# SData

In [1]:
#| default_exp sdata

In [2]:
#| hide
# Automatically reload edited modules (e.g. the xcai package) before each cell executes
%load_ext autoreload
%autoreload 2

In [83]:
#| export
import torch, inspect, numpy as np, scipy.sparse as sp, inspect
from typing import Callable, Optional, Union, Dict
from torch.utils.data import DataLoader
from transformers import BatchEncoding
from itertools import chain

from xcai.core import Filterer, Info
from xcai.data import MainXCData, MetaXCData
from xcai.data import BaseXCDataset, MainXCDataset, MetaXCDataset, XCDataset
from xcai.data import BaseXCDataBlock, XCDataBlock
from xcai.data import _read_sparse_file
from xcai.graph.operations import *

from fastcore.utils import *
from fastcore.meta import *
from plum import dispatch

In [4]:
#| hide
from nbdev.showdoc import *
# Regenerate the library modules from the notebooks' `#| export` cells (this notebook targets `sdata`)
import nbdev; nbdev.nbdev_export()

## Data

In [5]:
import os

# Dataset root. Set the `XCAI_DSET_DIR` environment variable instead of editing this machine-specific path.
dset_dir = os.environ.get('XCAI_DSET_DIR', '/Users/suchith720/Projects/data/(mapped)LF-WikiSeeAlsoTitles-320K')

# Shared options for reading the raw text files and tokenizing them
data_cfg = {
    'info_column_names': ['identifier', 'input_text'],
    'use_tokenizer': True,
    'tokenizer': 'sentence-transformers/msmarco-distilbert-base-v4',
    'tokenization_column': 'input_text',
    'main_max_data_sequence_length': 32,
    'main_max_lbl_sequence_length': 32,
    'meta_max_sequence_length': 32,
    'padding': True,
    'return_tensors': 'pt',
}

In [6]:
# Training split: data->label matrix, raw data/label texts and the label filter
train_cfg = dict(
    data_lbl=f'{dset_dir}/trn_X_Y.txt',
    data_info=f'{dset_dir}/raw_data/train.raw.txt',
    lbl_info=f'{dset_dir}/raw_data/label.raw.txt',
    data_lbl_filterer=f'{dset_dir}/filter_labels_train.txt',
)
train_data = MainXCData.from_file(**train_cfg, **data_cfg)

tokenizer_config.json:   0%|          | 0.00/319 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [47]:
# Category metadata ("hlk") attached to data points and to labels
meta_cfg = dict(
    prefix='hlk',
    data_meta=f'{dset_dir}/category_trn_X_Y.txt',
    lbl_meta=f'{dset_dir}/category_lbl_X_Y.txt',
    meta_info=f'{dset_dir}/raw_data/category.raw.txt',
)
meta_data = MetaXCData.from_file(**meta_cfg, **data_cfg)



In [48]:
#| export
def identity_collate_fn(batch):
    "Wrap `batch` as-is in a `BatchEncoding` (no padding or stacking is done here)."
    wrapped = BatchEncoding(batch)
    return wrapped

## Sampler

In [64]:
#| export
class Sampler:
    "Stateless helpers that sample per-item index lists (labels or metadata) and gather their info."

    @staticmethod
    def dropout(idxs:List, remove:Optional[float]=None, replace:Optional[float]=None):
        """Draw per-item 0/1 dropout masks.

        For each item in `idxs` a `remove` coin is flipped with probability `remove`; items that are not
        removed additionally flip a `replace` coin with probability `replace`. Each mask entry is a list of
        identical flags with the same length as the item.

        Returns:
            `(remove_mask, replace_mask)`; a mask stays empty when its probability is `None`.
        """
        remove_mask, replace_mask = list(), list()
        for idx in idxs:
            if remove is not None:
                if np.random.rand() < remove:
                    remove_mask.append([1]*len(idx))
                    if replace is not None:
                        replace_mask.append([0]*len(idx))
                else:
                    remove_mask.append([0]*len(idx))
                    if replace is not None: 
                        replace_mask.append([1]*len(idx) if np.random.rand() < replace else [0]*len(idx))
            elif replace is not None:
                replace_mask.append([1]*len(idx) if np.random.rand() < replace else [0]*len(idx))
        return remove_mask, replace_mask

    @staticmethod
    def prune_indices_and_scores(output:Dict, prefix:str, data_lbl_indices:List, data_lbl_scores:List, 
                                 indices:List, num_samples:Optional[int]=None, use_distribution:Optional[bool]=False,
                                 return_scores:Optional[bool]=False, dtype=torch.int64):
        """Gather the candidate lists of rows `indices` into `output`, optionally capping each at `num_samples`.

        Writes `output['p{prefix}_idx']` (list of per-item index lists) and `output['p{prefix}_{entity}2ptr']`
        (tensor of their lengths), where `entity` is the part of `prefix` after the last '2'
        (e.g. 'data' for 'lbl2data'). Capping keeps a uniformly random subset of each list.

        Returns:
            The per-item scores aligned with `output['p{prefix}_idx']` when `use_distribution` or
            `return_scores` is set, otherwise `None`.
        """
        entity = prefix.split('2')[-1]
        output[f'p{prefix}_idx'] = [data_lbl_indices[idx] for idx in indices]
        scores = [data_lbl_scores[idx] for idx in indices] if use_distribution or return_scores else None
        
        if num_samples:
            if scores is None:
                output[f'p{prefix}_idx'] = [[o[i] for i in np.random.permutation(len(o))[:num_samples]] for o in output[f'p{prefix}_idx']]
            else:
                # Prune indices and scores with the same permutation so they stay aligned
                idxs, sc = list(), list()
                for p,q in zip(output[f'p{prefix}_idx'], scores):
                    assert len(p) == len(q)
                    rnd_idx = np.random.permutation(len(p))[:num_samples]
                    idxs.append([p[i] for i in rnd_idx])
                    sc.append([q[i] for i in rnd_idx])
                output[f'p{prefix}_idx'], scores = idxs, sc
                
        output[f'p{prefix}_{entity}2ptr'] = torch.tensor([len(o) for o in output[f'p{prefix}_idx']], dtype=dtype)
        return scores

    @staticmethod
    def _normalize_probs(scores:List):
        "Rescale `scores` into a probability vector for `np.random.choice`; `None` (uniform) if they sum to zero."
        probs = np.asarray(scores, dtype=np.float64)
        total = probs.sum()
        return probs / total if total > 0 else None

    @staticmethod
    def sample_indices_and_scores(indices:List, scores:Optional[List]=None, num_samples:Optional[int]=1, 
                                  oversample:Optional[bool]=False, use_distribution:Optional[bool]=False, 
                                  return_scores:Optional[bool]=False):
        """Draw `num_samples` entries from every per-item list in `indices`.

        Without `oversample`, draws without replacement and at most `len(list)` entries; with it, draws
        exactly `num_samples` entries with replacement. Empty lists yield empty samples. With
        `use_distribution`, entries are drawn proportionally to `scores`. The scores are renormalised per
        item first: after pruning in `prune_indices_and_scores` they generally no longer sum to 1, which
        `np.random.choice` rejects. An all-zero item falls back to uniform sampling.

        Returns:
            `(sampled_indices, sampled_scores)`; `sampled_scores` holds the original (un-renormalised)
            scores of the drawn entries and is empty unless `return_scores`.

        Raises:
            ValueError: if `use_distribution` is set but `scores` is `None`.
        """
        if use_distribution and scores is None:
            raise ValueError(f'`scores` cannot be empty when `use_distribution` is set.')
        
        s_indices, s_scores = [], []
        for k, item in enumerate(indices):
            n_item = len(item)
            probs = Sampler._normalize_probs(scores[k]) if use_distribution and n_item else None
            size = num_samples if oversample else min(num_samples, n_item)
            
            rnd_idx = np.random.choice(n_item, size=size, p=probs, replace=oversample) if n_item else []

            s_indices.append([item[i] for i in rnd_idx])
            if return_scores:
                assert n_item == len(scores[k]), f'Length of indices({n_item}) and scores({(len(scores[k]))}) should be equal.'
                s_scores.append([scores[k][i] for i in rnd_idx])

        return s_indices, s_scores

    @staticmethod
    def get_info(prefix:str, idxs:List, info:Dict, info_keys:List):
        """Select rows `idxs` from every `info` entry whose key is in `info_keys`.

        Array/tensor values are fancy-indexed (NumPy results converted to tensors); other sequences are
        indexed element by element. Output keys are `f'{prefix}_{key}'`.
        """
        output = dict()
        for k,v in info.items():
            if k not in info_keys: continue
            if isinstance(v, (np.ndarray, torch.Tensor)):
                o = v[idxs]
                if isinstance(o, np.ndarray): o = torch.from_numpy(o)
                output[f'{prefix}_{k}'] = o
            else:
                output[f'{prefix}_{k}'] = [v[idx] for idx in idxs]
        return output

    @staticmethod
    def extract_items(
        prefix:str,
        data_lbl_indices:List,
        
        indices:List, 
        num_samples:int, 
        num_sampler_samples:int, 
        oversample:bool, 
                      
        info:Dict, 
        info_keys:List,
        
        use_distribution:Optional[bool]=False, 
        data_lbl_scores:Optional[List]=None, 
                      
        dropout_remove:Optional[float]=None, 
        dropout_replace:Optional[float]=None, 
        return_scores:Optional[bool]=False,
        dtype=torch.int64,
    ):
        """Prune, sample and flatten the entries of rows `indices`, and gather their info.

        Output keys (`entity` is the part of `prefix` after the last '2'):
            `p{prefix}_idx` / `p{prefix}_{entity}2ptr`: flattened pruned candidates and per-item counts.
            `{prefix}_idx` / `{prefix}_{entity}2ptr`: flattened sampled entries and per-item counts.
            `{prefix}_scores`: scores of the sampled entries (only with `return_scores`).
            `{prefix}_<info key>`: info rows of the sampled entries (only when `info` is given).

        NOTE(review): `dropout_remove` / `dropout_replace` are accepted but not used here (and `dropout`
        is never called) — confirm whether metadata dropout is meant to be applied.
        """
        output, entity = dict(), prefix.split('2')[-1]
            
        scores = Sampler.prune_indices_and_scores(output, prefix, data_lbl_indices, data_lbl_scores, indices, 
                                                  num_samples, use_distribution, return_scores, dtype=dtype)
        
        output[f'{prefix}_idx'], scores = Sampler.sample_indices_and_scores(output[f'p{prefix}_idx'], scores, 
                                                                            num_sampler_samples, oversample, 
                                                                            use_distribution, return_scores)
        if return_scores:
            output[f'{prefix}_scores'] = torch.tensor(list(chain(*scores)), dtype=torch.float32)

        output[f'{prefix}_{entity}2ptr'] = torch.tensor([len(o) for o in output[f'{prefix}_idx']], dtype=dtype)
        output[f'{prefix}_idx'] = torch.tensor(list(chain(*output[f'{prefix}_idx'])), dtype=dtype)
        output[f'p{prefix}_idx'] = torch.tensor(list(chain(*output[f'p{prefix}_idx'])), dtype=dtype)
        
        if info is not None:
            output.update(Sampler.get_info(prefix, output[f'{prefix}_idx'], info, info_keys))
            
        return output
        

## SDataset

### `SMainXCDataset`

In [74]:
#| export
class SMainXCDataset(MainXCDataset):
    "`MainXCDataset` that draws a fresh sample of each item's labels on every `__getitems__` call."

    def __init__(
        self,
        n_slbl_samples:Optional[int]=1,
        main_oversample:Optional[bool]=False,
        use_main_distribution:Optional[bool]=False,
        return_scores:Optional[bool]=False,
        **kwargs
    ):
        """
        Args:
            n_slbl_samples: labels drawn per item (from the candidates left after `n_lbl_samples` pruning).
            main_oversample: draw with replacement, so non-empty items always get `n_slbl_samples` labels.
            use_main_distribution: draw labels with probability proportional to their `data_lbl` values.
            return_scores: also return the `data_lbl` values of the drawn labels as `lbl2data_scores`.
            **kwargs: forwarded to `MainXCDataset`.
        """
        super().__init__(**kwargs)
        store_attr('n_slbl_samples,main_oversample,use_main_distribution,return_scores')
        
        self.data_lbl_scores = None
        if use_main_distribution or return_scores: self._store_scores()
        
    def _store_scores(self):
        """Cache each row's stored `data_lbl` values as a list (row-normalised if `use_main_distribution`).

        Values are read straight from the CSR `data` array, one `indptr` slice per row, so each list keeps
        the row's stored order. Sparse arithmetic (broadcast division followed by `tocsr`) may reorder
        entries or drop explicit zeros, which would misalign scores with the row's label indices.
        """
        if self.data_lbl is None: return
        # NOTE(review): assumes `curr_data_lbl` lists each row's labels in `data_lbl`'s stored order — confirm.
        data_lbl = self.data_lbl.tocsr()
        scores = []
        for start, end in zip(data_lbl.indptr[:-1], data_lbl.indptr[1:]):
            vals = data_lbl.data[start:end]
            if self.use_main_distribution: vals = vals / (vals.sum() + 1e-9)
            scores.append(vals.tolist())
        self.data_lbl_scores = scores
        
    def __getitems__(self, idxs:List):
        """Return data info for `idxs` plus pruned/sampled labels and their info (keys prefixed `lbl2data`)."""
        x = {'data_idx': torch.tensor(idxs, dtype=torch.int64)}
        x.update(self.get_info('data', idxs, self.data_info, self.data_info_keys))
        if self.data_lbl is not None:
            prefix = 'lbl2data'
            o = Sampler.extract_items(prefix, self.curr_data_lbl, idxs, self.n_lbl_samples, self.n_slbl_samples, 
                                      self.main_oversample, self.lbl_info, self.lbl_info_keys, self.use_main_distribution, 
                                      self.data_lbl_scores, return_scores=self.return_scores)
            x.update(o)
        return x
    

#### Example

In [66]:
# Main dataset drawing 2 labels per data point
train_main = SMainXCDataset(n_slbl_samples=2, **train_data)

In [67]:
# Sample labels for items 100 and 200 directly (bypassing the DataLoader)
train_main.__getitems__([100, 200])

{'data_idx': tensor([100, 200]),
 'data_identifier': ['Applet', 'Geography_of_Africa'],
 'data_input_text': ['Applet', 'Geography of Africa'],
 'data_input_ids': tensor([[  101,  6207,  2102,   102,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101, 10505,  1997,  3088,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0]]),
 'data_attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0]]),
 'plbl2data_idx': tensor([  927,   928,   929,  

In [None]:
# The oversampling switch of `SMainXCDataset` is `main_oversample`; assigning `oversample` had no effect
train_main.main_oversample = True
train_main.n_slbl_samples = 3

In [None]:
def func():
    "Fetch items 1-4 from `train_main` (use `%debug` after a failure instead of a hard-coded breakpoint)."
    return train_main.__getitems__([1,2,3,4])
    

In [None]:
# Fetch the first batch of 10 items through a DataLoader
train_dl = DataLoader(train_main, collate_fn=identity_collate_fn, batch_size=10)
train_iter = iter(train_dl)
batch = next(train_iter)

In [None]:
# Show the first collated batch
batch

{'data_idx': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 'data_identifier': ['Anarchism',
  'Autism',
  'Aristotle',
  'Academy_Awards',
  'International_Atomic_Time',
  'Altruism',
  'Allan_Dwan',
  'Anthropology',
  'Agricultural_science',
  'Animation'],
 'data_input_text': ['Anarchism',
  'Autism',
  'Aristotle',
  'Academy Awards',
  'International Atomic Time',
  'Altruism',
  'Allan Dwan',
  'Anthropology',
  'Agricultural science',
  'Animation'],
 'data_input_ids': tensor([[  101,  9617, 11140,  2964,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101, 19465,   102,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,    

In [None]:
from transformers import AutoTokenizer

# Same tokenizer as in `data_cfg`; used below to decode token ids back into text
tokz = AutoTokenizer.from_pretrained(data_cfg['tokenizer'])

In [None]:
# Number of candidate (pre-sampling) labels per data point in the batch
batch['plbl2data_data2ptr']

In [None]:
# Decode the batch's data token ids back to text as a sanity check
tokz.batch_decode(batch['data_input_ids'])

### `SMetaXCDataset`

In [76]:
#| export
class SMetaXCDataset(MetaXCDataset):
    "`MetaXCDataset` that draws a fresh sample of each data point's / label's metadata on every call."

    def __init__(
        self,
        n_sdata_meta_samples:Optional[int]=1,
        n_slbl_meta_samples:Optional[int]=1,
        meta_oversample:Optional[bool]=False,
        use_meta_distribution:Optional[bool]=False,
        meta_dropout_remove:Optional[float]=None,
        meta_dropout_replace:Optional[float]=None,
        return_scores:Optional[bool]=False,
        **kwargs
    ):
        """
        Args:
            n_sdata_meta_samples: metadata items drawn per data point (after `n_data_meta_samples` pruning).
            n_slbl_meta_samples: metadata items drawn per label (after `n_lbl_meta_samples` pruning).
            meta_oversample: draw with replacement, so non-empty items always get the full sample count.
            use_meta_distribution: draw metadata proportionally to their `data_meta` / `lbl_meta` values.
            meta_dropout_remove, meta_dropout_replace: forwarded to `Sampler.extract_items`.
                NOTE(review): `extract_items` does not currently use them — confirm intended behaviour.
            return_scores: also return the matrix values of the drawn metadata
                (`<prefix>2data_scores` / `<prefix>2lbl_scores`).
            **kwargs: forwarded to `MetaXCDataset`.
        """
        super().__init__(**kwargs)
        store_attr('n_sdata_meta_samples,n_slbl_meta_samples,meta_oversample,use_meta_distribution')
        store_attr('meta_dropout_remove,meta_dropout_replace,return_scores')

        self.data_meta_scores, self.lbl_meta_scores = None, None
        if use_meta_distribution or return_scores: self._store_scores()

    @staticmethod
    def _meta_row_scores(matrix, normalize:bool):
        """Per-row stored values of `matrix` as lists, row-normalised when `normalize`; `None` for no matrix.

        Reads each row's slice of the CSR `data` array directly so values stay in stored order. Sparse
        arithmetic may reorder entries or drop explicit zeros, misaligning them with the row's indices.
        """
        if matrix is None: return None
        matrix = matrix.tocsr()
        scores = []
        for start, end in zip(matrix.indptr[:-1], matrix.indptr[1:]):
            vals = matrix.data[start:end]
            if normalize: vals = vals / (vals.sum() + 1e-9)
            scores.append(vals.tolist())
        return scores

    def _store_scores(self):
        "Cache per-row scores of the data-side and label-side metadata matrices."
        # NOTE(review): assumes `curr_data_meta` / `curr_lbl_meta` list each row's items in the matrices'
        # stored order — confirm.
        self.data_meta_scores = self._meta_row_scores(self.data_meta, self.use_meta_distribution)
        self.lbl_meta_scores = self._meta_row_scores(self.lbl_meta, self.use_meta_distribution)

    def _extract_meta(self, prefix:str, curr_meta:List, idxs:List, n_samples:int, n_ssamples:int, scores:Optional[List]):
        "Prune/sample metadata of rows `idxs` of `curr_meta` via `Sampler.extract_items` with this dataset's settings."
        return Sampler.extract_items(prefix, curr_meta, idxs, n_samples, n_ssamples, 
                                     self.meta_oversample, self.meta_info, self.meta_info_keys, self.use_meta_distribution, 
                                     scores, dropout_remove=self.meta_dropout_remove, 
                                     dropout_replace=self.meta_dropout_replace, return_scores=self.return_scores)
        
    def get_data_meta(self, idxs:List):
        "Sample metadata of data points `idxs` (keys prefixed `<prefix>2data`); `{}` when there is no data metadata."
        if self.curr_data_meta is None: return {}
        return self._extract_meta(f'{self.prefix}2data', self.curr_data_meta, idxs, self.n_data_meta_samples, 
                                  self.n_sdata_meta_samples, self.data_meta_scores)
        
    def get_lbl_meta(self, idxs:List):
        "Sample metadata of labels `idxs` (keys prefixed `<prefix>2lbl`); `{}` when there is no label metadata."
        if self.curr_lbl_meta is None: return {}
        return self._extract_meta(f'{self.prefix}2lbl', self.curr_lbl_meta, idxs, self.n_lbl_meta_samples, 
                                  self.n_slbl_meta_samples, self.lbl_meta_scores)
        

#### Example

In [69]:
# Metadata dataset drawing 2 metadata items per data point and per label
train_meta = SMetaXCDataset(n_sdata_meta_samples=2, n_slbl_meta_samples=2, **meta_data)

In [70]:
# Draw 3 metadata items (with replacement) per data point and per label
train_meta.meta_oversample = True
train_meta.n_sdata_meta_samples = train_meta.n_slbl_meta_samples = 3

In [71]:
# Sample metadata for data points 100 and 200 (keys prefixed with `hlk2data`)
train_meta.get_data_meta([100, 200])

{'phlk2data_idx': tensor([  1058, 147261, 149012,  85726]),
 'phlk2data_data2ptr': tensor([3, 1]),
 'hlk2data_idx': tensor([  1058, 149012,   1058,  85726,  85726,  85726]),
 'hlk2data_data2ptr': tensor([3, 3]),
 'hlk2data_identifier': ['Category:Technology_neologisms',
  'Category:Component-based_software_engineering',
  'Category:Technology_neologisms',
  'Category:Geography_of_Africa',
  'Category:Geography_of_Africa',
  'Category:Geography_of_Africa'],
 'hlk2data_input_text': ['Technology neologisms',
  'Component-based software engineering',
  'Technology neologisms',
  'Geography of Africa',
  'Geography of Africa',
  'Geography of Africa'],
 'hlk2data_input_ids': tensor([[  101,  2974,  9253, 21197, 22556,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0],
         [  101,  6922,  1011,  2241,  4007,  3330,   102,     0,     0,     0,
        

In [72]:
# Sample metadata for labels 101 and 201 (keys prefixed with `hlk2lbl`)
train_meta.get_lbl_meta([101, 201])

{'phlk2lbl_idx': tensor([3384]),
 'phlk2lbl_lbl2ptr': tensor([1, 0]),
 'hlk2lbl_idx': tensor([3384, 3384, 3384]),
 'hlk2lbl_lbl2ptr': tensor([3, 0]),
 'hlk2lbl_identifier': ['Category:Animation_techniques',
  'Category:Animation_techniques',
  'Category:Animation_techniques'],
 'hlk2lbl_input_text': ['Animation techniques',
  'Animation techniques',
  'Animation techniques'],
 'hlk2lbl_input_ids': tensor([[ 101, 7284, 5461,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0],
         [ 101, 7284, 5461,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0],
         [ 101, 7284, 5461,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0

### `SXCDataset`

In [80]:
#| export
class SXCDataset(XCDataset):
    "An `SMainXCDataset` combined with named `SMetaXCDataset`s; every batch is sampled on the fly."

    def __init__(self, data:SMainXCDataset, **kwargs):
        """
        Args:
            data: main (data/label) dataset.
            **kwargs: metadata datasets keyed by the names `get_meta_args` selects; values that are not
                `SMetaXCDataset` instances are ignored.
        """
        super().__init__()
        # NOTE(review): `MetaXCDatasets` is not among this module's explicit imports; presumably it arrives
        # through one of the star imports — confirm.
        self.data, self.meta = data, MetaXCDatasets({k:kwargs[k] for k in self.get_meta_args(**kwargs) if isinstance(kwargs[k], SMetaXCDataset)})
        self._verify_inputs()
        
    @classmethod
    @delegates(SMainXCDataset.from_file)
    def from_file(cls, **kwargs):
        """Build the main dataset and every metadata dataset from files.

        Metadata arguments (selected by `get_meta_args`) are popped from `kwargs`; each is a dict of file
        arguments for one `SMetaXCDataset`. Dict-valued shared arguments are resolved per metadata name
        (`q.get(name)`), and keys set in the metadata dict itself take precedence over shared ones.
        """
        data = SMainXCDataset.from_file(**kwargs)
        meta_kwargs = {o:kwargs.pop(o) for o in cls.get_meta_args(**kwargs)}

        meta = dict()
        for k,v in meta_kwargs.items():
            input_kwargs = {p:q.get(k,None) if isinstance(q, dict) else q for p,q in kwargs.items()}
            for o in v: input_kwargs.pop(o, None)
            meta[k] = SMetaXCDataset.from_file(**v, **input_kwargs)
        
        return cls(data, **meta)

    @staticmethod
    def _sum_per_data(lbl2ptr:torch.Tensor, data2ptr:torch.Tensor):
        "Collapse per-label counts `lbl2ptr` into per-data-point counts, given each data point's label count `data2ptr`."
        return torch.tensor([int(o.sum()) for o in lbl2ptr.split_with_sizes(data2ptr.tolist())], dtype=torch.int64)
        
    def __getitems__(self, idxs:List):
        """Sample main items `idxs`, then the data-side and label-side samples of every metadata dataset.

        Label-side metadata is sampled for the flattened `lbl2data_idx`. Its per-label counts are also summed
        per data point into `<prefix>2lbl_data2ptr` and `p<prefix>2lbl_data2ptr`.
        """
        x = self.data.__getitems__(idxs)
        if self.n_meta:
            for meta in self.meta.values():
                x.update(meta.get_data_meta(idxs))
                if self.n_lbl:
                    z = meta.get_lbl_meta(x['lbl2data_idx'])
                    if len(z):
                        data2ptr = x['lbl2data_data2ptr']
                        z[f'{meta.prefix}2lbl_data2ptr'] = self._sum_per_data(z[f'{meta.prefix}2lbl_lbl2ptr'], data2ptr)
                        z[f'p{meta.prefix}2lbl_data2ptr'] = self._sum_per_data(z[f'p{meta.prefix}2lbl_lbl2ptr'], data2ptr)
                    x.update(z)
        return x

    # =========== Operations ===========
    
    def get_one_hop_metadata(self, batch_size:Optional[int]=1024, thresh:Optional[int]=10, topk:Optional[int]=10, **kwargs):
        "Add `ohm_meta`: one-hop metadata from the degree-thresholded `data_lbl`; metadata items are labels."
        data_lbl = Graph.threshold_on_degree(self.data.data_lbl, thresh=thresh)
        data_meta, lbl_meta = Graph.one_hop_matrix(data_lbl, batch_size=batch_size, topk=topk, do_normalize=True)
        self.meta['ohm_meta'] = SMetaXCDataset(prefix='ohm', data_meta=data_meta, lbl_meta=lbl_meta, 
                                               meta_info=self.data.lbl_info, **kwargs)

    def get_random_walk_metadata(self, batch_size:Optional[int]=1024, walk_to:Optional[int]=100, prob_reset:Optional[float]=0.8, 
                                 topk_thresh:Optional[int]=10, degree_thresh=20, **kwargs):
        """Add `rnw_meta` from random walks on `data_lbl` (data side: `n_hops=1`; label side, on the
        transposed matrix: `n_hops=2`); metadata items are labels."""
        # Previously referenced an undefined `data_lbl`; the walks run on the main dataset's matrix.
        data_lbl = self.data.data_lbl
        data_meta = perform_random_walk(data_lbl, batch_size=batch_size, walk_to=walk_to, prob_reset=prob_reset, 
                                        n_hops=1, thresh=degree_thresh, topk=topk_thresh, do_normalize=True)
        lbl_meta = perform_random_walk(data_lbl.transpose().tocsr(), batch_size=batch_size, walk_to=walk_to, 
                                       prob_reset=prob_reset, n_hops=2, thresh=degree_thresh, topk=topk_thresh, do_normalize=True)
        self.meta['rnw_meta'] = SMetaXCDataset(prefix='rnw', data_meta=data_meta, lbl_meta=lbl_meta,
                                               meta_info=self.data.lbl_info, **kwargs)

    def get_random_walk_with_matrices_metadata(self, meta_name:str, batch_size:Optional[int]=1024, walk_to:Optional[int]=100, 
                                               prob_reset:Optional[float]=0.8, topk_thresh:Optional[int]=10, data_degree_thresh=20, 
                                               lbl_degree_thresh=20, **kwargs):
        """Add `rnw_meta` (replacing any existing one) from random walks over the `data_meta` / `lbl_meta`
        matrices of `{meta_name}_meta`; metadata items are labels.

        Raises:
            ValueError: if `{meta_name}_meta` is not present.
        """
        if f'{meta_name}_meta' not in self.meta: raise ValueError(f'Invalid metadata: {meta_name}')
        data_meta, lbl_meta = self.meta[f'{meta_name}_meta'].data_meta, self.meta[f'{meta_name}_meta'].lbl_meta
        data_rnw = perform_random_walk_with_matrices(data_meta, lbl_meta, batch_size=batch_size, walk_to=walk_to, prob_reset=prob_reset, 
                                                     n_hops=2, data_thresh=data_degree_thresh, lbl_thresh=lbl_degree_thresh, 
                                                     topk=topk_thresh, do_normalize=True)
        lbl_rnw = perform_random_walk_with_matrices(lbl_meta, data_meta, batch_size=batch_size, walk_to=walk_to, prob_reset=prob_reset, 
                                                    n_hops=3, data_thresh=data_degree_thresh, lbl_thresh=lbl_degree_thresh, 
                                                    topk=topk_thresh, do_normalize=True)
        self.meta['rnw_meta'] = SMetaXCDataset(prefix='rnw', data_meta=data_rnw, lbl_meta=lbl_rnw, meta_info=self.data.lbl_info, **kwargs)
        
    @staticmethod
    def combine_info(info_1:Dict, info_2:Dict, pad_token:int=0):
        """Concatenate two info dicts key by key, with the entries of `info_2` after those of `info_1`.

        Lists/tuples are concatenated. 1-D tensors are concatenated. 2-D tensors are stacked along dim 0
        after right-padding to the longer sequence length, using `pad_token` for 'input_ids' and 0 for
        every other key (e.g. 'attention_mask', 'token_type_ids'). Keys of `info_1` holding other value
        types are dropped.
        """
        comb_info = dict()
        for k,v in info_1.items():
            if isinstance(v, (tuple, list)): comb_info[k] = v + info_2[k]
            elif isinstance(v, torch.Tensor):
                w = info_2[k]
                if v.dim() == 1:
                    comb_info[k] = torch.cat([v, w])
                    continue
                # Allocate a fresh buffer for every key so no key can alias (and overwrite) another key's tensor
                fill = pad_token if k == 'input_ids' else 0
                info = torch.full((v.shape[0] + w.shape[0], max(v.shape[1], w.shape[1])), fill, dtype=v.dtype)
                info[:v.shape[0], :v.shape[1]] = v
                info[v.shape[0]:, :w.shape[1]] = w
                comb_info[k] = info
        return comb_info

    def _get_main_dataset(
        self,
        data_info:Dict, 
        data_lbl:Optional[sp.csr_matrix]=None, 
        lbl_info:Optional[Dict]=None, 
        data_lbl_filterer:Optional[Union[sp.csr_matrix,np.array]]=None, 
        **kwargs
    ):
        "Build a new main dataset via `self.data._get_dataset` and wrap it in an `SXCDataset` without metadata."
        dset = self.data._get_dataset(data_info, data_lbl, lbl_info, data_lbl_filterer, **kwargs)
        return SXCDataset(dset)
        
    def combine_lbl_and_meta(self, meta_name:str, pad_token:int=0, p_data=0.5, **kwargs): 
        """Return an `SXCDataset` whose labels are the original labels followed by the items of `{meta_name}_meta`.

        In each data row, label entries are scaled by `p_data / nnz(row)` and metadata entries by
        `(1 - p_data) / nnz(row)` of their respective matrices.

        Raises:
            ValueError: if `{meta_name}_meta` is not present.
        """
        if f'{meta_name}_meta' not in self.meta: raise ValueError(f'Invalid metadata: {meta_name}')
            
        data_lbl = self.data.data_lbl
        data_lbl = data_lbl.multiply(1/(data_lbl.getnnz(axis=1).reshape(-1, 1) + 1e-9))
        data_lbl = data_lbl.tocsr() * p_data
        
        data_meta = self.meta[f'{meta_name}_meta'].data_meta
        data_meta = data_meta.multiply(1/(data_meta.getnnz(axis=1).reshape(-1, 1) + 1e-9))
        data_meta = data_meta.tocsr() * (1 - p_data)
        
        lbl_info = self.data.lbl_info
        meta_info = self.meta[f'{meta_name}_meta'].meta_info

        comb_info = self.combine_info(lbl_info, meta_info, pad_token)
        
        return self._get_main_dataset(self.data.data_info, sp.hstack([data_lbl, data_meta]), comb_info, 
                                      self.data.data_lbl_filterer, **kwargs)

    def combine_data_and_meta(self, meta_name:str, pad_token:int=0, **kwargs):
        """Return an `SXCDataset` whose data points are the original ones followed by the items of
        `{meta_name}_meta` (labelled by its transposed `lbl_meta`); rows without labels are dropped.

        Raises:
            ValueError: if `{meta_name}_meta` is not present.
        """
        if f'{meta_name}_meta' not in self.meta: raise ValueError(f'Invalid metadata: {meta_name}')
        meta = self.meta[f'{meta_name}_meta']
        return self.get_combined_data_and_meta(self, meta.lbl_meta.transpose().tocsr(), meta.meta_info, pad_token, **kwargs)

    @staticmethod
    def get_combined_data_and_meta(dset, meta_lbl:sp.csr_matrix, meta_info:Dict, pad_token:int=0, **kwargs):    
        """Append the rows of `meta_lbl` (with texts `meta_info`) to `dset` as extra data points.

        Returns the items of the combined dataset whose rows have at least one label.
        """
        data_lbl = dset.data.data_lbl
        assert data_lbl.shape[1] == meta_lbl.shape[1], f"Incompatible metadata shape: {meta_lbl.shape}"
        
        data_info = dset.data.data_info
        comb_info = dset.combine_info(data_info, meta_info, pad_token)
        
        dset = dset._get_main_dataset(comb_info, sp.vstack([data_lbl, meta_lbl]), dset.data.lbl_info, 
                                      dset.data.data_lbl_filterer, **kwargs)
        valid_idx = np.where(dset.data.data_lbl.getnnz(axis=1) > 0)[0]
        return dset._getitems(valid_idx)
        
        

#### Example

In [None]:
# Combined dataset: sampled labels plus the `hlk` metadata
train_dset = SXCDataset(data=train_main, hlk_meta=train_meta)

In [None]:
# Sample three items directly, bypassing the DataLoader
sample_idxs = [100, 200, 500]
bb = train_dset.__getitems__(sample_idxs)

In [None]:
# First batch of 10 items of the combined dataset through a DataLoader
train_dl = DataLoader(train_dset, collate_fn=identity_collate_fn, batch_size=10)
train_iter = iter(train_dl)
batch = next(train_iter)

In [None]:
# Keys produced for one DataLoader batch of `train_dset`
list(batch)

['data_idx',
 'data_identifier',
 'data_input_text',
 'data_input_ids',
 'data_token_type_ids',
 'data_attention_mask',
 'plbl2data_idx',
 'plbl2data_data2ptr',
 'lbl2data_idx',
 'lbl2data_data2ptr',
 'lbl2data_identifier',
 'lbl2data_input_text',
 'lbl2data_input_ids',
 'lbl2data_token_type_ids',
 'lbl2data_attention_mask',
 'phlk2data_idx',
 'phlk2data_data2ptr',
 'hlk2data_idx',
 'hlk2data_data2ptr',
 'hlk2data_identifier',
 'hlk2data_input_text',
 'hlk2data_input_ids',
 'hlk2data_token_type_ids',
 'hlk2data_attention_mask',
 'phlk2lbl_idx',
 'phlk2lbl_lbl2ptr',
 'hlk2lbl_idx',
 'hlk2lbl_lbl2ptr',
 'hlk2lbl_identifier',
 'hlk2lbl_input_text',
 'hlk2lbl_input_ids',
 'hlk2lbl_token_type_ids',
 'hlk2lbl_attention_mask',
 'hlk2lbl_data2ptr',
 'phlk2lbl_data2ptr']

### `SBaseXCDataBlock`

In [81]:
#| export
class SBaseXCDataBlock(BaseXCDataBlock):
    "`BaseXCDataBlock` whose dataset is an `SXCDataset`."
    
    @classmethod
    @delegates(SXCDataset.from_file)
    def from_file(cls, collate_fn:Callable=identity_collate_fn, **kwargs):
        "Load an `SXCDataset` from files and wrap it; `kwargs` go to both the dataset and the block."
        dset = SXCDataset.from_file(**kwargs)
        return cls(dset, collate_fn, **kwargs)
        

#### Example

In [None]:
# Training data block over the combined dataset, 2 items per batch
block_kwargs = dict(batch_size=2)
train_block = SBaseXCDataBlock(train_dset, **block_kwargs)

### `SXCDataBlock`

In [85]:
#| export
class SXCDataBlock(XCDataBlock):
    "`XCDataBlock` whose splits are `SBaseXCDataBlock`s over sampled `SXCDataset`s."

    @staticmethod
    def inference_dset(data_info:Dict, data_lbl:sp.csr_matrix, lbl_info:Dict, data_lbl_filterer, 
                       **kwargs):
        """Build an `SXCDataset` for prediction in which every data point has at least one label.

        Rows of `data_lbl` with no stored entries get a placeholder label 0, and those `(row, 0)` pairs are
        prepended to `data_lbl_filterer` (presumably so they are filtered from predictions). The caller's
        `data_lbl` is not modified.
        """
        x_idx = np.where(data_lbl.getnnz(axis=1) == 0)[0].reshape(-1,1)
        y_idx = np.zeros((len(x_idx),1), dtype=np.int64)
        if len(x_idx):
            # Work on a copy: assigning into the argument would leak placeholder labels into the caller's matrix
            data_lbl = data_lbl.copy()
            data_lbl[x_idx, y_idx] = 1
        placeholder_pairs = np.hstack([x_idx, y_idx])
        data_lbl_filterer = placeholder_pairs if data_lbl_filterer is None else np.vstack([placeholder_pairs, data_lbl_filterer])
    
        pred_dset = SXCDataset(SMainXCDataset(data_info=data_info, data_lbl=data_lbl, lbl_info=lbl_info,
                                              data_lbl_filterer=data_lbl_filterer, **kwargs))
        return pred_dset
    
    @classmethod
    def from_cfg(
        cls, 
        cfg:Union[str,Dict],
        collate_fn:Optional[Callable]=identity_collate_fn,
        valid_pct:Optional[float]=0.2,
        seed=None,
        **kwargs,
    ):
        """Build the 'train'/'valid'/'test' blocks present in `cfg['path']`.

        Each split uses its `cfg['path'][split]` file arguments plus a copy of `cfg['parameters']`
        overridden by `kwargs`. Metadata dropout is disabled for every split except 'train'.
        `cfg` may be a path, loaded with `load_cfg`. `valid_pct` and `seed` are currently unused.
        """
        if isinstance(cfg, str): cfg = cls.load_cfg(cfg)

        blocks = dict()
        for o in ['train', 'valid', 'test']:
            if o in cfg['path']:
                params = cfg['parameters'].copy()
                params.update(kwargs)
                if o != 'train': 
                    params['meta_dropout_remove'], params['meta_dropout_replace'] = None, None
                blocks[o] = SBaseXCDataBlock.from_file(**cfg['path'][o], **params, collate_fn=collate_fn)
                
        return cls(**blocks)
        

#### Example

In [None]:
from xcai.config import WIKISEEALSOTITLES

In [None]:
# A data block with only a training split
split_blocks = {'train': train_block}
block = SXCDataBlock(**split_blocks)

In [None]:
# Linker dataset for the `hlk_meta` metadata (see `XCDataBlock.linker_dset`)
linker_meta_name = 'hlk_meta'
linker_dset = block.linker_dset(linker_meta_name)

In [None]:
# First batch of the block's training DataLoader
train_loader = block.train.dl
batch = next(iter(train_loader))

In [None]:
# Keys of the first training batch
batch.keys()

In [None]:
import os

# Root directory of the XC datasets. Set `XCAI_DATA_DIR` instead of editing this machine-specific path.
data_root = os.environ.get('XCAI_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC/data')
config = WIKISEEALSOTITLES(data_root)['data_lnk']

In [None]:
# Show the shared parameters of the config
print(config['parameters'])

{'transform_type': 'xc', 'smp_features': [('lbl2data', 1, 2), ('hlk2data', 1, 1), ('hlk2lbl2data', 2, 1)], 'pad_token': 0, 'oversample': False, 'sampling_features': [('lbl2data', 2), ('hlk2data', 1), ('hlk2lbl2data', 1)], 'num_labels': 1, 'num_metadata': 1, 'metadata_name': None, 'info_column_names': ['identifier', 'input_text'], 'use_tokenizer': True, 'tokenizer': 'bert-base-cased', 'tokenization_column': 'input_text', 'max_sequence_length': 32, 'padding': False, 'return_tensors': None, 'sep': '->', 'prompt_func': None, 'pad_side': 'right', 'drop': True, 'ret_t': True, 'in_place': True, 'collapse': True, 'device': 'cpu', 'inp': 'data', 'targ': 'lbl2data', 'ptr': 'lbl2data_data2ptr', 'n_lbl_samples': None, 'data_info_keys': None, 'lbl_info_keys': None, 'n_slbl_samples': 1, 'main_oversample': False, 'n_data_meta_samples': 1, 'n_lbl_meta_samples': 1, 'meta_info_keys': None, 'meta_oversample': False}


In [None]:
params = {'return_tensors':'pt', 'padding':True}

# Return padded PyTorch tensors instead of the config defaults (`padding=False`, `return_tensors=None`)
config['parameters'].update(params)

In [None]:
# Build all splits listed in the config
block = SXCDataBlock.from_cfg(cfg=config)

In [None]:
# Sample two training items directly from the dataset
sample_idxs = [100, 200]
batch = block.train.dset.__getitems__(sample_idxs)

In [None]:
# Inspect the sampled batch
batch

{'data_idx': tensor([100, 200]),
 'data_identifier': ['Applet', 'Geography_of_Africa'],
 'data_input_text': ['Applet', 'Geography of Africa'],
 'data_input_ids': tensor([[  101,  7302,  1204,   102,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101, 20678,  1104,  2201,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0]]),
 'data_token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0]]),
 'data_attention_mask': tensor([[1, 1, 1, 1, 0, 