In [None]:
#| default_exp models.MMM00X

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, re, inspect
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any
from transformers import (
    BertLMHeadModel, 
    BatchEncoding, 
    BertPreTrainedModel, 
    BertModel, 
    RobertaForCausalLM, 
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput

from fastcore.meta import *

from xcai.losses import *

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()
from xcai.block import *

## Setup

In [None]:
# Build the training block used by every example below (the BERT variant of the
# same call is left commented out).
# NOTE(review): `tokz` is presumably the tokenizer checkpoint; the resulting batch
# is later fed to `bert-base-uncased`, `distilbert-base-uncased` and `roberta-base`
# models alike — confirm the vocabularies are compatible for each example.
# block = XCBlock.from_cfg('train', tokz='bert-base-uncased')
block = XCBlock.from_cfg('train', tokz='distilbert/distilbert-base-uncased')

  self._set_arrayXarray(i, j, x)


In [None]:
bsz = 20
batch = block.train.one_batch(bsz)

In [None]:
batch.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask'])

## Output

In [None]:
#| export
@dataclass
class XCModelOutput(ModelOutput):
    """Output container shared by the model classes defined in this notebook.

    Every field defaults to `None`; each model fills in only the entries it
    computes and leaves the others unset.
    """
    # Value returned by the model's loss function (for BT0004: `lm_loss + lw * dr_loss`).
    loss: Optional[torch.FloatTensor] = None
    # Output of the language-model head applied to the data encoder states.
    logits: Optional[torch.FloatTensor] = None
    # Individual loss terms; among the models visible in this notebook only BT0004 sets them.
    lm_loss: Optional[torch.FloatTensor] = None
    dr_loss: Optional[torch.FloatTensor] = None
    # Pooled representations of the data inputs and of the label (`lbl2data`) inputs.
    data_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    # Optional extras copied from the data encoder's output object.
    data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Label-side counterparts; no model in the visible part of this notebook populates them.
    lbl2data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    

## Pooling

In [None]:
#| export
class Pooling:
    """Namespace for pooling helpers that collapse token embeddings into one vector per sequence."""

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Average `data_embeds` over dimension 1, weighting each position by its mask value.

        The mask gains a trailing axis, is expanded to the size of `data_embeds`
        and cast to float. The denominator is clamped to at least 1e-9 so that a
        row whose mask is entirely zero yields zeros rather than a division by zero.
        """
        weights = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        weighted_sum = torch.sum(data_embeds * weights, 1)
        weight_total = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_sum / weight_total


## XCModel

In [None]:
class XCModel(BertLMHeadModel):
    """Thin `BertLMHeadModel` subclass that forwards its constructor arguments unchanged."""

    def __init__(self, config, *args, **kwargs):
        # `config` has to be an explicit parameter: the previous signature took only
        # `*args, **kwargs` yet passed a bare `config` name to the parent, which is
        # undefined inside this method and raised NameError on construction.
        super().__init__(config, *args, **kwargs)
        

## BT0001

In [None]:
#| export
class BT0001(BertLMHeadModel):
    """Exploratory BERT variant whose `forward` returns raw tensors instead of a loss.

    The data inputs are always encoded; the label (`lbl2data_*`) inputs are
    encoded only when both `lbl2data_input_ids` and `lbl2data_data2ptr` are given.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data (and optionally the label) inputs and return a plain tuple.

        Without label inputs the tuple is
        `(data_logits, lbl2data_input_ids, lbl2data_data2ptr, lbl2data_idx, kwargs)`;
        with them, `data_repr` and `lbl2data_repr` (the first encoder output averaged
        over dimension 1) are inserted before the trailing `kwargs` dict.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Options shared by both encoder passes.
        encoder_flags = dict(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        data_o = self.bert(data_input_ids, data_attention_mask, data_token_type_ids, **encoder_flags)
        data_hidden = data_o[0]
        data_logits = self.cls(data_hidden)
        data_repr = data_hidden.mean(dim=1)

        # No label side to encode: return the shorter tuple.
        if lbl2data_input_ids is None or lbl2data_data2ptr is None:
            return data_logits, lbl2data_input_ids, lbl2data_data2ptr, lbl2data_idx, kwargs

        lbl2data_o = self.bert(lbl2data_input_ids, lbl2data_attention_mask, lbl2data_token_type_ids, **encoder_flags)
        lbl2data_repr = lbl2data_o[0].mean(dim=1)
        return data_logits, lbl2data_input_ids, lbl2data_data2ptr, lbl2data_idx, data_repr, lbl2data_repr, kwargs
        

In [None]:
m = BT0001.from_pretrained('bert-base-uncased')

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [None]:
out = m(**batch)

In [None]:
# `BT0001.forward` ends its return tuple with the leftover `kwargs` dict, which has
# no `.shape`; report shapes for the tensor entries only.
for o in out:
    if torch.is_tensor(o): print(o.shape)

torch.Size([20, 18, 30522])
torch.Size([37, 16])
torch.Size([20])
torch.Size([20, 768])
torch.Size([37, 768])


## BT0002

In [None]:
#| export
class BT0002(BertLMHeadModel):
    """BERT LM-head model trained with a `MultiCrossEntropy` loss on the data logits."""
    use_generation = True
    use_representation = False

    def __init__(self,
                 config,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        """Build the BERT LM-head model and attach the loss; `tn_targ` and `ig_tok` go to `MultiCrossEntropy`."""
        super().__init__(config, *args, **kwargs)
        self.loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data inputs, apply the LM head and, if labels are given, compute the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise a tuple
        `(logits, *extras)` prefixed with the loss when one was computed. Extra
        keyword arguments are forwarded to the loss function.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        data_o = self.bert(
            data_input_ids,
            data_attention_mask,
            data_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        data_logits = self.cls(data_o[0])

        # The loss is only defined when label token ids are supplied.
        if lbl2data_input_ids is None:
            loss = None
        else:
            loss = self.loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)

        if return_dict:
            return XCModelOutput(
                loss=loss,
                logits=data_logits,
                data_hidden_states=data_o.hidden_states,
                data_attentions=data_o.attentions,
                data_cross_attentions=data_o.cross_attentions,
            )

        # Tuple mode: skip the first two encoder outputs and keep the remaining extras.
        extras = (data_logits,) + data_o[2:]
        return extras if loss is None else (loss,) + extras
        

### Example

In [None]:
m = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.loss

tensor(13.2184, device='cuda:0', grad_fn=<SumBackward0>)

## BT0003

In [None]:
#| export
class BT0003(BertPreTrainedModel):
    """BERT encoder producing mask-weighted mean-pooled representations, trained with `MultiTriplet`."""
    use_generation = False
    use_representation = True
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.8,
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        """Create the `BertModel` encoder and the triplet loss, then run `post_init()`."""
        super().__init__(config, *args, **kwargs)
        self.bert = BertModel(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, ig_tok=ig_tok, reduce='mean')
        self.post_init()

    @delegates(BertModel.__call__)
    def get_repr(self, 
                 input_ids:Optional[torch.Tensor]=None, 
                 attention_mask:Optional[torch.Tensor]=None,
                 token_type_ids:Optional[torch.Tensor]=None,
                 **kwargs):
        """Run the encoder and return `(encoder_output, pooled)`.

        `pooled` is `Pooling.mean_pooling` of the first encoder output using
        `attention_mask`; remaining keyword arguments go to the encoder.
        """
        encoder_out = self.bert(input_ids, attention_mask, token_type_ids, **kwargs)
        pooled = Pooling.mean_pooling(encoder_out[0], attention_mask)
        return encoder_out, pooled

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Pool the data inputs and, when label inputs are given, pool them too and compute the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise the tuple
        `(data_repr, lbl2data_repr)` prefixed with the loss when one was computed.
        Extra keyword arguments are forwarded to the loss function.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Options shared by both encoder passes.
        encoder_flags = dict(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        _, data_repr = self.get_repr(data_input_ids, data_attention_mask, data_token_type_ids, **encoder_flags)

        loss = None
        lbl2data_repr = None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = self.get_repr(lbl2data_input_ids, lbl2data_attention_mask,
                                             lbl2data_token_type_ids, **encoder_flags)
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, **kwargs)

        if return_dict:
            return XCModelOutput(
                loss=loss,
                data_repr=data_repr,
                lbl2data_repr=lbl2data_repr,
            )

        reprs = (data_repr, lbl2data_repr)
        return reprs if loss is None else (loss,) + reprs
        

### Example

In [None]:
m = BT0003.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

Some weights of BT0003 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.v']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
b.keys()

dict_keys(['lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.data_repr.shape

torch.Size([20, 768])

## BT0004

In [None]:
#| export
class BT0004(BertLMHeadModel):
    """BERT LM-head model trained jointly on a generation loss and a representation loss.

    The total loss is `lm_loss + lw * dr_loss`, where `lm_loss` comes from
    `MultiCrossEntropy` on the data logits and `dr_loss` from `SoupCon` on the
    normalized data and label representations.
    """
    use_generation,use_representation = True,True 

    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 lw:Optional[float]=0.5,
                 *args, **kwargs):
        """Build the model and both losses.

        `bsz` is passed to `SoupCon`; `tn_targ` and `ig_tok` to `MultiCrossEntropy`.
        `lw` is the weight of the representation loss in the total loss (annotated as
        float because the default and the intended values are fractional).
        """
        super().__init__(config, *args, **kwargs)
        self.lw = lw
        self.dr_loss_fn = SoupCon(bsz=bsz, reduce='mean')
        self.lm_loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')
        
    @delegates(BertModel.__call__)
    def get_repr(self, 
                 input_ids:Optional[torch.Tensor]=None, 
                 attention_mask:Optional[torch.Tensor]=None,
                 token_type_ids:Optional[torch.Tensor]=None,
                 **kwargs):
        """Run the encoder and return `(encoder_output, representation)`.

        The representation is the first encoder output averaged over dimension 1
        and then L2-normalized along dimension 1.
        """
        o = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            **kwargs
        )
        # NOTE(review): this mean runs over every position and does not consult
        # `attention_mask`, unlike BT0003/DBT008 which use `Pooling.mean_pooling`;
        # confirm whether including padded positions is intended.
        return o, F.normalize(o[0].mean(dim=1), dim=1)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_token_type_ids:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_token_type_ids:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Compute logits and representations, plus both losses when label inputs are given.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise the tuple
        `(logits, data_repr, lbl2data_repr, *extras)` prefixed with
        `(loss, lm_loss, dr_loss)` when the losses were computed. Extra keyword
        arguments are forwarded to the representation loss only.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        data_o, data_repr = self.get_repr(
            data_input_ids,
            data_attention_mask,
            data_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        data_logits = self.cls(data_o[0])
        
        loss, lm_loss, dr_loss, lbl2data_repr = None, None, None, None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = self.get_repr(
                lbl2data_input_ids,
                lbl2data_attention_mask,
                lbl2data_token_type_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            
            lm_loss = self.lm_loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr)
            dr_loss = self.dr_loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, **kwargs)
            loss = lm_loss + self.lw*dr_loss
            
        if not return_dict:
            # Tuple mode: skip the first two encoder outputs and keep the remaining extras.
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
            data_cross_attentions=data_o.cross_attentions,
        )
        

### Example

In [None]:
# NOTE(review): `m` is still the BT0003 instance from the previous example here;
# the BT0004 model is only created two cells below, after `bsz` is read from `b`.
b = prepare_batch(m, batch)

In [None]:
bsz = b['data_input_ids'].shape[0]

In [None]:
m = BT0004.from_pretrained('bert-base-uncased', lw=0.5, bsz=bsz, tn_targ=10_000, ig_tok=0)
m, b = m.to('cuda'), b.to('cuda')
o = m(**b)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0004 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['dr_loss_fn.t', 'lm_loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## RT0005

In [None]:
#| export
class RT0005(RobertaForCausalLM):
    """RoBERTa LM-head model trained with a `MultiCrossEntropy` loss on the data logits."""
    use_generation = True
    use_representation = False

    def __init__(self,
                 config,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        """Build the RoBERTa LM-head model and attach the loss; `tn_targ` and `ig_tok` go to `MultiCrossEntropy`."""
        super().__init__(config, *args, **kwargs)
        self.loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data inputs, apply the LM head and, if labels are given, compute the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise a tuple
        `(logits, *extras)` prefixed with the loss when one was computed. Extra
        keyword arguments are forwarded to the loss function.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        data_o = self.roberta(
            data_input_ids,
            data_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        data_logits = self.lm_head(data_o[0])

        # The loss is only defined when label token ids are supplied.
        if lbl2data_input_ids is None:
            loss = None
        else:
            loss = self.loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)

        if return_dict:
            return XCModelOutput(
                loss=loss,
                logits=data_logits,
                data_hidden_states=data_o.hidden_states,
                data_attentions=data_o.attentions,
                data_cross_attentions=data_o.cross_attentions,
            )

        # Tuple mode: skip the first two encoder outputs and keep the remaining extras.
        extras = (data_logits,) + data_o[2:]
        return extras if loss is None else (loss,) + extras
        

### Example

In [None]:
m = RT0005.from_pretrained('roberta-base', tn_targ=10_000, ig_tok=0)

If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of RT0005 were not initialized from the model checkpoint at roberta-base and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.loss

tensor(17.3047, device='cuda:0', grad_fn=<SumBackward0>)

## BT0006

In [None]:
#| export
class BT0006(BT0002):
    """Variant of BT0002 whose loss function is replaced by `SoupCon`."""
    use_generation,use_representation = False,True
    
    def __init__(self, config, *args, bsz:Optional[int]=None, **kwargs):
        """Build the parent model, then swap its loss for `SoupCon(bsz=bsz)`.

        `config` and `bsz` are explicit parameters: the previous signature took only
        `*args, **kwargs` but referenced both names in the body, where they were
        undefined. `bsz` is keyword-only so that positional arguments keep flowing
        to the parent constructor unchanged.
        """
        super().__init__(config, *args, **kwargs)
        # NOTE(review): the inherited `BT0002.forward` hands token logits and label
        # input ids to `loss_fn`, whereas BT0004 calls `SoupCon` with pooled
        # representations; together with `use_representation = True` this suggests
        # BT0003 may be the intended base class — confirm.
        self.loss_fn = SoupCon(bsz=bsz, reduce='mean')
        

### Example

In [None]:
m = BT0006.from_pretrained('bert-base-uncased', bsz=bsz)

Some weights of BT0006 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.t']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
m, b = m.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.loss

tensor(7.0530, device='cuda:0', grad_fn=<SumBackward0>)

## DBT007

In [None]:
#| export
class DBT007(DistilBertForMaskedLM):
    """DistilBERT masked-LM model trained with a `MultiCrossEntropy` loss on the data logits."""
    use_generation,use_representation = True,False 
    
    def __init__(self, 
                 config,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0):
        """Build the DistilBERT masked-LM model and attach the loss; `tn_targ` and `ig_tok` go to `MultiCrossEntropy`."""
        super().__init__(config)
        self.loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode the data inputs, apply the vocabulary head and, if labels are given, compute the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise a tuple
        `(logits, *extras)` prefixed with the loss when one was computed. Extra
        keyword arguments are forwarded to the loss function.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        data_o = self.distilbert(
            input_ids=data_input_ids,
            attention_mask=data_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Vocabulary head: transform -> activation -> layer norm -> projection.
        hs = data_o[0]
        data_logits = self.vocab_transform(hs)
        data_logits = self.activation(data_logits)
        data_logits = self.vocab_layer_norm(data_logits)
        data_logits = self.vocab_projector(data_logits)

        loss = None
        if lbl2data_input_ids is not None:
            loss = self.loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)

        if not return_dict:
            # The DistilBERT encoder has no pooled output, so its optional extras
            # (hidden states, attentions) start at index 1. Slicing from 2, as the
            # BERT models do, silently dropped the first of them.
            o = (data_logits,) + data_o[1:]
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            logits=data_logits,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )


### Example

In [None]:
m = DBT007.from_pretrained('distilbert-base-uncased', ig_tok=0)

In [None]:
o = m(**batch)

In [None]:
o

XCModelOutput(loss=tensor(15.3677, grad_fn=<SumBackward0>), logits=tensor([[[ -6.1614,  -6.1267,  -6.1175,  ...,  -5.4209,  -5.2192,  -3.5195],
         [-13.9941, -14.0756, -13.8388,  ..., -12.1233, -11.9061,  -9.5880],
         [ -5.8230,  -5.9046,  -5.8874,  ...,  -5.2475,  -4.5833,  -5.0284],
         ...,
         [ -6.8535,  -7.2897,  -6.7291,  ...,  -6.0015,  -8.1558,  -4.2692],
         [ -6.6621,  -7.0612,  -6.3850,  ...,  -5.4423,  -8.1755,  -2.0556],
         [ -6.4021,  -6.5815,  -5.9944,  ...,  -4.7292,  -7.1593,  -2.0416]],

        [[ -6.0825,  -6.0533,  -6.0468,  ...,  -5.3831,  -5.2671,  -3.3423],
         [ -8.2382,  -8.3393,  -8.1442,  ...,  -7.4286,  -7.2308,  -5.8274],
         [-11.7931, -11.6936, -11.6976,  ..., -10.0092, -10.2047,  -7.6660],
         ...,
         [ -6.2854,  -6.4565,  -6.1392,  ...,  -6.3359,  -6.2492,  -5.3608],
         [ -6.8786,  -7.0738,  -6.7573,  ...,  -6.6838,  -7.0853,  -4.0643],
         [ -5.4824,  -5.5097,  -5.3867,  ...,  -5.1087, 

## DBT008

In [None]:
#| export
class DBT008(DistilBertPreTrainedModel):
    """DistilBERT encoder producing mask-weighted mean-pooled representations, trained with `MultiTriplet`."""
    use_generation,use_representation = False,True
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.8,
                 ig_tok:Optional[int]=0,
                 *args, **kwargs):
        """Create the `DistilBertModel` encoder and the triplet loss, then run `post_init()`."""
        super().__init__(config, *args, **kwargs)
        self.distilbert = DistilBertModel(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, ig_tok=ig_tok, reduce='mean')
        self.post_init()

    # Delegate to the encoder class this model actually wraps (previously named
    # `BertModel`, a leftover from the BERT counterpart of this class).
    @delegates(DistilBertModel.__call__)
    def get_repr(self, 
                 input_ids:Optional[torch.Tensor]=None, 
                 attention_mask:Optional[torch.Tensor]=None,
                 **kwargs):
        """Run the encoder and return `(encoder_output, pooled)`.

        `pooled` is `Pooling.mean_pooling` of the first encoder output using
        `attention_mask`; remaining keyword arguments go to the encoder.
        """
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return o, Pooling.mean_pooling(o[0], attention_mask)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Pool the data inputs and, when label inputs are given, pool them too and compute the loss.

        Returns an `XCModelOutput` when `return_dict` is truthy, otherwise the tuple
        `(data_repr, lbl2data_repr)` prefixed with the loss when one was computed.
        Extra keyword arguments are forwarded to the loss function.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        _, data_repr = self.get_repr(data_input_ids, 
                                     data_attention_mask, 
                                     output_attentions=output_attentions, 
                                     output_hidden_states=output_hidden_states,
                                     return_dict=return_dict)
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = self.get_repr(lbl2data_input_ids, 
                                             lbl2data_attention_mask,  
                                             output_attentions=output_attentions, 
                                             output_hidden_states=output_hidden_states,
                                             return_dict=return_dict)
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### Example

In [None]:
m = DBT008.from_pretrained('sentence-transformers/msmarco-distilbert-dot-v5', ig_tok=0)

In [None]:
o = m(**batch)

In [None]:
o

XCModelOutput(loss=tensor(0.4810, grad_fn=<SumBackward0>), logits=None, lm_loss=None, dr_loss=None, data_repr=tensor([[ 0.0372, -0.0094,  0.1272,  ...,  0.2118, -0.5175, -0.0071],
        [-0.3130, -0.2676,  0.2495,  ...,  0.4102, -0.4511, -0.0603],
        [-0.3817,  0.3488, -0.2542,  ...,  0.1895, -0.0678, -0.2233],
        ...,
        [ 0.1648,  0.1906,  0.5809,  ...,  0.3762,  0.0511, -0.3078],
        [ 0.0934,  0.6205,  0.0835,  ...,  0.1357,  0.1037, -0.2269],
        [-0.4758,  0.0118,  0.1701,  ...,  0.1738,  0.1395, -0.1342]],
       grad_fn=<DivBackward0>), lbl2data_repr=tensor([[-0.2418,  0.2394,  0.2001,  ...,  0.0800,  0.2445, -0.3035],
        [-0.0580,  0.3176, -0.1738,  ..., -0.0049, -0.1210, -0.1995],
        [-0.0683,  0.4309, -0.1003,  ..., -0.1094, -0.1286, -0.2189],
        ...,
        [-0.4602,  0.1432, -0.1938,  ..., -0.1724, -0.0637,  0.2525],
        [-0.2790,  0.4353,  0.2206,  ..., -0.0257, -0.2271,  0.0971],
        [-0.3544,  0.2026, -0.1465,  ...,  0.24