# RADGA Model Architecture

In [1]:
#| default_exp 37-radga-model-architecture

In [2]:
#| hide
# Pull in nbdev's doc helpers, then run nbdev's export step for this project.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
# Autoreload mode 2: modified modules are re-imported before each cell is executed.
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.MMM0XX import DBT013

In [5]:
# Switch Weights & Biases off for this session.
os.environ.update({'WANDB_MODE': 'disabled'})

In [6]:
# Restrict CUDA to devices 2 and 3 and set the W&B project name for this notebook.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '2,3',
    'WANDB_PROJECT': 'xc-nlg_37-radga-model-architecture',
})

## Data

In [12]:
# Build the data block from the 'data_metas' configuration, requesting the
# `lbl2data`, `cat2data` and `hlk2data` features.
block = XCBlock.from_cfg(
    '/home/aiscuser/scratch/datasets',
    'data_metas',
    valid_pct=0.001,
    tfm='xcnlg',
    tokenizer='distilbert-base-uncased',
    smp_features=[('lbl2data', 1, 2), ('cat2data', 1, 1), ('hlk2data', 1, 3)],
    n_data_meta_samples=50,
)

  self._set_arrayXarray(i, j, x)


In [15]:
# Cache the constructed block on disk so it can be reloaded without rebuilding it.
pkl_dir = '/home/aiscuser/scratch/datasets/processed/'
with open(f'{pkl_dir}/wikiseealso-radga.pkl', 'wb') as file:
    pickle.dump(block, file)

In [12]:
# Restore the cached block. Unpickling executes arbitrary code, so only load files
# that were written by the dump cell of this notebook.
pkl_dir = '/home/aiscuser/scratch/datasets/processed/'
with open(f'{pkl_dir}/wikiseealso-radga.pkl', 'rb') as file:
    block = pickle.load(file)

In [11]:
# Alternative block built from the 'data_meta' configuration; it requests only the
# `lbl2data` and `hlk2data` features. Running this cell rebinds `block`.
block = XCBlock.from_cfg(
    '/home/aiscuser/scratch/datasets',
    'data_meta',
    valid_pct=0.001,
    tfm='xcnlg',
    tokenizer='distilbert-base-uncased',
    smp_features=[('lbl2data', 1, 2), ('hlk2data', 1, 3)],
    n_data_meta_samples=50,
)

  self._set_arrayXarray(i, j, x)


In [13]:
# Learner configuration whose `label_names` lists the `cat2data`, `pcat2data` and `hlk2data`
# batch keys. NOTE(review): a later cell assigns `args` again with an `hlk2data`-only
# `label_names`; whichever cell ran last wins.
args = XCLearningArguments(
    output_dir='/home/aiscuser/outputs/37-radga-model-architecture',
    logging_first_step=True,
    # batch sizes
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    representation_num_beams=200,
    representation_accumulation_steps=100,
    # evaluate and checkpoint every 1000 steps, keeping at most 5 checkpoints
    save_strategy="steps",
    evaluation_strategy='steps',
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    # optimizer / schedule
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.1,
    learning_rate=2e-4,
    # generation-based prediction
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    # cluster-related options (semantics defined by XCLearningArguments; not visible here)
    group_by_cluster=True,
    num_clustering_warmup_epochs=1,
    num_cluster_update_epochs=1,
    num_cluster_size_update_epochs=1,
    clustering_type='EXPO',
    minimum_cluster_size=1,
    maximum_cluster_size=300,
    output_concatenation_weight=1.0,
    metric_for_best_model='P@1_REPR',
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    fp16=True,
    label_names=['cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'cat2data_idx', 
                 'pcat2data_idx', 'pcat2data_data2ptr',
                 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr', 'hlk2data_idx', 
                ],
)



In [19]:
# Learner configuration whose `label_names` lists only the `hlk2data` batch keys.
# NOTE(review): this rebinds `args`; an earlier cell defines the same settings with the
# `cat2data`/`pcat2data` keys included as well.
args = XCLearningArguments(
    output_dir='/home/aiscuser/outputs/37-radga-model-architecture',
    logging_first_step=True,
    # batch sizes
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    representation_num_beams=200,
    representation_accumulation_steps=100,
    # evaluate and checkpoint every 1000 steps, keeping at most 5 checkpoints
    save_strategy="steps",
    evaluation_strategy='steps',
    eval_steps=1000,
    save_steps=1000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    # optimizer / schedule
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.1,
    learning_rate=2e-4,
    # generation-based prediction
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    # cluster-related options (semantics defined by XCLearningArguments; not visible here)
    group_by_cluster=True,
    num_clustering_warmup_epochs=1,
    num_cluster_update_epochs=1,
    num_cluster_size_update_epochs=1,
    clustering_type='EXPO',
    minimum_cluster_size=1,
    maximum_cluster_size=300,
    output_concatenation_weight=1.0,
    metric_for_best_model='P@1_REPR',
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    fp16=True,
    label_names=['hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr', 'hlk2data_idx'],
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## Model

In [20]:
#| export
from xcai.models.MMM0XX import Pooling, XCModelOutput

In [34]:
#| export
class DBT018(DBT013):
    """`DBT013` whose data encoding is augmented with encoded metadata.

    Keyword inputs whose name matches `aug_meta_prefix` (optionally preceded by `p`) are encoded
    with the same `self.distilbert` backbone as the data and merged into the data hidden states
    by `self.fuser`; the generation and representation heads then run on the fused states.
    """

    @delegates(DBT013.__init__)
    def __init__(self, config, aug_meta_prefix:Optional[str]=None, tn_meta:Optional[int]=None, *args, **kwargs):
        """
        config: model configuration; also passed to `Fuser`.
        aug_meta_prefix: prefix of the keyword inputs to fuse into the data encoding (e.g. `'hlk'`).
            With `None` no input is selected and the plain data encoding is used.
        tn_meta: length of the ones vector pre-allocated for `resize_meta`. With `None` the vector
            is created on demand for every call.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('aug_meta_prefix')
        self.fuser = Fuser(config)
        self.o = torch.ones(tn_meta, dtype=torch.long, device=self.device) if tn_meta is not None else None

    def _get_meta_inputs(self, meta_prefix:Optional[str], **kwargs):
        """Group the keyword arguments selected by `meta_prefix` by metadata name.

        A key is split at its first underscore: `hlk2data_input_ids` is stored as
        `inputs['hlk2data']['input_ids']`, and a key starting with `p` such as `phlk2data_idx`
        is stored as `inputs['hlk2data']['pidx']`. Returns an empty dict when `meta_prefix`
        is `None` or no key matches.

        NOTE(review): the `p` variant is detected from the first character of the key, so a
        `meta_prefix` that itself starts with `p` would be misfiled.
        """
        inputs = {}
        if meta_prefix is None: return inputs
        for t in kwargs:
            # `meta_prefix` is interpolated into the pattern as-is (it is not escaped).
            if not re.match(f'^[p]?{meta_prefix}.*', t): continue
            p,q = t.split('_', maxsplit=1)
            if t[0] == 'p': inputs.setdefault(p[1:], {})[f'p{q}'] = kwargs[t]
            else: inputs.setdefault(p, {})[q] = kwargs[t]
        return inputs

    def get_output(
        self,
        input_ids:Optional[torch.Tensor]=None,
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the DistilBERT backbone and return the first element of its output."""
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return o[0]

    def resize_meta(self, meta:torch.Tensor, mask:torch.Tensor, n_data2meta:torch.Tensor):
        """Pad the per-data groups of metadata rows to a common size and flatten each group.

        meta: encoded metadata; its first axis indexes metadata items and its last axis is the
            hidden size (assumed `(n_meta, seq_len, dim)` - TODO confirm).
        mask: attention mask of the metadata items, aligned with `meta` on the first axis.
        n_data2meta: number of metadata items of each data point. Entries are expected to be
            positive and to sum to `meta.shape[0]`.

        Every group is extended to the largest group size by repeating its last item. The mask of
        the repeated copies is zeroed, except for the copy that lands in the group's last slot,
        so each real item stays visible exactly once.

        Returns `(meta, mask)` viewed as `(bsz, -1, dim)` and `(bsz, -1)`.
        """
        bsz, dim, tn_data2meta = n_data2meta.shape[0], meta.shape[-1], meta.shape[0]

        # `self.o` is None when the model was built without `tn_meta`; only move it if it exists.
        if self.o is not None: self.o = self.o.to(meta.device)
        o = (
            torch.ones(tn_data2meta, dtype=torch.long, device=meta.device)
            if self.o is None or len(self.o) < tn_data2meta else self.o[:tn_data2meta]
        )

        # Number of copies the last item of each group needs so the group reaches the max size.
        max_n_data2meta = n_data2meta.max()
        xn_data2meta = max_n_data2meta-n_data2meta+1

        # Index of the last item of every group; all other items are repeated once.
        data2meta_ptr = n_data2meta.cumsum(dim=0)-1
        r_data2meta = o.scatter(0, data2meta_ptr, xn_data2meta)

        xmeta,xmask = meta.repeat_interleave(r_data2meta, dim=0),mask.repeat_interleave(r_data2meta, dim=0)
        # Zero every copy of each group's last item, then re-enable the final slot of each group.
        m = o.scatter(0, data2meta_ptr, 0).repeat_interleave(r_data2meta, dim=0).view(bsz, -1)
        m[:, -1] = 1; m = m.view(-1, 1)
        xmask *= m

        return xmeta.view(bsz, -1, dim),xmask.view(bsz, -1)

    def get_meta_fused_output(
        self,
        input_ids:Optional[torch.Tensor]=None,
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode the data and average its fusion with every selected metadata group.

        For each group returned by `_get_meta_inputs(self.aug_meta_prefix, ...)`, rows with
        `data2ptr > 0` are replaced by the fuser output; the other rows keep the plain data
        encoding. The per-group results are averaged. Without any matching group the plain data
        encoding is returned. The result is wrapped in a 1-tuple.
        """
        data_h,data_mh = self.get_output(input_ids, attention_mask), None

        meta_inputs = self._get_meta_inputs(self.aug_meta_prefix, **kwargs)
        for m in meta_inputs.values():
            valid_idx = torch.where(m['data2ptr'] > 0)[0]
            dmh = data_h.clone()
            if len(valid_idx):
                meta_h = self.get_output(m['input_ids'], m['attention_mask'])
                meta_h, meta_attention_mask = self.resize_meta(meta_h, m['attention_mask'], m['data2ptr'][valid_idx])
                dmh[valid_idx] = self.fuser(data_h[valid_idx], meta_h, attention_mask[valid_idx], meta_attention_mask)[0]
            data_mh = dmh if data_mh is None else data_mh+dmh

        data_mh = data_h if data_mh is None else data_mh/len(meta_inputs)
        return (data_mh,)

    def get_meta_fused_genrep(
        self,
        input_ids:Optional[torch.Tensor]=None,
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Return `(o, logits, rep)` computed from the metadata-fused hidden states.

        `o` is the tuple from `get_meta_fused_output`, `logits` comes from the vocabulary head and
        `rep` is the L2-normalised, mean-pooled output of the `dr_*` head.
        """
        o = self.get_meta_fused_output(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # Kept from the original: releases cached CUDA blocks after the encoder passes.
        torch.cuda.empty_cache()

        rep = self.dr_transform(o[0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        rep = F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)

        logits = self.vocab_transform(o[0])
        logits = self.activation(logits)
        logits = self.vocab_layer_norm(logits)
        logits = self.vocab_projector(logits)
        return o,logits,rep

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute logits and representations for the data and, when labels are given, the loss.

        The loss is `dr_loss + self.lw*lm_loss`, where `lm_loss` comes from `self.gen_lfn` and
        `dr_loss` from `self.rep_lfn`. Without `lbl2data_input_ids` all loss fields are `None`.
        Returns an `XCModelOutput`, or a tuple when `return_dict` resolves to false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        data_o, data_logits, data_repr = self.get_meta_fused_genrep(data_input_ids, data_attention_mask, **kwargs)

        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = self.get_representation(lbl2data_input_ids, lbl2data_attention_mask)
            lm_loss = self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx,
                                   plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss = dr_loss + self.lw*lm_loss

        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o[0],
        )


In [35]:
#| export
class DBT019(DBT018):
    """`DBT018` with auxiliary losses on the metadata groups selected by `pred_meta_prefix`.

    On top of the label loss, every selected group contributes a generation and a
    representation loss computed against the same data logits and representations.
    """

    @delegates(DBT018.__init__)
    def __init__(self, config, m_lw:Optional[float]=0.2, pred_meta_prefix:Optional[str]=None, *args, **kwargs):
        """
        config: model configuration, forwarded to `DBT018`.
        m_lw: total weight of the auxiliary metadata losses; it is split evenly across the
            selected metadata groups.
        pred_meta_prefix: prefix of the keyword inputs used as auxiliary targets (e.g. `'cat'`).
            With `None` no auxiliary loss is added.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('m_lw,pred_meta_prefix')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute logits, representations and, when labels are given, the combined loss.

        The label loss is `dr_loss + self.lw*lm_loss`. For every metadata group matching
        `self.pred_meta_prefix`, rows with a non-zero `data2ptr` add
        `(self.m_lw/n_groups) * (m_drl + self.lw*m_lml)` to it. The returned `lm_loss` and
        `dr_loss` fields hold the label terms only.
        Returns an `XCModelOutput`, or a tuple when `return_dict` resolves to false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        data_o, data_logits, data_repr = self.get_meta_fused_genrep(data_input_ids, data_attention_mask, **kwargs)

        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = self.get_representation(lbl2data_input_ids, lbl2data_attention_mask)
            lm_loss = self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx,
                                   plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss = dr_loss + self.lw*lm_loss

            meta_inputs = self._get_meta_inputs(self.pred_meta_prefix, **kwargs)
            # No matching group (e.g. `pred_meta_prefix=None`): skip the auxiliary terms instead
            # of dividing by zero.
            if meta_inputs:
                m_lw = self.m_lw/len(meta_inputs)
                for m in meta_inputs.values():
                    valid_idx = torch.where(m['data2ptr'])[0]
                    if len(valid_idx) > 0:
                        o, rep = self.get_representation(m['input_ids'], m['attention_mask'])
                        m_lml = self.gen_lfn(data_logits[valid_idx], m['input_ids'], m['data2ptr'][valid_idx], **kwargs)
                        m_drl = self.rep_lfn(data_repr[valid_idx], rep, m['data2ptr'][valid_idx], m['idx'],
                                             m['pdata2ptr'][valid_idx], m['pidx'], **kwargs)
                        loss = loss + m_lw * (m_drl + self.lw*m_lml)

        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o[0],
        )
    

## Training

In [23]:
# Draw a 2000-item sample of the test dataset (seed 50) and build the metric used for evaluation.
test_dset = block.test.dset.sample(n=2000, seed=50)
metric = PrecRecl(
    block.n_lbl,
    test_dset.data.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [1]:
# Total batch size: the larger per-device batch size times the number of visible GPUs.
# Requires the `args` cell to have been run first.
bsz = torch.cuda.device_count() * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

# Variant that only fuses `hlk` metadata into the data encoding.
model = DBT018.from_pretrained(
    'distilbert-base-uncased',
    ig_tok=0,
    bsz=bsz,
    tn_targ=10_000,
    tn_meta=10_000,
    margin=0.3,
    tau=0.1,
    n_negatives=5,
    apply_softmax=True,
    lw=0.01,
    aug_meta_prefix='hlk',
    init_drh=True,
)

NameError: name 'args' is not defined

In [24]:
# Total batch size: the larger per-device batch size times the number of visible GPUs.
bsz = torch.cuda.device_count() * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

# Variant that fuses `hlk` metadata and adds auxiliary losses on `cat` metadata.
model = DBT019.from_pretrained(
    'distilbert-base-uncased',
    ig_tok=0,
    bsz=bsz,
    tn_targ=1000,
    tn_meta=100,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    lw=0.01,
    m_lw=0.1,
    pred_meta_prefix='cat',
    aug_meta_prefix='hlk',
    init_drh=True,
)

Some weights of DBT019 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['dr_layer_norm.bias', 'dr_layer_norm.weight', 'dr_projector.bias', 'dr_projector.weight', 'dr_transform.bias', 'dr_transform.weight', 'fuser.k.bias', 'fuser.k.weight', 'fuser.layer_norm.bias', 'fuser.layer_norm.weight', 'fuser.o.bias', 'fuser.o.weight', 'fuser.q.bias', 'fuser.q.weight', 'fuser.v.bias', 'fuser.v.weight', 'gen_lfn.o', 'rep_lfn.u', 'rep_lfn.v']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
# Build the trie from the data block; it is handed to the learner below.
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [26]:
# Assemble the learner: trains on the block's train dataset and evaluates on the sampled test set.
learn = XCLearner(
    model=model, 
    args=args,
    trie=trie,
    train_dataset=block.train.dset,
    eval_dataset=test_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

In [37]:
# Forward-pass smoke test on a single batch.
# NOTE(review): `b` is not defined in any visible cell - it presumably came from a since-removed
# cell that produced one collated batch. Define it before this line, otherwise the cell raises
# NameError on a fresh kernel.
o = model(**b.to(model.device))

In [38]:
# Show the `loss` field of the forward-pass output.
o.loss

tensor(0.1738, device='cuda:0', grad_fn=<AddBackward0>)

In [40]:
# Start training with the learner configured above.
learn.train()