In [1]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer, models, LoggingHandler, InputExample, losses, evaluation
from sentence_transformers.losses import SiameseDistanceMetric
import numpy as np
import random
# Reproducibility: drive every RNG used in this notebook (PyTorch, the stdlib
# `random` module and NumPy's legacy global generator) from one shared seed.
seed = 42
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(seed)

# Request deterministic cuDNN behaviour and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Variant names; presumably the "variants of concern" the TODO cell refers to.
# Not referenced by any other visible cell.
VOC_NAMES = ["Alpha", "Beta", "Delta", "Gamma", "Omicron"]

In [2]:
# Alternative encoder kept for reference (ProtBERT instead of the local checkpoint):
#word_embedding_model = models.Transformer(model_name_or_path="Rostlab/prot_bert", max_seq_length=1280)

# Token-level encoder loaded from the local MLM checkpoint, with its own tokenizer.
encoder = models.Transformer(
    model_name_or_path="./mlm_checkpoints/CoV-RoBERTa_2048",
    max_seq_length=1280,
    tokenizer_name_or_path="tok/",
)

# All four pooling modes are enabled, so the pooled vector is four times the
# token embedding width (768 -> 3072 in the recorded output).
pooler = models.Pooling(
    encoder.get_word_embedding_dimension(),
    pooling_mode_cls_token=True,
    pooling_mode_max_tokens=True,
    pooling_mode_mean_tokens=True,
    pooling_mode_mean_sqrt_len_tokens=True,
)

# Projection head: dropout before each linear layer, shrinking the width by
# 16x, then 8x, then down to a single value.
pooled_dim = pooler.get_sentence_embedding_dimension()
dropout1 = models.Dropout(0.5)
linear1 = models.Linear(pooled_dim, pooled_dim // 16)

hidden_dim = linear1.get_sentence_embedding_dimension()
dropout2 = models.Dropout(0.2)
linear2 = models.Linear(hidden_dim, hidden_dim // 8)

dropout3 = models.Dropout(0.2)
# NOTE(review): the head ends in a 1-dimensional "sentence embedding", and that
# dimension is what the loss cell hands to SoftmaxLoss. Confirm this is the
# intended embedding size for the pair classifier.
linear3 = models.Linear(linear2.get_sentence_embedding_dimension(), 1)

# Show the width after each stage (768, 3072, 192, 24, 1 in the recorded output).
for stage_dim in (
    encoder.get_word_embedding_dimension(),
    pooler.get_sentence_embedding_dimension(),
    linear1.get_sentence_embedding_dimension(),
    linear2.get_sentence_embedding_dimension(),
    linear3.get_sentence_embedding_dimension(),
):
    print(stage_dim)

# Chain the stages in order into one SentenceTransformer.
pipeline_stages = [encoder, pooler, dropout1, linear1, dropout2, linear2, dropout3, linear3]
model = SentenceTransformer(modules=pipeline_stages)

Some weights of the model checkpoint at ./mlm_checkpoints/CoV-RoBERTa_2048 were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ./mlm_checkpoints/CoV-RoBERTa_2048 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and

768
3072
192
24
1


# Load Contrastive Dataset

In [3]:
import pickle


def _read_pickle(path):
    """Unpickle and return the single object stored at ``path``."""
    # pickle.load can execute arbitrary code while deserialising, so only use
    # this on files you produced yourself or otherwise trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# The three pre-built splits of the contrastive pair dataset.
train_all = _read_pickle('contrastive_train.pkl')
dev_all = _read_pickle('contrastive_dev.pkl')
test_all = _read_pickle('contrastive_test.pkl')

In [4]:
# Rich preview of the train and dev pair tables (columns seq1, seq2, label).
display(train_all, dev_all)

Unnamed: 0,seq1,seq2,label
0,MFVFFVLLPLVSSQCVNLTTRTQLPTAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
1,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,0
2,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,1
3,MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,0
4,MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLRTRTQLPSAYTNSFTRGVYYPDKVFRSS...,1
...,...,...,...
13996495,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,1
13996496,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
13996497,MFVFLVLLPLVCSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
13996498,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFFVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1


Unnamed: 0,seq1,seq2,label
0,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,1
1,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSXTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
2,MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVFSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
3,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQXXXXYTNSFTRGVYYPDKVFRSS...,0
4,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
...,...,...,...
4018990,MFVFFVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFFVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,1
4018991,MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
4018992,MFVFLVLLPLVSTQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLRTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
4018993,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,1


In [5]:
# Smoke test: push the first training sequence through the tokenizer and the
# full model, then print what comes out of each.
sample = train_all.at[0, 'seq1']
t = model.tokenize(sample)
e = model.encode(sample)

# Printed in order: embedding size, token-id tensor shape, the tokenizer
# features, the raw tokenizer ids, and the embedding itself. In the recorded
# output the embedding has size 1 and the sequence maps to 8 token ids.
for shown in (
    e.size,
    t['input_ids'].shape,
    t,
    model.tokenizer.encode(sample),
    e,
):
    print(shown)

1
torch.Size([1, 8])
{'input_ids': tensor([[   0, 9529,  826,   23,   83, 2126, 7477,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
[0, 9529, 826, 23, 83, 2126, 7477, 2]
[0.43635342]


In [6]:
def _build_examples(frame, keep_divisor=100):
    """Turn the leading rows of a pair table into ``InputExample`` objects.

    Args:
        frame: DataFrame with ``seq1``, ``seq2`` and ``label`` columns.
        keep_divisor: Only the first ``len(frame) // keep_divisor`` rows are
            converted; the default of 100 keeps the first 1%.

    Returns:
        A list of ``InputExample`` in row order, one per kept row.
    """
    keep = len(frame) // keep_divisor
    # Slice first so only the rows actually used are iterated, instead of
    # copying every full column into a Python list.
    # NOTE(review): this is the first 1% in stored order, not a random sample;
    # confirm the pickled tables are already shuffled.
    head = frame.iloc[:keep]
    pairs = zip(head['seq1'], head['seq2'], head['label'])
    return [
        InputExample(texts=[first, second], label=target)
        for first, second, target in tqdm(pairs, total=keep)
    ]


train_examples = _build_examples(train_all)
dev_examples = _build_examples(dev_all)

  0%|          | 0/139965 [00:00<?, ?it/s]

  0%|          | 0/40189 [00:00<?, ?it/s]

In [7]:
# Pairs per batch, shared by the training and evaluation loaders.
train_batch_size = dev_batch_size = 600 # 600 for CoV-RoBERTa, 3000 for ProtBERT

# Training batches are shuffled.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
# The evaluation loader is not shuffled: accuracy does not depend on batch
# order, and a fixed order keeps evaluation runs reproducible.
dev_dataloader = DataLoader(dev_examples, shuffle=False, batch_size=dev_batch_size)

In [8]:
# Alternative objective kept for reference:
#train_loss = losses.ContrastiveLoss(model=model, distance_metric=SiameseDistanceMetric.MANHATTAN, margin=0.5)

# Two-class softmax objective over sequence pairs, sized from the model's own
# output dimension.
train_loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=2,
)

In [9]:
# Alternative evaluator kept for reference (scores the full dev table directly):
# evaluator = evaluation.BinaryClassificationEvaluator(sentences1=list(dev_all['seq1']),
#                                                      sentences2=list(dev_all['seq2']),
#                                                      labels=list(dev_all['label']),
#                                                      batch_size=1000,
#                                                      show_progress_bar=True,
#                                                      write_csv=True)

# Label accuracy on the dev loader, computed through the softmax loss module.
evaluator = evaluation.LabelAccuracyEvaluator(
    dataloader=dev_dataloader,
    softmax_model=train_loss,
)

In [10]:
# Evaluate roughly ten times per epoch. The step count comes from the loader
# itself, so it stays correct if the subsample size or batch size changes
# (the recorded run has 234 batches, which gives 23). The floor of 1 keeps
# periodic evaluation enabled when there are fewer than ten batches.
eval_iter = max(1, len(train_dataloader) // 10)
print(eval_iter)

23


In [12]:
# Training configuration, collected in one place and then handed to fit().
fit_settings = dict(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=1,
    # Original note: 1e-3 for CoV-RoBERTa, 1e-6 for ProtBERT.
    # NOTE(review): the value set here is 1e-4, which matches neither; confirm.
    optimizer_params={'lr': 1e-4},
    weight_decay=0.1,  # 0.1 for CoV-RoBERTa, 0.01 for ProtBERT
    evaluation_steps=eval_iter,  # run an evaluation every 10% of the epoch
    output_path='./snn_output',
    save_best_model=True,
    checkpoint_path='./snn_checkpoints',
    # NOTE(review): the recorded run shows 234 iterations for the single
    # epoch, fewer than 500; confirm a step checkpoint is ever written.
    checkpoint_save_steps=500,
    checkpoint_save_total_limit=2,
    show_progress_bar=True,
)
model.fit(**fit_settings)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/234 [00:00<?, ?it/s]

Accuracy: 0.4998 (20085/40189)

Accuracy: 0.5003 (20106/40189)



KeyboardInterrupt: 

In [None]:
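# TODO
# - Try learning rate 1e-3 (this note originally said "instead of 2e-5"; the fit call above currently passes 1e-4)
# - Create a test-set evaluator
# - Perform zero-shot testing by removing a VOC (variant of concern) or adding a new VOI (variant of interest)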