# data sampler

In [None]:
#| default_exp data_sampler

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import os,pickle,torch,re, numpy as np
from typing import Optional,List

from transformers import BatchEncoding

from fastcore.utils import *

from xcai.transform import PadFeatTfm,CollapseTfm
from xcai.core import store_attr

## Setup

In [None]:
# Expose only GPUs 0 and 1 to CUDA for this notebook session.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
# Machine-specific absolute path to the preprocessed dataset pickle used by the examples below.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_ramen-cat.pkl'

In [None]:
# NOTE: `pickle.load` can execute arbitrary code embedded in the file; only load pickles from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Grab a 5-item training batch with a fixed seed (presumably for reproducibility) to drive the examples below.
batch = block.train.dset.one_batch(bsz=5, seed=10)

## `XCSamplerFeatTfm`

In [None]:
#| export
class XCSamplerFeatTfm:
    """Collates a batch while sub-sampling selected groups of features.

    `sampling_features` is stored on the instance; `__call__` iterates it as
    `(names, n_samples)` pairs, where `names` is a comma-separated feature group.
    The padding arguments are forwarded to `PadFeatTfm`.
    """

    def __init__(
        self,
        pad_token:Optional[int]=0,
        pad_in_place:Optional[bool]=True,
        pad_drop=True,

        sampling_features:Optional[List]=None,
    ):
        store_attr('sampling_features')
        # Helpers shared by the patched methods: one pads/collates, one collapses nested lists.
        self.pad_proc = PadFeatTfm(pad_tok=pad_token, in_place=pad_in_place, drop=pad_drop)
        self.col_proc = CollapseTfm()

    def sample_feature(self, batch, names, n_samples, oversample):
        """Sample one comma-separated feature group from `batch`.

        The first name in `names` is the base feature; every following name must
        end with the base name. `n_samples` is either one count per name or a
        single int that is reused for all of them.

        Raises `ValueError` when the counts do not line up with the names or
        when a dependent name does not end with the base name.
        """
        group = names.split(',')

        # A lone integer applies to every feature in the group.
        counts = (n_samples,)*len(group) if isinstance(n_samples, int) else n_samples
        if len(group) != len(counts):
            raise ValueError(f'`feature_names` and `n_samples` should have same length.')

        base_name, base_count = group[0], counts[0]
        dep_names, dep_counts = group[1:], counts[1:]

        # Report the first dependent feature that is not rooted at the base feature.
        offender = next((p for p in dep_names if not p.endswith(base_name)), None)
        if offender is not None:
            raise ValueError(f'{offender} does not end with the base prefix `{base_name}`.')

        sampled_batch, sbatch = self.sample_base_feature(batch, names, base_name, base_count, oversample)
        return self.sample_dep_features(sampled_batch, sbatch, dep_names, dep_counts, oversample)


In [None]:
#| export
@patch
def rename_idx_ptr(self:XCSamplerFeatTfm, x, prefix, sampling_prefix=None):
    """Rename the `{prefix}_idx_ptr-<k>` entries of `x` in place and return `x`.

    For k = 1 .. (number of '2'-separated parts in `prefix`) - 1, the entry
    `{prefix}_idx_ptr-<k>` is moved to `{p}_{suffix}2ptr`, where `suffix` is the
    last k parts of `prefix` joined by '2' and `p` is `sampling_prefix` when
    given, else `prefix`.
    """
    parts = prefix.split('2')
    target = prefix if sampling_prefix is None else sampling_prefix
    for depth in range(1, len(parts)):
        suffix = '2'.join(parts[len(parts)-depth:])
        old_key = f'{prefix}_idx_ptr-{depth}'
        x[f'{target}_{suffix}2ptr'] = x[old_key]
        del x[old_key]
    return x
    

In [None]:
#| export
@patch
def collate_feature_idx(self:XCSamplerFeatTfm, x, name, sampling_name=None):
    """Collate the `{name}_idx` feature of `x` and return it under 'p'-prefixed keys.

    `self.pad_proc` is run with prefix `{name}_idx` and `lev` equal to the number
    of '2' in `name`; the idx entry is moved under `sampling_name` when one is
    given, pointer entries are renamed via `rename_idx_ptr`, and every resulting
    key is prefixed with 'p' in a new dict.
    """
    idx_key = f'{name}_idx'
    collated = self.pad_proc(x, prefix=idx_key, lev=name.count('2'))

    if sampling_name is not None:
        # NOTE(review): when `sampling_name == name` this assign-then-delete removes
        # the idx entry altogether — confirm that is intended.
        collated[f'{sampling_name}_idx'] = collated[idx_key]
        del collated[idx_key]

    collated = self.rename_idx_ptr(collated, name, sampling_name)
    return {f'p{key}':value for key,value in collated.items()}
    

In [None]:
#| export
@patch
def get_rnd_idx_from_ptr(self:XCSamplerFeatTfm, x, n_samples, oversample=True):
    """Draw random positions for every count in `x`.

    For a positive count `c`: with `oversample`, `n_samples` positions from
    `torch.randint(c, ...)` (repeats possible); otherwise the first `n_samples`
    entries of `torch.randperm(c)`. A non-positive count yields `tensor([-1])`.
    """
    def _draw(count):
        if count > 0:
            if oversample: return torch.randint(count, size=(n_samples,))
            return torch.randperm(count)[:n_samples]
        # Sentinel understood by `sample_batch` as "nothing to pick".
        return torch.tensor([-1])

    return [_draw(count) for count in x]


In [None]:
#| export
@patch
def get_features(self:XCSamplerFeatTfm, x, prefix:str):
    """Return the keys of `x` that start with one of the comma-separated prefixes followed by '_'.

    The prefixes are inserted into the regular expression as-is (not escaped).
    """
    alternatives = prefix.replace(",","|")
    matcher = re.compile(f'^({alternatives})_.*')
    selected = []
    for key in x:
        if matcher.match(key): selected.append(key)
    return selected
    

In [None]:
#| export
@patch
def sample_batch(self:XCSamplerFeatTfm, batch, features, idxs, level):
    """Pick the entries selected by `idxs` from every listed feature of each batch item.

    Each feature value is first passed through `self.col_proc(value, level)` and the
    first element of the result is indexed. An index list whose first entry is -1
    produces an empty list for every feature of that item.
    """
    def _pick(item, idx):
        picked = {}
        for feature in features:
            collapsed = self.col_proc(item[feature], level)[0]
            if idx[0] == -1: picked[feature] = []
            else: picked[feature] = [collapsed[i] for i in idx]
        return picked

    return [_pick(item, idx) for item,idx in zip(batch, idxs)]
    

In [None]:
#| export
@patch
def remove_unwanted_ptr(self:XCSamplerFeatTfm, x):
    """Return a new dict with the entries of `x` whose key does not end in `_ptr-<digits>`."""
    kept = {}
    for key,value in x.items():
        if re.match('.*_ptr-[0-9]+$', key) is None: kept[key] = value
    return kept

@patch
def rename_keys(self:XCSamplerFeatTfm, x, prefix):
    """Replace, in place, the part of every key before its first '_' with `prefix`.

    A key is left untouched when the renamed key already exists in `x` (which
    includes the case where the key already carries `prefix`). Returns `x`.
    """
    # Snapshot the keys first: the mapping is mutated while renaming.
    for old_key in list(x.keys()):
        remainder = old_key.split('_', maxsplit=1)[1]
        new_key = f'{prefix}_{remainder}'
        if new_key in x: continue
        x[new_key] = x[old_key]
        del x[old_key]
    return x

@patch
def collate_features(self:XCSamplerFeatTfm, x, name, sampling_name=None):
    """Collate every `name`-prefixed feature of `x`.

    Runs `self.pad_proc` with prefix `name` and `lev` equal to the number of '2'
    in `name`, renames the pointer entries, drops the remaining `_ptr-<k>` entries
    and, when `sampling_name` is given, re-prefixes the keys with it.
    """
    padded = self.pad_proc(x, prefix=name, lev=name.count('2'))
    cleaned = self.remove_unwanted_ptr(self.rename_idx_ptr(padded, name, sampling_name))
    if sampling_name is None: return cleaned
    return self.rename_keys(cleaned, sampling_name)
    

In [None]:
#| export
@patch
def sample_base_feature(self:XCSamplerFeatTfm, batch:List, prefix_names:str, name:str, n_sample:int, oversample:Optional[bool]=True):
    """Sample the base feature `name` for every item of `batch`.

    Returns `(sampled_batch, sbatch)`: `sampled_batch` is a dict holding the
    collated idx/pointer entries plus the collated sampled features, and `sbatch`
    is the per-item list of sampled features (for every feature matching
    `prefix_names`) that dependent features are sampled from afterwards.
    """
    parts = name.split('2')
    head, ptr_name = parts[0], parts[-1]
    # The sampled output is keyed by the first and last component of `name` only.
    sampling_name = f'{head}2{ptr_name}'

    idx_out = self.collate_feature_idx(batch, name=name, sampling_name=sampling_name)
    ptr = idx_out[f'p{sampling_name}_{ptr_name}2ptr']
    sampling_idx = self.get_rnd_idx_from_ptr(ptr, n_sample, oversample=oversample)

    feats = self.get_features(batch[0], prefix_names)
    sbatch = self.sample_batch(batch, feats, sampling_idx, name.count('2')-1)

    sampled_batch = dict(idx_out)
    sampled_batch.update(self.collate_features(sbatch, name=name, sampling_name=sampling_name))
    return sampled_batch, sbatch
    

In [None]:
#| export
@patch
def sample_sbatch(self:XCSamplerFeatTfm, batch, features, n_samples, oversample=True):
    """Sub-sample each inner list of the given features for every batch item.

    With `oversample`, `n_samples` positions are drawn via `np.random.randint`
    (repeats possible; an empty list stays empty). Otherwise the first
    `n_samples` positions of a random permutation are used.
    """
    def _positions(val):
        if not oversample: return np.random.permutation(len(val))[:n_samples]
        if len(val) > 0: return np.random.randint(len(val), size=n_samples)
        return []

    sampled = []
    for item in batch:
        per_feature = {}
        for feature in features:
            per_feature[feature] = [[val[i] for i in _positions(val)] for val in item[feature]]
        sampled.append(per_feature)
    return sampled
    

In [None]:
#| export
@patch
def sample_dep_features(
    self:XCSamplerFeatTfm, 
    sampled_batch:List, 
    sbatch:List, 
    names:List, 
    n_samples:List, 
    oversample:Optional[bool]=True
):
    """Sample every dependent feature from the already-sampled `sbatch`.

    For each `(name, n_sample)` pair the collated idx/pointer entries and the
    collated sub-sampled features are merged into `sampled_batch`, which is
    updated in place and returned.
    """
    for dep_name,dep_count in zip(names,n_samples):
        # Output keys use only the first two '2'-separated components of the name.
        sampling_name = '2'.join(dep_name.split('2')[:2])

        idx_out = self.collate_feature_idx(sbatch, name=dep_name, sampling_name=sampling_name)
        sampled_batch.update(idx_out)

        feats = self.get_features(sbatch[0], dep_name)
        drawn = self.sample_sbatch(sbatch, feats, dep_count, oversample=oversample)
        sampled_batch.update(self.collate_features(drawn, name=dep_name, sampling_name=sampling_name))

    return sampled_batch
    

In [None]:
#| export
@patch
def process_features(self:XCSamplerFeatTfm, sampled_batch:BatchEncoding, batch:BatchEncoding, names:List):
    """Collate each feature in `names` from `batch` (no sampling) and merge it into `sampled_batch`.

    `sampled_batch` is updated in place and returned.
    """
    for feature_name in names:
        sampled_batch.update(self.collate_features(batch, name=feature_name))
    return sampled_batch
    

In [None]:
#| export
@patch
def __call__(
    self:XCSamplerFeatTfm, 
    batch:List, 
    sampling_features:Optional[List]=None,
    oversample:Optional[bool]=True
):  
    """Collate `batch`, sub-sampling the configured feature groups.

    `sampling_features` is a list of `(names, n_samples)` pairs, where `names` is
    a comma-separated feature group handed to `sample_feature`. Features of
    `batch` (identified by the part of each key before the first '_') that are
    not part of any sampled group are collated without sampling.

    Returns a `BatchEncoding` with all collated entries.
    """
    # Presumably only overwrites the stored attributes when the argument is not None
    # (`is_none=False`) — confirm against `xcai.core.store_attr`.
    store_attr('sampling_features,oversample', is_none=False)

    # `sampling_features` defaults to None in both `__init__` and `__call__`;
    # treat that as "nothing to sample" instead of failing on iteration.
    specs = () if self.sampling_features is None else self.sampling_features

    sampled_features = set()
    out = BatchEncoding({})
    for name, n_sample in specs:
        out.update(self.sample_feature(batch, name, n_sample, oversample))
        sampled_features.update(name.split(','))

    all_features = {k.split('_')[0] for k in batch[0].keys()}
    # Sorted so the key order of the output does not depend on set iteration order.
    remaining_features = sorted(all_features.difference(sampled_features))
    out = self.process_features(out, batch, remaining_features)
    
    return out
    

### Example

In [None]:
# Two sampling groups, each as `(comma-separated feature names, n_samples)`;
# an integer `n_samples` is reused for every feature in its group.
sampler = XCSamplerFeatTfm(pad_token=0, pad_in_place=False, pad_drop=False,
                           sampling_features=[('lbl2data,cat2lbl2data', 2), ('cat2data', 2)])

In [None]:
# Run the sampler on the example batch, drawing with replacement.
o = sampler(batch, oversample=True)

In [None]:
# Inspect which entries the sampler produced.
o.keys()

dict_keys(['plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_data2ptr', 'pcat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'cat2lbl_lbl2data2ptr', 'cat2lbl_idx', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

### XCSamplerFeatTfm code

In [None]:
# Stand-alone copies of the two helper transforms that `XCSamplerFeatTfm.__init__` creates.
pad_proc = PadFeatTfm(pad_tok=0, in_place=False, drop=False)
col_proc = CollapseTfm()

In [None]:
# Feature group to sample: the first name is the base feature, the rest depend on it.
prefix = 'lbl2data,cat2lbl2data'
# Number of samples per feature, aligned with the names in `prefix`.
n_samples = [2,3]

names = prefix.split(',')
base_name, smp_names = names[0], names[1:]

# Every dependent feature name must end with the base feature name.
for p in smp_names:
    if not p.endswith(base_name): 
        raise ValueError(f'{p} does not end with the base prefix `{base_name}`.')
        

In [None]:
# A single integer sample count is reused for every feature in `names`
# (base feature included); broadcasting over `smp_names` only would leave the
# tuple one entry short and make the length check below always fail.
if isinstance(n_samples, int): 
    n_samples = (n_samples,)*len(names)

# One sample count is required per feature name.
if len(names) != len(n_samples):
    raise ValueError(f'`prefixes` and `n_samples` should have same length.')
    

In [None]:
# Accumulates every collated entry produced by the steps below.
sampled_batch = {}

In [None]:
# Step 1: collate the base feature's indices and draw random positions from its pointer.
name, n_sample = base_name, n_samples[0]

feat_prefix = name.split('2')
# Output keys use only the first and last '2'-separated component of the name.
sampling_name,ptr_name = f'{feat_prefix[0]}2{feat_prefix[-1]}',feat_prefix[-1]


# NOTE(review): `collate_feature_idx` / `get_rnd_idx_from_ptr` are defined in this notebook
# as `@patch`-ed methods taking `self`; these bare calls presumably predate that refactor —
# confirm before re-running this cell.
o = collate_feature_idx(batch, name=name, sampling_name=sampling_name)
sampling_idx = get_rnd_idx_from_ptr(o[f'p{sampling_name}_{ptr_name}2ptr'], n_sample, oversample=True)

sampled_batch.update(o)

In [None]:
# Step 2: pick the drawn positions from every feature of the group, one level above the leaf.
feats,level = get_features(batch[0], prefix), name.count('2')-1
sbatch = sample_batch(batch, feats, sampling_idx, level)

In [None]:
# Step 3: collate the sampled base feature and merge it into the result.
o = collate_features(sbatch, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

In [None]:
# Step 4: switch to the first dependent feature.
# NOTE(review): `n_sample` takes `n_samples[0]` here, while the following cell samples with
# `n_samples[1]` — the assignment looks stale; confirm.
name, n_sample = smp_names[0], n_samples[0]
# Output keys use only the first two '2'-separated components of the name.
sampling_name = '2'.join(name.split('2')[:2])

# NOTE(review): `XCSamplerFeatTfm.sample_dep_features` passes `sbatch` (not `batch`) to
# `collate_feature_idx` at this step — confirm which input is intended.
o = collate_feature_idx(batch, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

feats = get_features(sbatch[0], name)

In [None]:
# Step 5: sub-sample the dependent feature inside the already-sampled batch, then collate and merge it.
o = sample_sbatch(sbatch, feats, n_samples[1], oversample=True)
o = collate_features(o, name=name, sampling_name=sampling_name)
sampled_batch.update(o)

In [None]:
# Display the accumulated result.
sampled_batch

{'plbl2data_data2ptr': tensor([ 1, 14,  1,  1,  1]),
 'lbl2data_idx': tensor([ 97475,  97475, 134705,  14241, 196033, 196033,  26569,  26569, 195049,
         195049]),
 'lbl2data_identifier': ['List_of_Test_cricket_umpires',
  'List_of_Test_cricket_umpires',
  'List_of_drugs_used_by_militaries',
  'Military_medicine',
  'List_of_rivers_of_Mexico',
  'List_of_rivers_of_Mexico',
  'List_of_New_South_Wales_representative_cricketers',
  'List_of_New_South_Wales_representative_cricketers',
  'List_of_antarctic_and_sub-antarctic_islands',
  'List_of_antarctic_and_sub-antarctic_islands'],
 'lbl2data_input_text': ['List of Test cricket umpires',
  'List of Test cricket umpires',
  'List of drugs used by militaries',
  'Military medicine',
  'List of rivers of Mexico',
  'List of rivers of Mexico',
  'List of New South Wales representative cricketers',
  'List of New South Wales representative cricketers',
  'List of antarctic and sub-antarctic islands',
  'List of antarctic and sub-antarctic 