In [None]:
#| default_exp block

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import inspect, os, re
from typing import Optional, Dict

import numpy as np
from transformers import AutoTokenizer, BatchEncoding

from fastcore.meta import *

from xcai.data import *
from xcai.transform import *

In [None]:
#| hide
from nbdev.showdoc import *
# Regenerate the library module from this notebook's `#| export` cells (target set by `#| default_exp block`).
# NOTE(review): this runs on every execution of the notebook and exports the copy saved on disk — confirm intended.
import nbdev; nbdev.nbdev_export()

## Config

### `PARAM`

In [None]:
#| export
# Default parameters referenced (by identity) from every dataset config in this module.
# `XCBlock.from_cfg` lets callers override any of these keys with a non-None keyword argument,
# loads `tokenizer` with `AutoTokenizer.from_pretrained`, selects the transform list `TFMS[tfm]`,
# and passes the whole dict as keyword arguments to each of those transforms.
PARAM = dict(
    info_column_names=['identifier', 'input_text'],
    use_tokenizer=True,
    tokenizer='bert-base-cased',
    tokenization_column='input_text',
    max_sequence_length=32,
    pad_side='right',
    inp='data',
    targ='lbl2data',
    ptr='lbl2data_data2ptr',
    drop=True,
    ret_t=True,
    in_place=True,
    collapse=True,
    device='cpu',
    tfm='xc',
)

### `CONFIGS`

In [None]:
#| export
# Directory of the mapped LF-WikiSeeAlsoTitles-320K dataset. The root defaults to the original
# author's location; set the `XCAI_DATA_DIR` environment variable (before import) to relocate it.
_WIKISEEALSO_DIR = os.path.join(
    os.environ.get('XCAI_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC_NLG/data'),
    '(mapped)LF-WikiSeeAlsoTitles-320K',
)

def _wikiseealso_paths(split:str, meta:bool=False) -> dict:
    "Return a new path dict for `split` ('train' or 'test'); add the `hlk_meta` (hyper-link) entry when `meta` is True."
    abbr = {'train': 'trn', 'test': 'tst'}[split]  # short split name used in the `*_X_Y.txt` file names
    d = _WIKISEEALSO_DIR
    paths = {
        'data_lbl': f'{d}/{abbr}_X_Y.txt',
        'data_info': f'{d}/raw_data/{split}.raw.txt',
        'lbl_info': f'{d}/raw_data/label.raw.txt',
        'data_lbl_filterer': f'{d}/filter_labels_{split}.txt',
    }
    if meta:
        paths['hlk_meta'] = {
            'prefix': 'hlk',
            'data_meta': f'{d}/hyper_link_{abbr}_X_Y.txt',
            'lbl_meta': f'{d}/hyper_link_lbl_X_Y.txt',
            'meta_info': f'{d}/raw_data/hyper_link.raw.txt',
        }
    return paths

# Named configs: `train`/`data` carry no metadata, `train_meta`/`data_meta` add hyper-link metadata;
# the `data*` variants also include the test split. Every helper call builds new dicts, so configs never
# share a path dict; `parameters` is the module-level `PARAM` object.
WIKISEEALSO = {
    'train': {
        'path': {'train': _wikiseealso_paths('train')},
        'parameters': PARAM,
    },
    'data': {
        'path': {'train': _wikiseealso_paths('train'), 'test': _wikiseealso_paths('test')},
        'parameters': PARAM,
    },
    'train_meta': {
        'path': {'train': _wikiseealso_paths('train', meta=True)},
        'parameters': PARAM,
    },
    'data_meta': {
        'path': {'train': _wikiseealso_paths('train', meta=True), 'test': _wikiseealso_paths('test', meta=True)},
        'parameters': PARAM,
    },
}

In [None]:
#| export
# Directory of the mapped LF-WikiTitles-500K dataset. The root defaults to the original author's
# location; set the `XCAI_DATA_DIR` environment variable (before import) to relocate it.
_WIKICATEGORY_DIR = os.path.join(
    os.environ.get('XCAI_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC_NLG/data'),
    '(mapped)LF-WikiTitles-500K',
)

def _wikicategory_paths(split:str, meta:bool=False) -> dict:
    "Return a new path dict for `split` ('train' or 'test'); add the `hlk_meta` (hyper-link) entry when `meta` is True."
    abbr = {'train': 'trn', 'test': 'tst'}[split]  # short split name used in the `*_X_Y.txt` file names
    d = _WIKICATEGORY_DIR
    # Unlike the other datasets, no `data_lbl_filterer` file is configured here.
    paths = {
        'data_lbl': f'{d}/{abbr}_X_Y.txt',
        'data_info': f'{d}/raw_data/{split}.raw.txt',
        'lbl_info': f'{d}/raw_data/label.raw.txt',
    }
    if meta:
        paths['hlk_meta'] = {
            'prefix': 'hlk',
            'data_meta': f'{d}/hyper_link_{abbr}_X_Y.txt',
            'lbl_meta': f'{d}/hyper_link_lbl_X_Y.txt',
            'meta_info': f'{d}/raw_data/hyper_link.raw.txt',
        }
    return paths

# Named configs: `train`/`data` carry no metadata, `train_meta`/`data_meta` add hyper-link metadata;
# the `data*` variants also include the test split. Every helper call builds new dicts, so configs never
# share a path dict; `parameters` is the module-level `PARAM` object.
WIKICATEGORY = {
    'train': {
        'path': {'train': _wikicategory_paths('train')},
        'parameters': PARAM,
    },
    'data': {
        'path': {'train': _wikicategory_paths('train'), 'test': _wikicategory_paths('test')},
        'parameters': PARAM,
    },
    'train_meta': {
        'path': {'train': _wikicategory_paths('train', meta=True)},
        'parameters': PARAM,
    },
    'data_meta': {
        'path': {'train': _wikicategory_paths('train', meta=True), 'test': _wikicategory_paths('test', meta=True)},
        'parameters': PARAM,
    },
}

In [None]:
#| export
# Directory of the mapped LF-AmazonTitles-1.3M dataset. The root defaults to the original author's
# location; set the `XCAI_DATA_DIR` environment variable (before import) to relocate it.
_AMAZON_DIR = os.path.join(
    os.environ.get('XCAI_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC_NLG/data'),
    '(mapped)LF-AmazonTitles-1.3M',
)

def _amazon_paths(split:str, meta:bool=False) -> dict:
    "Return a new path dict for `split` ('train' or 'test'); add the `cat_meta` (category) entry when `meta` is True."
    abbr = {'train': 'trn', 'test': 'tst'}[split]  # short split name used in the `*_X_Y.txt` file names
    d = _AMAZON_DIR
    paths = {
        'data_lbl': f'{d}/{abbr}_X_Y.txt',
        'data_info': f'{d}/raw_data/{split}.raw.txt',
        'lbl_info': f'{d}/raw_data/label.raw.txt',
        'data_lbl_filterer': f'{d}/filter_labels_{split}.txt',
    }
    if meta:
        paths['cat_meta'] = {
            'prefix': 'cat',
            'data_meta': f'{d}/category_{abbr}_X_Y.txt',
            'lbl_meta': f'{d}/category_lbl_X_Y.txt',
            'meta_info': f'{d}/raw_data/category.raw.txt',
        }
    return paths

# Named configs: `train`/`data` carry no metadata, `train_meta`/`data_meta` add category metadata;
# the `data*` variants also include the test split. Every helper call builds new dicts, so configs never
# share a path dict; `parameters` is the module-level `PARAM` object.
AMAZON = {
    'train': {
        'path': {'train': _amazon_paths('train')},
        'parameters': PARAM,
    },
    'data': {
        'path': {'train': _amazon_paths('train'), 'test': _amazon_paths('test')},
        'parameters': PARAM,
    },
    'train_meta': {
        'path': {'train': _amazon_paths('train', meta=True)},
        'parameters': PARAM,
    },
    'data_meta': {
        'path': {'train': _amazon_paths('train', meta=True), 'test': _amazon_paths('test', meta=True)},
        'parameters': PARAM,
    },
}

## Block

In [None]:
#| export
# Dataset name (the `dset` argument of `XCBlock.from_cfg`) -> its dict of named configs.
CFGS = dict(wiki_seealso=WIKISEEALSO, wiki_category=WIKICATEGORY, amazon=AMAZON)
# Value of the `tfm` parameter -> transform classes that `XCBlock.from_cfg` instantiates, in this order,
# to build the collator's pipeline.
TFMS = dict(xc=[XCPadFeatTfm, AlignInputIdsTfm], ng=[NGPadFeatTfm])

In [None]:
#| export
class XCBlock:
    "Entry point for building an `XCDataBlock` from one of the named dataset configs in `CFGS`."

    @delegates(XCDataBlock.from_cfg)
    @classmethod
    def from_cfg(cls, cfg:str, dset:Optional[str]='wiki_seealso', bsz:Optional[int]=10, **kwargs):
        """Build an `XCDataBlock` for config `cfg` (e.g. 'train', 'data_meta') of dataset `dset`.

        Keyword arguments whose names are keys of the config's `parameters` and whose values are not
        None override those parameters. All other kwargs, including matching keys passed as None, are
        forwarded to `XCDataBlock.from_cfg`. The tokenizer's separator/padding token ids and `bsz` (as
        `batch_size`) are added to the parameters. Each class in `TFMS[tfm]` is then instantiated with
        the full parameter dict, and the instances are chained into the batch collator.

        Raises `ValueError` if `dset` is not a key of `CFGS` or `cfg` is not a config of that dataset.
        """
        from copy import deepcopy

        if dset not in CFGS: raise ValueError(f'Invalid `dset`({dset})')
        cfgs = CFGS[dset]

        if cfg not in cfgs: raise ValueError(f'Invalid `cfg`({cfg})')
        # Work on a deep copy: every config references the module-level `PARAM` dict, so writing the
        # overrides below in place would leak them into every later call and every other dataset.
        config = deepcopy(cfgs[cfg])
        params = config['parameters']

        for k in params:
            if k in kwargs and kwargs[k] is not None: params[k] = kwargs.pop(k)

        tokz = AutoTokenizer.from_pretrained(params['tokenizer'])
        params['sep_tok'] = tokz.sep_token_id
        params['pad_tok'] = tokz.pad_token_id
        params['batch_size'] = bsz

        collator = XCCollator(TfmPipeline([o(**params) for o in TFMS[params['tfm']]]))

        return XCDataBlock.from_cfg(config, collate_fn=collator, **kwargs)


#### Examples

##### `WikiSeeAlso`

In [None]:
# `train` config of the wiki_seealso dataset, using the `ng` transforms and an uncased BERT tokenizer.
block = XCBlock.from_cfg('train', tokenizer='bert-base-uncased', tfm='ng', dset='wiki_seealso')

  self._set_arrayXarray(i, j, x)


In [None]:
# Fetch one collated batch from `block.train` (a dict keyed by field name).
b = block.train.one_batch()

In [None]:
# Field names present in the collated batch.
b.keys()

dict_keys(['lbl2data_idx', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
import torch
# Show each batch field with its shape (tensors) or its length (everything else).
for name, val in b.items():
    size = val.shape if isinstance(val, torch.Tensor) else len(val)
    print(name, ':', size)

lbl2data_idx : torch.Size([10])
plbl2data_idx : torch.Size([21])
plbl2data_data2ptr : torch.Size([10])
lbl2data_identifier : 10
lbl2data_input_text : 10
lbl2data_input_ids : torch.Size([10, 15])
lbl2data_token_type_ids : torch.Size([10, 15])
lbl2data_attention_mask : torch.Size([10, 15])
lbl2data_data2ptr : torch.Size([10])
data_identifier : 10
data_input_text : 10
data_input_ids : torch.Size([10, 11])
data_token_type_ids : torch.Size([10, 11])
data_attention_mask : torch.Size([10, 11])


In [None]:
# Fetch items from the underlying dataset: a list of per-example dicts holding plain Python lists rather than tensors.
b = block.train.dset.one_batch()

In [None]:
# Display the raw per-example items.
b

[{'data_identifier': 'Kaskaskia_River',
  'data_input_text': 'Kaskaskia River',
  'data_input_ids': [101, 10556, 8337, 5488, 2050, 2314, 102],
  'data_token_type_ids': [0, 0, 0, 0, 0, 0, 0],
  'data_attention_mask': [1, 1, 1, 1, 1, 1, 1],
  'lbl2data_idx': [109204],
  'lbl2data_identifier': ['List_of_Illinois_rivers'],
  'lbl2data_input_text': ['List of Illinois rivers'],
  'lbl2data_input_ids': [[101, 2862, 1997, 4307, 5485, 102]],
  'lbl2data_token_type_ids': [[0, 0, 0, 0, 0, 0]],
  'lbl2data_attention_mask': [[1, 1, 1, 1, 1, 1]]},
 {'data_identifier': 'HaÊ»apai',
  'data_input_text': 'HaÊ»apai',
  'data_input_ids': [101, 5292, 2063, 1090, 9706, 4886, 102],
  'data_token_type_ids': [0, 0, 0, 0, 0, 0, 0],
  'data_attention_mask': [1, 1, 1, 1, 1, 1, 1],
  'lbl2data_idx': [30134, 203987],
  'lbl2data_identifier': ['List_of_islands_and_towns_in_Tonga',
   '2006_Tonga_earthquake'],
  'lbl2data_input_text': ['List of islands and towns in Tonga',
   '2006 Tonga earthquake'],
  'lbl2data_inp

##### `WikiCategory`

In [None]:
# `train` config of the wiki_category dataset, using the `ng` transforms and an uncased BERT tokenizer.
block = XCBlock.from_cfg('train', tokenizer='bert-base-uncased', tfm='ng', dset='wiki_category')

In [None]:
# Collated batch from `block.train` and its field names.
b = block.train.one_batch()
b.keys()

dict_keys(['lbl2data_idx', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
import torch
# Show each batch field with its shape (tensors) or its length (everything else).
for name, val in b.items():
    size = val.shape if isinstance(val, torch.Tensor) else len(val)
    print(name, ':', size)

lbl2data_idx : torch.Size([10])
plbl2data_idx : torch.Size([47])
plbl2data_data2ptr : torch.Size([10])
lbl2data_identifier : 10
lbl2data_input_text : 10
lbl2data_input_ids : torch.Size([10, 11])
lbl2data_token_type_ids : torch.Size([10, 11])
lbl2data_attention_mask : torch.Size([10, 11])
lbl2data_data2ptr : torch.Size([10])
data_identifier : 10
data_input_text : 10
data_input_ids : torch.Size([10, 32])
data_token_type_ids : torch.Size([10, 32])
data_attention_mask : torch.Size([10, 32])


In [None]:
# Items from the underlying dataset: per-example dicts holding plain Python lists.
b = block.train.dset.one_batch()
b

[{'data_identifier': 'Lincs_Wind_Farm',
  'data_input_text': 'Lincs Wind Farm',
  'data_input_ids': [101, 11409, 6169, 3612, 3888, 102],
  'data_token_type_ids': [0, 0, 0, 0, 0, 0],
  'data_attention_mask': [1, 1, 1, 1, 1, 1],
  'lbl2data_idx': [161683, 176709, 328961, 403954, 492958],
  'lbl2data_identifier': ['Category:DONG_Energy_wind_farms',
   'Category:East_Lindsey',
   'Category:Offshore_wind_farms_in_the_North_Sea',
   'Category:Round_2_offshore_wind_farms',
   'Category:Wind_farms_in_England'],
  'lbl2data_input_text': ['DONG Energy wind farms',
   'East Lindsey',
   'Offshore wind farms in the North Sea',
   'Round 2 offshore wind farms',
   'Wind farms in England'],
  'lbl2data_input_ids': [[101, 11947, 2943, 3612, 8623, 102],
   [101, 2264, 17518, 102],
   [101, 12195, 3612, 8623, 1999, 1996, 2167, 2712, 102],
   [101, 2461, 1016, 12195, 3612, 8623, 102],
   [101, 3612, 8623, 1999, 2563, 102]],
  'lbl2data_token_type_ids': [[0, 0, 0, 0, 0, 0],
   [0, 0, 0, 0],
   [0, 0, 0, 

##### `AmazonProduct`

In [None]:
# `train` config of the amazon dataset, using the `ng` transforms and an uncased BERT tokenizer.
block = XCBlock.from_cfg('train', tokenizer='bert-base-uncased', tfm='ng', dset='amazon')

In [None]:
# Collated batch from `block.train` and its field names.
b = block.train.one_batch()
b.keys()

dict_keys(['lbl2data_idx', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
import torch
# Show each batch field with its shape (tensors) or its length (everything else).
for name, val in b.items():
    size = val.shape if isinstance(val, torch.Tensor) else len(val)
    print(name, ':', size)

lbl2data_idx : torch.Size([10])
plbl2data_idx : torch.Size([279])
plbl2data_data2ptr : torch.Size([10])
lbl2data_identifier : 10
lbl2data_input_text : 10
lbl2data_input_ids : torch.Size([10, 32])
lbl2data_token_type_ids : torch.Size([10, 32])
lbl2data_attention_mask : torch.Size([10, 32])
lbl2data_data2ptr : torch.Size([10])
data_identifier : 10
data_input_text : 10
data_input_ids : torch.Size([10, 25])
data_token_type_ids : torch.Size([10, 25])
data_attention_mask : torch.Size([10, 25])


In [None]:
# Items from the underlying dataset: per-example dicts holding plain Python lists.
b = block.train.dset.one_batch()
b

[{'data_identifier': 'B001C8B0H6',
  'data_input_text': 'Rapid 02892 Heavy Duty Cartridge Stapler, 80 Sheet Capacity, Silver',
  'data_input_ids': [101,
   5915,
   6185,
   2620,
   2683,
   2475,
   3082,
   4611,
   15110,
   18785,
   2099,
   1010,
   3770,
   7123,
   3977,
   1010,
   3165,
   102],
  'data_token_type_ids': [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  'data_attention_mask': [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  'lbl2data_idx': [391348,
   542913,
   591984,
   606011,
   679776,
   741684,
   883696,
   883697,
   883698],
  'lbl2data_identifier': ['B00006IF79',
   'B008ALWW5W',
   'B0006HV93Y',
   'B00006IFMC',
   'B00168CPYO',
   'B00006IA5E',
   'B001C89FFA',
   'B001C8B0HQ',
   'B001C8B0HG'],
  'lbl2data_input_text': ['Scotch Desk Tape Dispenser, 1in. Core, Black',
   'Rubbermaid Commercial FG295600BLA Plastic Deskside Wasteb

## Batch

In [None]:
#| export
def prepare_batch(m, b, m_args=None):
    """Select the entries of batch `b` that model `m` can take and wrap them in a `BatchEncoding`.

    A key is kept when it is the name of a parameter of `m.forward`, or when it appears in `m_args`
    (if `m_args` is given). Matching is by parameter name only, so a `**kwargs` catch-all in
    `forward` does not admit additional keys.
    """
    forward_params = inspect.signature(m.forward).parameters
    extra = () if m_args is None else m_args
    selected = {}
    for key, val in b.items():
        if key in forward_params or key in extra:
            selected[key] = val
    return BatchEncoding(selected)