In [1]:
#| default_exp models.BBB0XX

In [2]:
#| hide
# Reload edited modules (e.g. xcai) automatically so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import BertPreTrainedModel, BertModel

from transformers.activations import get_activation
from transformers.utils.generic import ModelOutput

from fastcore.meta import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [4]:
from nbdev.showdoc import *
# Write the `#| export` cells of this notebook out to the library module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()

In [5]:
from xcai.block import *

In [6]:
# Restrict this process to GPU 1. CUDA reads this variable when it is first initialised, so this
# must run before any CUDA call (importing torch above does not by itself initialise CUDA).
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

## Setup

In [7]:
from xcai.main import *

In [8]:
# Root of the processed XC datasets. The default is a machine-specific absolute path; set the
# XCAI_DATA_DIR environment variable to point elsewhere instead of editing this cell.
data_dir = os.environ.get('XCAI_DATA_DIR', '/Users/suchith720/Projects/data/')
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'google-bert/bert-base-uncased'

# os.path.join avoids the doubled separators ('data//processed//mogicX') that plain f-string
# concatenation produced when data_dir ends with '/'.
pkl_dir = os.path.join(data_dir, 'processed')
# NOTE(review): the cached block was tokenised for distilbert-base-uncased while the examples load
# bert-*-uncased models; presumably fine because they share the uncased WordPiece vocab — confirm.
pkl_file = os.path.join(pkl_dir, 'mogicX', 'wikiseealsotitles_data_distilbert-base-uncased_sxc.joblib')

In [13]:
# Load the training/eval block for the dataset configured above.
# do_build=False presumably reuses the cached pkl_file rather than rebuilding it — TODO confirm.
# return_scores=True keeps the *_scores fields (e.g. lbl2data_scores, neg2data_scores) that BRT023's loss consumes.
block = build_block(
    pkl_file,
    config_file,
    True,  # NOTE(review): positional flag; its parameter name is not visible here — confirm meaning
    config_key,
    data_dir=data_dir,
    n_slbl_samples=2,
    do_build=False,
    main_oversample=True,
    meta_oversample=True,
    return_scores=True,
)

In [41]:
# Expose the category metadata under the 'neg' prefix so batches carry neg2data_* / neg2lbl_* fields.
# NOTE(review): if `meta` is a plain dict (as the indexing suggests), both keys now refer to the SAME
# object, so the `.prefix` assignment below also changes meta['cat_meta'].prefix to 'neg'. If cat_meta
# must keep its own prefix, store a copy (e.g. copy.copy) under 'neg_meta' instead — confirm intent.
block.train.dset.meta['neg_meta'] = block.train.dset.meta['cat_meta']
block.train.dset.meta['neg_meta'].prefix = 'neg'

In [42]:
# Fetch a batch for the model examples below.
# NOTE(review): the for-loop rebinds `batch`, so the one_batch(5) result is discarded; after the
# break at i == 4, `batch` is the fifth batch yielded by block.train.dl. Confirm which batch the
# examples are meant to use.
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 3: break

In [43]:
# Inspect the fields present in the batch; the models' forward() picks out the ones it names.
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_scores', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pneg2data_idx', 'pneg2data_data2ptr', 'neg2data_idx', 'neg2data_scores', 'neg2data_data2ptr', 'neg2data_identifier', 'neg2data_input_text', 'neg2data_input_ids', 'neg2data_attention_mask', 'pneg2lbl_idx', 'pneg2lbl_lbl2ptr', 'neg2lbl_idx', 'neg2lbl_scores', 'neg2lbl_lbl2ptr', 'neg2lbl_identifier', 'neg2lbl_input_text', 'neg2lbl_input_ids', 'neg2lbl_attention_mask', 'neg2lbl_data2ptr', 'pneg2lbl_data2ptr'])

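## `BRT009`

A BERT bi-encoder whose head applies transform → activation → LayerNorm → projector to the token states, then mean-pools and L2-normalises them. Training uses the `MultiTriplet` loss over the query and label representations.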

In [27]:
#| export
class BRT009Encoder(BertPreTrainedModel):
    """BERT encoder with a dense-retrieval (DR) head on top of the token states.

    Token states from `BertModel` go through `dr_transform` -> activation ->
    `dr_layer_norm` -> `dr_projector`, are pooled with `Pooling.mean_pooling`
    using the attention mask, and are L2-normalised along dim 1.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Registration order (bert, activation, dr_*) is kept so parameter ordering is unchanged.
        self.bert = BertModel(config)
        self.activation = get_activation(config.hidden_act)

        width = config.hidden_size
        self.dr_transform = nn.Linear(width, width)
        self.dr_layer_norm = nn.LayerNorm(width, eps=1e-12)
        self.dr_projector = nn.Linear(width, width)

    @torch.no_grad()
    def init_dr_head(self):
        """Set the weights of both DR linear layers to the identity (biases are left untouched)."""
        for linear in (self.dr_transform, self.dr_projector):
            nn.init.eye_(linear.weight)

    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode a batch and return `(bert_outputs, normalised_pooled_representation)`.

        Extra keyword arguments are passed straight to `BertModel`.
        """
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        head = self.dr_layer_norm(self.activation(self.dr_transform(bert_out[0])))
        pooled = Pooling.mean_pooling(self.dr_projector(head), attention_mask)
        return bert_out, F.normalize(pooled, dim=1)
    

In [28]:
#| export
class BRT009(BertPreTrainedModel):
    """BERT bi-encoder (shared `BRT009Encoder` for queries and labels) trained with `MultiTriplet`."""
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.bert"]
    
    def __init__(
        self,
        config,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        use_encoder_parallel:Optional[bool]=True,
        *args, 
        **kwargs
    ):
        """
        Args:
            config: BERT configuration for the base class and the encoder.
            bsz, tn_targ: accepted for signature compatibility; not used by this class.
            margin, tau, apply_softmax, n_negatives: forwarded to `MultiTriplet`.
            use_encoder_parallel: wrap the encoder in `nn.DataParallel` on each forward call.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = BRT009Encoder(config)
        self.loss_fn = MultiTriplet(
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        """Delegate DR-head identity initialisation to the encoder."""
        self.encoder.init_dr_head()
        
    def remap_post_init(self):
        # Make `self.bert` the same module as `self.encoder.bert` (cf. `_tied_weights_keys`).
        self.bert = self.encoder.bert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode queries (and labels, if given) and compute the `MultiTriplet` loss.

        Any extra keyword arguments are passed on to the loss function.

        Returns:
            `XCModelOutput(loss, data_repr, lbl2data_repr)` when `return_dict` is truthy;
            otherwise `(loss, data_repr, lbl2data_repr)` if a loss was computed, else
            `(data_repr, lbl2data_repr)`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder = nn.DataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder
        encoder_kwargs = dict(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The encoder returns (bert_outputs, pooled_repr); only the representation is needed.
        _, data_repr = encoder(data_input_ids, data_attention_mask, **encoder_kwargs)

        loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, **encoder_kwargs)
            loss = self.loss_fn(
                data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx,
                plbl2data_data2ptr, plbl2data_idx, **kwargs,
            )

        if return_dict:
            return XCModelOutput(loss=loss, data_repr=data_repr, lbl2data_repr=lbl2data_repr)

        reprs = (data_repr, lbl2data_repr)
        return reprs if loss is None else (loss,) + reprs
        

### Example

In [23]:
# Example: BRT009 on bert-large-uncased with the MultiTriplet settings below. DataParallel wrapping is
# disabled; the encoder.dr_* head weights are not in the checkpoint and are freshly initialised.
m = BRT009.from_pretrained('bert-large-uncased', margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True, 
                           use_encoder_parallel=False)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of BRT009 were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
# Set dr_transform / dr_projector weights to the identity (biases keep their initial values).
m.init_dr_head()

In [25]:
# Run the whole batch; keys not named in forward() land in **kwargs and are passed on to the loss.
o = m(**batch)

In [26]:
# MultiTriplet loss for this batch.
o.loss

tensor(0.0286, grad_fn=<DivBackward0>)

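## `BRT023`

A BERT bi-encoder with a transform → (optional LayerNorm) → projector head, mean pooling and optional L2 normalisation. A single shared encoder embeds the queries, labels and negatives, and the model trains with `MarginMSEWithNegatives` using the label and negative scores.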

In [35]:
#| export
class BRT023Encoder(BertPreTrainedModel):
    """BERT encoder with a linear transform -> (optional LayerNorm) -> linear projector head.

    Token states are pooled with `Pooling.mean_pooling` using the attention mask and,
    when `normalize` is set, L2-normalised along dim 1.
    """
    
    def __init__(
        self, 
        config,
        normalize:Optional[bool]=False,
        use_ln:Optional[bool]=False,
        *args, 
        **kwargs
    ):
        super().__init__(config, *args, **kwargs)
        store_attr('normalize,use_ln')
        self.bert = BertModel(config)

        width = config.hidden_size
        self.transform = nn.Linear(width, width)
        self.projector = nn.Linear(width, width)
        if use_ln:
            self.layer_norm = nn.LayerNorm(width, eps=1e-6)
        else:
            self.layer_norm = None
        # NOTE(review): the activation is constructed but never applied in forward()
        # (BRT009Encoder applies its activation after the transform) — confirm this is intended.
        self.activation = get_activation(config.hidden_act)

    @torch.no_grad()
    def init_dr_head(self):
        """Set the transform and projector weights to the identity (biases are left untouched)."""
        for linear in (self.transform, self.projector):
            nn.init.eye_(linear.weight)
        
    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode a batch and return `(bert_outputs, pooled_representation)`.

        The pooled representation is L2-normalised only when `self.normalize` is truthy.
        Extra keyword arguments are passed straight to `BertModel`.
        """
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        hidden = self.transform(bert_out[0])
        if self.use_ln:
            hidden = self.layer_norm(hidden)
        pooled = Pooling.mean_pooling(self.projector(hidden), attention_mask)

        if self.normalize:
            pooled = F.normalize(pooled, dim=1)
        return bert_out, pooled
        

In [36]:
#| export
class BRT023(BertPreTrainedModel):
    """BERT bi-encoder (one shared `BRT023Encoder`) trained with `MarginMSEWithNegatives`.

    Queries (`data_*`), labels (`lbl2data_*`) and negatives (`neg2data_*`) are embedded with
    the same encoder. The loss uses the label and negative scores and is computed only when
    both label and negative inputs are supplied.
    """
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.bert"]
    
    def __init__(
        self,
        config,
        normalize:Optional[bool]=False,
        use_layer_norm:Optional[bool]=False,
        use_encoder_parallel:Optional[bool]=True,
        *args, **kwargs
    ):
        """
        Args:
            config: BERT configuration for the base class and the encoder.
            normalize: L2-normalise pooled representations inside the encoder.
            use_layer_norm: apply a LayerNorm between the encoder's transform and projector.
            use_encoder_parallel: wrap the encoder in `nn.DataParallel` on each forward call.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = BRT023Encoder(config, normalize=normalize, use_ln=use_layer_norm)
        self.loss_fn = MarginMSEWithNegatives()
        self.post_init(); self.remap_post_init()
        
    def init_dr_head(self):
        """Delegate identity initialisation of the transform/projector weights to the encoder."""
        self.encoder.init_dr_head()
        
    def remap_post_init(self):
        # Make `self.bert` the same module as `self.encoder.bert` (cf. `_tied_weights_keys`).
        self.bert = self.encoder.bert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_scores:Optional[torch.Tensor]=None,

        neg2data_input_ids:Optional[torch.Tensor]=None,
        neg2data_attention_mask:Optional[torch.Tensor]=None,
        neg2data_scores:Optional[torch.Tensor]=None,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode queries, labels and negatives; compute the loss when labels and negatives are both given.

        Label and negative representations are each computed whenever their input ids are
        supplied, so a batch with labels but no negatives still returns `lbl2data_repr`
        (with `loss=None`). Extra keyword arguments are passed on to the loss function.

        Returns:
            `XCModelOutput(loss, data_repr, lbl2data_repr, neg2data_repr)` when `return_dict` is
            truthy; otherwise `(loss, data_repr, lbl2data_repr)` if a loss was computed, else
            `(data_repr, lbl2data_repr)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder

        def encode(input_ids, attention_mask):
            # The encoder returns (bert_outputs, pooled_repr); only the representation is used here.
            _, rep = encoder(input_ids, attention_mask,
                             output_attentions=output_attentions,
                             output_hidden_states=output_hidden_states,
                             return_dict=return_dict)
            return rep

        data_repr = encode(data_input_ids, data_attention_mask)

        lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_repr = encode(lbl2data_input_ids, lbl2data_attention_mask)

        neg2data_repr = None
        if neg2data_input_ids is not None:
            neg2data_repr = encode(neg2data_input_ids, neg2data_attention_mask)

        # The margin loss needs both positives and negatives.
        loss = None
        if lbl2data_repr is not None and neg2data_repr is not None:
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_scores, neg2data_repr, neg2data_scores, **kwargs)

        if not return_dict:
            # NOTE(review): neg2data_repr is not part of the tuple form; kept so positional
            # unpacking by existing callers is unchanged.
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            neg2data_repr=neg2data_repr,
        )
        

### Example

In [47]:
# Example: BRT023 on bert-base-uncased with L2-normalised representations and DataParallel wrapping disabled.
m = BRT023.from_pretrained('bert-base-uncased', normalize=True, use_encoder_parallel=False)

Some weights of BRT023 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['encoder.projector.bias', 'encoder.projector.weight', 'encoder.transform.bias', 'encoder.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:
# Set the encoder's transform / projector weights to the identity (biases keep their initial values).
m.init_dr_head()

In [49]:
# Run the whole batch; keys not named in forward() land in **kwargs and are passed on to the loss.
o = m(**batch)

In [50]:
# MarginMSEWithNegatives loss for this batch.
o.loss

tensor(0.1506, grad_fn=<MseLossBackward0>)