# data sampler

In [None]:
#| default_exp data_sampler

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import os,pickle,torch,re

from xcai.transform import PadFeatTfm,CollapseTfm

## Setup

In [None]:
# Expose only GPU devices 0 and 1 to this process through the CUDA_VISIBLE_DEVICES environment variable.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
# Location of the preprocessed dataset pickle.
# NOTE(review): hard-coded absolute path — adjust when running on a different machine.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_ramen-cat.pkl'

In [None]:
# Load the preprocessed data block from disk.
# NOTE(review): `pickle.load` can execute arbitrary code while loading; only open pickles from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Fetch one small batch from the training dataset (bsz=5, fixed seed) to prototype the sampler on.
batch = block.train.dset.one_batch(bsz=5, seed=10)

## Sampler

In [None]:
# Transform used by the collate helpers in this notebook; they call it as `pad_proc(x, prefix=..., lev=...)`.
pad_proc = PadFeatTfm(pad_tok=0, in_place=False, drop=False)
# NOTE(review): `col_proc` is not used anywhere in the visible part of this notebook.
col_proc = CollapseTfm()

In [None]:
# Comma-separated feature prefixes: the first one is the base prefix, the
# remaining ones are the prefixes that get sampled on top of it.
prefix = 'lbl2data,cat2lbl2data'
n_samples = 1

base_prefix, *prefixes = prefix.split(',')

# Every additional prefix has to be an extension of the base prefix.
for p in prefixes:
    if p.endswith(base_prefix):
        continue
    raise ValueError(f'{p} does not end with the base prefix `{base_prefix}`.')
        

In [None]:
# A single integer sample count is applied to every non-base prefix.
if isinstance(n_samples, int):
    n_samples = tuple(n_samples for _ in prefixes)

# After normalisation there must be exactly one sample count per prefix.
if len(n_samples) != len(prefixes):
    raise ValueError(f'`prefixes` and `n_samples` should have same length.')
    

In [None]:
def rename_idx_ptr(x, prefix, smp_prefix=None):
    """Rename the `{prefix}_idx_ptr-<k>` entries of `x` to `<head>_<tail>2ptr` keys.

    `prefix` is split on '2' into its feature names. The split points are
    visited from the last one to the first; the k-th visited split point
    (k starting at 1) moves `{prefix}_idx_ptr-<k>` to `<head>_<tail>2ptr`, where
    `<tail>` is the features from the split point onwards and `<head>` is
    `smp_prefix` when given, otherwise the features up to and including the
    split point.

    Example: `prefix='cat2lbl2data'`, `smp_prefix='cat2lbl'` turns
    `cat2lbl2data_idx_ptr-1` into `cat2lbl_data2ptr` and
    `cat2lbl2data_idx_ptr-2` into `cat2lbl_lbl2data2ptr`.

    `x` is modified in place and returned; a missing pointer key raises `KeyError`.
    """
    parts = prefix.split('2')
    level = 0
    for cut in reversed(range(1, len(parts))):
        level += 1
        tail = '2'.join(parts[cut:])
        head = smp_prefix if smp_prefix is not None else '2'.join(parts[:cut + 1])
        old_key = f'{prefix}_idx_ptr-{level}'
        x[f'{head}_{tail}2ptr'] = x[old_key]
        del x[old_key]
    return x
    

In [None]:
def collate_feature_idx(x, prefix, smp_prefix=None):
    """Collate only the `{prefix}_idx` feature of `x` and normalise the resulting keys.

    `pad_proc` is run on the `{prefix}_idx` feature with `lev` set to the number
    of '2' separators in `prefix`. When `smp_prefix` is given, the collated
    `{prefix}_idx` entry is moved to `{smp_prefix}_idx`. The `_idx_ptr-<k>`
    entries are then renamed by `rename_idx_ptr`.

    Returns a new dict in which every key is prefixed with the letter 'p'.
    """
    idx_key = f'{prefix}_idx'
    collated = pad_proc(x, prefix=idx_key, lev=prefix.count('2'))

    if smp_prefix is not None:
        collated[f'{smp_prefix}_idx'] = collated[idx_key]
        del collated[idx_key]

    collated = rename_idx_ptr(collated, prefix, smp_prefix)

    out = {}
    for key, value in collated.items():
        out[f'p{key}'] = value
    return out
    

In [None]:
def get_rnd_idx(x):
    """Draw one random position for every count in `x`.

    Args:
        x: iterable of integer counts (e.g. a 1-D integer tensor of pointer sizes).

    Returns:
        A 1-D `torch.int64` tensor with one value per count: a uniform draw from
        `[0, count)` when `count > 0`, and `-1` otherwise (nothing to sample).
        An empty input yields an empty tensor instead of raising.
    """
    picks = []
    for count in x:
        # Work with a plain Python int so tensors, lists and tuples behave alike.
        count = int(count)
        picks.append(torch.randint(count, size=(1,)) if count > 0 else torch.tensor([-1]))

    # `torch.cat` rejects an empty list, so an empty batch is handled explicitly.
    if not picks:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(picks)
    

In [None]:
# Split the base prefix at its first '2' into the sampled feature and the remaining
# base feature (for 'lbl2data' this gives 'lbl' and 'data').
smp_feat, base_feat = base_prefix.split('2', maxsplit=1)
# Collate the indices of the base prefix for the whole batch; the returned keys start with 'p'.
o = collate_feature_idx(batch, base_prefix)

In [None]:
# Dict that accumulates the collated outputs of the sampling steps in the following cells.
sampled_batch = {}

In [None]:
# Store the 'p'-prefixed index collation of the base prefix.
sampled_batch.update(o)

In [None]:
o

{'plbl2data_idx': tensor([ 97475,   8095,  14241,  53207,  85334,  87177,  87553, 134705, 150807,
         153681, 168326, 186062, 188361, 221141, 252304, 196033,  26569, 195049]),
 'plbl2data_data2ptr': tensor([ 1, 14,  1,  1,  1])}

In [None]:
# Pick one random position per point from the base-prefix pointer counts
# (`get_rnd_idx` yields -1 for a count of 0).
smp_idx = get_rnd_idx(o[f'p{base_prefix}_{base_feat}2ptr']); smp_idx

tensor([0, 3, 0, 0, 0])

In [None]:
def get_features(x, prefix:str):
    """Return the entries of `x` that start with one of the given prefixes followed by '_'.

    `prefix` is a comma-separated list of prefixes. The commas are turned into
    a regex alternation, so each prefix is interpreted as a regular-expression
    fragment. Iterating `x` must yield strings (e.g. the keys of a dict); the
    iteration order is preserved in the returned list.
    """
    alternatives = '|'.join(prefix.split(','))
    matcher = re.compile(f'^({alternatives})_.*')
    return list(filter(matcher.match, x))

def sample_batch(x, feat, idx):
    """Keep a single entry of every feature in `feat` for each point of `x`.

    `idx` and `x` are walked in lockstep. For a point `o` with index `i`, each
    feature `f` maps to the one-element list `[o[f][i]]`, or to an empty list
    when `i` is negative.

    Returns a list with one dict per point, keyed by the names in `feat`.
    """
    sampled = []
    for pos, point in zip(idx, x):
        picked = {}
        for name in feat:
            picked[name] = [] if pos < 0 else [point[name][pos]]
        sampled.append(picked)
    return sampled


In [None]:
# Feature keys of the first point that belong to any of the comma-separated prefixes.
feats = get_features(batch[0], prefix)
# For every point keep only the entry selected by `smp_idx` for each of those features.
smp_batch = sample_batch(batch, feats, smp_idx)

In [None]:
def remove_unwanted_ptr(x):
    """Return a new dict holding the items of `x` whose key does not end in `_ptr-<digits>`."""
    ptr_key = re.compile('.*_ptr-[0-9]+$')
    kept = {}
    for key, value in x.items():
        if ptr_key.match(key) is None:
            kept[key] = value
    return kept

def rename_keys(x, prefix):
    """Replace the part of every key before its first '_' with `prefix`.

    A key is only renamed when the resulting key is not already present in
    `x`; otherwise the entry is left untouched. `x` is modified in place and
    returned. A key that contains no '_' raises `IndexError`.
    """
    # Iterate over a snapshot because the dict is mutated inside the loop.
    for old_key in tuple(x.keys()):
        suffix = old_key.split('_', maxsplit=1)[1]
        new_key = f'{prefix}_{suffix}'
        if new_key in x:
            continue
        x[new_key] = x[old_key]
        del x[old_key]
    return x
    
def collate_feat(x, prefix, smp_prefix=None):
    """Run `pad_proc` over the `prefix` features of `x` and tidy up the resulting keys.

    `pad_proc` is called with `lev` set to the number of '2' separators in
    `prefix`. The `_idx_ptr-<k>` entries of its output are renamed by
    `rename_idx_ptr`, any remaining `_ptr-<digits>` entries are dropped, and,
    when `smp_prefix` is given, the keys are re-prefixed with it via `rename_keys`.
    """
    depth = prefix.count('2')
    collated = rename_idx_ptr(pad_proc(x, prefix=prefix, lev=depth), prefix, smp_prefix)
    collated = remove_unwanted_ptr(collated)
    if smp_prefix is None:
        return collated
    return rename_keys(collated, smp_prefix)
    

In [None]:
# Collate the sampled batch for the base prefix (no `smp_prefix`, so keys are not re-prefixed).
o = collate_feat(smp_batch, prefix=base_prefix)

In [None]:
# Add the collated features of the sampled base-prefix batch to the result.
sampled_batch.update(o)

In [None]:
sampled_batch

{'plbl2data_idx': tensor([ 97475,   8095,  14241,  53207,  85334,  87177,  87553, 134705, 150807,
         153681, 168326, 186062, 188361, 221141, 252304, 196033,  26569, 195049]),
 'plbl2data_data2ptr': tensor([ 1, 14,  1,  1,  1]),
 'lbl2data_idx': tensor([ 97475,  85334, 196033,  26569, 195049]),
 'lbl2data_identifier': ['List_of_Test_cricket_umpires',
  'Triage',
  'List_of_rivers_of_Mexico',
  'List_of_New_South_Wales_representative_cricketers',
  'List_of_antarctic_and_sub-antarctic_islands'],
 'lbl2data_input_text': ['List of Test cricket umpires',
  'Triage',
  'List of rivers of Mexico',
  'List of New South Wales representative cricketers',
  'List of antarctic and sub-antarctic islands'],
 'lbl2data_input_ids': tensor([[  101,  2862,  1997,  3231,  4533, 20887,  2015,   102,     0,     0],
         [  101, 13012,  4270,   102,     0,     0,     0,     0,     0,     0],
         [  101,  2862,  1997,  5485,  1997,  3290,   102,     0,     0,     0],
         [  101,  2862,  1

In [None]:
def get_sample_prefix(prefix, smp_feat):
    """Truncate `prefix` right after the first occurrence of the feature `smp_feat`.

    `prefix` is a chain of feature names joined by '2'; e.g.
    `get_sample_prefix('cat2lbl2data', 'lbl')` gives `'cat2lbl'`.
    Raises `ValueError` when `smp_feat` is not one of the features in `prefix`.
    """
    parts = prefix.split('2')
    cut = parts.index(smp_feat) + 1
    return '2'.join(parts[:cut])

def sample_smp_batch(x, feat, idx, smp_prefix=None):
    """Keep one entry of the first element of every feature in `feat`, per point.

    `idx` and `x` are walked in lockstep. For a point `o` with index `i`, each
    feature `f` maps to `[[o[f][0][i]]]`, or to `[[]]` when `i` is negative.
    When `smp_prefix` is given, the output key is `f` with everything before
    its first '_' replaced by `smp_prefix`; otherwise the key is `f` itself.

    Returns a list with one dict per point.
    """
    def _out_key(name):
        # Swap the part before the first '_' for `smp_prefix` when one is given.
        if smp_prefix is None:
            return name
        return f"{smp_prefix}_{name.split('_', maxsplit=1)[1]}"

    return [
        {_out_key(name): [[]] if pos < 0 else [[point[name][0][pos]]] for name in feat}
        for pos, point in zip(idx, x)
    ]
    

In [None]:
prefixes

['cat2lbl2data']

In [None]:
# Work through the first non-base prefix.
p = prefixes[0]

In [None]:
# `p` truncated right after the sampled feature `smp_feat`.
smp_prefix = get_sample_prefix(p, smp_feat)
# Collate the `{p}_idx` feature of the sampled batch; keys are renamed to `smp_prefix` and prefixed with 'p'.
o = collate_feature_idx(smp_batch, prefix=p, smp_prefix=smp_prefix); o

{'pcat2lbl_idx': tensor([402688, 495564, 497116, 497117,  55311,  57683,  74600, 381870, 464092,
          72163, 499473, 504533,  62743, 426229, 490629]),
 'pcat2lbl_data2ptr': tensor([4, 5, 3, 3, 0]),
 'pcat2lbl_lbl2data2ptr': tensor([4, 5, 3, 3, 0])}

In [None]:
# Pick one random position per point from the `p{smp_prefix}_{base_prefix}2ptr` counts
# (`get_rnd_idx` yields -1 for a count of 0).
smp_idx = get_rnd_idx(o[f'p{smp_prefix}_{base_prefix}2ptr']); smp_idx

tensor([ 1,  1,  1,  2, -1])

In [None]:
# Feature keys of the first sampled point that belong to prefix `p`.
feats = get_features(smp_batch[0], p)
# Keep only the entry selected by `smp_idx` for each of those features (keys unchanged: no `smp_prefix` passed here).
o = sample_smp_batch(smp_batch, feats, smp_idx)
# Collate the sampled entries and re-prefix the resulting keys with `smp_prefix`.
o = collate_feat(o, prefix=p, smp_prefix=smp_prefix)

In [None]:
o

{'cat2lbl_data2ptr': tensor([1, 1, 1, 1, 0]),
 'cat2lbl_lbl2data2ptr': tensor([1, 1, 1, 1, 0]),
 'cat2lbl_idx': tensor([495564,  57683, 499473, 490629]),
 'cat2lbl_identifier': ['Category:International_cricket_umpires',
  'Category:Intensive_care_medicine',
  'Category:Lists_of_landforms_of_Mexico',
  'Category:Lists_of_Australian_cricketers'],
 'cat2lbl_input_text': ['International cricket umpires',
  'Intensive care medicine',
  'Lists of landforms of Mexico',
  'Lists of Australian cricketers'],
 'cat2lbl_input_ids': tensor([[  101,  2248,  4533, 20887,  2015,   102,     0,     0],
         [  101, 11806,  2729,  4200,   102,     0,     0,     0],
         [  101,  7201,  1997,  2455, 22694,  1997,  3290,   102],
         [  101,  7201,  1997,  2827,  9490,  2015,   102,     0]]),
 'cat2lbl_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 0]])}