In [1]:
#| default_exp block

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import numpy as np, re, inspect
from typing import Optional, Dict, Callable, Union
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase

from fastcore.meta import *

from xcai.data import *
from xcai.sdata import *
from xcai.ndata import *
from xcai.transform import *
from xcai.data_sampler import *

from xcai.config import PARAM, WIKISEEALSOTITLES, WIKITITLES, WIKISEEALSO, WIKIPEDIA, ORCAS, AMAZONTITLES131, AMAZON131, AMAZONTITLES

In [4]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Config

In [5]:
#| export
# Configuration name -> config factory; the `*Block.from_cfg` methods call `CFGS[name](data_dir)`
# and then pick one entry of the returned mapping with `cfg_key`.
CFGS = dict(
    wikiseealsotitles=WIKISEEALSOTITLES,
    wikititles=WIKITITLES,
    wikiseealso=WIKISEEALSO,
    wikipedia=WIKIPEDIA,
    orcas=ORCAS,
    amazontitles131=AMAZONTITLES131,
    amazon131=AMAZON131,
    amazontitles=AMAZONTITLES,
)

# `transform_type` -> transform classes; `XCBlock.from_cfg` instantiates each one with the
# configuration parameters and chains them in a `TfmPipeline`.
TFMS = dict(
    xc=[XCPadFeatTfm, AlignInputIdsTfm],
    ng=[NGPadFeatTfm],
    xcnlg=[XCSamplePadFeatTfm],
    rm=[RamenPadFeatTfm],
    xcs=[XCSamplerFeatTfm],
    oak=[OAKSamplerFeatTfm],
)

## Block

In [6]:
#| export
class XCBlock:
    "Factory that resolves a configuration and builds an `XCDataBlock` with a transform-pipeline collator."

    @delegates(XCDataBlock.from_cfg)
    @classmethod
    def from_cfg(
        cls, 
        cfg:Union[str,Dict],
        cfg_key:Optional[str]=None,
        data_dir:Optional[str]=None,  
        bsz:Optional[int]=10, 
        **kwargs
    ):
        """
        Create an `XCDataBlock` from a named or explicit configuration.

        Args:
            cfg: Either a key of `CFGS` (e.g. `'wikiseealsotitles'`) or an already-built
                configuration dict that contains a `'parameters'` dict.
            cfg_key: When `cfg` is a name, selects the entry of `CFGS[cfg](data_dir)` to use.
            data_dir: Passed to the `CFGS[cfg]` factory when `cfg` is a name.
            bsz: Stored as the `'batch_size'` parameter.
            **kwargs: Keys already present in the config's `'parameters'` override those entries;
                all remaining kwargs are forwarded to `XCDataBlock.from_cfg`.

        Returns:
            Whatever `XCDataBlock.from_cfg(cfg, collate_fn=collator, **kwargs)` returns.

        Raises:
            ValueError: If `cfg` is an unknown configuration name or `cfg_key` is not one of its keys.

        The caller's `cfg` dict is not modified; overrides are applied to a copy.
        """
        if isinstance(cfg, str):
            # Resolve the named configuration.
            if cfg not in CFGS: raise ValueError(f'Invalid configuration ({cfg})')
            cfgs = CFGS[cfg](data_dir)

            # Select the dataset variant within that configuration.
            if cfg_key not in cfgs: raise ValueError(f'Invalid configuration key ({cfg_key})')
            cfg = cfgs[cfg_key]

        # Shallow-copy the config and its parameter dict so the overrides below never leak
        # into a dict the caller passed in (or into repeated calls with the same dict).
        cfg = {**cfg, 'parameters': dict(cfg['parameters'])}
        params = cfg['parameters']

        # kwargs that name an existing parameter override it; the rest go to XCDataBlock.from_cfg.
        for k in params:
            if k in kwargs: params[k] = kwargs.pop(k)

        # Only the pad token id is needed here; load the tokenizer by name unless an instance was given.
        tokenizer = params['tokenizer']
        tokz = tokenizer if isinstance(tokenizer, PreTrainedTokenizerBase) else AutoTokenizer.from_pretrained(tokenizer)
        params['pad_token'] = tokz.pad_token_id
        params['batch_size'] = bsz

        # Every transform registered for this `transform_type` is built from the same parameter dict.
        collator = XCCollator(TfmPipeline([o(**params) for o in TFMS[params['transform_type']]]))

        return XCDataBlock.from_cfg(cfg, collate_fn=collator, **kwargs)


In [7]:
#| export
class SXCBlock:
    "Factory that resolves a configuration and builds an `SXCDataBlock`."

    @delegates(SXCDataBlock.from_cfg)
    @classmethod
    def from_cfg(
        cls,
        cfg:Union[str,Dict],
        cfg_key:Optional[str]=None,
        data_dir:Optional[str]=None, 
        collate_fn:Optional[Callable]=identity_collate_fn, 
        **kwargs
    ):
        """
        Create an `SXCDataBlock` from a named or explicit configuration.

        Args:
            cfg: Either a key of `CFGS` or an already-built configuration dict that
                contains a `'parameters'` dict.
            cfg_key: When `cfg` is a name, selects the entry of `CFGS[cfg](data_dir)` to use.
            data_dir: Passed to the `CFGS[cfg]` factory when `cfg` is a name.
            collate_fn: Forwarded unchanged to `SXCDataBlock.from_cfg`.
            **kwargs: Keys already present in the config's `'parameters'` override those entries;
                all remaining kwargs are forwarded to `SXCDataBlock.from_cfg`.

        Returns:
            Whatever `SXCDataBlock.from_cfg(cfg, collate_fn=collate_fn, **kwargs)` returns.

        Raises:
            ValueError: If `cfg` is an unknown configuration name or `cfg_key` is not one of its keys.

        The caller's `cfg` dict is not modified; overrides are applied to a copy.
        """
        if isinstance(cfg, str):
            # Resolve the named configuration.
            if cfg not in CFGS: raise ValueError(f'Invalid configuration ({cfg})')
            cfgs = CFGS[cfg](data_dir)

            # Select the dataset variant within that configuration.
            if cfg_key not in cfgs: raise ValueError(f'Invalid configuration key ({cfg_key})')
            cfg = cfgs[cfg_key]

        # Shallow-copy so parameter overrides do not mutate a dict supplied by the caller.
        cfg = {**cfg, 'parameters': dict(cfg['parameters'])}
        params = cfg['parameters']

        # kwargs that name an existing parameter override it; the rest go to SXCDataBlock.from_cfg.
        for k in params:
            if k in kwargs: params[k] = kwargs.pop(k)

        return SXCDataBlock.from_cfg(cfg, collate_fn=collate_fn, **kwargs)
        

In [8]:
#| export
class NXCBlock:
    "Factory that resolves a configuration and builds an `NXCDataBlock`."

    @delegates(NXCDataBlock.from_cfg)
    @classmethod
    def from_cfg(
        cls,
        cfg:Union[str,Dict],
        cfg_key:Optional[str]=None,
        data_dir:Optional[str]=None, 
        collate_fn:Optional[Callable]=identity_collate_fn, 
        **kwargs
    ):
        """
        Create an `NXCDataBlock` from a named or explicit configuration.

        Args:
            cfg: Either a key of `CFGS` or an already-built configuration dict that
                contains a `'parameters'` dict.
            cfg_key: When `cfg` is a name, selects the entry of `CFGS[cfg](data_dir)` to use.
            data_dir: Passed to the `CFGS[cfg]` factory when `cfg` is a name.
            collate_fn: Forwarded unchanged to `NXCDataBlock.from_cfg`.
            **kwargs: Keys already present in the config's `'parameters'` override those entries;
                all remaining kwargs are forwarded to `NXCDataBlock.from_cfg`.

        Returns:
            Whatever `NXCDataBlock.from_cfg(cfg, collate_fn=collate_fn, **kwargs)` returns.

        Raises:
            ValueError: If `cfg` is an unknown configuration name or `cfg_key` is not one of its keys.

        The caller's `cfg` dict is not modified; overrides are applied to a copy.
        """
        if isinstance(cfg, str):
            # Resolve the named configuration.
            if cfg not in CFGS: raise ValueError(f'Invalid configuration ({cfg})')
            cfgs = CFGS[cfg](data_dir)

            # Select the dataset variant within that configuration.
            if cfg_key not in cfgs: raise ValueError(f'Invalid configuration key ({cfg_key})')
            cfg = cfgs[cfg_key]

        # Shallow-copy so parameter overrides do not mutate a dict supplied by the caller.
        cfg = {**cfg, 'parameters': dict(cfg['parameters'])}
        params = cfg['parameters']

        # kwargs that name an existing parameter override it; the rest go to NXCDataBlock.from_cfg.
        for k in params:
            if k in kwargs: params[k] = kwargs.pop(k)

        return NXCDataBlock.from_cfg(cfg, collate_fn=collate_fn, **kwargs)
        

## Examples

In [None]:
import os
# Root of the XC datasets; set the XC_DATA_DIR environment variable to point elsewhere
# instead of editing this machine-specific default.
data_dir = os.environ.get('XC_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC/data/')

### `WIKISEEALSOTITLES`

In [None]:
# `from_cfg` takes the configuration name (a key of `CFGS`) first, then the config key; `data_dir` is a keyword.
block = XCBlock.from_cfg('wikiseealsotitles', 'train_meta', data_dir=data_dir, transform_type='xcnlg', tokenizer='bert-base-uncased')



In [None]:
# Draw a single batch from the training split for inspection.
b = block.train.one_batch()

In [None]:
# Inspect which fields the batch contains.
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'phlk2data_idx', 'phlk2data_data2ptr', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_token_type_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr', 'phlk2lbl2data_idx', 'phlk2lbl2data_data2ptr', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_token_type_ids', 'hlk2lbl2data_attention_mask', 'hlk2lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask', 'data_idx'])

In [None]:
# Configuration name first, then the config key; `data_dir` must be passed as a keyword.
block = SXCBlock.from_cfg('wikiseealsotitles', 'data_lnk', data_dir=data_dir, tokenizer='distilbert-base-uncased')



In [None]:
# Draw one training batch; the `10` argument is presumably the batch size — TODO confirm against `one_batch`.
batch = block.train.one_batch(10)

In [None]:
# Inspect which fields the batch contains.
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'plnk2lbl_idx', 'plnk2lbl_lbl2ptr', 'lnk2lbl_idx', 'lnk2lbl_lbl2ptr', 'lnk2lbl_identifier', 'lnk2lbl_input_text', 'lnk2lbl_input_ids', 'lnk2lbl_attention_mask', 'lnk2lbl_data2ptr', 'plnk2lbl_data2ptr'])

In [None]:
# Check the container type of the collated batch (the output below shows a `BatchEncoding`).
type(batch)

transformers.tokenization_utils_base.BatchEncoding

### `WIKITITLES`

In [None]:
# The transform set is chosen via the `transform_type` parameter (see `TFMS`); `tfm` is not a recognised key.
block = XCBlock.from_cfg('wikititles', 'train', data_dir=data_dir, transform_type='ng', tokenizer='bert-base-uncased')

In [None]:
# Draw one training batch and list its fields.
b = block.train.one_batch()
b.keys()

dict_keys(['lbl2data_idx', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
import torch
# Report the shape of tensor fields and the length of every other (list-like) field.
for name, value in b.items():
    print(name, ':', value.shape if isinstance(value, torch.Tensor) else len(value))

lbl2data_idx : torch.Size([10])
plbl2data_idx : torch.Size([32])
plbl2data_data2ptr : torch.Size([10])
lbl2data_identifier : 10
lbl2data_input_text : 10
lbl2data_input_ids : torch.Size([10, 9])
lbl2data_token_type_ids : torch.Size([10, 9])
lbl2data_attention_mask : torch.Size([10, 9])
lbl2data_data2ptr : torch.Size([10])
data_identifier : 10
data_input_text : 10
data_input_ids : torch.Size([10, 10])
data_token_type_ids : torch.Size([10, 10])
data_attention_mask : torch.Size([10, 10])


In [None]:
# Sample straight from the underlying dataset (`.dset`); the output below is a list of per-example dicts.
b = block.train.dset.one_batch()
b

[{'data_identifier': 'MiRA_Resource_Centre_for_Black,_Immigrant_and_Refugee_Women',
  'data_input_text': 'MiRA Resource Centre for Black, Immigrant and Refugee Women',
  'data_input_ids': [101,
   18062,
   7692,
   2803,
   2005,
   2304,
   1010,
   11560,
   1998,
   13141,
   2308,
   102],
  'data_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  'data_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  'lbl2data_idx': [194516, 242040, 333379, 334112, 494210],
  'lbl2data_identifier': ['Category:Feminist_organizations',
   'Category:Human_rights_organizations',
   'Category:Organisations_based_in_Norway',
   'Category:Organizations_established_in_1989',
   "Category:Women\\'s_organizations"],
  'lbl2data_input_text': ['Feminist organizations',
   'Human rights organizations',
   'Organisations based in Norway',
   'Organizations established in 1989',
   'women organizations'],
  'lbl2data_input_ids': [[101, 10469, 4411, 102],
   [101, 2529, 2916, 4411, 102],
   [101, 8