# Data

> Datasets and collators for Extreme Classification

In [None]:
#| default_exp data

In [None]:
#| hide
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from scipy import sparse
import torch, inspect, numpy as np, pandas as pd
from IPython.display import display
from typing import Dict, Optional, Callable
from torch.utils.data import Dataset,DataLoader
from xclib.data import data_utils as du
from transformers import PreTrainedTokenizerBase, AutoTokenizer

from fastcore.utils import *
from fastcore.meta import *
from fastcore.dispatch import *

from xcai.core import *

In [None]:
#| hide
from xcai.transform import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Data

In [None]:
import os

# Root of the mapped LF-WikiSeeAlsoTitles-320K dataset; set the XC_DSET_DIR environment
# variable to point the notebook at another copy instead of editing this absolute path.
dset_dir = os.environ.get('XC_DSET_DIR', '/home/scai/phd/aiz218323/Projects/XC_NLG/data/(mapped)LF-WikiSeeAlsoTitles-320K')
# Shared loader options, forwarded as `**data_cfg` to the `from_file` calls below.
data_cfg = {
    'cols': ['identifier', 'input_text'],
    'use_tokz': True,
    'tokz': 'bert-base-uncased',
    'fld': 'input_text',
    'max_len': 32,
}

In [None]:
#| export
class MainXCData:
    "File loader for the main task: returns the keyword arguments that `MainXCDataset` is constructed from."

    @classmethod
    @delegates(Info.from_txt)
    def from_file(cls, data_lbl:str, data_info:str, lbl_info:str, data_lbl_filterer:Optional[str]=None, **kwargs):
        """Read the datapoint-label sparse file, the datapoint/label info files and the label filter.

        Extra `kwargs` are passed unchanged to `Info.from_txt` for both info files.
        """
        lbl_matrix = du.read_sparse_file(data_lbl)
        data_fields = Info.from_txt(data_info, **kwargs)
        lbl_fields = Info.from_txt(lbl_info, **kwargs)
        # NOTE(review): `load_filter` is called even when `data_lbl_filterer` is None —
        # presumably it returns None in that case; confirm.
        filterer = Filterer.load_filter(data_lbl_filterer)
        return dict(data_lbl=lbl_matrix, data_info=data_fields,
                    lbl_info=lbl_fields, data_lbl_filterer=filterer)
    

In [None]:
# File layout of the training split inside `dset_dir`.
train_cfg = dict(
    data_lbl=f'{dset_dir}/trn_X_Y.txt',
    data_info=f'{dset_dir}/raw_data/train.raw.txt',
    lbl_info=f'{dset_dir}/raw_data/label.raw.txt',
    data_lbl_filterer=f'{dset_dir}/filter_labels_train.txt',
)
train_data = MainXCData.from_file(**train_cfg, **data_cfg)

In [None]:
#| export
class MetaXCData:
    "File loader for one metadata type: returns the keyword arguments that `MetaXCDataset` is constructed from."

    @classmethod
    @delegates(Info.from_txt)
    def from_file(cls, data_meta:str, lbl_meta:str, meta_info:str, prefix:str, **kwargs):
        """Read the datapoint-metadata and label-metadata sparse files and the metadata info file.

        `prefix` is passed through untouched; extra `kwargs` go to `Info.from_txt`.
        """
        data_side = du.read_sparse_file(data_meta)
        lbl_side = du.read_sparse_file(lbl_meta)
        info_fields = Info.from_txt(meta_info, **kwargs)
        return dict(prefix=prefix, data_meta=data_side, lbl_meta=lbl_side, meta_info=info_fields)
    

In [None]:
# The 'hlk' metadata (hyper_link_* files): datapoint-side and label-side matrices plus raw texts.
meta_cfg = dict(
    prefix='hlk',
    data_meta=f'{dset_dir}/hyper_link_trn_X_Y.txt',
    lbl_meta=f'{dset_dir}/hyper_link_lbl_X_Y.txt',
    meta_info=f'{dset_dir}/raw_data/hyper_link.raw.txt',
)
meta_data = MetaXCData.from_file(**meta_cfg, **data_cfg)



## Dataset

### `BaseXCDataset`

In [None]:
#| export
class BaseXCDataset(Dataset):
    "Common base for the XC datasets: size bookkeeping, random train/valid splitting and a tabular preview."

    def __init__(self):
        # Subclasses fill these in (typically from their `_verify_inputs`).
        self.n_data, self.n_lbl, self.n_meta, self.n_samples = None, None, None, None

    def __len__(self):
        "Number of datapoints; 0 while `n_data` is unset."
        return self.n_data if self.n_data is not None else 0

    def splitter(self, valid_pct:Optional[float]=0.2, seed=None):
        """Randomly split the datapoints into `(train, valid)` datasets.

        `valid_pct` is the fraction of datapoints (rounded down) that goes to `valid`.
        If `seed` is given, torch's global RNG is seeded first so the split is reproducible.
        Both halves are built by the subclass's `_getitems`.
        """
        if seed is not None: torch.manual_seed(seed)
        rnd_idx = list(torch.randperm(self.n_data).numpy())
        cut = int(valid_pct * self.n_data)
        train, valid = self._getitems(rnd_idx[cut:]), self._getitems(rnd_idx[:cut])
        return train, valid

    def _verify_info(self, info:Dict):
        """Check that `info` is a non-empty mapping whose fields are non-empty and equally long.

        Returns the common field length; raises `ValueError` otherwise.
        """
        if info is None: raise ValueError('`info` cannot be empty.')
        n_info = [len(v) for v in info.values()]
        if len(n_info) == 0 or n_info[0] == 0: raise ValueError('`info` cannot be empty.')
        if not all(n_info[0] == o for o in n_info):
            raise ValueError('All `data_info` fields should have equal number of elements.')
        return n_info[0]

    def show_data(self, n:Optional[int]=10, seed:Optional[int]=None):
        """Display `n` randomly chosen datapoints as a `pandas.DataFrame`.

        Any non-None `seed`, including 0, seeds numpy's global RNG for a reproducible sample.
        Does nothing when `n` is falsy or < 1, or when the dataset is empty.
        """
        if not n or n < 1 or len(self) == 0: return
        if seed is not None: np.random.seed(seed)
        idx = np.random.permutation(self.n_data)[:n]
        d = [self[i] for i in idx]
        df = pd.DataFrame({k:[o[k] for o in d] for k in d[0]})
        with pd.option_context('display.max_colwidth', None, 'display.max_columns', None):
            display(df)
            

### `MainXCDataset`

In [None]:
#| export
class MainXCDataset(BaseXCDataset):
    """Datapoints and (optionally) their labels for the main XC task.

    `data_info` maps field names to per-datapoint sequences. `data_lbl` is the sparse
    datapoint x label matrix, and `lbl_info` maps field names to per-label sequences.
    If `n_samples` is set, items expose at most that many randomly sampled labels per
    datapoint. Extra keyword arguments are accepted and ignored.
    """
    def __init__(self,
                 data_info:Dict,
                 data_lbl:Optional[sparse.csr_matrix]=None,
                 lbl_info:Optional[Dict]=None,
                 data_lbl_filterer:Optional[Union[sparse.csr_matrix,np.ndarray]]=None,
                 n_samples:Optional[int]=None,
                 **kwargs):
        super().__init__()
        store_attr('data_info,data_lbl,lbl_info,data_lbl_filterer,n_samples')
        self._verify_inputs()
        
    @classmethod
    @delegates(MainXCData.from_file)
    def from_file(cls, n_samples:Optional[int]=None, **kwargs):
        "Load the inputs with `MainXCData.from_file(**kwargs)` and build the dataset."
        return cls(**MainXCData.from_file(**kwargs), n_samples=n_samples)
        

In [None]:
#| export
@patch
def _verify_inputs(cls:MainXCDataset):
    "Validate `data_info`, `data_lbl` and `lbl_info` against each other; set `n_data` and, when labelled, `n_lbl`."
    cls.n_data = cls._verify_info(cls.data_info)
    if cls.data_lbl is None: return
    n_rows, n_cols = cls.data_lbl.shape
    if n_rows != cls.n_data:
        raise ValueError(f'`data_info`({cls.n_data}) should have same number of datapoints as `data_lbl`({n_rows})')
    cls.n_lbl = n_cols
    if cls.lbl_info is None: return
    n_lbl = cls._verify_info(cls.lbl_info)
    if n_lbl != n_cols:
        raise ValueError(f'`lbl_info`({n_lbl}) should have same number of labels as `data_lbl`({n_cols})')

In [None]:
#| export
@patch
def __getitem__(cls:MainXCDataset, idx:int):
    "Datapoint `idx` as a flat dict: `data_*` fields plus, when labels exist, `lbl2data_*` fields."
    item = {f'data_{name}': vals[idx] for name, vals in cls.data_info.items()}
    if cls.n_lbl is None: return item
    prefix = 'lbl2data'
    lbl_idx = cls.data_lbl[idx].indices.tolist()
    if cls.n_samples:
        # keep at most `n_samples` labels, chosen uniformly at random
        order = np.random.permutation(len(lbl_idx))[:cls.n_samples]
        lbl_idx = [lbl_idx[i] for i in order]
    item[f'{prefix}_idx'] = lbl_idx
    if cls.lbl_info is not None:
        for name, vals in cls.lbl_info.items():
            item[f'{prefix}_{name}'] = [vals[i] for i in lbl_idx]
    return item
    

In [None]:
#| export
@patch
def __getitems__(cls:MainXCDataset, idxs:List):
    """Batched `__getitem__`: fields for all datapoints in `idxs` at once.

    torch's DataLoader (torch>=2.0) calls `__getitems__` for batched fetching when a
    dataset defines it. `data_*` entries are lists with one element per datapoint. When
    labels exist, `lbl2data_idx` and the `lbl2data_*` info fields are flattened across the
    batch, and `lbl2data_data2ptr` gives how many labels each datapoint contributed.
    """
    x = {f'data_{k}':[v[idx] for idx in idxs] for k,v in cls.data_info.items()}
    if cls.n_lbl is not None:
        prefix = 'lbl2data'
        lbl_idxs = [cls.data_lbl[idx].indices.tolist() for idx in idxs]
        if cls.n_samples:
            lbl_idxs = [[o[i] for i in np.random.permutation(len(o))[:cls.n_samples]] for o in lbl_idxs]
        x[f'{prefix}_idx'] = list(chain.from_iterable(lbl_idxs))
        # Per-datapoint label counts, so the flat label lists can be regrouped by datapoint.
        x[f'{prefix}_data2ptr'] = [len(o) for o in lbl_idxs]
        if cls.lbl_info is not None:
            x.update({f'{prefix}_{k}':[v[i] for i in x[f'{prefix}_idx']] for k,v in cls.lbl_info.items()})
    # Always return the batch, including for datasets without labels (`n_lbl is None`).
    return x
        

In [None]:
#| export
@patch
def _getitems(cls:MainXCDataset, idxs:List):
    "A new `MainXCDataset` restricted to the datapoints at `idxs`; `lbl_info` is shared, not copied."
    # NOTE(review): `data_lbl_filterer` is not carried over to the subset — confirm whether splits should keep it.
    sub_info = {name: [vals[i] for i in idxs] for name, vals in cls.data_info.items()}
    sub_lbl = None if cls.data_lbl is None else cls.data_lbl[idxs]
    return MainXCDataset(sub_info, sub_lbl, cls.lbl_info, n_samples=cls.n_samples)

#### Example

In [None]:
# `n_samples=2`: each datapoint exposes at most 2 randomly sampled labels.
train_main = MainXCDataset(n_samples=2, **train_data)

In [None]:
# seeded, so the same 3 datapoints (and sampled labels) are shown on every run
train_main.show_data(seed=10, n=3)

Unnamed: 0,data_identifier,data_input_text,data_input_ids,data_token_type_ids,data_attention_mask,lbl2data_idx,lbl2data_identifier,lbl2data_input_text,lbl2data_input_ids,lbl2data_token_type_ids,lbl2data_attention_mask
0,Certificate_of_Formula_Compliance,Certificate of Formula Compliance,"[101, 8196, 1997, 5675, 12646, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]",[244514],[Certificate_of_Origin],[Certificate of Origin],"[[101, 8196, 1997, 4761, 102]]","[[0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1]]"
1,2017_Sunwolves_season,2017 Sunwolves season,"[101, 2418, 3103, 12155, 20899, 2161, 102]","[0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1]","[163257, 308041]","[Sunwolves, 2017_Super_Rugby_season]","[Sunwolves, 2017 Super Rugby season]","[[101, 3103, 12155, 20899, 102], [101, 2418, 3565, 4043, 2161, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]"
2,Hadza_people,Hadza people,"[101, 2018, 4143, 2111, 102]","[0, 0, 0, 0, 0]","[1, 1, 1, 1, 1]","[8411, 164636]","[Twa_peoples, Aka_people]","[Twa peoples, Aka people]","[[101, 1056, 4213, 7243, 102], [101, 9875, 2111, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1]]"


In [None]:
split1, split2 = train_main.splitter(valid_pct=0.2)
# sizes of both splits, their sum, and the original size (the last two should match)
split1.n_data, split2.n_data, split1.n_data + split2.n_data, train_main.n_data

(554466, 138616, 693082, 693082)

In [None]:
# Same data loaded straight from files; `n_samples` stays None, so every label is kept.
train_main = MainXCDataset.from_file(**data_cfg, **train_cfg)

In [None]:
train_main.show_data(seed=10, n=3)

Unnamed: 0,data_identifier,data_input_text,data_input_ids,data_token_type_ids,data_attention_mask,lbl2data_idx,lbl2data_identifier,lbl2data_input_text,lbl2data_input_ids,lbl2data_token_type_ids,lbl2data_attention_mask
0,Certificate_of_Formula_Compliance,Certificate of Formula Compliance,"[101, 8196, 1997, 5675, 12646, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]",[244514],[Certificate_of_Origin],[Certificate of Origin],"[[101, 8196, 1997, 4761, 102]]","[[0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1]]"
1,2017_Sunwolves_season,2017 Sunwolves season,"[101, 2418, 3103, 12155, 20899, 2161, 102]","[0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1]","[163257, 308041]","[Sunwolves, 2017_Super_Rugby_season]","[Sunwolves, 2017 Super Rugby season]","[[101, 3103, 12155, 20899, 102], [101, 2418, 3565, 4043, 2161, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]"
2,Hadza_people,Hadza people,"[101, 2018, 4143, 2111, 102]","[0, 0, 0, 0, 0]","[1, 1, 1, 1, 1]","[8411, 164636, 230647]","[Twa_peoples, Aka_people, Bushmen]","[Twa peoples, Aka people, Bushmen]","[[101, 1056, 4213, 7243, 102], [101, 9875, 2111, 102], [101, 5747, 3549, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]"


### `MetaXCDataset`

In [None]:
#| export
class MetaXCDataset(BaseXCDataset):
    """One metadata type attached to the main XC task.

    `data_meta` (datapoints x metadata) and `lbl_meta` (labels x metadata) are sparse
    matrices that share the metadata axis; `meta_info` maps field names to per-metadata
    sequences. Output keys are prefixed `{prefix}2data_` (datapoint side) or
    `{prefix}2lbl2data_` (label side). If `n_samples` is set, at most that many randomly
    sampled metadata entries are returned per row.
    """

    def __init__(self,
                 prefix:str,
                 data_meta:sparse.csr_matrix, 
                 lbl_meta:sparse.csr_matrix, 
                 meta_info:Optional[Dict]=None, 
                 n_samples:Optional[int]=None, 
                 **kwargs):
        super().__init__()
        store_attr('prefix,data_meta,lbl_meta,meta_info,n_samples')
        self._verify_inputs()

    def _getitems(self, idxs:List):
        "Subset the datapoint side (`data_meta` rows) to `idxs`; `lbl_meta` and `meta_info` are shared."
        return MetaXCDataset(self.prefix, self.data_meta[idxs], self.lbl_meta, self.meta_info, self.n_samples)
        
    @classmethod
    @delegates(MetaXCData.from_file)
    def from_file(cls, n_samples:Optional[int]=None, **kwargs):
        "Load the inputs with `MetaXCData.from_file(**kwargs)` and build the dataset."
        return cls(**MetaXCData.from_file(**kwargs), n_samples=n_samples)

    def _sample_idx(self, idxs:List):
        "At most `n_samples` of `idxs`, chosen uniformly at random; all of `idxs` when `n_samples` is falsy."
        if not self.n_samples: return idxs
        return [idxs[i] for i in np.random.permutation(len(idxs))[:self.n_samples]]

    def _gather_info(self, prefix:str, meta_idx:List):
        "`{prefix}_{field}` -> values of each `meta_info` field at `meta_idx` (empty when there is no `meta_info`)."
        if self.meta_info is None: return {}
        return {f'{prefix}_{k}':[v[i] for i in meta_idx] for k,v in self.meta_info.items()}

    @typedispatch
    def get_lbl_meta(self, idx:int):
        "Metadata of label `idx`, under `{prefix}2lbl2data_*` keys."
        prefix = f'{self.prefix}2lbl2data'
        x = {f'{prefix}_idx': self._sample_idx(self.lbl_meta[idx].indices.tolist())}
        x.update(self._gather_info(prefix, x[f'{prefix}_idx']))
        return x
    
    @typedispatch
    def get_lbl_meta(self, idxs:List):
        "Metadata of every label in `idxs`: each `{prefix}2lbl2data_*` value is a list with one entry per label."
        prefix = f'{self.prefix}2lbl2data'
        x = {f'{prefix}_idx': [self._sample_idx(self.lbl_meta[idx].indices.tolist()) for idx in idxs]}
        if self.meta_info is not None:
            x.update({f'{prefix}_{k}':[[v[i] for i in o] for o in x[f'{prefix}_idx']] for k,v in self.meta_info.items()})
        return x
        
    def get_data_meta(self, idx:int):
        "Metadata of datapoint `idx`, under `{prefix}2data_*` keys."
        prefix = f'{self.prefix}2data'
        x = {f'{prefix}_idx': self._sample_idx(self.data_meta[idx].indices.tolist())}
        x.update(self._gather_info(prefix, x[f'{prefix}_idx']))
        return x

    def shape(self):
        "`(n_data, n_lbl, n_meta)`."
        return (self.n_data, self.n_lbl, self.n_meta)
        
    def show_data(self, is_lbl:Optional[bool]=False, n:Optional[int]=10, seed:Optional[int]=None):
        """Display the metadata of `n` random labels (`is_lbl=True`) or datapoints as a DataFrame.

        Any non-None `seed`, including 0, seeds numpy's global RNG. Does nothing when `n` is
        falsy or < 1, or when there are no rows to show.
        """
        n_rows = self.n_lbl if is_lbl else self.n_data
        if not n or n < 1 or not n_rows: return
        if seed is not None: np.random.seed(seed)
        idx = np.random.permutation(n_rows)[:n]
        # `get_lbl_meta` is type-dispatched on `int`, so numpy integers are converted first.
        d = [self.get_lbl_meta(int(i)) for i in idx] if is_lbl else [self.get_data_meta(i) for i in idx]
        df = pd.DataFrame({k:[o[k] for o in d] for k in d[0]})
        with pd.option_context('display.max_colwidth', None, 'display.max_columns', None):
            display(df)
    

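When metadata is available for the main task, it comes as two sparse matrices that share the metadata axis: `data_meta` (datapoints × metadata) and `lbl_meta` (labels × metadata). `_verify_inputs` below checks that their column counts match each other and the number of `meta_info` entries.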

In [None]:
#| export
@patch
def _verify_inputs(cls:MetaXCDataset):
    cls.n_data, cls.n_lbl, cls.n_meta = cls.data_meta.shape[0], cls.lbl_meta.shape[0], cls.data_meta.shape[1]
    if cls.lbl_meta.shape[1] != cls.n_meta:
        raise ValueError(f'`lbl_meta`({cls.lbl_meta.shape[1]}) should have same number of columns as `data_meta`({cls.n_meta}).')
    if cls.meta_info is not None:
        n_meta = cls._verify_info(cls.meta_info)
        if n_meta != cls.n_meta:
            raise ValueError(f'`meta_info`({n_meta}) should have same number of entries as number of columns of `data_meta`({cls.n_meta})')
            

#### Example

In [None]:
# `n_samples=3`: at most 3 randomly sampled metadata entries per row.
train_meta = MetaXCDataset(n_samples=3, **meta_data)

In [None]:
# datapoint side; unseeded, so a different sample is shown on each run
train_meta.show_data(is_lbl=False, n=3)

Unnamed: 0,hlk2data_idx,hlk2data_identifier,hlk2data_input_text,hlk2data_input_ids,hlk2data_token_type_ids,hlk2data_attention_mask
0,"[1188675, 186761, 1188713]","[Thomas_Howard_Fellows, Eric_Verdonk, Martin_Studach]","[Thomas Howard Fellows, Eric Verdonk, Martin Studach]","[[101, 2726, 4922, 13572, 102], [101, 4388, 2310, 28176, 2243, 102], [101, 3235, 16054, 6776, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]"
1,"[484231, 39074, 1362879]","[Geographic_Names_Information_System, United_States_Census_Bureau, List_of_cities_and_towns_in_Arizona]","[Geographic Names Information System, United States Census Bureau, List of cities and towns in Arizona]","[[101, 9183, 3415, 2592, 2291, 102], [101, 2142, 2163, 2883, 4879, 102], [101, 2862, 1997, 3655, 1998, 4865, 1999, 5334, 102]]","[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]]"
2,"[733563, 145339, 1442571]","[Gene_mapping, Proceedings_of_the_National_Academy_of_Sciences_of_the_United_States_of_America, Marc_Rieffel]","[Gene mapping, Proceedings of the National Academy of Sciences of the United States of America, Marc Rieffel]","[[101, 4962, 12375, 102], [101, 8931, 1997, 1996, 2120, 2914, 1997, 4163, 1997, 1996, 2142, 2163, 1997, 2637, 102], [101, 7871, 15544, 12879, 7959, 2140, 102]]","[[0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]"


In [None]:
split1, split2 = train_meta.splitter(valid_pct=0.2)
# (n_data, n_lbl, n_meta) per split: only the datapoint side is subset, so only n_data differs
split1.shape(), split2.shape()

((554466, 312330, 2458399), (138616, 312330, 2458399))

In [None]:
# label side: metadata attached to 3 random labels
train_meta.show_data(is_lbl=True, n=3)

Unnamed: 0,hlk2lbl2data_idx,hlk2lbl2data_identifier,hlk2lbl2data_input_text,hlk2lbl2data_input_ids,hlk2lbl2data_token_type_ids,hlk2lbl2data_attention_mask
0,"[25125, 11543, 49413]","[Newspaper, Catholic_Church, Romandy]","[Newspaper, Catholic Church, Romandy]","[[101, 3780, 102], [101, 3234, 2277, 102], [101, 3142, 5149, 102]]","[[0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]"
1,"[657604, 507661, 163502]","[Niklas_Luhmann, JÃ¼rgen_Habermas, Talcott_Parsons]","[Niklas Luhmann, JÃ¼rgen Habermas, Talcott Parsons]","[[101, 23205, 8523, 11320, 13890, 2078, 102], [101, 14855, 29664, 25892, 5292, 5677, 9335, 102], [101, 21368, 13124, 13505, 102]]","[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]"
2,"[112522, 592260, 22599]","[1960_Winter_Olympics, Frensham_School, Order_of_the_British_Empire]","[1960 Winter Olympics, Frensham School, Order of the British Empire]","[[101, 3624, 3467, 3783, 102], [101, 10424, 6132, 3511, 2082, 102], [101, 2344, 1997, 1996, 2329, 3400, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]"


In [None]:
# (n_data, n_lbl, n_meta) of the full training metadata
train_meta.shape()

(693082, 312330, 2458399)

In [None]:
# Same metadata loaded straight from files; `n_samples` stays None, so every entry is kept.
train_meta = MetaXCDataset.from_file(**data_cfg, **meta_cfg)



In [None]:
train_meta.show_data(is_lbl=False, n=3)

Unnamed: 0,hlk2data_idx,hlk2data_identifier,hlk2data_input_text,hlk2data_input_ids,hlk2data_token_type_ids,hlk2data_attention_mask
0,"[38, 122, 254, 281, 292, 697, 698, 701, 767, 2504, 2754, 2764, 2808, 3223, 3600, 4083, 4322, 4502, 4503, 5203, 19555, 26950, 35767, 45880, 57613, 189987, 247697, 260730, 262600, 280189, 297163, 297313, 361107, 416836, 471498, 619765, 666342, 719698, 818704, 1677742, 2037687, 2414194, 2414195, 2414196]","[Philippines, The_New_York_Times, United_States, France, Forbes, Netherlands, Colombia, Australia, Canada, South_Korea, Japan, Taiwan, New_York_Post, Thailand, Vice_president, Singapore, Food_and_Drug_Administration, Russell_2000_Index, Public_company, Los_Angeles_Times, Internal_Revenue_Service, Certified_Public_Accountant, Time_(magazine), Deseret_News, The_Salt_Lake_Tribune, Barry_Minkow, Chief_financial_officer, UniversitÃ©_de_MontrÃ©al, Therapeutic_Goods_Administration, Chief_operating_officer, U.S._Securities_and_Exchange_Commission, National_Business_Review, Naval_Postgraduate_School, San_Diego_Reader, Commerce_Commission, NSF_International, Public_relations_officer, United_States_Pharmacopeia, USANA_Amphitheatre, Herb_Greenberg, MonaVie, Fraud_Discovery_Institute, Denis_Waitley, Myron_W._Wentz]","[Philippines, The New York Times, United States, France, Forbes, Netherlands, Colombia, Australia, Canada, South Korea, Japan, Taiwan, New York Post, Thailand, Vice president, Singapore, Food and Drug Administration, Russell 2000 Index, Public company, Los Angeles Times, Internal Revenue Service, Certified Public Accountant, Time (magazine), Deseret News, The Salt Lake Tribune, Barry Minkow, Chief financial officer, UniversitÃ© de MontrÃ©al, Therapeutic Goods Administration, Chief operating officer, U.S. Securities and Exchange Commission, National Business Review, Naval Postgraduate School, San Diego Reader, Commerce Commission, NSF International, Public relations officer, United States Pharmacopeia, USANA Amphitheatre, Herb Greenberg, MonaVie, Fraud Discovery Institute, Denis Waitley, Myron W. 
Wentz]","[[101, 5137, 102], [101, 1996, 2047, 2259, 2335, 102], [101, 2142, 2163, 102], [101, 2605, 102], [101, 10822, 102], [101, 4549, 102], [101, 7379, 102], [101, 2660, 102], [101, 2710, 102], [101, 2148, 4420, 102], [101, 2900, 102], [101, 6629, 102], [101, 2047, 2259, 2695, 102], [101, 6504, 102], [101, 3580, 2343, 102], [101, 5264, 102], [101, 2833, 1998, 4319, 3447, 102], [101, 5735, 2456, 5950, 102], [101, 2270, 2194, 102], [101, 3050, 3349, 2335, 102], [101, 4722, 6599, 2326, 102], [101, 7378, 2270, 17907, 102], [101, 2051, 1006, 2932, 1007, 102], [101, 4078, 7869, 2102, 2739, 102], [101, 1996, 5474, 2697, 10969, 102], [101, 6287, 8117, 24144, 102], [101, 2708, 3361, 2961, 102], [101, 4895, 16402, 28032, 2050, 29652, 2139, 18318, 2527, 29652, 2389, 102], [101, 17261, 5350, 3447, 102], [101, 2708, 4082, 2961, 102], [101, 1057, 1012, 1055, 1012, 12012, 1998, 3863, 3222, 102], [101, 2120, 2449, 3319, 102], [101, 3987, 15438, 2082, 102], [101, 2624, 5277, 8068, 102], [101, 6236, 3222, 102], [101, 24978, 2546, 2248, 102], [101, 2270, 4262, 2961, 102], [101, 2142, 2163, 6887, 27292, 22684, 5051, 2401, 102], [101, 3915, 2532, 23713, 16584, 20192, 7913, 102], [101, 12810, 24190, 102], [101, 13813, 13469, 102], [101, 9861, 5456, 2820, 102], [101, 11064, 3524, 3051, 102], [101, 2026, 4948, 1059, 1012, 2253, 2480, 102]]","[[0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], 
[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]"
1,"[697, 1484, 1871, 2159, 2223, 2227, 3223, 3260, 3303, 3588, 3800, 3831, 3835, 5286, 8191, 13788, 14573, 16317, 18724, 18738, 19404, 21017, 21635, 25410, 25536, 25564, 28677, 29037, 30019, 36602, 37224, 57165, 57695, 59199, 72706, 88559, 89269, 122763, 135904, 174586, 250630, 266282, 310377, 316289, 332140, 419321, 468604, 468653, 471838, 487747, 490758, 513294, 520315, 538961, 548112, 549034, 656170, 731832, 768225, 994269, 1052104, 1327913, 1364034, 2243097]","[Netherlands, India, Antarctica, Tajikistani_somoni, Erbil, Russia, Thailand, Heard_Island_and_McDonald_Islands, Bonaire, Pyongyang, United_States_Minor_Outlying_Islands, Slovenia, Socialist_Federal_Republic_of_Yugoslavia, Ho_Chi_Minh_City, Azerbaijan, Manipur, Hanoi, Australian_Antarctic_Territory, Lakshadweep, Andaman_and_Nicobar_Islands, Ashmore_and_Cartier_Islands, Sikkim, Da_Nang, Wake_Island, Eastern_Orthodox_Church, Tibet_Autonomous_Region, Gaza_Strip, Uttarakhand, Jammu_and_Kashmir, Korean_Demilitarized_Zone, Ross_Dependency, Johnston_Atoll, Sarez_Lake, Asmara, Crimea, Kingman_Reef, Midway_Atoll, Palmyra_Atoll, Clipperton_Island, United_Nations_Disengagement_Observer_Force, Ghajar, Kawthaung, Yangon_International_Airport, Jarvis_Island, Norwegian_Police_Service, Mount_Athos, Myawaddy, Tachileik, Nay_Pyi_Taw_International_Airport, Addis_Ababa_Bole_International_Airport, Republic_of_Artsakh, LÃ©on-Mba_International_Airport, Minsk_National_Airport, Howland_Island, Mandalay_International_Airport, Banjul_International_Airport, Peter_I_Island, Inaccessible_Island, Electronic_System_for_Travel_Authorization, Eritrean_nakfa, Visa_requirements_for_European_Union_citizens, List_of_diplomatic_missions_of_Slovenia, Nightingale_Islands, Slovenian_identity_card]","[Netherlands, India, Antarctica, Tajikistani somoni, Erbil, Russia, Thailand, Heard Island and McDonald Islands, Bonaire, Pyongyang, United States Minor Outlying Islands, Slovenia, Socialist Federal Republic of Yugoslavia, Ho Chi Minh City, 
Azerbaijan, Manipur, Hanoi, Australian Antarctic Territory, Lakshadweep, Andaman and Nicobar Islands, Ashmore and Cartier Islands, Sikkim, Da Nang, Wake Island, Eastern Orthodox Church, Tibet Autonomous Region, Gaza Strip, Uttarakhand, Jammu and Kashmir, Korean Demilitarized Zone, Ross Dependency, Johnston Atoll, Sarez Lake, Asmara, Crimea, Kingman Reef, Midway Atoll, Palmyra Atoll, Clipperton Island, United Nations Disengagement Observer Force, Ghajar, Kawthaung, Yangon International Airport, Jarvis Island, Norwegian Police Service, Mount Athos, Myawaddy, Tachileik, Nay Pyi Taw International Airport, Addis Ababa Bole International Airport, Republic of Artsakh, LÃ©on-Mba International Airport, Minsk National Airport, Howland Island, Mandalay International Airport, Banjul International Airport, Peter I Island, Inaccessible Island, Electronic System for Travel Authorization, Eritrean nakfa, Visa requirements for European Union citizens, List of diplomatic missions of Slovenia, Nightingale Islands, Slovenian identity card]","[[101, 4549, 102], [101, 2634, 102], [101, 12615, 102], [101, 23538, 2072, 2061, 8202, 2072, 102], [101, 9413, 14454, 102], [101, 3607, 102], [101, 6504, 102], [101, 2657, 2479, 1998, 9383, 3470, 102], [101, 14753, 14737, 102], [101, 1052, 14001, 6292, 5654, 102], [101, 2142, 2163, 3576, 25376, 3470, 102], [101, 10307, 102], [101, 6102, 2976, 3072, 1997, 8936, 102], [101, 7570, 9610, 19538, 2103, 102], [101, 8365, 102], [101, 23624, 5311, 102], [101, 24809, 102], [101, 2827, 10227, 3700, 102], [101, 2474, 28132, 2094, 28394, 2361, 102], [101, 1998, 23093, 1998, 19332, 8237, 3470, 102], [101, 6683, 5974, 1998, 11122, 3771, 3470, 102], [101, 9033, 24103, 2213, 102], [101, 4830, 16660, 2290, 102], [101, 5256, 2479, 102], [101, 2789, 6244, 2277, 102], [101, 13319, 8392, 2555, 102], [101, 14474, 6167, 102], [101, 14940, 27573, 5685, 102], [101, 21433, 1998, 13329, 102], [101, 4759, 27668, 27606, 18425, 4224, 102], [101, 5811, 24394, 102], [101, 10773, 
22292, 102], [101, 18906, 9351, 2697, 102], [101, 2004, 28225, 102], [101, 21516, 102], [101, 2332, 2386, 12664, 102], [101, 12213, 22292, 102], [101, 5340, 19563, 22292, 102], [101, 12528, 4842, 2669, 2479, 102], [101, 2142, 3741, 4487, 5054, 3654, 20511, 9718, 2486, 102], [101, 1043, 3270, 16084, 102], [101, 10556, 26677, 3270, 5575, 102], [101, 8675, 2239, 2248, 3199, 102], [101, 21072, 2479, 102], [101, 5046, 2610, 2326, 102], [101, 4057, 2012, 15006, 102], [101, 2026, 10830, 14968, 102], [101, 11937, 5428, 23057, 2243, 102], [101, 29349, 1052, 10139, 11937, 2860, 2248, 3199, 102], [101, 5587, 2483, 19557, 3676, 8945, 2571, 2248, 3199, 102], [101, 3072, 1997, 2840, 27573, 102], [101, 2474, 29652, 2239, 1011, 15038, 2248, 3199, 102], [101, 20790, 2120, 3199, 102], [101, 22912, 5685, 2479, 102], [101, 24373, 4710, 2248, 3199, 102], [101, 7221, 9103, 2140, 2248, 3199, 102], [101, 2848, 1045, 2479, 102], [101, 29104, 2479, 102], [101, 4816, 2291, 2005, 3604, 20104, 102], [101, 26040, 2078, 17823, 7011, 102], [101, 9425, 5918, 2005, 2647, 2586, 4480, 102], [101, 2862, 1997, 8041, 6416, 1997, 10307, 102], [101, 21771, 3470, 102], [101, 16583, 4767, 4003, 102]]","[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], 
[0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1]]"
2,"[701, 725, 767, 1766, 2755, 10047, 10995, 10996, 13286, 19632, 22835, 27642, 30198, 38457, 44752, 45529, 60631, 61082, 82417, 95577, 116852, 266393, 266425, 266434, 266446, 269665, 526242, 538228, 739978, 816008, 885929, 943971, 1105838, 1445429, 2282672, 2282673, 2282674]","[Australia, Ontario, Canada, Old_Norse, Commonwealth_of_Nations, Yorkshire, Henry_I_of_England, Norman_conquest_of_England, County_Cork, Canadian_Confederation, Local_Government_Act_1972, Scandinavia, Restoration_(England), East_Riding_of_Yorkshire, West_Riding_of_Yorkshire, Lindsey,_Lincolnshire, North_Riding_of_Yorkshire, County_council, Winifred_Holtby, Nenagh, Politics_of_Canada, North_Tipperary, Local_Government_Act_2001, Clonmel, South_Tipperary, David_Peace, Fourth_Age, Farthings_of_Iceland, Police_division, List_of_former_counties_of_Quebec, Local_Government_Reform_Act_2014, Leges_Edwardi_Confessoris, Walking_Purchase, Yorkshire_Ridings_Society, South_Riding_of_Lindsey, North_Riding_of_Lindsey, West_Riding_of_Lindsey]","[Australia, Ontario, Canada, Old Norse, Commonwealth of Nations, Yorkshire, Henry I of England, Norman conquest of England, County Cork, Canadian Confederation, Local Government Act 1972, Scandinavia, Restoration (England), East Riding of Yorkshire, West Riding of Yorkshire, Lindsey, Lincolnshire, North Riding of Yorkshire, County council, Winifred Holtby, Nenagh, Politics of Canada, North Tipperary, Local Government Act 2001, Clonmel, South Tipperary, David Peace, Fourth Age, Farthings of Iceland, Police division, List of former counties of Quebec, Local Government Reform Act 2014, Leges Edwardi Confessoris, Walking Purchase, Yorkshire Ridings Society, South Riding of Lindsey, North Riding of Lindsey, West Riding of Lindsey]","[[101, 2660, 102], [101, 4561, 102], [101, 2710, 102], [101, 2214, 15342, 102], [101, 5663, 1997, 3741, 102], [101, 7018, 102], [101, 2888, 1045, 1997, 2563, 102], [101, 5879, 9187, 1997, 2563, 102], [101, 2221, 8513, 102], [101, 3010, 11078, 
102], [101, 2334, 2231, 2552, 3285, 102], [101, 20612, 102], [101, 6418, 1006, 2563, 1007, 102], [101, 2264, 5559, 1997, 7018, 102], [101, 2225, 5559, 1997, 7018, 102], [101, 17518, 1010, 16628, 102], [101, 2167, 5559, 1997, 7018, 102], [101, 2221, 2473, 102], [101, 2663, 10128, 5596, 12621, 3762, 102], [101, 11265, 2532, 5603, 102], [101, 4331, 1997, 2710, 102], [101, 2167, 17333, 102], [101, 2334, 2231, 2552, 2541, 102], [101, 18856, 2239, 10199, 102], [101, 2148, 17333, 102], [101, 2585, 3521, 102], [101, 2959, 2287, 102], [101, 2521, 20744, 2015, 1997, 10399, 102], [101, 2610, 2407, 102], [101, 2862, 1997, 2280, 5721, 1997, 5447, 102], [101, 2334, 2231, 5290, 2552, 2297, 102], [101, 4190, 2229, 3487, 2072, 18766, 21239, 102], [101, 3788, 5309, 102], [101, 7018, 5559, 2015, 2554, 102], [101, 2148, 5559, 1997, 17518, 102], [101, 2167, 5559, 1997, 17518, 102], [101, 2225, 5559, 1997, 17518, 102]]","[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]","[[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 
1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]"


### `XCDataset`

In [None]:
#| export
class MetaXCDatasets(dict):
    "A `dict` of named meta-datasets whose entries are also exposed as attributes."

    def __init__(self, meta:Dict):
        super().__init__(meta)
        # Mirror every key as an attribute, so `metas.hlk_meta` is the same object as `metas['hlk_meta']`.
        # Only keys present at construction time are mirrored; later dict updates are not.
        for name, value in meta.items():
            setattr(self, name, value)
        

In [None]:
#| export
class XCDataset(BaseXCDataset):
    "Extreme-classification dataset: a main dataset plus zero or more named (`*_meta`) meta-datasets."

    def __init__(self, data:MainXCDataset, **kwargs):
        super().__init__()
        # Only `*_meta` keyword arguments that actually hold a `MetaXCDataset` are kept.
        meta = {k:kwargs[k] for k in self.get_meta_args(**kwargs) if isinstance(kwargs[k], MetaXCDataset)}
        self.data, self.meta = data, MetaXCDatasets(meta)
        self._verify_inputs()

    def _getitems(self, idxs:List):
        "Return a new dataset of the same class restricted to `idxs` (main data and every meta-dataset)."
        return type(self)(self.data._getitems(idxs), **{k:meta._getitems(idxs) for k,meta in self.meta.items()})

    @staticmethod
    def get_meta_args(**kwargs):
        "Names of the keyword arguments ending in `_meta`."
        return [k for k in kwargs if re.match(r'.*_meta$', k)]
        
    @classmethod
    @delegates(MainXCDataset.from_file)
    def from_file(cls, **kwargs):
        """Load the main dataset and every `*_meta` dataset from files.

        Each `*_meta` kwarg is a dict of meta-specific file arguments. The remaining kwargs are shared:
        they go to `MainXCDataset.from_file` and to every `MetaXCDataset.from_file` call."""
        # Pop the per-meta config dicts *before* loading the main data so they are not forwarded to it.
        meta_kwargs = {o:kwargs.pop(o) for o in cls.get_meta_args(**kwargs)}
        data = MainXCDataset.from_file(**kwargs)
        meta = {k:MetaXCDataset.from_file(**v, **kwargs) for k,v in meta_kwargs.items()}
        return cls(data, **meta)

    def _verify_inputs(self):
        """Cache `n_data`/`n_lbl` (and `n_meta` when meta-datasets exist) and check consistency.

        Raises `ValueError` if a meta-dataset disagrees with the main data on datapoint or label counts,
        or if the meta-datasets disagree with each other on `n_meta`."""
        self.n_data, self.n_lbl = self.data.n_data, self.data.n_lbl
        if len(self.meta):
            self.n_meta = next(iter(self.meta.values())).n_meta
            for meta in self.meta.values():
                if meta.n_data != self.n_data: raise ValueError(f'`meta`({meta.n_data}) and `data`({self.n_data}) should have the same number of datapoints.')
                if self.n_lbl is not None and meta.n_lbl != self.n_lbl: 
                    raise ValueError(f'`meta`({meta.n_lbl}) and `data`({self.n_lbl}) should have the same number of labels.')
                if meta.n_meta != self.n_meta: raise ValueError(f'Every `meta`({meta.n_meta},{self.n_meta}) should have the same number of entries.')

    def __getitem__(self, idx:int):
        "Item `idx` of the main data, updated with each meta-dataset's data-side and (if labels exist) label-side fields."
        x = self.data[idx]
        if self.n_meta:
            for m in self.meta.values():
                x.update(m.get_data_meta(idx))
                # Label-side meta is looked up with this item's label indices.
                if self.n_lbl: x.update(m.get_lbl_meta(x['lbl2data_idx']))
        return x

    @property
    def lbl_info(self): return self.data.lbl_info

    def one_batch(self, bsz:Optional[int]=10, seed:Optional[int]=None):
        "Return a list of `bsz` random items (all items when `bsz` is None), seeding torch's global RNG if `seed` is given."
        if seed is not None: torch.manual_seed(seed)
        # Slice the permutation before converting, instead of materialising a Python list of every index;
        # `.tolist()` also yields plain ints, matching `__getitem__`'s `idx:int`.
        idxs = torch.randperm(len(self))[:bsz].tolist()
        return [self[idx] for idx in idxs]
       

#### Example

In [None]:
# Combine the main dataset with a meta-dataset; the `hlk_meta` kwarg is picked up because its name ends in `_meta`.
# NOTE(review): `train_main` / `train_meta` are not created in this part of the notebook — confirm they exist on a fresh kernel.
train_dset = XCDataset(train_main, hlk_meta=train_meta)

In [None]:
# Sample 10 random, un-collated items (pass `seed=` for a reproducible sample).
b = train_dset.one_batch(10)

In [None]:
# Each raw item is a dict of main-data (`data_*`), label (`lbl2data_*`) and meta (`hlk2data_*`) fields.
b

[{'data_identifier': 'Orikabe_Station',
  'data_input_text': 'Orikabe Station',
  'data_input_ids': [101, 2030, 7556, 4783, 2276, 102],
  'data_token_type_ids': [0, 0, 0, 0, 0, 0],
  'data_attention_mask': [1, 1, 1, 1, 1, 1],
  'lbl2data_idx': [113248],
  'lbl2data_identifier': ['List_of_Railway_Stations_in_Japan'],
  'lbl2data_input_text': ['List of Railway Stations in Japan'],
  'lbl2data_input_ids': [[101, 2862, 1997, 2737, 3703, 1999, 2900, 102]],
  'lbl2data_token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0]],
  'lbl2data_attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1]],
  'hlk2data_idx': [4170, 1596964, 1359126],
  'hlk2data_identifier': ['Japanese_National_Railways',
   'List_of_highways_numbered_284',
   'Ichinoseki_Station'],
  'hlk2data_input_text': ['Japanese National Railways',
   'List of highways numbered 284',
   'Ichinoseki Station'],
  'hlk2data_input_ids': [[101, 2887, 2120, 7111, 102],
   [101, 2862, 1997, 10292, 8597, 26871, 102],
   [101, 22564, 5740, 3366, 3211, 2276, 102]],
 

In [None]:
# Tabular preview of `n` datapoints; `show_data` is not defined on `XCDataset` itself (presumably inherited from `BaseXCDataset`).
train_dset.show_data(n=2)

Unnamed: 0,data_identifier,data_input_text,data_input_ids,data_token_type_ids,data_attention_mask,lbl2data_idx,lbl2data_identifier,lbl2data_input_text,lbl2data_input_ids,lbl2data_token_type_ids,lbl2data_attention_mask,hlk2data_idx,hlk2data_identifier,hlk2data_input_text,hlk2data_input_ids,hlk2data_token_type_ids,hlk2data_attention_mask,hlk2lbl2data_idx,hlk2lbl2data_identifier,hlk2lbl2data_input_text,hlk2lbl2data_input_ids,hlk2lbl2data_token_type_ids,hlk2lbl2data_attention_mask
0,Portuguese_Paratroopers,Portuguese Paratroopers,"[101, 5077, 11498, 13181, 27342, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]","[151631, 183910]","[Airborne_forces, Portuguese_Air_Force]","[Airborne forces, Portuguese Air Force]","[[101, 10519, 2749, 102], [101, 5077, 2250, 2486, 102]]","[[0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1, 1], [1, 1, 1, 1, 1]]","[1082329, 701264, 281]","[Special_Operations_Troops_Centre, Tancos, France]","[Special Operations Troops Centre, Tancos, France]","[[101, 2569, 3136, 3629, 2803, 102], [101, 9092, 13186, 102], [101, 2605, 102]]","[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]]","[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]","[[321269, 152744, 7156], [205409, 111094, 1980540]]","[[Le_Muy, Indonesian_invasion_of_East_Timor, Benjamin_Franklin], [Airstrike, Dornier_Do_27, Culatra_Island]]","[[Le Muy, Indonesian invasion of East Timor, Benjamin Franklin], [Airstrike, Dornier Do 27, Culatra Island]]","[[[101, 3393, 14163, 2100, 102], [101, 9003, 5274, 1997, 2264, 19746, 102], [101, 6425, 5951, 102]], [[101, 14369, 18886, 3489, 102], [101, 2079, 28484, 2079, 2676, 102], [101, 12731, 20051, 2527, 2479, 102]]]","[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]]","[[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]]"
1,African_theology,African theology,"[101, 3060, 8006, 102]","[0, 0, 0, 0]","[1, 1, 1, 1]","[228980, 7989]","[African_theology, Christianity_in_Africa]","[African theology, Christianity in Africa]","[[101, 3060, 8006, 102], [101, 7988, 1999, 3088, 102]]","[[0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1, 1], [1, 1, 1, 1, 1]]","[534154, 534150, 91924]","[John_S._Pobee, Musimbi_Kanyoro, Western_Christianity]","[John S. Pobee, Musimbi Kanyoro, Western Christianity]","[[101, 2198, 1055, 1012, 13433, 11306, 102], [101, 14163, 5332, 14905, 2072, 22827, 7677, 3217, 102], [101, 2530, 7988, 102]]","[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]","[[534152, 534154, 531359], [442884, 531315, 805]]","[[Kwesi_Dickson, John_S._Pobee, Christianity_in_Africa], [Athanasius_of_Alexandria, Southeast_Africa, Bible]]","[[Kwesi Dickson, John S. Pobee, Christianity in Africa], [Athanasius of Alexandria, Southeast Africa, Bible]]","[[[101, 6448, 2229, 2072, 22076, 102], [101, 2198, 1055, 1012, 13433, 11306, 102], [101, 7988, 1999, 3088, 102]], [[101, 2012, 15788, 24721, 1997, 10297, 102], [101, 4643, 3088, 102], [101, 6331, 102]]]","[[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]]]","[[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]]"


In [None]:
# Random split of the dataset (the counts below show roughly 80% / 20%); compare datapoint counts and meta shapes.
split1, split2 = train_dset.splitter()
split1.data.n_data, {k:m.shape() for k,m in split1.meta.items()}, split2.data.n_data, {k:m.shape() for k,m in split2.meta.items()}

(554466,
 {'hlk_meta': (554466, 312330, 2458399)},
 138616,
 {'hlk_meta': (138616, 312330, 2458399)})

In [None]:
# Load the dataset (main data + `hlk_meta`) straight from the file config.
# NOTE(review): `cfg` is only defined further down, in the "Configuration" section; this cell fails on a fresh kernel unless that cell runs first.
train_dset = XCDataset.from_file(**cfg['path']['train'], **cfg['parameters'])



In [None]:
# Preview the file-loaded dataset.
train_dset.show_data(n=2)

Unnamed: 0,data_identifier,data_input_text,data_input_ids,data_token_type_ids,data_attention_mask,lbl2data_idx,lbl2data_identifier,lbl2data_input_text,lbl2data_input_ids,lbl2data_token_type_ids,lbl2data_attention_mask,hlk2data_idx,hlk2data_identifier,hlk2data_input_text,hlk2data_input_ids,hlk2data_token_type_ids,hlk2data_attention_mask,hlk2lbl2data_idx,hlk2lbl2data_identifier,hlk2lbl2data_input_text,hlk2lbl2data_input_ids,hlk2lbl2data_token_type_ids,hlk2lbl2data_attention_mask
0,Microsail,Microsail,"[101, 12702, 15816, 2140, 102]","[0, 0, 0, 0, 0]","[1, 1, 1, 1, 1]",[194848],[List_of_sailing_boat_types],[List of sailing boat types],"[[101, 2862, 1997, 8354, 4049, 4127, 102]]","[[0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1]]","[254, 281, 527389, 527392, 883760]","[United_States, France, Gary_Mull, Fractional_rig, Jeanneau]","[United States, France, Gary Mull, Fractional rig, Jeanneau]","[[101, 2142, 2163, 102], [101, 2605, 102], [101, 5639, 14163, 3363, 102], [101, 12884, 2389, 19838, 102], [101, 14537, 4887, 102]]","[[0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1]]","[[43854, 261695, 297082]]","[[National_Collegiate_Athletic_Association, Sacred_Heart_University, Sacred_Heart_Pioneers]]","[[National Collegiate Athletic Association, Sacred Heart University, Sacred Heart Pioneers]]","[[[101, 2120, 9234, 5188, 2523, 102], [101, 6730, 2540, 2118, 102], [101, 6730, 2540, 13200, 102]]]","[[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]","[[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]]"
1,Captain_Benjamin_Williams_House,Captain Benjamin Williams House,"[101, 2952, 6425, 3766, 2160, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]",[235682],"[National_Register_of_Historic_Places_listings_in_Middletown,_Connecticut]","[National Register of Historic Places listings in Middletown, Connecticut]","[[101, 2120, 4236, 1997, 3181, 3182, 26213, 1999, 28747, 1010, 6117, 102]]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]","[7140, 38915]","[National_Register_of_Historic_Places, Connecticut_River]","[National Register of Historic Places, Connecticut River]","[[101, 2120, 4236, 1997, 3181, 3182, 102], [101, 6117, 2314, 102]]","[[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]","[[8398, 19705, 26775, 44140, 55989, 64716, 75195, 81592, 129263, 195143, 267626, 281086, 281149, 281203, 281221, 281352, 304511, 316421, 578952, 723622, 742621, 744544, 799315, 844632, 1019297, 1357876, 1357877, 1357878, 1357879, 1357880]]","[[2004_Summer_Olympics, Baseball_at_the_Summer_Olympics, Australians, Run_batted_in, Second_baseman, Maccabi_World_Union, Jewish_Virtual_Library, New_South_Wales_Patriots, Southeastern_Louisiana_University, Sri_Lanka_national_cricket_team, Baseball_at_the_2004_Summer_Olympics, World_Baseball_Classic, Rodney_van_Buizen, Brett_Roneberg, Glenn_Williams, Brendan_Kingman, 2007_Baseball_World_Cup, Greg_Jelks, 2005_Baseball_World_Cup, Judaism_in_Australia, New_Haven_County_Cutters, International_Baseball_League_of_Australia, 2001_Baseball_World_Cup, Northeast_League, Taiwan_Major_League, Masada_College, 2006_Claxton_Shield, 2004_Claxton_Shield, 2007_Claxton_Shield, Catskill_Cougars]]","[[2004 Summer Olympics, Baseball at the Summer Olympics, Australians, Run batted in, Second baseman, Maccabi World Union, Jewish Virtual Library, New South Wales Patriots, Southeastern Louisiana University, Sri Lanka national cricket team, Baseball at the 2004 Summer Olympics, World Baseball Classic, Rodney van 
Buizen, Brett Roneberg, Glenn Williams, Brendan Kingman, 2007 Baseball World Cup, Greg Jelks, 2005 Baseball World Cup, Judaism in Australia, New Haven County Cutters, International Baseball League of Australia, 2001 Baseball World Cup, Northeast League, Taiwan Major League, Masada College, 2006 Claxton Shield, 2004 Claxton Shield, 2007 Claxton Shield, Catskill Cougars]]","[[[101, 2432, 2621, 3783, 102], [101, 3598, 2012, 1996, 2621, 3783, 102], [101, 15739, 102], [101, 2448, 12822, 1999, 102], [101, 2117, 18038, 102], [101, 24055, 2088, 2586, 102], [101, 3644, 7484, 3075, 102], [101, 2047, 2148, 3575, 11579, 102], [101, 8252, 5773, 2118, 102], [101, 5185, 7252, 2120, 4533, 2136, 102], [101, 3598, 2012, 1996, 2432, 2621, 3783, 102], [101, 2088, 3598, 4438, 102], [101, 13898, 3158, 20934, 4697, 2078, 102], [101, 12049, 6902, 22669, 2290, 102], [101, 9465, 3766, 102], [101, 15039, 2332, 2386, 102], [101, 2289, 3598, 2088, 2452, 102], [101, 6754, 15333, 13687, 2015, 102], [101, 2384, 3598, 2088, 2452, 102], [101, 13725, 1999, 2660, 102], [101, 2047, 4033, 2221, 16343, 2015, 102], [101, 2248, 3598, 2223, 1997, 2660, 102], [101, 2541, 3598, 2088, 2452, 102], [101, 4794, 2223, 102], [101, 6629, 2350, 2223, 102], [101, 16137, 8447, 2267, 102], [101, 2294, 18856, 8528, 2669, 6099, 102], [101, 2432, 18856, 8528, 2669, 6099, 102], [101, 2289, 18856, 8528, 2669, 6099, 102], [101, 8870, 15872, 26317, 102]]]","[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 
0]]]","[[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]]"


## Collator

In [None]:
#| export
class XCCollator:
    "`DataLoader` collate function: runs the transform pipeline `tfms` over a list of dataset items."

    def __init__(self, tfms): self.tfms = tfms

    def __call__(self, x):
        "Apply the stored transforms to the batch `x` and return their result."
        pipeline = self.tfms
        return pipeline(x)
        

### Example

In [None]:
tokz = AutoTokenizer.from_pretrained('bert-base-uncased')
# Settings shared by the collator transforms (`XCPadFeatTfm`, `AlignInputIdsTfm`); special-token ids come from the tokenizer.
tfm_cfg = dict(
    sep_tok=tokz.sep_token_id,
    pad_tok=tokz.pad_token_id,
    prefix='lbl',
    pad_side='right',
    inp='data',
    targ='lbl2data',
    ptr='lbl2data_data2ptr',
    drop=False,
    ret_t=True,
    in_place=False,
    lev=0,
    device='cpu',
)

In [None]:
# Collator = pipeline of `XCPadFeatTfm` followed by `AlignInputIdsTfm`, both configured from `tfm_cfg`.
tfms = TfmPipeline([XCPadFeatTfm(**tfm_cfg), AlignInputIdsTfm(**tfm_cfg)])
collator = XCCollator(tfms)

In [None]:
# Collate 10 random raw items into a single batch.
b = train_dset.one_batch(10)
o = collator(b)

In [None]:
# Collated batch: a dict keyed by field name, with tensors for numeric fields and lists for text fields.
o

{'lbl2data_idx': tensor([152270, 152271,  21303,  19438, 259968, 134235, 181495,  22805, 266055,
        138124,  63265, 130051, 195205,  57429,   2567,   2568]), 'lbl2data_identifier': ['Communes_of_the_Alpes-de-Haute-Provence_department', '3=Liste_des_anciennes_communes_des_Alpes-de-Haute-Provence', 'List_of_mountain_peaks_of_the_United_States', 'List_of_mountain_peaks_of_North_America', '2002-03_Heineken_Cup', 'Football_in_Poland', 'PogoÃ\x85Â\x84_Szczecin', 'List_of_national_anthems', 'National_Anthem_of_South_Ossetia', 'VicuÃ\x83Â±a_family', 'Louisiana_African_American_Heritage_Trail', 'Creoles_of_color', 'UK_insolvency_law', 'Communes_of_the_Seine-et-Marne_department', 'List_of_foreign_ministers_in_2017', 'List_of_current_foreign_ministers'], 'lbl2data_input_text': ['Communes of the Alpes-de-Haute-Provence department', '3=Liste des anciennes communes des Alpes-de-Haute-Provence', 'List of mountain peaks of the United States', 'List of mountain peaks of North America', '2002-03 He

In [None]:
# Standard PyTorch DataLoader that collates every batch with `collator`.
dl = DataLoader(train_dset, batch_size=10, collate_fn=collator)

In [None]:
# First batch; DataLoader defaults to `shuffle=False`, so it holds datapoints 0-9.
b = next(iter(dl))

In [None]:
# Inspect the collated batch.
b

{'lbl2data_idx': tensor([    0,     2,     3,     9, 26766,    14,    13,    16,    18,    23,
           29,    45,    60,    51,    67,    65,   101,   102]), 'lbl2data_identifier': ['Antinomianism', 'Anarchism', 'Autism', 'Conimbricenses', 'Aristotle', 'List_of_actors_with_Academy_Award_nominations', 'List_of_Academy_Award_records', 'Clock_synchronization', 'Precision_Time_Protocol', 'Egotism', 'Inclusive_fitness', 'Canadian_pioneers_in_early_Hollywood', 'Qualitative_research', 'Ethnobiology', 'History_of_agricultural_science', 'American_Society_of_Agronomy', '12_basic_principles_of_animation', 'Anime'], 'lbl2data_input_text': ['Antinomianism', 'Anarchism', 'Autism', 'Conimbricenses', 'Aristotle', 'List of actors with Academy Award nominations', 'List of Academy Award records', 'Clock synchronization', 'Precision Time Protocol', 'Egotism', 'Inclusive fitness', 'Canadian pioneers in early Hollywood', 'Qualitative research', 'Ethnobiology', 'History of agricultural science', 'American

In [None]:
# Field names produced by the collator, including the `*2ptr` pointer fields.
b.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_token_type_ids', 'hlk2lbl2data_attention_mask', 'hlk2lbl2data_data2ptr', 'hlk2lbl2data_lbl2ptr', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_token_type_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr'])

## Configuration

In [None]:
import json

In [None]:
# Data config used by `XCDataset.from_file`, `BaseXCDataBlock.from_file` and `XCDataBlock.from_cfg`:
# `path` holds per-split file arguments, `parameters` the arguments shared by every split.
# Paths are built from `dset_dir` (defined in the Data section) instead of repeating the absolute prefix.
cfg = {
    'path':{
        'train': {
            'data_lbl': f'{dset_dir}/trn_X_Y.txt',
            'data_info': f'{dset_dir}/raw_data/train.raw.txt',
            'lbl_info': f'{dset_dir}/raw_data/label.raw.txt',
            'data_lbl_filterer': f'{dset_dir}/filter_labels_train.txt',
            'hlk_meta': {
                'prefix': 'hlk',
                'data_meta': f'{dset_dir}/hyper_link_trn_X_Y.txt',
                'lbl_meta': f'{dset_dir}/hyper_link_lbl_X_Y.txt',
                'meta_info': f'{dset_dir}/raw_data/hyper_link.raw.txt',
            },
        },
        'test': {
            'data_lbl': f'{dset_dir}/tst_X_Y.txt',
            'data_info': f'{dset_dir}/raw_data/test.raw.txt',
            'lbl_info': f'{dset_dir}/raw_data/label.raw.txt',
            'data_lbl_filterer': f'{dset_dir}/filter_labels_test.txt',
            'hlk_meta': {
                'prefix': 'hlk',
                'data_meta': f'{dset_dir}/hyper_link_tst_X_Y.txt',
                'lbl_meta': f'{dset_dir}/hyper_link_lbl_X_Y.txt',
                'meta_info': f'{dset_dir}/raw_data/hyper_link.raw.txt',
            },
        },
    },
    'parameters': {
        'cols': ['identifier', 'input_text'],
        'use_tokz': True,
        'tokz': 'bert-base-uncased',
        'fld': 'input_text',
        'max_len': 32,
        'pad_side': 'right',
        'inp': 'data',
        'targ': 'lbl2data',
        # Must match the pointer key the collator emits (`lbl2data_data2ptr`, as in `tfm_cfg`).
        'ptr': 'lbl2data_data2ptr',
        'drop': False,
        'ret_t': True,
        'in_place': False,
        'lev': 0,
        'device': 'cpu',
    },
}

In [None]:
# Directory the data config is written to.
# NOTE(review): hard-coded absolute, user-specific path; consider making it configurable.
cfg_dir = '/home/scai/phd/aiz218323/scratch/Projects/xcai/cfg/'

In [None]:
import os

# Save the config so it can be reloaded with `XCDataBlock.from_cfg(<path>)`.
# Create the directory if it is missing and join paths properly (`cfg_dir` already ends with '/').
os.makedirs(cfg_dir, exist_ok=True)
with open(os.path.join(cfg_dir, 'data.json'), 'w', encoding='utf-8') as f:
    json.dump(cfg, f, indent=2)

## Data block

### BaseXCDataBlock

In [None]:
#| export
class BaseXCDataBlock:
    "Pairs an `XCDataset` with a collator and the `DataLoader` arguments used to batch it."

    @delegates(DataLoader.__init__)
    def __init__(self, 
                 dset:XCDataset, 
                 collate_fn:Callable=None,
                 **kwargs):
        # Only kwargs accepted by `DataLoader.__init__` are kept; everything else is ignored.
        self.dset, self.dl_kwargs, self.collate_fn = dset, self._get_dl_kwargs(**kwargs), collate_fn
        self.dl = self._make_dl()

    def _make_dl(self):
        "Build a `DataLoader` from the stored kwargs, or return None when no collator was given."
        if self.collate_fn is None: return None
        return DataLoader(self.dset, collate_fn=self.collate_fn, **self.dl_kwargs)

    @classmethod
    @delegates(XCDataset.from_file)
    def from_file(cls, collate_fn:Callable=None, **kwargs):
        "Load an `XCDataset` from files and wrap it; `DataLoader` arguments are picked out of `kwargs`."
        return cls(XCDataset.from_file(**kwargs), collate_fn, **kwargs)

    def __len__(self):
        return len(self.dset)

    def _get_dl_kwargs(self, **kwargs):
        "Keep the kwargs that `DataLoader.__init__` accepts, except those this class supplies itself."
        dl_params = inspect.signature(DataLoader.__init__).parameters
        reserved = {'self', 'dataset', 'collate_fn'}
        return {k:v for k,v in kwargs.items() if k in dl_params and k not in reserved}

    def _getitems(self, idxs:List):
        "New block of the same class over the subset `idxs`, reusing the collator and loader settings."
        return type(self)(self.dset._getitems(idxs), collate_fn=self.collate_fn, **self.dl_kwargs)

    @property
    def bsz(self):
        "Current batch size; without a loader, the stored `batch_size` kwarg (DataLoader's default is 1)."
        if self.dl is not None: return self.dl.batch_size
        return self.dl_kwargs.get('batch_size', 1)

    @bsz.setter
    def bsz(self, v):
        self.dl_kwargs['batch_size'] = v
        self.dl = self._make_dl()

    @property
    def data_lbl_filterer(self): return self.dset.data.data_lbl_filterer

    @data_lbl_filterer.setter
    def data_lbl_filterer(self, val): self.dset.data.data_lbl_filterer = val

    def one_batch(self, bsz:Optional[int]=None):
        """Return the first batch from the loader.

        If `bsz` is given, the batch size is changed first; the change persists, as with setting `bsz`.
        Raises `ValueError` when the block has no collator and therefore no `DataLoader`."""
        if bsz is not None: self.bsz = bsz
        if self.dl is None: raise ValueError('`one_batch` needs a `DataLoader`; create the block with a `collate_fn`.')
        return next(iter(self.dl))
        
        

In [None]:
#| export
@patch
def filterer(cls:BaseXCDataBlock, train:'BaseXCDataBlock', valid:'BaseXCDataBlock', fld:Optional[str]='identifier'):
    """Generate label filterers for a train/valid pair and prune `valid` accordingly.

    `fld` names the info column passed to `Filterer.generate`. `train` gets its new filterer in place; `valid`
    is replaced by the subset of points selected by `Filterer.prune`, carrying its own filterer. The patched
    receiver `cls` is unused. Returns `(train, valid)`.
    Raises `ValueError` if `fld` is missing from the train/valid `data_info` or from `lbl_info`."""
    train_info, valid_info, lbl_info = train.dset.data.data_info, valid.dset.data.data_info, train.dset.data.lbl_info
    # Check every info table up front rather than failing later with a bare KeyError.
    if fld not in train_info: raise ValueError(f'`{fld}` not in `data_info`')
    if fld not in valid_info: raise ValueError(f'`{fld}` not in validation `data_info`')
    if fld not in lbl_info: raise ValueError(f'`{fld}` not in `lbl_info`')
        
    train.data_lbl_filterer, valid_filterer = Filterer.generate(train_info[fld], valid_info[fld], lbl_info[fld], 
                                                                train.dset.data.data_lbl, valid.dset.data.data_lbl)
    _, valid_filterer, idx = Filterer.prune(valid.dset.data.data_lbl, valid_filterer)
    
    valid = valid._getitems(idx)
    valid.data_lbl_filterer = valid_filterer
    
    return train, valid

@patch
def splitter(cls:BaseXCDataBlock, valid_pct:Optional[float]=0.2, seed=None):
    """Randomly split the block into `(train, valid)` blocks, then build their filterers with `filterer`.

    `valid_pct` is the fraction of datapoints, in [0, 1], sent to `valid`. If `seed` is given it seeds torch's
    global RNG before shuffling. Raises `ValueError` for a missing or out-of-range `valid_pct`."""
    # Out-of-range values would otherwise silently produce meaningless slices (e.g. a negative cut).
    if valid_pct is None or not 0 <= valid_pct <= 1:
        raise ValueError(f'`valid_pct` should be in [0, 1], got {valid_pct}.')
    if seed is not None: torch.manual_seed(seed)
    n = len(cls)
    rnd_idx = list(torch.randperm(n).numpy())
    cut = int(valid_pct * n)
    train, valid = cls._getitems(rnd_idx[cut:]), cls._getitems(rnd_idx[:cut])
    return cls.filterer(train, valid)
        

#### Example

In [None]:
# Block with batch size 2; a `DataLoader` is built because `collate_fn` is given.
train_block = BaseXCDataBlock(train_dset, batch_size=2, collate_fn=collator)

In [None]:
# Random 80/20 split of the block. The validation side is then pruned by `filterer`,
# so split 2 can end up smaller than 20% (compare with the dataset-level split above).
split1, split2 = train_block.splitter()
print('split 1:',split1.dset.n_data,split1.dset.n_lbl,split1.dset.n_meta)
print('split 2:',split2.dset.n_data,split2.dset.n_lbl,split2.dset.n_meta) 

  self._set_arrayXarray(i, j, x)


split 1: 554466 312330 2458399
split 2: 137588 312330 2458399


In [None]:
# First batch at the block's current batch size.
b = train_block.one_batch()

In [None]:
# First batch after switching the batch size to 20; the new size stays set on the block.
b = train_block.one_batch(20)

In [None]:
# Build a block directly from the file config (`cfg`, defined in the Configuration section above).
train_block = BaseXCDataBlock.from_file(**cfg['path']['train'], **cfg['parameters'], collate_fn=collator)



In [None]:
# First collated batch from the file-loaded block.
next(iter(train_block.dl))

{'lbl2data_idx': tensor([0, 1, 2]), 'lbl2data_identifier': ['Antinomianism', 'Libertarian_socialism', 'Anarchism'], 'lbl2data_input_text': ['Antinomianism', 'Libertarian socialism', 'Anarchism'], 'lbl2data_input_ids': tensor([[  101,  3424,  3630, 20924,   102,     0],
        [  101, 19297, 14649,   102,     0,     0],
        [  101,  9617, 11140,  2964,   102,     0]]), 'lbl2data_token_type_ids': tensor([[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]]), 'lbl2data_attention_mask': tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0]]), 'lbl2data_data2ptr': tensor([3]), 'data_identifier': ['Anarchism'], 'data_input_text': ['Anarchism'], 'data_input_ids': tensor([[  101,  9617, 11140,  2964,   102]]), 'data_token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'data_attention_mask': tensor([[1, 1, 1, 1, 1]]), 'hlk2lbl2data_idx': tensor([   1992,    4612,    4697,    4706,    4709,    4892,    5256,    7587,
           9813,   10948,   110

### XCDataBlock

In [None]:
#| export
class XCDataBlock:
    "Holds the train / valid / test `BaseXCDataBlock`s of one extreme-classification dataset."

    def __init__(self, train:BaseXCDataBlock=None, valid:BaseXCDataBlock=None, test:BaseXCDataBlock=None):
        self.train, self.valid, self.test = train, valid, test

    @staticmethod
    def load_cfg(fname):
        "Read a JSON config file and return its contents."
        # Local import so the exported module does not rely on the notebook-only `import json` cell.
        import json
        with open(fname, 'r', encoding='utf-8') as f: return json.load(f)

    @property
    def lbl_info(self): return self.train.dset.data.lbl_info

    @property
    def n_lbl(self): return self.train.dset.n_lbl

    @property
    def collator(self): return self.train.collate_fn
        
    @classmethod
    def from_cfg(cls, 
                 cfg:Union[str,Dict],
                 collate_fn:Optional[Callable]=None,
                 valid_pct:Optional[float]=0.2,
                 seed=None):
        """Build the split blocks from `cfg`, a dict or a path to a JSON config file.

        `cfg['path']` maps split names ('train'/'valid'/'test') to file arguments, and `cfg['parameters']`
        holds arguments shared by every split. When no 'valid' split is configured but a 'train' split is,
        a validation split is carved out of train with `splitter(valid_pct, seed=seed)`."""
        import os
        if isinstance(cfg, (str, os.PathLike)): cfg = cls.load_cfg(cfg)
        paths, params = cfg['path'], cfg['parameters']
        blks = {o:BaseXCDataBlock.from_file(**paths[o], **params, collate_fn=collate_fn)
                for o in ('train', 'valid', 'test') if o in paths}
        if 'valid' not in blks and 'train' in blks:
            blks['train'], blks['valid'] = blks['train'].splitter(valid_pct, seed=seed)
        return cls(**blks)
        

#### Example

In [None]:
# Data block with only a training split.
block = XCDataBlock(train=train_block)

In [None]:
# Label info is taken from the training split's dataset.
lbl_info = block.lbl_info

In [None]:
# Token ids of the first 10 labels.
lbl_info['input_ids'][:10]

[[101, 3424, 3630, 20924, 2964, 102],
 [101, 19297, 14649, 102],
 [101, 9617, 11140, 2964, 102],
 [101, 19465, 102],
 [101, 13798, 1997, 8181, 5367, 102],
 [101, 7734, 2162, 1997, 6889, 102],
 [101, 4519, 2793, 5349, 102],
 [101, 2862, 1997, 7008, 1997, 8181, 5367, 102],
 [101, 2862, 1997, 2942, 2916, 4177, 102],
 [101, 9530, 5714, 23736, 19023, 2229, 102]]

In [None]:
# `cfg['path']` has no 'valid' entry, so `from_cfg` carves a validation split out of train (default 20%).
%time block = XCDataBlock.from_cfg(cfg, collate_fn=collator)

CPU times: user 1min 18s, sys: 930 ms, total: 1min 19s
Wall time: 44.2 s


In [None]:
# Per-split sizes as "train,valid,test".
# NOTE(review): the recorded output shows `n_meta` as None even though `hlk_meta` is configured — worth checking.
splits = (block.train, block.valid, block.test)
for label, attr in (('n_data :', 'n_data'), ('n_lbl  :', 'n_lbl'), ('n_meta :', 'n_meta')):
    print(label, ','.join(str(getattr(s.dset, attr)) for s in splits))

n_data : 554466,138616,177515
n_lbl  : 312330,312330,312330
n_meta : None,None,None


In [None]:
# Preview two datapoints from the validation split.
block.valid.dset.show_data(n=2)

Unnamed: 0,data_identifier,data_input_text,data_input_ids,data_token_type_ids,data_attention_mask,lbl2data_idx,lbl2data_identifier,lbl2data_input_text,lbl2data_input_ids,lbl2data_token_type_ids,lbl2data_attention_mask,hlk2data_idx,hlk2data_identifier,hlk2data_input_text,hlk2data_input_ids,hlk2data_token_type_ids,hlk2data_attention_mask,hlk2lbl2data_idx,hlk2lbl2data_identifier,hlk2lbl2data_input_text,hlk2lbl2data_input_ids,hlk2lbl2data_token_type_ids,hlk2lbl2data_attention_mask
0,2019_Caribbean_Club_Championship,2019 Caribbean Club Championship,"[101, 10476, 7139, 2252, 2528, 102]","[0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1]","[310617, 310618]","[2019_CONCACAF_League, 2020_CONCACAF_Champions_League]","[2019 CONCACAF League, 2020 CONCACAF Champions League]","[[101, 10476, 22169, 2223, 102], [101, 12609, 22169, 3966, 2223, 102]]","[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]","[2932, 4843, 8132, 115570, 122131, 147896, 158253, 158260, 246100, 286923, 289314, 304800, 350441, 403932, 416728, 416741, 416746, 416750, 416921, 416926, 416929, 416930, 416945, 416949, 416954, 416961, 427446, 427447, 427448, 427449, 427450, 427451, 427452, 427453, 427454, 427455, 427456, 427457, 427458, 427459, 427460, 427461]","[California, San_Francisco, Eastern_Time_Zone, Willemstad, CONCACAF, Jamaica_Football_Federation, Football_Association_of_Cuba, Bermuda_Football_Association, Caribbean_Football_Union, Surinamese_Football_Association, Guyana_Football_Federation, Barbados_Football_Association, Caribbean_Club_Championship, Bahamas_Football_Association, Cibao_FC, Grenada_Football_Association, Club_Franciscain, Ligue_de_football_de_la_Martinique, Weymouth_Wales_FC, Anguilla_Football_Association, Aruba_Football_Federation, SV_Real_Rincon, Bonaire_Football_Federation, Dominica_Football_Association, Guadeloupean_League_of_Football, Montserrat_Football_Association, 2017_TT_Pro_League, AS_Capoise, 2018_Caribbean_Club_Championship, SV_Dakota, Fruta_Conquerors_FC, FC_Santiago_de_Cuba, Platinum_FC, 2018_SLFA_First_Division, Village_Superstars_FC, 2020_Caribbean_Club_Championship, Ergilio_Hato_Stadium, 2018_Campeonato_Nacional_de_FÃºtbol_de_Cuba, 2018_Barbados_Premier_League, Scholars_International_SC, CRKSV_Jong_Holland, CS_Moulien]","[California, San Francisco, Eastern Time Zone, Willemstad, CONCACAF, Jamaica Football Federation, Football Association of Cuba, Bermuda Football Association, Caribbean Football Union, Surinamese Football 
Association, Guyana Football Federation, Barbados Football Association, Caribbean Club Championship, Bahamas Football Association, Cibao FC, Grenada Football Association, Club Franciscain, Ligue de football de la Martinique, Weymouth Wales FC, Anguilla Football Association, Aruba Football Federation, SV Real Rincon, Bonaire Football Federation, Dominica Football Association, Guadeloupean League of Football, Montserrat Football Association, 2017 TT Pro League, AS Capoise, 2018 Caribbean Club Championship, SV Dakota, Fruta Conquerors FC, FC Santiago de Cuba, Platinum FC, 2018 SLFA First Division, Village Superstars FC, 2020 Caribbean Club Championship, Ergilio Hato Stadium, 2018 Campeonato Nacional de FÃºtbol de Cuba, 2018 Barbados Premier League, Scholars International SC, CRKSV Jong Holland, CS Moulien]","[[101, 2662, 102], [101, 2624, 3799, 102], [101, 2789, 2051, 4224, 102], [101, 18811, 16917, 102], [101, 22169, 102], [101, 9156, 2374, 4657, 102], [101, 2374, 2523, 1997, 7394, 102], [101, 13525, 2374, 2523, 102], [101, 7139, 2374, 2586, 102], [101, 25050, 3366, 2374, 2523, 102], [101, 18786, 2374, 4657, 102], [101, 16893, 2374, 2523, 102], [101, 7139, 2252, 2528, 102], [101, 17094, 2374, 2523, 102], [101, 25022, 3676, 2080, 4429, 102], [101, 29153, 2374, 2523, 102], [101, 2252, 4557, 3540, 2378, 102], [101, 18374, 2139, 2374, 2139, 2474, 29365, 102], [101, 2057, 24335, 17167, 3575, 4429, 102], [101, 17076, 19231, 2721, 2374, 2523, 102], [101, 12098, 19761, 2374, 4657, 102], [101, 17917, 2613, 15544, 15305, 2078, 102], [101, 14753, 14737, 2374, 4657, 102], [101, 11282, 2050, 2374, 2523, 102], [101, 19739, 9648, 23743, 5051, 2319, 2223, 1997, 2374, 102], [101, 18318, 8043, 8609, 2374, 2523, 102], [101, 2418, 23746, 4013, 2223, 102], [101, 2004, 6178, 23565, 102], [101, 2760, 7139, 2252, 2528, 102], [101, 17917, 7734, 102], [101, 10424, 13210, 25466, 2015, 4429, 102], [101, 4429, 8728, 2139, 7394, 102], [101, 8899, 4429, 102], [101, 2760, 22889, 7011, 2034, 2407, 
102], [101, 2352, 18795, 2015, 4429, 102], [101, 12609, 7139, 2252, 2528, 102], [101, 9413, 20142, 3695, 6045, 2080, 3346, 102], [101, 2760, 17675, 10718, 2139, 6904, 29662, 2102, 14956, 2139, 7394, 102], [101, 2760, 16893, 4239, 2223, 102], [101, 5784, 2248, 8040, 102], [101, 13675, 5705, 2615, 18528, 7935, 102], [101, 20116, 9587, 15859, 2368, 102]]","[[0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]","[[1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 
1, 1, 1, 1, 1]]","[[122131, 133440, 133443, 246100, 262512, 284189, 304790, 304795, 350432, 350434, 350435, 350436, 350437, 350439, 350441, 393339, 403927, 403940, 406583, 427443], [122131, 133440, 133443, 150577, 171079, 246100, 260995, 262512, 284189, 289987, 304790, 342673, 350434, 350435, 350437, 350438, 350441, 393339, 406583, 416759, 419069, 427776]]","[[CONCACAF, National_Autonomous_Federation_of_Football_of_Honduras, National_Football_Federation_of_Guatemala, Caribbean_Football_Union, Panamanian_Football_Federation, Apertura_and_Clausura, Central_American_Football_Union, Football_Federation_of_Belize, Nicaraguan_Football_Federation, Salvadoran_Primera_DivisiÃ³n, Liga_FPD, Premier_League_of_Belize, Liga_Nacional_de_FÃºtbol_de_Guatemala, Nicaraguan_Primera_DivisiÃ³n, Caribbean_Club_Championship, 2018_CONCACAF_Champions_League, 2018_CONCACAF_League, 2017_CONCACAF_League, 2019_CONCACAF_Champions_League, 2020_CONCACAF_Champions_League], [CONCACAF, National_Autonomous_Federation_of_Football_of_Honduras, National_Football_Federation_of_Guatemala, Cruz_Azul, Mexican_Football_Federation, Caribbean_Football_Union, MLS_Cup_Playoffs, Panamanian_Football_Federation, Apertura_and_Clausura, Canadian_Championship, Central_American_Football_Union, Voyageurs_Cup, Salvadoran_Primera_DivisiÃ³n, Liga_FPD, Liga_Nacional_de_FÃºtbol_de_Guatemala, Liga_MX, Caribbean_Club_Championship, 2018_CONCACAF_Champions_League, 2019_CONCACAF_Champions_League, 2019_CONCACAF_League, 2019_Major_League_Soccer_season, 2019_Canadian_Championship]]","[[CONCACAF, National Autonomous Federation of Football of Honduras, National Football Federation of Guatemala, Caribbean Football Union, Panamanian Football Federation, Apertura and Clausura, Central American Football Union, Football Federation of Belize, Nicaraguan Football Federation, Salvadoran Primera DivisiÃ³n, Liga FPD, Premier League of Belize, Liga Nacional de FÃºtbol de Guatemala, Nicaraguan Primera DivisiÃ³n, Caribbean Club Championship, 2018 
CONCACAF Champions League, 2018 CONCACAF League, 2017 CONCACAF League, 2019 CONCACAF Champions League, 2020 CONCACAF Champions League], [CONCACAF, National Autonomous Federation of Football of Honduras, National Football Federation of Guatemala, Cruz Azul, Mexican Football Federation, Caribbean Football Union, MLS Cup Playoffs, Panamanian Football Federation, Apertura and Clausura, Canadian Championship, Central American Football Union, Voyageurs Cup, Salvadoran Primera DivisiÃ³n, Liga FPD, Liga Nacional de FÃºtbol de Guatemala, Liga MX, Caribbean Club Championship, 2018 CONCACAF Champions League, 2019 CONCACAF Champions League, 2019 CONCACAF League, 2019 Major League Soccer season, 2019 Canadian Championship]]","[[[101, 22169, 102], [101, 2120, 8392, 4657, 1997, 2374, 1997, 14373, 102], [101, 2120, 2374, 4657, 1997, 11779, 102], [101, 7139, 2374, 2586, 102], [101, 8515, 11148, 2374, 4657, 102], [101, 23957, 5339, 4648, 1998, 19118, 4648, 102], [101, 2430, 2137, 2374, 2586, 102], [101, 2374, 4657, 1997, 18867, 102], [101, 15448, 2078, 2374, 4657, 102], [101, 10582, 2319, 14837, 4487, 11365, 2401, 18107, 2078, 102], [101, 8018, 1042, 17299, 102], [101, 4239, 2223, 1997, 18867, 102], [101, 8018, 10718, 2139, 6904, 29662, 2102, 14956, 2139, 11779, 102], [101, 15448, 2078, 14837, 4487, 11365, 2401, 18107, 2078, 102], [101, 7139, 2252, 2528, 102], [101, 2760, 22169, 3966, 2223, 102], [101, 2760, 22169, 2223, 102], [101, 2418, 22169, 2223, 102], [101, 10476, 22169, 3966, 2223, 102], [101, 12609, 22169, 3966, 2223, 102]], [[101, 22169, 102], [101, 2120, 8392, 4657, 1997, 2374, 1997, 14373, 102], [101, 2120, 2374, 4657, 1997, 11779, 102], [101, 8096, 17207, 5313, 102], [101, 4916, 2374, 4657, 102], [101, 7139, 2374, 2586, 102], [101, 16287, 2452, 7555, 102], [101, 8515, 11148, 2374, 4657, 102], [101, 23957, 5339, 4648, 1998, 19118, 4648, 102], [101, 3010, 2528, 102], [101, 2430, 2137, 2374, 2586, 102], [101, 8774, 9236, 2452, 102], [101, 10582, 2319, 14837, 4487, 11365, 
2401, 18107, 2078, 102], [101, 8018, 1042, 17299, 102], [101, 8018, 10718, 2139, 6904, 29662, 2102, 14956, 2139, 11779, 102], [101, 8018, 25630, 102], [101, 7139, 2252, 2528, 102], [101, 2760, 22169, 3966, 2223, 102], [101, 10476, 22169, 3966, 2223, 102], [101, 10476, 22169, 2223, 102], [101, 10476, 2350, 2223, 4715, 2161, 102], [101, 10476, 3010, 2528, 102]]]","[[[0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], [[0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]","[[[1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], [[1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 
1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]]"
1,1994_United_States_Senate_election_in_Texas,1994 United States Senate election in Texas,"[101, 2807, 2142, 2163, 4001, 2602, 1999, 3146, 102]","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 1, 1, 1, 1, 1, 1, 1, 1]",[116620],"[United_States_Senate_elections,_1994]","[United States Senate elections, 1994]","[[101, 2142, 2163, 4001, 3864, 1010, 2807, 102]]","[[0, 0, 0, 0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1, 1, 1]]","[10363, 156277, 174973]","[Kay_Bailey_Hutchison, Jim_Mattox, Texas_Attorney_General]","[Kay Bailey Hutchison, Jim Mattox, Texas Attorney General]","[[101, 10905, 8925, 12570, 27795, 102], [101, 3958, 4717, 11636, 102], [101, 3146, 4905, 2236, 102]]","[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]","[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]",[[]],[[]],[[]],[[]],[[]],[[]]


In [None]:
# Pull the sparse data->label matrices out of the train and validation datasets.
# Note: the *validation* split of `block` is what this notebook calls "test".
train_lbl, test_lbl = block.train.dset.data.data_lbl, block.valid.dset.data.data_lbl

In [None]:
# Identifier columns for the data points (train and validation splits)
# and for the label set (taken from the train dataset's label info).
train_identifier, test_identifier = (
    block.train.dset.data.data_info['identifier'],
    block.valid.dset.data.data_info['identifier'],
)
lbl_identifier = block.train.dset.data.lbl_info['identifier']

In [None]:
x, y = gen_filterer(train_identifier, test_identifier, lbl_identifier, train_lbl, test_lbl)