# Data sampler

In [1]:
#| default_exp data_sampler

In [2]:
%load_ext autoreload
%autoreload 2

In [14]:
#| export
import os,pickle,torch,re, numpy as np
from typing import Optional,List,Dict
from itertools import chain

from transformers import BatchEncoding, AutoTokenizer

from fastcore.utils import *

from xcai.transform import PadFeatTfm,CollapseTfm
from xcai.core import store_attr

## Setup

In [4]:
from xcai.block import *

In [5]:
# Expose only GPU devices 0 and 1 to CUDA for this notebook session.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [11]:
# Machine-specific absolute paths: `pkl_file` points at a previously pickled block
# stored under `<pkl_dir>/processed/`.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-meta_distilbert-base-uncased_xcs.pkl'

In [7]:
# Restore the pickled `block` object from disk.
# NOTE(review): `pickle.load` can execute arbitrary code; only load files from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [5]:
# Alternative to unpickling: build `block` from the 'data_meta' configuration found in `data_dir`.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

# The collator is asked to sample 4 `lbl2data` entries, without oversampling, and to return padded 'pt' tensors.
block = XCBlock.from_cfg(data_dir, 'data_meta', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                         sampling_features=[('lbl2data',4)], oversample=False, padding=True, return_tensors='pt')



In [8]:
# Override the sampling spec on the first collator transform. Each entry is `(names, n_samples)`:
# `names` is a comma-separated list (base feature first, dependents after) and `n_samples` is an
# int or one count per name.
block.collator.tfms.tfms[0].sampling_features = [('lbl2data,cat2lbl2data', (4,1)), ('cat2data', 1)]

In [10]:
# Draw a small batch (bsz=5) with a fixed seed so the raw per-example features can be inspected.
batch = block.train.dset.one_batch(bsz=5, seed=10)

In [13]:
batch[0].keys()

dict_keys(['data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2lbl2data_idx', 'cat2lbl2data_identifier', 'cat2lbl2data_input_text', 'cat2lbl2data_input_ids', 'cat2lbl2data_attention_mask'])

## `XCSamplerFeatTfm`

In [9]:
#| export
class XCSamplerFeatTfm:
    """Sample a fixed number of entries from nested batch features and collate the result.

    The remaining behaviour (`sample_base_feature`, `sample_dep_features`, `__call__`, ...)
    is attached to this class further down the file via `@patch`.
    """

    def __init__(
        self,
        pad_token:Optional[int]=0,
        oversample:Optional[bool]=False,
        sampling_features:Optional[List]=None,
        **kwargs
    ):
        # Keep this call first and unchanged: `store_attr` is handed only the argument names,
        # so it presumably reads their values from this frame.
        store_attr('sampling_features,oversample')
        self.pad_proc = PadFeatTfm(pad_tok=pad_token, in_place=False, drop=False)
        self.col_proc = CollapseTfm()

    def sample_feature(self, batch, names, n_samples, oversample):
        """Sample one feature group described by `names`.

        `names` is a comma-separated string whose first entry is the base feature and whose
        remaining entries are dependent features; every dependent name must end with the base
        name. `n_samples` is either an int (applied to every name) or a sequence holding one
        count per name.

        Raises `ValueError` when the counts do not line up with the names, or when a dependent
        name does not end with the base name.
        """
        feature_names = names.split(',')

        # A scalar count is broadcast to every feature in the group.
        counts = (n_samples,)*len(feature_names) if isinstance(n_samples, int) else n_samples
        if len(counts) != len(feature_names):
            raise ValueError(f'`feature_names` and `n_samples` should have same length.')

        base_name, *dep_names = feature_names
        base_n_sample, dep_n_samples = counts[0], counts[1:]

        # Report the first dependent feature that is not nested under the base feature.
        offending = next((p for p in dep_names if not p.endswith(base_name)), None)
        if offending is not None:
            raise ValueError(f'{offending} does not end with the base prefix `{base_name}`.')

        sampled_batch, sbatch = self.sample_base_feature(batch, names, base_name, base_n_sample, oversample)
        return self.sample_dep_features(sampled_batch, sbatch, dep_names, dep_n_samples, oversample)


In [10]:
#| export
@patch
def rename_idx_ptr(self:XCSamplerFeatTfm, x, prefix, sampling_prefix=None):
    """Rename the `{prefix}_idx_ptr-<k>` entries of `x` to `{p}_{suffix}2ptr`, in place.

    `prefix` is split on '2'; pointer number `k` is renamed using the last `k` parts joined
    back with '2' as its suffix. `p` is `sampling_prefix` when given, otherwise `prefix`.
    Returns the same mapping `x`.
    """
    out_prefix = sampling_prefix if sampling_prefix is not None else prefix
    parts = prefix.split('2')
    for depth in range(1, len(parts)):
        tail = '2'.join(parts[-depth:])
        src = f'{prefix}_idx_ptr-{depth}'
        x[f'{out_prefix}_{tail}2ptr'] = x[src]
        del x[src]
    return x
    

In [11]:
#| export
@patch
def collate_feature_idx(self:XCSamplerFeatTfm, x, name, sampling_name=None):
    """Pad the `{name}_idx` feature of `x` and return it with every key prefixed by 'p'.

    When the padded output has no `{name}_idx` entry it is returned untouched. Otherwise the
    index entry is moved to `{sampling_name}_idx` (if requested and not already present), the
    pointer entries are renamed via `rename_idx_ptr`, and a new dict with 'p'-prefixed keys is
    returned.
    """
    idx_key = f'{name}_idx'
    o = self.pad_proc(x, prefix=idx_key, lev=name.count('2'))
    if idx_key not in o: return o

    if sampling_name is not None:
        sampled_key = f'{sampling_name}_idx'
        if sampled_key not in o:
            o[sampled_key] = o[idx_key]
            del o[idx_key]

    o = self.rename_idx_ptr(o, name, sampling_name)
    return {f'p{k}':v for k,v in o.items()}
    

In [12]:
#| export
@patch
def get_rnd_idx_from_ptr(self:XCSamplerFeatTfm, x, n_samples, oversample=True):
    """Draw random positions for every count in `x`.

    With `oversample` the positions are drawn with `torch.randint` (always `n_samples` of them,
    repeats possible); otherwise the first `n_samples` entries of a `torch.randperm` are kept.
    A count that is not positive yields the sentinel `tensor([-1])`.
    """
    def draw(count):
        if oversample: return torch.randint(count, size=(n_samples,))
        return torch.randperm(count)[:n_samples]

    return [draw(count) if count>0 else torch.tensor([-1]) for count in x]


In [13]:
#| export
@patch
def get_features(self:XCSamplerFeatTfm, x, prefix:str):
    """Return the keys of `x` that start with one of the comma-separated prefixes plus '_'."""
    matcher = re.compile(f'^({prefix.replace(",","|")})_.*')
    return [key for key in x if matcher.match(key)]
    

In [14]:
#| export
@patch
def sample_batch(self:XCSamplerFeatTfm, batch, features, idxs, level):
    """Select, for every example, the entries at the sampled positions of each feature.

    Each feature value is first collapsed with `col_proc` at `level`; an index whose first
    element is -1 (the "nothing to sample" sentinel) produces an empty list.
    """
    def pick(example, idx):
        picked = {}
        for feature in features:
            collapsed = self.col_proc(example[feature], level)[0]
            picked[feature] = [collapsed[i] for i in idx] if idx[0] != -1 else []
        return picked

    return [pick(example, idx) for example,idx in zip(batch, idxs)]
    

In [15]:
#| export
@patch
def remove_unwanted_ptr(self:XCSamplerFeatTfm, x):
    """Return a copy of `x` without the keys that still end in `_ptr-<number>`."""
    is_raw_ptr = re.compile('.*_ptr-[0-9]+$').match
    return {key:val for key,val in x.items() if is_raw_ptr(key) is None}

@patch
def rename_keys(self:XCSamplerFeatTfm, x, prefix):
    """Replace, in place, the part of every key before its first '_' with `prefix`.

    A key is left as it is when the renamed key already exists in `x`. Returns `x`.
    """
    for old in list(x.keys()):
        new = f"{prefix}_{old.split('_', maxsplit=1)[1]}"
        if new in x: continue
        x[new] = x[old]
        del x[old]
    return x

@patch
def collate_features(self:XCSamplerFeatTfm, x, name, sampling_name=None):
    """Pad every `name`-prefixed feature of `x` and normalise the resulting keys.

    Pointer entries are renamed with `rename_idx_ptr`, leftover `_ptr-<k>` entries are dropped,
    and, when `sampling_name` is given, the keys are re-prefixed with it.
    """
    padded = self.pad_proc(x, prefix=name, lev=name.count('2'))
    renamed = self.rename_idx_ptr(padded, name, sampling_name)
    cleaned = self.remove_unwanted_ptr(renamed)
    if sampling_name is None: return cleaned
    return self.rename_keys(cleaned, sampling_name)
    

In [16]:
#| export
@patch
def sample_base_feature(self:XCSamplerFeatTfm, batch:List, prefix_names:str, name:str, n_sample:int, oversample:Optional[bool]=True):
    """Sample `n_sample` entries of the base feature `name` for every example in `batch`.

    Returns `(sampled_batch, sbatch)`: the collated dict (the 'p'-prefixed index entries plus
    the sampled, padded features) and the per-example sampled features, which still carry every
    feature matching `prefix_names` so dependent features can be sampled from them afterwards.
    Both are empty dicts when the index collation yields nothing.
    """
    parts = name.split('2')
    ptr_name = parts[-1]
    sampling_name = f'{parts[0]}2{ptr_name}'

    idx_info = self.collate_feature_idx(batch, name=name, sampling_name=sampling_name)
    if not len(idx_info): return {}, {}

    # Random positions are drawn from the per-example counts stored in the pointer entry.
    counts = idx_info[f'p{sampling_name}_{ptr_name}2ptr']
    sampling_idx = self.get_rnd_idx_from_ptr(counts, n_sample, oversample=oversample)

    feats = self.get_features(batch[0], prefix_names)
    sbatch = self.sample_batch(batch, feats, sampling_idx, name.count('2')-1)

    sampled_batch = {}
    sampled_batch.update(idx_info)
    sampled_batch.update(self.collate_features(sbatch, name=name, sampling_name=sampling_name))
    return sampled_batch, sbatch
    

In [17]:
#| export
@patch
def sample_sbatch(self:XCSamplerFeatTfm, batch, features, n_samples, oversample=True):
    """Sample nested entries of `features` for every example of an already sampled batch.

    Positions are drawn once per entry of the first feature and then applied to every feature,
    so the sampled values of the different features stay aligned.
    """
    def draw(size):
        if not oversample: return np.random.permutation(size)[:n_samples]
        return np.random.randint(size, size=n_samples) if size > 0 else []

    sbatch = []
    for example in batch:
        idxs = [draw(len(val)) for val in example[features[0]]]
        sbatch.append({
            feature: [[val[i] for i in idx] for val,idx in zip(example[feature], idxs)]
            for feature in features
        })
    return sbatch
    

In [18]:
#| export
@patch
def sample_dep_features(
    self:XCSamplerFeatTfm, 
    sampled_batch:List, 
    sbatch:List, 
    names:List, 
    n_samples:List, 
    oversample:Optional[bool]=True
):
    """Sample every dependent feature in `names` from `sbatch` and merge it into `sampled_batch`.

    Each name is paired with its count from `n_samples`; the output keys use the first two
    '2'-separated parts of the name as prefix. Names whose index collation is empty are skipped.
    Returns the updated `sampled_batch`.
    """
    for name,n_sample in zip(names,n_samples):
        sampling_name = '2'.join(name.split('2')[:2])
        idx_info = self.collate_feature_idx(sbatch, name=name, sampling_name=sampling_name)
        if not len(idx_info): continue

        sampled_batch.update(idx_info)

        feats = self.get_features(sbatch[0], name)
        drawn = self.sample_sbatch(sbatch, feats, n_sample, oversample=oversample)
        sampled_batch.update(self.collate_features(drawn, name=name, sampling_name=sampling_name))

    return sampled_batch
    

In [19]:
#| export
@patch
def process_features(self:XCSamplerFeatTfm, sampled_batch:BatchEncoding, batch:BatchEncoding, names:List):
    """Collate the features in `names` without sampling and merge them into `sampled_batch`."""
    for feature_name in names:
        sampled_batch.update(self.collate_features(batch, name=feature_name))
    return sampled_batch
    

In [20]:
#| export
@patch
def __call__(
    self:XCSamplerFeatTfm, 
    batch:List, 
    sampling_features:Optional[List]=None,
    oversample:Optional[bool]=None,
):  
    """Sample the configured feature groups of `batch` and collate everything else as is.

    `sampling_features` / `oversample` are handed to `store_attr` with `is_none=False`
    (presumably so that only non-None values override the stored settings — confirm).
    Returns a `BatchEncoding` holding the sampled groups followed by the remaining features.
    """
    # Must stay the first statement with the argument names intact (see `store_attr`).
    store_attr('sampling_features,oversample', is_none=False)

    out = BatchEncoding({})
    sampled_features = set()
    for name, n_sample in self.sampling_features:
        out.update(self.sample_feature(batch, name, n_sample, self.oversample))
        sampled_features.update(name.split(','))

    # A feature name is whatever precedes the first '_' of a key of the first example.
    all_features = {k.split('_')[0] for k in batch[0].keys()}
    return self.process_features(out, batch, all_features.difference(sampled_features))
    

### Example

In [21]:
# Larger seeded batch (1600 examples) used to exercise the sampler below.
batch = block.train.dset.one_batch(bsz=1600, seed=100)

In [22]:
# Two sampling groups: `lbl2data` as base feature (4 samples) with the dependent `cat2lbl2data`
# (1 sample), and `cat2data` on its own (1 sample).
sampler = XCSamplerFeatTfm(pad_token=0, sampling_features=[('lbl2data,cat2lbl2data', (4,1)), ('cat2data', 1)])

In [23]:
# `oversample=False` selects the permutation-based sampling branches, i.e. no repeated positions.
o = sampler(batch, oversample=False)

In [24]:
o['cat2lbl_input_ids'].shape, o['cat2lbl_attention_mask'].shape

(torch.Size([2373, 17]), torch.Size([2373, 17]))

In [25]:
o.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_data2ptr', 'pcat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'cat2lbl_lbl2data2ptr', 'cat2lbl_idx', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [30]:
from tqdm.auto import tqdm

In [46]:
# Sanity check over the whole training loader: in every `cat2lbl` row, the column index just
# past token id 102 must equal the number of attended positions.
# NOTE(review): 102 is presumably the [SEP] id of the distilbert-base-uncased vocabulary — confirm.
for batch in tqdm(block.train.dl):
    a = torch.where(batch['cat2lbl_input_ids'] == 102)[1] + 1 
    b = batch['cat2lbl_attention_mask'].sum(dim=1)
    assert torch.all(a == b).item()


  0%|          | 0/434 [00:00<?, ?it/s]

### XCSamplerFeatTfm code

In [None]:
# Stand-alone copies of the two helper transforms that `XCSamplerFeatTfm.__init__` creates,
# for stepping through the sampling logic cell by cell.
pad_proc = PadFeatTfm(pad_tok=0, in_place=False, drop=False)
col_proc = CollapseTfm()

In [None]:
# Sampling spec under test: base feature first, dependent feature(s) after, one count per name.
prefix = 'lbl2data,cat2lbl2data'
n_samples = [2,3]

names = prefix.split(',')
base_name, smp_names = names[0], names[1:]

# Every dependent feature has to be nested under the base feature, i.e. end with its name.
for p in smp_names:
    if not p.endswith(base_name): 
        raise ValueError(f'{p} does not end with the base prefix `{base_name}`.')
        

In [None]:
# Broadcast a scalar count to every feature name (base + dependents), as
# `XCSamplerFeatTfm.sample_feature` does. Broadcasting over `smp_names` only produced one
# count too few, so the length check below failed for every integer `n_samples`.
if isinstance(n_samples, int): 
    n_samples = (n_samples,)*len(names)

if len(names) != len(n_samples):
    raise ValueError(f'`prefixes` and `n_samples` should have same length.')
    

In [None]:
sampled_batch = {}

In [None]:
# Step 1: collate the index pointers of the base feature and draw random positions from them.
# NOTE(review): the helpers are called as free functions here, whereas the exported versions
# are methods patched onto `XCSamplerFeatTfm`; this scratch code presumably predates that.
name, n_sample = base_name, n_samples[0]

# e.g. 'lbl2data' -> sampling_name 'lbl2data', ptr_name 'data'
feat_prefix = name.split('2')
sampling_name,ptr_name = f'{feat_prefix[0]}2{feat_prefix[-1]}',feat_prefix[-1]


o = collate_feature_idx(batch, name=name, sampling_name=sampling_name)
sampling_idx = get_rnd_idx_from_ptr(o[f'p{sampling_name}_{ptr_name}2ptr'], n_sample, oversample=True)

sampled_batch.update(o)

In [None]:
# Step 2: pick the sampled positions from every feature matching any of the prefixes
# (base and dependents), so dependents can be sampled from `sbatch` later.
feats,level = get_features(batch[0], prefix), name.count('2')-1
sbatch = sample_batch(batch, feats, sampling_idx, level)

In [None]:
# Step 3: pad/collate the sampled base feature and merge it into the output.
o = collate_features(sbatch, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

In [None]:
# Step 4: index pointers of the first dependent feature.
# NOTE(review): `n_sample` takes `n_samples[0]` (the base count) although the sampling cell
# that follows passes `n_samples[1]`; and the exported `sample_dep_features` collates from
# `sbatch`, not `batch` — confirm which is intended.
name, n_sample = smp_names[0], n_samples[0]
sampling_name = '2'.join(name.split('2')[:2])

o = collate_feature_idx(batch, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

feats = get_features(sbatch[0], name)

In [None]:
# Step 5: sample the dependent feature (second count) from the already sampled batch,
# collate it, and merge it into the output.
o = sample_sbatch(sbatch, feats, n_samples[1], oversample=True)
o = collate_features(o, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

In [None]:
sampled_batch

{'plbl2data_data2ptr': tensor([ 1, 14,  1,  1,  1]),
 'lbl2data_idx': tensor([ 97475,  97475, 134705,  14241, 196033, 196033,  26569,  26569, 195049,
         195049]),
 'lbl2data_identifier': ['List_of_Test_cricket_umpires',
  'List_of_Test_cricket_umpires',
  'List_of_drugs_used_by_militaries',
  'Military_medicine',
  'List_of_rivers_of_Mexico',
  'List_of_rivers_of_Mexico',
  'List_of_New_South_Wales_representative_cricketers',
  'List_of_New_South_Wales_representative_cricketers',
  'List_of_antarctic_and_sub-antarctic_islands',
  'List_of_antarctic_and_sub-antarctic_islands'],
 'lbl2data_input_text': ['List of Test cricket umpires',
  'List of Test cricket umpires',
  'List of drugs used by militaries',
  'Military medicine',
  'List of rivers of Mexico',
  'List of rivers of Mexico',
  'List of New South Wales representative cricketers',
  'List of New South Wales representative cricketers',
  'List of antarctic and sub-antarctic islands',
  'List of antarctic and sub-antarctic 

## `OAKSamplerFeatTfm`

In [6]:
#| export
class OAKSamplerFeatTfm:
    """Collate a list of per-example feature dicts into one batch, sub-sampling labels and metadata.

    Every example must provide `data_idx`, `data_input_ids`, `data_attention_mask`,
    `lbl2data_idx`, `lbl2data_input_ids` and `lbl2data_attention_mask`; when a metadata name is
    configured it must also provide `{metadata_name}2data_idx`.

    The output keeps the complete label / metadata indices under 'p'-prefixed keys
    (`plbl2data_*`, `p{meta}2data_*`) and the randomly sub-sampled ones under the plain keys.
    """

    def __init__(self, metadata_name:Optional[str]=None, num_labels:Optional[int]=1, num_metadata:Optional[int]=1, **kwargs):
        # `metadata_name` may be None / empty: `__call__` then skips the metadata collation.
        self.meta_name,self.n_labels,self.n_meta = metadata_name,num_labels,num_metadata

    @staticmethod
    def _sample_from_ptr(ptr:torch.Tensor, n_samples:int):
        """Pick at most `n_samples` random positions from each segment described by `ptr`.

        `ptr[i]` is the number of entries of example `i` in the flattened (concatenated) layout.
        Returns `(idx, sampled_ptr)`: the flat positions of the kept entries and the number of
        entries kept per example.
        """
        counts = ptr.tolist()
        # Exclusive prefix sum: where each example's segment starts in the flat layout.
        offsets = (ptr.cumsum(dim=0) - ptr).tolist()
        pieces = [torch.randperm(n)[:n_samples]+offset for n,offset in zip(counts, offsets)]
        idx = torch.hstack(pieces) if pieces else torch.empty(0, dtype=torch.int64)
        return idx, torch.clamp(ptr, max=n_samples)

    @staticmethod
    def collate_data(batch:Dict, features:List):
        """Stack the per-example data index, input ids and attention mask into `batch` (in place)."""
        batch['data_idx'] = torch.tensor([o['data_idx'] for o in features], dtype=torch.int64)
        # NOTE(review): `vstack` needs every example to have the same sequence length — assumes
        # the inputs are already padded.
        batch['data_input_ids'] = torch.vstack([o['data_input_ids'] for o in features])
        batch['data_attention_mask'] = torch.vstack([o['data_attention_mask'] for o in features])

    @staticmethod
    def collate_labels(batch:Dict, features:List, n_labels:Optional[int]=1):
        """Flatten the labels of all examples and keep at most `n_labels` random ones per example.

        Writes `plbl2data_data2ptr` / `plbl2data_idx` (all labels) and `lbl2data_data2ptr`,
        `lbl2data_idx`, `lbl2data_attention_mask`, `lbl2data_input_ids` (sampled labels) into
        `batch`, in place.
        """
        batch['plbl2data_data2ptr'] = torch.tensor([len(o['lbl2data_idx']) for o in features], dtype=torch.int64)
        batch['plbl2data_idx'] = torch.tensor(list(chain.from_iterable(o['lbl2data_idx'] for o in features)), dtype=torch.int64)

        input_ids = torch.vstack(list(chain.from_iterable(o['lbl2data_input_ids'] for o in features)))
        attention_mask = torch.vstack(list(chain.from_iterable(o['lbl2data_attention_mask'] for o in features)))

        idx, batch['lbl2data_data2ptr'] = OAKSamplerFeatTfm._sample_from_ptr(batch['plbl2data_data2ptr'], n_labels)
        batch['lbl2data_idx'] = batch['plbl2data_idx'][idx]
        batch['lbl2data_attention_mask'] = attention_mask[idx]
        batch['lbl2data_input_ids'] = input_ids[idx]

    @staticmethod
    def collate_metadata(batch:Dict, features:List, meta_name:str, n_meta:Optional[int]=1):
        """Flatten the `{meta_name}2data` indices and keep at most `n_meta` random ones per example.

        Only indices are collated for metadata; no input ids or attention masks are read.
        """
        batch[f'p{meta_name}2data_data2ptr'] = torch.tensor([len(o[f'{meta_name}2data_idx']) for o in features], dtype=torch.int64)
        batch[f'p{meta_name}2data_idx'] = torch.tensor(list(chain.from_iterable(o[f'{meta_name}2data_idx'] for o in features)), dtype=torch.int64)

        idx, batch[f'{meta_name}2data_data2ptr'] = OAKSamplerFeatTfm._sample_from_ptr(batch[f'p{meta_name}2data_data2ptr'], n_meta)
        batch[f'{meta_name}2data_idx'] = batch[f'p{meta_name}2data_idx'][idx]

    def __call__(self, features:List):
        """Return the collated batch dict for `features` (a list of per-example dicts)."""
        batch = {}
        self.collate_data(batch, features)
        self.collate_labels(batch, features, n_labels=self.n_labels)
        if self.meta_name:
            self.collate_metadata(batch, features, self.meta_name, n_meta=self.n_meta)
        return batch
        

### Example

In [7]:
batch = block.train.dset.one_batch(bsz=1600, seed=1000)

In [9]:
# Keep at most 2 labels and 3 'cat' metadata entries per example.
sampler = OAKSamplerFeatTfm(num_labels=2, num_metadata=3, metadata_name='cat')

In [10]:
o = sampler(batch)

In [11]:
o.keys()

dict_keys(['data_idx', 'data_input_ids', 'data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_attention_mask', 'lbl2data_input_ids', 'pcat2data_data2ptr', 'pcat2data_idx', 'cat2data_data2ptr', 'cat2data_idx'])

In [12]:
# Print the shape of every collated tensor; the 'p'-prefixed entries hold the full
# (pre-sampling) indices, the plain ones the sub-sampled indices.
for k,v in o.items():
    print(k, ':', v.shape)

data_idx : torch.Size([1600])
data_input_ids : torch.Size([1600, 32])
data_attention_mask : torch.Size([1600, 32])
plbl2data_data2ptr : torch.Size([1600])
plbl2data_idx : torch.Size([3361])
lbl2data_data2ptr : torch.Size([1600])
lbl2data_idx : torch.Size([2387])
lbl2data_attention_mask : torch.Size([2387, 32])
lbl2data_input_ids : torch.Size([2387, 32])
pcat2data_data2ptr : torch.Size([1600])
pcat2data_idx : torch.Size([7952])
cat2data_data2ptr : torch.Size([1600])
cat2data_idx : torch.Size([3961])


In [15]:
# Tokenizer used only to decode input ids back to text for the spot checks below.
tokz = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [77]:
n = 700

In [78]:
tokz.decode(o['data_input_ids'][n], skip_special_tokens=True)

'republic of maryland'

In [79]:
# Inclusive cumulative count of sampled labels: `indptr[n-1]:indptr[n]` spans example n.
# There is no leading zero, so this slicing is only valid for n >= 1.
indptr = o['lbl2data_data2ptr'].cumsum(dim=0)

In [80]:
o['lbl2data_idx'][indptr[n-1]:indptr[n]]

tensor([90687, 89859])

In [81]:
tokz.batch_decode(o['lbl2data_input_ids'][indptr[n-1]:indptr[n]], skip_special_tokens=True)

['history of liberia', 'history of slavery in maryland']

In [82]:
# Same cumulative offsets, but over the full (pre-sampling) label lists, to compare against the sample.
indptr = o['plbl2data_data2ptr'].cumsum(dim=0)

In [83]:
o['plbl2data_idx'][indptr[n-1]:indptr[n]]

tensor([89859, 90687, 90688])