In [1]:
#| default_exp models.PPP0XX

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    BertPreTrainedModel,
    BertLMHeadModel, 
    BatchEncoding, 
    BertPreTrainedModel, 
    BertModel, 
    RobertaForCausalLM, 
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput

from fastcore.meta import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from xcai.block import *

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

## Setup

In [8]:
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [None]:
block = XCBlock.from_cfg(data_dir, 'data', tfm='xcnlg', tokenizer='distilbert-base-uncased', smp_features=[('lbl2data',1,1)])

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

  self._set_arrayXarray(i, j, x)


In [None]:
block = XCBlock.from_cfg(data_dir, 'data_meta', valid_pct=0.001, tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data', 1, 1), ('cat2data', 1, 1)],
                         n_data_meta_samples=50, n_lbl_meta_samples=50)

  self._set_arrayXarray(i, j, x)


In [9]:
data_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_dir = f'{data_dir}/processed'

In [10]:
fname = f'{pkl_dir}/wikiseealso_data-meta_distilbert-base-uncased_rm_ramen-cat.pkl'

In [None]:
with open(fname, 'wb') as file: pickle.dump(block, file)

In [11]:
with open(fname, 'rb') as file: block = pickle.load(file)

In [3]:
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 3: break

In [None]:
batch.keys()

## Output

In [None]:
#| export
@dataclass
class XCModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    lm_loss: Optional[torch.FloatTensor] = None
    dr_loss: Optional[torch.FloatTensor] = None
    data_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    data_embed: Optional[torch.FloatTensor] = None
    lbl2data_embed: Optional[torch.FloatTensor] = None
    data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    lbl2data_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    

## Pooling

In [None]:
#| export
class Pooling:

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        data_attention_mask = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        return torch.sum(data_embeds * data_attention_mask, 1) / torch.clamp(data_attention_mask.sum(1), min=1e-9)


## DBT007

In [None]:
#| export
class DBT007Encoder(DistilBertForMaskedLM):
    
    def __init__(self, config):
        super().__init__(config)
    
    def forward(
        self,
        input_ids:Optional[torch.Tensor]=None,
        attention_mask:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        data_o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        data_logits = self.vocab_transform(data_o[0])
        data_logits = self.activation(data_logits)
        data_logits = self.vocab_layer_norm(data_logits)
        data_logits = self.vocab_projector(data_logits)
        
        return data_o, data_logits


In [None]:
#| export
class DBT007(DistilBertForMaskedLM):
    use_generation,use_representation = True,False
    _tied_weights_keys = ["encoder.module.distilbert", "encoder.module.activation", "encoder.module.vocab_transform",
                          "encoder.module.vocab_layer_norm", "encoder.module.vocab_projector"]
    
    def __init__(self, 
                 config,
                 tn_targ:Optional[int]=None, 
                 ig_tok:Optional[int]=0,
                 vocab_weights:Optional[torch.Tensor]=None,
                 reduction:Optional[str]='mean',
                ):
        super().__init__(config)
        self.encoder = nn.DataParallel(DBT007Encoder(config))
        self.loss_fn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, vocab_weights=vocab_weights, 
                                         reduce=reduction)
        self.remap_post_init()
    
    def remap_post_init(self):
        self.encoder.module.activation = self.activation 
        self.encoder.module.distilbert= self.distilbert
        self.encoder.module.vocab_transform = self.vocab_transform
        self.encoder.module.vocab_layer_norm = self.vocab_layer_norm
        self.encoder.module.vocab_projector = self.vocab_projector 

    @delegates(DBT007Encoder.forward)
    def forward(
        self, 
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        data_o, data_logits = self.encoder(data_input_ids, data_attention_mask, output_attentions,
                                           output_hidden_states, return_dict, **kwargs)
        
        loss = None
        if lbl2data_input_ids is not None:
            loss = self.loss_fn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            
        if not return_dict:
            o = (data_logits,) + data_o[2:]
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            logits=data_logits,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )


### Example

In [None]:
m = DBT007.from_pretrained('distilbert-base-uncased', ig_tok=0)

In [None]:
o = m(**batch)

In [None]:
o.loss

tensor(15.7426, grad_fn=<SumBackward0>)

## DBT009

In [None]:
#| export
class DBT009Encoder(DistilBertPreTrainedModel):
    
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.distilbert = DistilBertModel(config)
        self.dr_transform = nn.Linear(config.dim, config.dim)
        self.dr_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dr_projector = nn.Linear(config.dim, config.dim)
        
    def init_dr_head(self):
        self.dr_transform.weight.data = torch.eye(self.dr_transform.out_features, self.dr_transform.in_features, 
                                                  dtype=self.dr_transform.weight.dtype)
        self.dr_projector.weight.data = torch.eye(self.dr_projector.out_features, self.dr_projector.in_features, 
                                                  dtype=self.dr_projector.weight.dtype)
        
    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        rep = self.dr_transform(o[0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        return o, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
    

In [None]:
#| export
class DBT009(DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = DBT009Encoder(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                    apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        self.encoder.init_dr_head()
        
    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### Example

In [None]:
m = DBT009.from_pretrained('distilbert-base-uncased', bsz=1024, margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, use_encoder_parallel=False)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
m.init_dr_head()

In [None]:
o = m(**batch)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss, o.data_repr.shape, o.lbl2data_repr.shape

(tensor(0.0550, grad_fn=<DivBackward0>),
 torch.Size([5, 768]),
 torch.Size([5, 768]))

In [None]:
o.loss

tensor(0.0550, grad_fn=<DivBackward0>)

## DBT010

In [None]:
#| export
class DBT010Encoder(DBT009Encoder):
    
    def __init__(self, config, repr_type:Optional[str]='pool', *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('repr_type')

    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        repr_type:Optional[str]=None,
        **kwargs
    ):
        if repr_type is None: repr_type = self.repr_type
        
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
        if repr_type == 'pool': rep = o[0]
        elif repr_type == 'cls': rep, attention_mask = o[0][:,0:1], attention_mask[:,0:1]
        elif repr_type == 'tok': rep, attention_mask = o[0][:,1:], attention_mask[:,1:]
        else: raise ValueError(f'Invalid representation type `repr_type`({repr_type}).')
        
        rep = self.dr_transform(rep)
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)

        return o, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
        

In [None]:
#| export
class DBT010(DBT009):
    
    def __init__(self, config, repr_type:Optional['str']='pool', *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.encoder = DBT010Encoder(config, repr_type)
        self.post_init()
        self.remap_post_init()

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask,
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr, lbl2data_o = None, None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, 
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)
            
        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_embed=data_o.last_hidden_state,
            lbl2data_embed=None if lbl2data_o is None else lbl2data_o.last_hidden_state,
        )
            

### Example

In [None]:
m = DBT010.from_pretrained('distilbert-base-uncased', bsz=1024, margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, 
                           repr_type='tok', use_encoder_parallel=False)
m.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
o = m(**b)

In [None]:
o.loss, o.data_repr.shape, o.lbl2data_repr.shape

(tensor(0.0539, grad_fn=<DivBackward0>),
 torch.Size([5, 768]),
 torch.Size([5, 768]))

In [None]:
o.loss

tensor(0.0539, grad_fn=<DivBackward0>)

## DBT011

In [None]:
#| export
class DBT011(DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(self,
                 config,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = DBT009Encoder(config)
        self.loss_fn = Triplet(margin=margin, n_negatives=n_negatives, tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        self.encoder.init_dr_head()
        
    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### Example

In [None]:
m = DBT011.from_pretrained('distilbert-base-uncased', margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, use_encoder_parallel=False)
m.init_dr_head()

Some weights of DBT011 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**batch)

In [None]:
o.loss

tensor(0.0537, grad_fn=<MeanBackward0>)

## DBT012

In [None]:
#| export
class DBT012Encoder(DBT009Encoder):

    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        rep = self.dr_transform(o[0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        return o, F.log_softmax(Pooling.mean_pooling(rep, attention_mask), dim=-1)
        

In [None]:
#| export
class DBT012(DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(self,
                 config,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 psi:Optional[float]=1.0,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = DBT012Encoder(config)
        self.loss_fn = Entropy(margin=margin, n_negatives=n_negatives, tau=tau, psi=psi, apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        self.encoder.init_dr_head()
        
    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask, 
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask,  
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)
            
        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### Examples

In [None]:
m = DBT012.from_pretrained('distilbert-base-uncased', margin=0.01, tau=10, n_negatives=10, apply_softmax=True, use_encoder_parallel=False)
m.init_dr_head()

Some weights of DBT012 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**batch)

In [None]:
o.loss

tensor(0.0013, grad_fn=<MeanBackward0>)

## DBT022

In [None]:
#| export
class DBT022(DBT009):
    
    def __init__(self, config, c_lw:Optional[float]=0.1, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.embed_loss, self.c_lw = Cosine(reduce='mean'), c_lw

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask,
                                    output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states,
                                    return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, 
                                                output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states,
                                                return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss += self.c_lw * self.embed_loss(data_o[0], data_attention_mask, lbl2data_o[0], lbl2data_attention_mask)
            
        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )

### Example

In [None]:
m = DBT022.from_pretrained('distilbert-base-uncased', bsz=1024, margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, use_encoder_parallel=False)
m.init_dr_head()

Some weights of DBT022 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(m, batch)

In [None]:
o = m(**b)

In [None]:
o.loss, o.data_repr.shape, o.lbl2data_repr.shape

(tensor(0.1189, grad_fn=<AddBackward0>),
 torch.Size([5, 768]),
 torch.Size([5, 768]))

In [None]:
o.loss

tensor(0.1189, grad_fn=<AddBackward0>)

## DBT013

In [None]:
#| export
class DBT013Encoder(DistilBertForMaskedLM):
    
    def __init__(self, config):
        super().__init__(config)
        self.dr_transform = nn.Linear(config.dim, config.dim)
        self.dr_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dr_projector = nn.Linear(config.dim, config.dim)
        
    def init_dr_head(self):
        self.dr_transform.weight.data = torch.eye(self.dr_transform.out_features, self.dr_transform.in_features, 
                                                  dtype=self.dr_transform.weight.dtype)
        self.dr_projector.weight.data = torch.eye(self.dr_projector.out_features, self.dr_projector.in_features, 
                                                  dtype=self.dr_projector.weight.dtype)
        
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None, 
        **kwargs
    ):    
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        rep = self.dr_transform(o[0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        rep = F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
        
        logits = self.vocab_transform(o[0])
        logits = self.activation(logits)
        logits = self.vocab_layer_norm(logits)
        logits = self.vocab_projector(logits)
        
        return o,logits,rep
        

In [None]:
#| export
class DBT013(DistilBertForMaskedLM):
    use_generation,use_representation = True,True
    _tied_weights_keys = ["encoder.distilbert", "encoder.vocab_transform", "encoder.vocab_layer_norm", "encoder.vocab_projector"]
    
    def __init__(
        self, config,
        tn_targ:Optional[int]=None, 
        ig_tok:Optional[int]=0,
        bsz:Optional[int]=None,
        margin:Optional[int]=0.3,
        n_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,
        use_encoder_parallel:Optional[bool]=True,
        lw:Optional[float]=0.8,
    ):
        super().__init__(config)
        store_attr('lw,use_encoder_parallel')
        self.encoder = DBT013Encoder(config)
        self.gen_lfn = MultiCrossEntropy(tn_targ=tn_targ, ig_tok=ig_tok, reduce='mean')
        self.rep_lfn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau,
                                    apply_softmax=apply_softmax, reduce='mean')
        
        self.remap_post_init()
        
    def init_dr_head(self):
        self.encoder.init_dr_head()
        
    def remap_post_init(self): 
        self.encoder.distilbert= self.distilbert
        self.encoder.vocab_transform = self.vocab_transform
        self.encoder.vocab_layer_norm = self.vocab_layer_norm
        self.encoder.vocab_projector = self.vocab_projector
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_logits, data_repr = encoder(data_input_ids, data_attention_mask)
        
        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_logits, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
            
            lm_loss = 0.5 * (self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr) + self.gen_lfn(lbl2data_logits, data_input_ids))
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                   plbl2data_data2ptr, plbl2data_idx)
            loss = dr_loss + self.lw*lm_loss
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )
    

### Example

In [None]:
m = DBT013.from_pretrained('distilbert-base-uncased', tn_targ=10_000, ig_tok=0, margin=0.4, tau=0.7, apply_softmax=True, 
                           n_negatives=5, lw=0.8, use_encoder_parallel=False)
m.init_dr_head()

b = prepare_batch(m, batch, m_args='lbl2data_idx')

Some weights of DBT013 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**b)

In [None]:
o.loss

tensor(13.6458, grad_fn=<AddBackward0>)

## DBT014

In [None]:
#| export
class DBT014Encoder(DBT013Encoder):
    
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None, 
        **kwargs
    ):    
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        rep = self.dr_transform(o[0][:,0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        rep = F.normalize(rep, dim=1)
        
        logits = self.vocab_transform(o[0][:, 1:])
        logits = self.activation(logits)
        logits = self.vocab_layer_norm(logits)
        logits = self.vocab_projector(logits)
        
        return o,logits,rep
        

In [None]:
#| export
class DBT014(DBT013):
    
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.encoder = DBT014Encoder(config)
        self.post_init()
        self.remap_post_init()

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_logits, data_repr = encoder(data_input_ids, data_attention_mask)
        
        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_logits, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
            
            lm_loss = 0.5 * (self.gen_lfn(data_logits, lbl2data_input_ids[:, 1:], lbl2data_data2ptr) + self.gen_lfn(lbl2data_logits, data_input_ids[:, 1:]))
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                   plbl2data_data2ptr, plbl2data_idx)
            loss = dr_loss + self.lw*lm_loss
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )
        

### Example

In [None]:
m = DBT014.from_pretrained('distilbert-base-uncased', tn_targ=10_000, ig_tok=0, margin=0.4, tau=0.7, apply_softmax=True, 
                           n_negatives=5, lw=0.8, use_encoder_parallel=False)
m.init_dr_head()

b = prepare_batch(m, batch, m_args='lbl2data_idx')

Some weights of DBT014 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**b)

In [None]:
o.loss

tensor(14.0373, grad_fn=<AddBackward0>)

## DBT017

In [None]:
#| export
class DBT017(DBT013):
    
    @delegates(DBT013.__init__)
    def __init__(self, config, m_lw:Optional[Union[float,List]]=0.2, meta_prefix:Optional[str]=None, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.m_lw, self.meta_prefix = m_lw, meta_prefix
        
    def _get_meta_inputs(self, **kwargs):
        inputs = {}
        for t in [o for o in kwargs if self.meta_prefix is not None and re.match(f'^[p]?{self.meta_prefix}.*', o)]:
            p,q = t.split('_', maxsplit=1)
            if t[0] == 'p': inputs.setdefault(p[1:], {})[f'p{q}'] = kwargs[t]
            else: inputs.setdefault(p, {})[q] = kwargs[t]
        return inputs

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o, data_logits, data_repr = encoder(data_input_ids, data_attention_mask)
        
        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_logits, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
            
            lm_loss = self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                   plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss = dr_loss + self.lw*lm_loss
            
            meta_inputs = self._get_meta_inputs(**kwargs)
            if isinstance(self.m_lw, float):
                meta_lw = self.m_lw/len(meta_inputs) if len(meta_inputs) else None
                meta_lw = [meta_lw]*len(meta_inputs)
            else:
                if len(self.m_lw) != len(meta_inputs): raise ValueError(f'length of `m_lw` should be equal to number of metadata.')
                meta_lw = self.m_lw
                
            for m,m_lw in zip(meta_inputs.values(), meta_lw):
                if 'lbl2data2ptr' in m:
                    valid_idx = torch.where(m['lbl2data2ptr'])[0]
                    if len(valid_idx) > 0:
                        o, logits, rep = encoder(m['input_ids'], m['attention_mask'])
                        m_lml = self.gen_lfn(lbl2data_logits[valid_idx], m['input_ids'], m['lbl2data2ptr'][valid_idx], **kwargs)
                        m_drl = self.rep_lfn(lbl2data_repr[valid_idx], rep, m['lbl2data2ptr'][valid_idx], m['idx'], 
                                             m['plbl2data2ptr'][valid_idx], m['pidx'], **kwargs)
                        loss += m_lw * (m_drl + self.lw* m_lml)
                        
                elif 'data2ptr' in m:
                    valid_idx = torch.where(m['data2ptr'])[0]
                    if len(valid_idx) > 0:
                        o, logits, rep = encoder(m['input_ids'], m['attention_mask'])
                        m_lml = self.gen_lfn(data_logits[valid_idx], m['input_ids'], m['data2ptr'][valid_idx], **kwargs)
                        m_drl = self.rep_lfn(data_repr[valid_idx], rep, m['data2ptr'][valid_idx], m['idx'], 
                                             m['pdata2ptr'][valid_idx], m['pidx'], **kwargs)
                        loss += m_lw * (m_drl + self.lw*m_lml) 
                        
                else: raise ValueError('Invalid metadata input arguments.')
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )
    

### Example

In [None]:
m = DBT017.from_pretrained('distilbert-base-uncased', tn_targ=1000, ig_tok=0, margin=0.4, tau=0.7, apply_softmax=True, n_negatives=5, 
                           lw=0.8, m_lw=[0.1], meta_prefix='cat', use_encoder_parallel=False)

b = prepare_batch(m, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 
                                    'cat2data_attention_mask', 'cat2data_data2ptr'])

Some weights of DBT017 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**b)

In [None]:
o.loss

tensor(11.4427, grad_fn=<AddBackward0>)

In [None]:
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

## DBT021

In [None]:
#| export
class DBT021Encoder(DistilBertPreTrainedModel):
    
    def __init__(self, 
                 config, 
                 repr_type:Optional[str]='pool', 
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.distilbert = DistilBertModel(config)
        
        self.dr_transform = nn.Linear(config.dim, config.dim)
        self.dr_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dr_projector = nn.Linear(config.dim, config.dim)

        self.meta_transform = nn.Linear(config.dim, config.dim)
        self.meta_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.meta_projector = nn.Linear(config.dim, config.dim)
        
    def init_dr_head(self):
        self.dr_transform.weight.data = torch.eye(self.dr_transform.out_features, self.dr_transform.in_features, 
                                                  dtype=self.dr_transform.weight.dtype)
        self.dr_projector.weight.data = torch.eye(self.dr_projector.out_features, self.dr_projector.in_features, 
                                                  dtype=self.dr_projector.weight.dtype)

        self.meta_transform.weight.data = torch.eye(self.meta_transform.out_features, self.meta_transform.in_features, 
                                                  dtype=self.meta_transform.weight.dtype)
        self.meta_projector.weight.data = torch.eye(self.meta_projector.out_features, self.meta_projector.in_features, 
                                                    dtype=self.meta_projector.weight.dtype)
        
    def dr(self, x):
        x = self.dr_transform(x)
        x = self.dr_layer_norm(x)
        x = self.dr_projector(x)
        return x

    def meta(self, x):
        x = self.meta_transform(x)
        x = self.meta_layer_norm(x)
        x = self.meta_projector(x)
        return x
        
    @delegates(BertModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        repr_type:Optional[str]='pool',
        input_type:Optional[str]='data',
        **kwargs
    ):
        if repr_type is None: repr_type = self.repr_type
            
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
        if repr_type == 'pool': rep = o[0]
        elif repr_type == 'cls': rep, attention_mask = o[0][:,0:1], attention_mask[:,0:1]
        elif repr_type == 'tok': rep, attention_mask = o[0][:,1:], attention_mask[:,1:]
        else: raise ValueError(f'Invalid representation type `repr_type`({repr_type}).')

        if input_type == 'data': rep = self.dr(rep)
        elif input_type == 'meta': rep = self.meta(rep)
        else: raise ValueError(f'Invalid `input_type`({input_type})')
        
        return o, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
        

In [None]:
#| export
class DBT021(DBT010):

    @delegates(DBT010.__init__)
    def __init__(self, 
                 config, 
                 m_lw:Optional[Union[float,List]]=0.2, 
                 data_meta_prefix:Optional[str]=None,
                 lbl2data_meta_prefix:Optional[str]=None,
                 task_repr_type:Optional[str]='pool',
                 meta_repr_type:Optional[str]='pool',
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.m_lw, self.data_meta_prefix, self.lbl2data_meta_prefix = m_lw, data_meta_prefix, lbl2data_meta_prefix
        self.task_repr_type, self.meta_repr_type = task_repr_type, meta_repr_type
        
        self.encoder = DBT021Encoder(config)
        self.post_init()
        self.remap_post_init()
        
    def _get_meta_inputs(self, meta_prefix, **kwargs):
        inputs = {}
        for t in [o for o in kwargs if meta_prefix is not None and re.match(f'^[p]?{meta_prefix}.*', o)]:
            p,q = t.split('_', maxsplit=1)
            if t[0] == 'p': inputs.setdefault(p[1:], {})[f'p{q}'] = kwargs[t]
            else: inputs.setdefault(p, {})[q] = kwargs[t]
        return inputs

    def _get_meta_loss_weights(self, m_lw, n_meta_inputs):
        if isinstance(m_lw, float):
            meta_lw = m_lw/n_meta_inputs if n_meta_inputs else None
            meta_lw = [meta_lw]*n_meta_inputs
        else:
            if len(m_lw) != n_meta_inputs: raise ValueError(f'length of `m_lw` should be equal to number of metadata.')
            meta_lw = m_lw
        return meta_lw

    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder

        data_meta_inputs = self._get_meta_inputs(self.data_meta_prefix, **kwargs)
        lbl2data_meta_inputs = self._get_meta_inputs(self.lbl2data_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}
        
        m_lw = self._get_meta_loss_weights(self.m_lw, len(meta_inputs))
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(input_ids=inputs['input_ids'],
                                       attention_mask=inputs['attention_mask'], 
                                       input_type="meta", repr_type=self.meta_repr_type)
                    m_loss = self.loss_fn(lbl2data_repr[idx], inputs_o[1], inputs['lbl2data2ptr'][idx],
                                          inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(input_ids=inputs['input_ids'], 
                                       attention_mask=inputs['attention_mask'], 
                                       input_type="meta", repr_type=self.meta_repr_type)
                    m_loss = self.loss_fn(data_repr[idx], inputs_o[1], inputs['data2ptr'][idx], inputs['idx'], 
                                          inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder

        data_o, data_repr = encoder(data_input_ids, data_attention_mask, repr_type=self.task_repr_type)
        
        loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, repr_type=self.task_repr_type)
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)
            loss += self.compute_meta_loss(data_repr, lbl2data_repr, **kwargs)
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            dr_loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o.hidden_states,
            data_attentions=data_o.attentions,
        )
    

### Example

In [None]:
m = DBT021.from_pretrained('distilbert-base-uncased', tn_targ=1000, margin=0.3, tau=0.1, apply_softmax=True, n_negatives=5, 
                           m_lw=[0.1, 0.1], data_meta_prefix='cat2data', lbl2data_meta_prefix='cat2lbl', use_encoder_parallel=False, 
                           task_repr_type='tok', meta_repr_type='cls')

b = prepare_batch(m, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 
                                    'cat2data_attention_mask', 'cat2data_data2ptr', 
                                    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
                                    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr'])

Some weights of DBT021 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight', 'encoder.meta_layer_norm.bias', 'encoder.meta_layer_norm.weight', 'encoder.meta_projector.bias', 'encoder.meta_projector.weight', 'encoder.meta_transform.bias', 'encoder.meta_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
o = m(**b)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0532, grad_fn=<AddBackward0>)

## Fuser

In [None]:
#| export
import math

In [None]:
#| export
class Fuser(nn.Module):
    
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")
            
        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def forward(
        self, 
        data: torch.Tensor,
        data_mask: torch.Tensor,
        meta: torch.Tensor, 
        meta_mask: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        q, k, v, q_m, k_m = data, meta, meta, data_mask, meta_mask
        
        bs, q_len, dim = q.size()
        k_len = k.size(1)

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
        
        sc = sc.masked_fill(mask.bool(), torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
        
        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
from transformers import AutoConfig

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')



In [None]:
fuser = Fuser(config)

In [None]:
bsz, data_seq_len, meta_seq_len, dim, dtype = 10, 14, 17, config.dim, torch.float32
data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, meta_seq_len, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,meta_seq_len), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([10, 14, 768])

## DBT018

In [None]:
#| export
from fastcore.utils import *

In [None]:
#| export
class DBT018Encoder(DBT013Encoder):
    
    def __init__(self, config, tn_meta:Optional[int]=None):
        super().__init__(config)
        self.fuser, self.ln = Fuser(config), nn.LayerNorm(config.dim, eps=1e-12)
        self.o = torch.ones(tn_meta, dtype=torch.long, device=self.device) if tn_meta is not None else None
        
        
    def _get_meta_inputs(self, meta_prefix:Optional[str]=None, **kwargs):
        inputs = {}
        for t in [o for o in kwargs if meta_prefix is not None and re.match(f'^[p]?{meta_prefix}.*', o)]:
            p,q = t.split('_', maxsplit=1)
            if t[0] == 'p': inputs.setdefault(p[1:], {})[f'p{q}'] = kwargs[t]
            else: inputs.setdefault(p, {})[q] = kwargs[t]
        return inputs
    
    
    def get_output(self, input_ids:Optional[torch.Tensor]=None, attention_mask:Optional[torch.Tensor]=None, 
                   **kwargs):
        return self.distilbert(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
    
    
    def resize_meta(self, meta:torch.Tensor, mask:torch.Tensor, n_data2meta:torch.Tensor):
        bsz, dim, tn_data2meta = n_data2meta.shape[0], meta.shape[-1], meta.shape[0]
        self.o = self.o.to(meta.device)
        o = (
            torch.ones(tn_data2meta, dtype=torch.long, device=meta.device) 
            if self.o is None or len(self.o) < tn_data2meta else self.o[:tn_data2meta]
        )

        max_n_data2meta = n_data2meta.max()
        xn_data2meta = max_n_data2meta-n_data2meta+1

        data2meta_ptr = n_data2meta.cumsum(dim=0)-1
        r_data2meta = o.scatter(0, data2meta_ptr, xn_data2meta)

        xmeta,xmask = meta.repeat_interleave(r_data2meta, dim=0),mask.repeat_interleave(r_data2meta, dim=0)
        m = o.scatter(0, data2meta_ptr, 0).repeat_interleave(r_data2meta, dim=0).view(bsz, -1)
        m[:, -1] = 1; m = m.view(-1, 1)
        xmask *= m

        return xmeta,xmask
    

    def get_meta_fused_output(self, input_ids:Optional[torch.Tensor]=None, attention_mask:Optional[torch.Tensor]=None, 
                              aug_meta_prefix:Optional[str]=None, **kwargs):
        data_h = self.get_output(input_ids, attention_mask)[0]

        meta_inputs = self._get_meta_inputs(aug_meta_prefix, **kwargs)
        for m in meta_inputs.values():
            valid_idx = torch.where(m['data2ptr'] > 0)[0]
            if len(valid_idx):
                bsz = len(valid_idx)
                meta_input_ids, meta_attention_mask = self.resize_meta(m['input_ids'], m['attention_mask'], 
                                                                       m['data2ptr'][valid_idx])
                meta_h = self.get_output(meta_input_ids, meta_attention_mask)[0]

                meta_h,meta_attention_mask = meta_h.view(bsz, -1, self.config.dim), meta_attention_mask.view(bsz, -1)

                data_h[valid_idx] += self.fuser(data_h[valid_idx], attention_mask[valid_idx], 
                                                meta_h, meta_attention_mask)[0]

        data_h = self.ln(data_h)
        return (data_h,)
    
    
    def forward(
        self, 
        data_input_ids:Optional[torch.Tensor]=None, 
        data_attention_mask:Optional[torch.Tensor]=None, 
        data_aug_meta_prefix:Optional[str]=None,
        **kwargs
    ):  
        o = self.get_meta_fused_output(
            input_ids=data_input_ids,
            attention_mask=data_attention_mask,
            aug_meta_prefix=data_aug_meta_prefix,
            **kwargs
        )
        torch.cuda.empty_cache()

        rep = self.dr_transform(o[0])
        rep = self.dr_layer_norm(rep)
        rep = self.dr_projector(rep)
        rep = F.normalize(Pooling.mean_pooling(rep, data_attention_mask), dim=1)

        logits = self.vocab_transform(o[0])
        logits = self.activation(logits)
        logits = self.vocab_layer_norm(logits)
        logits = self.vocab_projector(logits)
        
        return o,logits,rep
    

In [None]:
#| export
class DBT018(DBT013):
    
    @delegates(DBT013.__init__)
    def __init__(
        self, 
        config, 
        data_aug_meta_prefix:Optional[List]=None, 
        lbl2data_aug_meta_prefix:Optional[List]=None, 
        tn_meta:Optional[int]=None, 
        **kwargs
    ):
        super().__init__(config, **kwargs)
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        self.encoder = XCDataParallel(module=DBT018Encoder(config, tn_meta))
        self.remap_post_init()
        
    def _get_meta_kwargs(self, feat:str, meta_prefix:Optional[str]=None, **kwargs):
        keys = ['attention_mask', 'input_ids']
        o = {f'{meta_prefix}_{k}': kwargs[f'{meta_prefix}_{k}'] for k in keys if f'{meta_prefix}_{k}' in kwargs}
        if meta_prefix is not None and f'{meta_prefix}_{feat}2ptr' in kwargs:
            o.update({f'{meta_prefix}_data2ptr': kwargs[f'{meta_prefix}_{feat}2ptr']})
        return o
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        data_o, data_logits, data_repr = self.encoder(data_input_ids=data_input_ids, 
                                                      data_attention_mask=data_attention_mask, 
                                                      data_aug_meta_prefix=self.data_aug_meta_prefix,
                                                      **self._get_meta_kwargs('data', self.data_aug_meta_prefix, 
                                                                              **kwargs))
        
        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_logits, lbl2data_repr = self.encoder(data_input_ids=lbl2data_input_ids, 
                                                                      data_attention_mask=lbl2data_attention_mask, 
                                                                      data_aug_meta_prefix=self.lbl2data_aug_meta_prefix,
                                                                      **self._get_meta_kwargs('lbl2data', 
                                                                                              self.lbl2data_aug_meta_prefix, 
                                                                                              **kwargs))
            
            lm_loss = self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                   plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss = dr_loss + self.lw*lm_loss
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o[0],
        )
    

### Example

In [None]:
torch.manual_seed(100)

model = DBT018.from_pretrained('distilbert-base-uncased', ig_tok=0, tn_targ=1000, tn_meta=1000, 
                               margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, lw=0.01,
                               data_aug_meta_prefix='hlk2data', lbl2data_aug_meta_prefix='hlk2lbl')

Some weights of DBT018 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.module.dr_layer_norm.bias', 'encoder.module.dr_layer_norm.weight', 'encoder.module.dr_projector.bias', 'encoder.module.dr_projector.weight', 'encoder.module.dr_transform.bias', 'encoder.module.dr_transform.weight', 'encoder.module.fuser.k.bias', 'encoder.module.fuser.k.weight', 'encoder.module.fuser.o.bias', 'encoder.module.fuser.o.weight', 'encoder.module.fuser.q.bias', 'encoder.module.fuser.q.weight', 'encoder.module.fuser.v.bias', 'encoder.module.fuser.v.weight', 'encoder.module.ln.bias', 'encoder.module.ln.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
b = prepare_batch(model, batch, m_args=[
    'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr',
    'hlk2lbl_input_ids', 'hlk2lbl_attention_mask', 'hlk2lbl_lbl2data2ptr'
])

In [None]:
model = model.to('cuda')

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.1960, device='cuda:0', grad_fn=<AddBackward0>)

## DBT020

In [None]:
#| export
class DBT020(DBT018):
    
    @delegates(DBT018.__init__)
    def __init__(
        self, 
        config, 
        m_lw:Optional[Union[List,float]]=0.2, 
        pred_meta_prefix:Optional[List]=None, 
        **kwargs
    ):
        super().__init__(config, **kwargs)
        store_attr('m_lw,pred_meta_prefix')
        
        
    def _get_meta_inputs(self, meta_prefix:Optional[str]=None, **kwargs):
        inputs = {}
        for t in [o for o in kwargs if meta_prefix is not None and re.match(f'^[p]?{meta_prefix}.*', o)]:
            p,q = t.split('_', maxsplit=1)
            if t[0] == 'p': inputs.setdefault(p[1:], {})[f'p{q}'] = kwargs[t]
            else: inputs.setdefault(p, {})[q] = kwargs[t]
        return inputs
    
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        data_o, data_logits, data_repr = self.encoder(data_input_ids=data_input_ids, 
                                                      data_attention_mask=data_attention_mask, 
                                                      data_aug_meta_prefix=self.data_aug_meta_prefix,
                                                      **self._get_meta_kwargs('data', self.data_aug_meta_prefix, 
                                                                              **kwargs))
        
        loss = lm_loss = dr_loss = lbl2data_repr = None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_logits, lbl2data_repr = self.encoder(data_input_ids=lbl2data_input_ids, 
                                                                      data_attention_mask=lbl2data_attention_mask, 
                                                                      data_aug_meta_prefix=self.lbl2data_aug_meta_prefix,
                                                                      **self._get_meta_kwargs('lbl2data', 
                                                                                              self.lbl2data_aug_meta_prefix, 
                                                                                              **kwargs))
            
            lm_loss = self.gen_lfn(data_logits, lbl2data_input_ids, lbl2data_data2ptr, **kwargs)
            dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                   plbl2data_data2ptr, plbl2data_idx, **kwargs)
            loss = dr_loss + self.lw*lm_loss
            
            meta_inputs = self._get_meta_inputs(self.pred_meta_prefix, **kwargs)
            if isinstance(self.m_lw, float):
                meta_lw = self.m_lw/len(meta_inputs) if len(meta_inputs) else None
                meta_lw = [meta_lw]*len(meta_inputs)
            else:
                if len(self.m_lw) != len(meta_inputs): raise ValueError(f'length of `m_lw` should be equal to number of metadata.')
                meta_lw = self.m_lw
            
            for m,m_lw in zip(meta_inputs.values(),meta_lw):
                if 'lbl2data2ptr' in m:
                    valid_idx = torch.where(m['lbl2data2ptr'])[0]
                    if len(valid_idx) > 0:
                        o, logits, rep = self.encoder(data_input_ids=m['input_ids'], data_attention_mask=m['attention_mask'])
                        m_lml = self.gen_lfn(lbl2data_logits[valid_idx], m['input_ids'], m['lbl2data2ptr'][valid_idx], **kwargs)
                        m_drl = self.rep_lfn(lbl2data_repr[valid_idx], rep, m['lbl2data2ptr'][valid_idx], m['idx'], 
                                             m['plbl2data2ptr'][valid_idx], m['pidx'], **kwargs)
                        loss += m_lw * (m_drl + self.lw* m_lml)
                        
                elif 'data2ptr' in m:
                    valid_idx = torch.where(m['data2ptr'])[0]
                    if len(valid_idx) > 0:
                        o, logits, rep = self.encoder(data_input_ids=m['input_ids'], data_attention_mask=m['attention_mask'])
                        m_lml = self.gen_lfn(data_logits[valid_idx], m['input_ids'], m['data2ptr'][valid_idx], **kwargs)
                        m_drl = self.rep_lfn(data_repr[valid_idx], rep, m['data2ptr'][valid_idx], m['idx'], 
                                             m['pdata2ptr'][valid_idx], m['pidx'], **kwargs)
                        loss += m_lw * (m_drl + self.lw*m_lml)       
                
                else: raise ValueError('Invalid metadata input arguments.')
            
        if not return_dict:
            o = (data_logits,data_repr,lbl2data_repr) + data_o[2:]
            return ((loss,lm_loss,dr_loss) + o) if loss is not None else o
        
        return XCModelOutput(
            loss=loss,
            lm_loss=lm_loss,
            dr_loss=dr_loss,
            logits=data_logits,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            data_hidden_states=data_o[0],
        )
    

### Example

In [None]:
model = DBT020.from_pretrained('distilbert-base-uncased', ig_tok=0, tn_targ=10_000, tn_meta=10_000, 
                               margin=0.3, tau=0.1, n_negatives=5, apply_softmax=True, lw=0.01,
                               pred_meta_prefix='cat', data_aug_meta_prefix='hlk2data', 
                               lbl2data_aug_meta_prefix='hlk2lbl', m_lw=0.3)

Some weights of DBT020 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.module.dr_layer_norm.bias', 'encoder.module.dr_layer_norm.weight', 'encoder.module.dr_projector.bias', 'encoder.module.dr_projector.weight', 'encoder.module.dr_transform.bias', 'encoder.module.dr_transform.weight', 'encoder.module.fuser.k.bias', 'encoder.module.fuser.k.weight', 'encoder.module.fuser.o.bias', 'encoder.module.fuser.o.weight', 'encoder.module.fuser.q.bias', 'encoder.module.fuser.q.weight', 'encoder.module.fuser.v.bias', 'encoder.module.fuser.v.weight', 'encoder.module.ln.bias', 'encoder.module.ln.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr',
    
    'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr',
    'hlk2lbl_input_ids', 'hlk2lbl_attention_mask', 'hlk2lbl_lbl2data2ptr'
])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.2606, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
hasattr(model, 'encoder')

True