In [1]:
#| default_exp models.BBB0XX

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [26]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import BertPreTrainedModel, BertModel

from transformers.activations import get_activation
from transformers.utils.generic import ModelOutput

from fastcore.meta import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [5]:
from nbdev.showdoc import *
# Run nbdev's export step (presumably writes the `#| export` cells to the
# module named by `#| default_exp` — confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()

In [6]:
from xcai.block import *

In [7]:
# Expose only the GPU with index 1 to this process.
# NOTE(review): presumably has to run before CUDA is first initialised — confirm.
os.environ.update({'CUDA_VISIBLE_DEVICES': "1"})

## Setup

In [12]:
# Base directory of the datasets.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets'

# Sub-directory holding pickled data and the path of the pickle opened in the
# next cell.
pkl_dir = f'{data_dir}/processed'
fname = f'{pkl_dir}/wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl'

In [13]:
# Deserialize the pickled object at `fname` into `block`.
# NOTE(review): unpickling can execute arbitrary code — only load files you trust.
with open(fname, 'rb') as file:
    block = pickle.load(file)

In [14]:
# Ask the training split for one batch (the argument 5 is presumably the
# batch size — confirm).
# NOTE(review): this value is overwritten by the loop variable below; the call
# only matters if `one_batch` has side effects.
batch = block.train.one_batch(5)
# Step through the training dataloader and stop as soon as i exceeds 3, which
# leaves `batch` bound to the item at index 4 (or to the last item if the
# dataloader yields fewer than five).
for i,batch in enumerate(block.train.dl):
    if i > 3: break

In [15]:
# Show which fields the batch provides.
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

## `BRT009`

In [34]:
#| export
class BRT009Encoder(BertPreTrainedModel):
    """BERT backbone with a small dense-representation ("dr") head on top.

    The head runs `dr_transform` -> activation -> `dr_layer_norm` ->
    `dr_projector` on the first element of the backbone output, pools the
    result with `Pooling.mean_pooling` and normalises it with `F.normalize`.
    """
    
    def __init__(self, config, *args, **kwargs):
        """Create the `BertModel` backbone and the head layers.

        All head layers are `config.hidden_size` wide; the activation is looked
        up from `config.hidden_act`.
        """
        super().__init__(config, *args, **kwargs)
        self.bert = BertModel(config)
        self.activation = get_activation(config.hidden_act)
        
        self.dr_transform = nn.Linear(config.hidden_size, config.hidden_size)
        self.dr_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dr_projector = nn.Linear(config.hidden_size, config.hidden_size)
        
    def init_dr_head(self):
        """Set the weights of `dr_transform` and `dr_projector` to the identity.

        The matrices are filled in place, so each weight stays on the device
        and in the dtype it already has (building a fresh `torch.eye` tensor
        and assigning it to `.data` would drop a GPU-resident weight back onto
        the default device). Biases and the layer norm are left untouched.
        """
        for layer in (self.dr_transform, self.dr_projector):
            # `eye_` runs under no_grad and writes into the existing storage.
            nn.init.eye_(layer.weight)
        
    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode the inputs and return the backbone output plus a pooled vector.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: mask passed to the backbone and reused for pooling.
            **kwargs: forwarded unchanged to `self.bert`.

        Returns:
            A 2-tuple `(o, rep)`: `o` is whatever `self.bert` returned and `rep`
            is the pooled head output normalised along dim 1.
        """
        o = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # o[0] is the first backbone output (presumably the per-token hidden
        # states of the last layer — confirm).
        rep = self.dr_transform(o[0])
        rep = self.activation(rep)
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        return o, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
    

In [36]:
#| export
class BRT009(BertPreTrainedModel):
    """Representation model that encodes data and label text with one shared encoder.

    A single `BRT009Encoder` embeds both the data inputs and the label inputs;
    when label inputs are supplied, the two sets of embeddings are scored with
    a `MultiTriplet` loss.
    """
    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.bert"]

    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        """Build the encoder and the loss.

        `bsz`, `tn_targ`, `margin`, `tau`, `apply_softmax` and `n_negatives`
        are handed straight to `MultiTriplet`; `use_encoder_parallel` decides
        whether `forward` wraps the encoder in `nn.DataParallel`.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = BRT009Encoder(config)
        self.loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )
        self.post_init()
        self.remap_post_init()

    def init_dr_head(self):
        """Delegate to the encoder's `init_dr_head`."""
        self.encoder.init_dr_head()

    def remap_post_init(self):
        """Expose the encoder's backbone as `self.bert` as well.

        NOTE(review): presumably lets checkpoint keys prefixed with `bert.`
        resolve when loading — confirm.
        """
        self.bert = self.encoder.bert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Embed the data (and, if given, the labels) and compute the loss.

        Args:
            data_input_ids, data_attention_mask: encoder inputs for the data side.
            lbl2data_input_ids, lbl2data_attention_mask: encoder inputs for the
                label side; when `lbl2data_input_ids` is None no loss is computed.
            lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx:
                passed unchanged to the loss function.
            output_attentions, output_hidden_states, return_dict: forwarded to
                the encoder; `return_dict` falls back to
                `config.use_return_dict` when None.
            **kwargs: extra keyword arguments forwarded to the loss function.

        Returns:
            An `XCModelOutput` with `loss`, `data_repr` and `lbl2data_repr`
            when `return_dict` is truthy; otherwise the tuple
            `(data_repr, lbl2data_repr)`, prefixed with `loss` when it is not None.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        # A new DataParallel wrapper is created on every call when enabled.
        encoder = nn.DataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        # Options shared by the data-side and label-side encoder calls.
        encoder_opts = dict(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        _data_out, data_repr = encoder(data_input_ids, data_attention_mask, **encoder_opts)

        loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            _lbl_out, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, **encoder_opts)
            loss = self.loss_fn(
                data_repr, lbl2data_repr,
                lbl2data_data2ptr, lbl2data_idx,
                plbl2data_data2ptr, plbl2data_idx,
                **kwargs
            )

        if return_dict:
            return XCModelOutput(
                loss=loss,
                data_repr=data_repr,
                lbl2data_repr=lbl2data_repr,
            )

        reprs = (data_repr, lbl2data_repr)
        return reprs if loss is None else (loss,) + reprs

### Example

In [37]:
# Load the pretrained BERT checkpoint into BRT009; the keyword arguments are
# the loss / parallelism options accepted by `BRT009.__init__`.
m = BRT009.from_pretrained(
    'bert-base-uncased',
    bsz=1024,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    use_encoder_parallel=False,
)

Some weights of BRT009 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [38]:
# Set the `dr_transform` / `dr_projector` weights of the encoder head to the
# identity matrix (they are reported as newly initialized by the load above).
m.init_dr_head()

In [42]:
# Forward pass on the sampled batch: keys matching `forward`'s named parameters
# bind to them, the remaining keys land in **kwargs and go on to the loss.
o = m(**batch)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [43]:
# Loss returned by the forward pass in the previous cell.
o.loss

tensor(0.0148, grad_fn=<DivBackward0>)