In [13]:
from transformers import DebertaV2Model, AutoTokenizer
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler
import torch.optim as optim
import torch
import torch.nn.functional as F

In [2]:
MODEL_PATH = 'microsoft/deberta-v3-large'

# Slow (non-"fast") tokenizer, explicitly requested via use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)

# Bare DebertaV2Model backbone (no task head requested here); device_map puts
# the weights on the GPU directly at load time.
model = DebertaV2Model.from_pretrained(MODEL_PATH, device_map="cuda")

In [3]:
# Freeze every pretrained encoder weight: the encoder acts as a fixed feature
# extractor, so only modules created after this cell will receive gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [5]:
class DebertaV2Feature(torch.nn.Module):
    """Sequence-level feature extractor: a DeBERTa-v2 encoder followed by ``ContextPooler``.

    The wrapped encoder is stored as ``self.deberta``. A new ``ContextPooler``
    built from the encoder's config is stored as ``self.pooler``; it is the
    only submodule created (and freshly initialised) by this wrapper.
    """

    def __init__(self, model):
        """
        Args:
            model: a loaded DeBERTa-v2 encoder; its ``.config`` is used to
                construct the pooler.
        """
        super().__init__()
        self.deberta = model
        self.pooler = ContextPooler(model.config)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` and return the pooled representation.

        Args:
            input_ids: token-id tensor, passed straight to the encoder.
            attention_mask: optional mask forwarded unchanged to the encoder.

        Returns:
            ``self.pooler`` applied to the encoder's first output
            (presumably the last hidden state).
        """
        encoder_outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=None,
        )
        sequence_output = encoder_outputs[0]
        return self.pooler(sequence_output)

In [6]:
# Wrap the frozen encoder with a trainable pooler head and move it to the GPU.
# If this cell is re-run, `model` is already a DebertaV2Feature. Unwrap it
# first: otherwise the wrapper gets nested, and DebertaV2Feature.__init__
# reads `.config`, which only the bare encoder has (AttributeError).
encoder = model.deberta if isinstance(model, DebertaV2Feature) else model
model = DebertaV2Feature(encoder).to('cuda')

In [8]:
# Hand Adam only the parameters that still require gradients. With the encoder
# frozen earlier, this should be just the newly created pooler.
params_to_optimize = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.Adam(params_to_optimize, lr=1e-3)

In [26]:
# Synthetic smoke-test batch: two sides of 4 pairs, each 32 random token ids.
# Seed torch's global RNG (CPU and CUDA) so the batch, and any later dropout,
# is identical on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
input_ids1 = torch.randint(0, 1000, (4, 32), device='cuda')
input_ids2 = torch.randint(0, 1000, (4, 32), device='cuda')
# TODO: build an attention_mask once real (padded) inputs are used.
# Pair labels: 1 = same category, 0 = different. Float dtype because
# BCEWithLogitsLoss needs float targets -> tensor([0., 0., 1., 1.], device='cuda:0')
labels = torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32, device='cuda')
# At retrieval time, pick the candidate k with the highest score (argmax over k).

In [27]:
# Embed both sides of each pair with the shared encoder + pooler, first side
# first. The original note records a (4, 1024) result per call.
outputs1, outputs2 = (model(input_ids=ids) for ids in (input_ids1, input_ids2))

In [28]:
# BCEWithLogitsLoss expects unbounded logits, but cosine similarity lies in
# [-1, 1]. sigmoid(+1) ~= 0.73 and sigmoid(-1) ~= 0.27, so the raw cosine can
# never give a confident prediction: the per-pair loss never drops below ~0.31.
# Dividing by a temperature stretches the score into a usable logit range.
# 0.05 is a common starting value; tune it.
TEMPERATURE = 0.05
loss_fn = torch.nn.BCEWithLogitsLoss()
cos_sim = F.cosine_similarity(outputs1, outputs2, dim=-1)  # one score per pair -> (4,)
logits = cos_sim / TEMPERATURE
loss = loss_fn(logits, labels)