In [None]:
#| default_exp models.BT000X

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, re, inspect
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any
from transformers import BertLMHeadModel, BatchEncoding, BertPreTrainedModel, BertModel
from transformers.utils.generic import ModelOutput
from fastcore.meta import *

from xcai.test_utils import *
from xcai.losses import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [None]:
# Build the 'train' test data block (`Test` is star-imported, presumably from `xcai.test_utils`).
block = Test.from_cfg('train')

  self._set_arrayXarray(i, j, x)


In [None]:
# Draw one training batch with `bsz` data points to exercise the models below.
bsz = 20
batch = block.train.one_batch(bsz)

In [None]:
# The raw batch also carries ids and raw text besides the tokenised tensors.
batch.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

## Output

In [None]:
#| export
@dataclass
class XCModelOutput(ModelOutput):
    """Output container shared by the `BT000X` models in this notebook.

    Every field defaults to `None`; each model fills in only the fields it computes
    (e.g. `BT0002` sets `loss`/`logits` and the data-side encoder outputs, `BT0003`
    sets the representation fields).
    """
    # Total training loss (in `BT0004`: `lm_loss + lw * dr_loss`).
    loss: Optional[torch.FloatTensor] = None
    # Scores from the LM head (`self.cls`) applied to the data-side last hidden state.
    logits: Optional[torch.FloatTensor] = None
    # Loss computed from the LM logits (`lm_loss_fn` in `BT0004`).
    lm_loss: Optional[torch.FloatTensor] = None
    # Loss computed from the data/label representations (`dr_loss_fn` in `BT0004`).
    dr_loss: Optional[torch.FloatTensor] = None
    # Pooled data embedding (in `BT0003`/`BT0004`: mean over dim 1 of the last hidden state, L2-normalised).
    data_repr: Optional[torch.FloatTensor] = None
    # Pooled label embedding, computed the same way as `data_repr`.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    # Encoder diagnostics for the data-side pass (`hidden_states`/`attentions`/`cross_attentions`
    # of the BERT encoder output); typically `None` unless requested.
    data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Counterparts of the data-side diagnostics for the label-side encoder pass.
    lbl2data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    

## XCModel

In [None]:
class XCModel(BertLMHeadModel):
    """Placeholder `BertLMHeadModel` subclass with no extra behaviour yet.

    `config` is an explicit first parameter so it is in scope when forwarded to the
    parent constructor; remaining positional/keyword arguments are passed through.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        

## BT0001

In [None]:
#| export
class BT0001(BertLMHeadModel):
    """Prototype `BertLMHeadModel` that returns a plain tuple (no loss, no `ModelOutput`).

    The data text is always encoded. The label text is encoded only when both
    `lbl2data_input_ids` and `lbl2data_data2ptr` are supplied. Both passes share `self.bert`.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data (and optionally the labels) and return raw tensors.

        Returns `(data_logits, lbl2data_input_ids, lbl2data_data2ptr)` when the label pass
        is skipped. Otherwise it returns that triple followed by `(data_repr, lbl2data_repr)`,
        where each `*_repr` is the un-normalised mean of the last hidden state over dim 1.
        Unknown keyword arguments are accepted and ignored.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Both encoder passes use identical options.
        enc_kw = dict(output_attentions=output_attentions,
                      output_hidden_states=output_hidden_states,
                      return_dict=return_dict)

        data_last = self.bert(data_input_ids, data_attention_mask, data_token_type_ids, **enc_kw)[0]
        data_logits = self.cls(data_last)
        # NOTE(review): the mean covers every position along dim 1, so padded tokens contribute.
        data_repr = data_last.mean(dim=1)

        skip_labels = lbl2data_input_ids is None or lbl2data_data2ptr is None
        if skip_labels:
            return data_logits, lbl2data_input_ids, lbl2data_data2ptr

        lbl2data_last = self.bert(lbl2data_input_ids, lbl2data_attention_mask, lbl2data_token_type_ids, **enc_kw)[0]
        return data_logits, lbl2data_input_ids, lbl2data_data2ptr, data_repr, lbl2data_last.mean(dim=1)
        

In [None]:
# Load the `bert-base-uncased` weights into the prototype (BT0001 takes no extra constructor args).
m = BT0001.from_pretrained('bert-base-uncased')

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [None]:
# The unfiltered `batch` works because `forward` swallows unknown keys via `**kwargs`.
out = m(**batch)

In [None]:
# Order: data_logits, lbl2data_input_ids, lbl2data_data2ptr, data_repr, lbl2data_repr.
for o in out: print(o.shape)

torch.Size([20, 18, 30522])
torch.Size([37, 16])
torch.Size([20])
torch.Size([20, 768])
torch.Size([37, 768])


## BT0002

In [None]:
#| export
class BT0002(BertLMHeadModel):
    """Generation-only BERT LM-head model.

    The data text is encoded with `self.bert` and scored by the LM head (`self.cls`).
    When label token ids are supplied, `MultiCrossEntropy` (from `xcai.losses`) compares
    those logits with the label token ids. No pooled representation is produced.

    Args:
        config: model config forwarded to `BertLMHeadModel`.
        tn_targ: passed to `MultiCrossEntropy` as `tn_targ`.
        ig_tok: passed to `MultiCrossEntropy` as `ig_tok`.
    """
    use_generation,use_representation = True,False

    def __init__(self,
                 config,
                 tn_targ:Optional[int]=None,
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute LM-head logits for the data text and, if labels are given, the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy. Otherwise it returns a tuple
        `(loss?, data_logits, *encoder_extras)`, where `loss` is present only when computed.
        The label attention mask and token type ids are accepted but unused.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        data_o = self.bert(
            data_input_ids, data_attention_mask, data_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        data_logits = self.cls(data_o[0])

        # Without label token ids there is nothing to score the logits against.
        loss = None if lbl2data_input_ids is None else self.loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)

        if return_dict:
            return XCModelOutput(
                loss=loss,
                logits=data_logits,
                data_hidden_states=data_o.hidden_states,
                data_attentions=data_o.attentions,
                data_cross_attentions=data_o.cross_attentions,
            )

        # Skip the first two encoder entries (presumably sequence output and pooled output).
        head = (data_logits,) + data_o[2:]
        return head if loss is None else (loss,) + head
        

### Example

In [None]:
# Extra kwargs (`tn_targ`, `ig_tok`) reach `BT0002.__init__`, which uses them to build `loss_fn`.
m = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# `prepare_batch` (star-imported) presumably keeps only the keys `m.forward` uses. TODO confirm.
b = prepare_batch(m, batch)

In [None]:
# NOTE(review): hard-codes a CUDA device; this cell fails on CPU-only machines.
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
# Set because `b` contains `lbl2data_input_ids`.
o.loss

tensor(13.2184, device='cuda:0', grad_fn=<SumBackward0>)

## BT0003

In [None]:
#| export
class BT0003(BertPreTrainedModel):
    """Representation-only bi-encoder built on a plain `BertModel`.

    Data and label texts are encoded by the same `self.bert` and pooled into L2-normalised
    embeddings. When label inputs are present, `MultiTriplet` (from `xcai.losses`) is applied
    to the data/label embeddings.

    Args:
        config: BERT config; also used to build `self.bert`.
        bsz, tn_targ, margin, ig_tok: forwarded unchanged to `MultiTriplet`.
    """
    use_generation,use_representation = False,True

    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.8,
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.bert = BertModel(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, ig_tok=ig_tok, reduce='mean')
        self.post_init()

    # NOTE(review): delegating to `BertModel.__call__` likely advertises no extra named
    # parameters; `BertModel.forward` may have been intended.
    @delegates(BertModel.__call__)
    def get_repr(self,
                 input_ids:Optional[torch.Tensor]=None,
                 attention_mask:Optional[torch.Tensor]=None,
                 token_type_ids:Optional[torch.Tensor]=None,
                 **kwargs):
        """Encode a batch and return `(encoder_output, normalised_mean_embedding)`.

        The embedding is the mean of the last hidden state (`o[0]`) over dim 1, then
        L2-normalised along dim 1. Extra kwargs go straight to `self.bert`.
        NOTE(review): the mean covers padded positions too; `attention_mask` only affects the encoder.
        """
        o = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            **kwargs
        )
        return o, F.normalize(o[0].mean(dim=1), dim=1)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Embed data (and labels, if given) and compute the triplet loss.

        Returns an `XCModelOutput` when `return_dict` is truthy. It carries the
        representations, the loss (when labels were given) and, when requested via
        `output_attentions`/`output_hidden_states`, the encoder diagnostics for each pass.
        Otherwise it returns `(loss, data_repr, lbl2data_repr)`, or
        `(data_repr, lbl2data_repr)` without labels.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        data_o, data_repr = self.get_repr(data_input_ids,
                                          data_attention_mask,
                                          data_token_type_ids,
                                          output_attentions=output_attentions,
                                          output_hidden_states=output_hidden_states,
                                          return_dict=return_dict)
        loss, lbl2data_o, lbl2data_repr = None, None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = self.get_repr(lbl2data_input_ids,
                                                      lbl2data_attention_mask,
                                                      lbl2data_token_type_ids,
                                                      output_attentions=output_attentions,
                                                      output_hidden_states=output_hidden_states,
                                                      return_dict=return_dict)
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        # Forward encoder diagnostics as BT0002/BT0004 do; they stay None unless requested.
        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
            data_cross_attentions=data_o.cross_attentions,
            lbl2data_hidden_states=None if lbl2data_o is None else lbl2data_o.hidden_states,
            lbl2data_attentions=None if lbl2data_o is None else lbl2data_o.attentions,
            lbl2data_cross_attentions=None if lbl2data_o is None else lbl2data_o.cross_attentions,
        )
        

### Example

In [None]:
# Extra kwargs (`tn_targ`, `ig_tok`) reach `BT0003.__init__`, which forwards them to `MultiTriplet`.
m = BT0003.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

Some weights of BT0003 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.v']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Restrict the raw batch to what the model consumes.
b = prepare_batch(m, batch)

In [None]:
# Only tokenised tensors and `lbl2data_data2ptr` remain (ids and raw text are dropped).
b.keys()

dict_keys(['lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
# NOTE(review): hard-codes a CUDA device; this cell fails on CPU-only machines.
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
# One pooled embedding per data point.
o.data_repr.shape

torch.Size([20, 768])

## BT0004

In [None]:
#| export
class BT0004(BertLMHeadModel):
    """BERT LM-head model trained jointly for generation and representation.

    The total loss is `lm_loss + lw * dr_loss`, where:
      * `lm_loss` comes from `MultiCrossEntropy` on the LM-head logits of the data text
        against the label token ids.
      * `dr_loss` comes from `SoupCon` on the pooled data/label embeddings.
    Both losses come from `xcai.losses`.

    Args:
        config: model config forwarded to `BertLMHeadModel`.
        bsz: forwarded to `SoupCon`.
        tn_targ, ig_tok: forwarded to `MultiCrossEntropy`.
        lw: weight of `dr_loss` in the total loss (a float multiplier; must not be None).
    """
    use_generation,use_representation = True,True

    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 ig_tok:Optional[int]=0,
                 lw:float=0.5,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.lw, self.dr_loss_fn = lw, SoupCon(bsz=bsz, reduce='mean')
        self.lm_loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')

    @delegates(BertModel.__call__)
    def get_repr(self,
                 input_ids:Optional[torch.Tensor]=None,
                 attention_mask:Optional[torch.Tensor]=None,
                 token_type_ids:Optional[torch.Tensor]=None,
                 **kwargs):
        """Encode a batch and return `(encoder_output, normalised_mean_embedding)`.

        The embedding is the mean of the last hidden state (`o[0]`) over dim 1, then
        L2-normalised along dim 1. Extra kwargs go straight to `self.bert`.
        NOTE(review): the mean covers padded positions too; `attention_mask` only affects the encoder.
        """
        o = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            **kwargs
        )
        return o, F.normalize(o[0].mean(dim=1), dim=1)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute LM logits and data embeddings, plus both losses when labels are given.

        Returns an `XCModelOutput` when `return_dict` is truthy. Otherwise it returns a tuple
        `(loss, lm_loss, dr_loss, data_logits, data_repr, lbl2data_repr, *encoder_extras)`.
        The three leading loss entries are present only when labels were supplied.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        data_o, data_repr = self.get_repr(
            data_input_ids,
            data_attention_mask,
            data_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        data_logits = self.cls(data_o[0])

        loss, lm_loss, dr_loss, lbl2data_repr = None, None, None, None
        if lbl2data_input_ids is not None:
            # Only the pooled label embedding is needed; the raw encoder output is discarded.
            _, lbl2data_repr = self.get_repr(
                lbl2data_input_ids,
                lbl2data_attention_mask,
                lbl2data_token_type_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            lm_loss = self.lm_loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
            dr_loss = self.dr_loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr)
            loss = lm_loss + self.lw*dr_loss

        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
            data_cross_attentions=data_o.cross_attentions,
        )
        

### Example

In [None]:
b = prepare_batch(m, batch)

In [None]:
bsz = b['data_input_ids'].shape[0]

In [None]:
m = BT0004.from_pretrained('bert-base-uncased', lw=0.5, bsz=bsz, tn_targ=10_000, ig_tok=0)
m, b = m.to('cuda'), b.to('cuda')
o = m(**b)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0004 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['dr_loss_fn.t', 'lm_loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
