In [1]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer, models, InputExample, losses, evaluation
from sentence_transformers.util import SiameseDistanceMetric
import numpy as np
import random
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

CONSTANTS = {
    "VOC_NAMES": ["Greaney", "Baum"],
    "LOSS_NAME": "ContrastiveLoss",
    "NEG_SET": "Greaney",
    "POOLING_MODE": "max",
    "CONCAT": None,
    "NUM_LABELS": None,
    "CONF_THRESHOLD": None,
    "BATCH_SIZE": 32,
    "EPOCHS": 10,
    "LR": 1e-4,
    "WD": 1e-3,
    "RELU": 0.3,
    "DROPOUT": 0.5,
    "MARGIN": 0.2
}

In [2]:
#word_embedding_model = models.Transformer(model_name_or_path="Rostlab/prot_bert", max_seq_length=1280)

encoder = models.Transformer(model_name_or_path="./mlm_checkpoints/CoV-RoBERTa_128",
                                          max_seq_length=1280,
                                          tokenizer_name_or_path="tok/")

dim = encoder.get_word_embedding_dimension() # 768

pooler = models.Pooling(dim, pooling_mode = CONSTANTS["POOLING_MODE"])

modules = [encoder, pooler]

if CONSTANTS["RELU"] > 0:
    dense = models.Dense(in_features=dim, out_features=int(dim*CONSTANTS["RELU"]), activation_function=nn.ReLU())
    modules.append(dense)

if CONSTANTS["DROPOUT"] > 0:
    dropout = models.Dropout(CONSTANTS["DROPOUT"])
    modules.append(dropout)

model = SentenceTransformer(modules=modules)

# # Freeze initial transformer layers
# for param in model[0].auto_model.embeddings.parameters():
#     param.requires_grad = False
# for param in model[0].auto_model.encoder.layer[:6].parameters():
#     param.requires_grad = False

print(model)

Some weights of the model checkpoint at ./mlm_checkpoints/CoV-RoBERTa_128 were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ./mlm_checkpoints/CoV-RoBERTa_128 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and i

SentenceTransformer(
  (0): Transformer({'max_seq_length': 1280, 'do_lower_case': False}) with Transformer model: RobertaModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': True, 'pooling_mode_global_max': False, 'pooling_mode_global_avg': False, 'pooling_mode_attention': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
  (2): Dense({'in_features': 768, 'out_features': 230, 'bias': True, 'activation_function': 'torch.nn.modules.activation.ReLU'})
  (3): Dropout(
    (dropout_layer): Dropout(p=0.5, inplace=False)
  )
)


# Generate Pairs for Training

In [3]:
sig_seq = pd.read_csv('exp_data/sig_train_val_extended.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()
non_sig_seq = pd.read_csv('exp_data/non_sig_train_val_extended.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()
print(len(sig_seq), len(non_sig_seq))
from Bio import SeqIO
wt = str(SeqIO.read('exp_data/wild_type.fasta', 'fasta').seq)
examples = []

for neg in sig_seq:
    examples.append(InputExample(texts=[wt, neg], label=0))

for pos in non_sig_seq:
    examples.append(InputExample(texts=[wt, pos], label=1))

print("Training set length:", len(examples))
# split examples list into train, validation and test sets
random.shuffle(examples)
train_size = int(len(examples) * 0.8)
val_size = int(len(examples) * 0.1)
train_examples = examples[:train_size]
val_examples = examples[train_size:train_size + val_size]
test_examples = examples[train_size + val_size:]

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=CONSTANTS["BATCH_SIZE"])
# val_dataloader = DataLoader(val_examples, shuffle=False, batch_size=CONSTANTS["BATCH_SIZE"])
# test_dataloader = DataLoader(test_examples, shuffle=False, batch_size=CONSTANTS["BATCH_SIZE"])

2404 2698
Training set length: 5102


# Generate Pairs for Zero-shot Test

In [4]:
if CONSTANTS["NEG_SET"] == "Greaney":
        non_sig_seq = pd.read_csv('exp_data/non_sig_seq_greaney_filtered.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()
        sig_seq = pd.read_csv('exp_data/sig_seq_greaney.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()
elif CONSTANTS["NEG_SET"] == "Baum":
        non_sig_seq = pd.read_csv('exp_data/non_sig_seq_baum_filtered.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()
        sig_seq = pd.read_csv('exp_data/sig_seq_baum.csv', header=None, names=['mutation', 'sequence'])['sequence'].tolist()

print("non-sig:", len(non_sig_seq), "sig:", len(sig_seq))

zero_test_examples = []

for seq in non_sig_seq:
        zero_test_examples.append(InputExample(texts=[wt, seq], label=1))

for seq in sig_seq:
        zero_test_examples.append(InputExample(texts=[wt, seq], label=0))

# shuffle the zero-shot test examples
random.shuffle(zero_test_examples)

print("Zero-shot test set length: ", len(zero_test_examples))

non-sig: 407 sig: 181
Zero-shot test set length:  588


# Define Loss

In [5]:
if CONSTANTS["LOSS_NAME"] == "ContrastiveLoss":
    train_loss = losses.ContrastiveLoss(model=model,
                                        distance_metric=SiameseDistanceMetric.EUCLIDEAN,
                                        margin = CONSTANTS["MARGIN"])
elif CONSTANTS["LOSS_NAME"] == "OnlineContrastiveLoss":
    train_loss = losses.OnlineContrastiveLoss(model=model,
                                              distance_metric=SiameseDistanceMetric.EUCLIDEAN,
                                              margin = CONSTANTS["MARGIN"])

# Construct Evaluators

In [6]:
evaluator = evaluation.BinaryClassificationEvaluator(
    sentences1=[val_example.texts[0] for val_example in val_examples],
    sentences2=[val_example.texts[1] for val_example in val_examples],
    labels=[val_example.label for val_example in val_examples],
    distance_metric=SiameseDistanceMetric.EUCLIDEAN,
    batch_size=CONSTANTS["BATCH_SIZE"],
    margin = CONSTANTS["MARGIN"],
    show_progress_bar=False,
    write_csv=True,
    name='Eval')

test_evaluator = evaluation.BinaryClassificationEvaluator(
    sentences1=[test_example.texts[0] for test_example in test_examples],
    sentences2=[test_example.texts[1] for test_example in test_examples],
    labels=[test_example.label for test_example in test_examples],
    batch_size=CONSTANTS['BATCH_SIZE'],
    margin=CONSTANTS['MARGIN'],
    show_progress_bar=False,
    name="Test")

zero_test_evaluator = evaluation.BinaryClassificationEvaluator(
    sentences1=[zero_test_example.texts[0] for zero_test_example in zero_test_examples],
    sentences2=[zero_test_example.texts[1] for zero_test_example in zero_test_examples],
    labels=[zero_test_example.label for zero_test_example in zero_test_examples],
    batch_size=CONSTANTS['BATCH_SIZE'],
    margin=CONSTANTS['MARGIN'],
    show_progress_bar=False,
    name="Zero")

# Prepare Folders

In [7]:
import os
import shutil

# Create output directory if needed
output_dir = f"./exp_outputs/{CONSTANTS['NEG_SET']}"

if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
    print(f"Removed directory: {output_dir}")

checkpoint_dir = f"{output_dir}/checkpoints"
stats_dir = f"{output_dir}/stats"

for d in [checkpoint_dir, stats_dir]:
    if not os.path.exists(d):
        os.makedirs(d)
        print(f"Created directory: {d}")

# Dump CONSTANTS dict to file
import json
with open(f"{output_dir}/constants.json", "w") as f:
    json.dump(CONSTANTS, f, indent=4)

Removed directory: ./exp_outputs/Greaney
Created directory: ./exp_outputs/Greaney/checkpoints
Created directory: ./exp_outputs/Greaney/stats


# Run Training & Test

In [8]:
# print CONSTANTS
for k, v in CONSTANTS.items():
    print(f"{k}: {v}")

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          tester=test_evaluator,
          zero_shot_tester=zero_test_evaluator,
          epochs=CONSTANTS['EPOCHS'],
          optimizer_class=torch.optim.AdamW,
          optimizer_params= {'lr': CONSTANTS['LR']}, # 1e-3 for CoV-RoBERTa, 1e-6 for ProtBERT
          weight_decay=CONSTANTS['WD'], # 0.1 for CoV-RoBERTa, 0.01 for ProtBERT
          # evaluation_steps=64,
          output_path=output_dir,
          #save_best_model=True,
          #checkpoint_path=checkpoint_dir,
          #checkpoint_save_steps=len(train_dataloader),
          #checkpoint_save_total_limit=1000000,
          show_progress_bar=True,
          loss_name=CONSTANTS['LOSS_NAME'])

VOC_NAMES: ['Greaney', 'Baum']
LOSS_NAME: ContrastiveLoss
NEG_SET: Greaney
POOLING_MODE: max
CONCAT: None
NUM_LABELS: None
CONF_THRESHOLD: None
BATCH_SIZE: 32
EPOCHS: 10
LR: 0.0001
WD: 0.001
RELU: 0.3
DROPOUT: 0.5
MARGIN: 0.2


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 0 ---
Train Loss = 6.6125   Train Accuracy = 0.7059    Train AUC = 0.5972
Eval Loss  = 0.8175   Eval Accuracy  = 0.5863    Eval AUC  = 0.5729    (using best distance threshold   = 1.7684)
Test Loss  = 0.7880   Test Accuracy  = 0.6008    Test AUC  = 0.6068    (using best distance threshold   = 1.6915)
Zero Loss  = 1.2744   Zero Accuracy  = 0.6922    Zero AUC  = 0.5362    (using best distance threshold   = 3.2962)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 1 ---
Train Loss = 0.0045   Train Accuracy = 0.7647    Train AUC = 0.7500
Eval Loss  = 0.0053   Eval Accuracy  = 0.6294    Eval AUC  = 0.6336    (using best distance threshold   = 0.1258)
Test Loss  = 0.0051   Test Accuracy  = 0.6301    Test AUC  = 0.6641    (using best distance threshold   = 0.1037)
Zero Loss  = 0.0081   Zero Accuracy  = 0.6905    Zero AUC  = 0.5771    (using best distance threshold   = 0.3068)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 2 ---
Train Loss = 0.0063   Train Accuracy = 0.8235    Train AUC = 0.7000
Eval Loss  = 0.0051   Eval Accuracy  = 0.6157    Eval AUC  = 0.5912    (using best distance threshold   = 0.0946)
Test Loss  = 0.0051   Test Accuracy  = 0.5988    Test AUC  = 0.6306    (using best distance threshold   = 0.0922)
Zero Loss  = 0.0056   Zero Accuracy  = 0.6905    Zero AUC  = 0.5534    (using best distance threshold   = 0.2494)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 3 ---
Train Loss = 0.0094   Train Accuracy = 0.5294    Train AUC = 0.3542
Eval Loss  = 0.0058   Eval Accuracy  = 0.6294    Eval AUC  = 0.6053    (using best distance threshold   = 0.0607)
Test Loss  = 0.0061   Test Accuracy  = 0.6086    Test AUC  = 0.6313    (using best distance threshold   = 0.0597)
Zero Loss  = 0.0047   Zero Accuracy  = 0.7245    Zero AUC  = 0.6459    (using best distance threshold   = 0.1204)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 4 ---
Train Loss = 0.0091   Train Accuracy = 0.7059    Train AUC = 0.7083
Eval Loss  = 0.0036   Eval Accuracy  = 0.8000    Eval AUC  = 0.8278    (using best distance threshold   = 0.0770)
Test Loss  = 0.0040   Test Accuracy  = 0.7867    Test AUC  = 0.8330    (using best distance threshold   = 0.0523)
Zero Loss  = 0.0101   Zero Accuracy  = 0.7262    Zero AUC  = 0.6899    (using best distance threshold   = 0.3314)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 5 ---
Train Loss = 0.0074   Train Accuracy = 0.4706    Train AUC = 0.6250
Eval Loss  = 0.0034   Eval Accuracy  = 0.8451    Eval AUC  = 0.8560    (using best distance threshold   = 0.0415)
Test Loss  = 0.0038   Test Accuracy  = 0.5049    Test AUC  = 0.8574    (using best distance threshold   = 0.0000)
Zero Loss  = 0.0143   Zero Accuracy  = 0.6939    Zero AUC  = 0.6592    (using best distance threshold   = 0.4344)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 6 ---
Train Loss = 0.0084   Train Accuracy = 0.4706    Train AUC = 0.5833
Eval Loss  = 0.0029   Eval Accuracy  = 0.8451    Eval AUC  = 0.8986    (using best distance threshold   = 0.1337)
Test Loss  = 0.0043   Test Accuracy  = 0.8317    Test AUC  = 0.8893    (using best distance threshold   = 0.0871)
Zero Loss  = 0.0370   Zero Accuracy  = 0.7058    Zero AUC  = 0.6909    (using best distance threshold   = 0.4438)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 7 ---
Train Loss = 0.0083   Train Accuracy = 0.5294    Train AUC = 0.7361
Eval Loss  = 0.0064   Eval Accuracy  = 0.4569    Eval AUC  = 0.6944    (using best distance threshold   = 0.0000)
Test Loss  = 0.0070   Test Accuracy  = 0.5049    Test AUC  = 0.7167    (using best distance threshold   = 0.0000)
Zero Loss  = 0.0104   Zero Accuracy  = 0.7194    Zero AUC  = 0.6902    (using best distance threshold   = 0.1218)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 8 ---
Train Loss = 0.0049   Train Accuracy = 0.9412    Train AUC = 0.8462
Eval Loss  = 0.0045   Eval Accuracy  = 0.7922    Eval AUC  = 0.7753    (using best distance threshold   = 0.0346)
Test Loss  = 0.0049   Test Accuracy  = 0.7730    Test AUC  = 0.7862    (using best distance threshold   = 0.0314)
Zero Loss  = 0.0168   Zero Accuracy  = 0.7279    Zero AUC  = 0.7443    (using best distance threshold   = 0.1785)


Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

--- Epoch 9 ---
Train Loss = 0.0141   Train Accuracy = 0.5294    Train AUC = 0.5486
Eval Loss  = 0.0062   Eval Accuracy  = 0.8647    Eval AUC  = 0.8924    (using best distance threshold   = 0.2038)
Test Loss  = 0.0067   Test Accuracy  = 0.8611    Test AUC  = 0.9116    (using best distance threshold   = 0.2461)
Zero Loss  = 0.1566   Zero Accuracy  = 0.7058    Zero AUC  = 0.6503    (using best distance threshold   = 1.1966)


# Display Stats

In [9]:
# read loss values from csv:
f_train_stats = os.path.join(stats_dir, 'Train.csv')
f_eval_stats = os.path.join(stats_dir, 'Eval.csv')
f_test_stats = os.path.join(stats_dir, 'Test.csv')
f_zero_stats = os.path.join(stats_dir, 'Zero.csv')

train_stats = pd.read_csv(f_train_stats)
eval_stats = pd.read_csv(f_eval_stats)
test_stats = pd.read_csv(f_test_stats)
zero_stats = pd.read_csv(f_zero_stats)

best_test_auc = test_stats["auc"].max()
best_zero_auc = zero_stats["auc"].max()

# create a dataframe with CONSTANTS and best accuracies
df = pd.DataFrame()
for k, v in CONSTANTS.items():
    if k not in ["VOC_NAMES", "CONCAT", "NUM_LABELS", "CONF_THRESHOLD"]:
        df[k] = [v] # if v is not None else ["N/A"]

df["MAX_TEST_ACC"] = best_test_auc
df["MAX_ZERO_ACC"] = best_zero_auc

display(df)

# save the dataframe to a csv file under stats_dir
df.to_csv(os.path.join(stats_dir, "summary.csv"), index=False)

# append row to global_stats.csv
if not os.path.exists("global_stats.csv") or os.path.getsize("global_stats.csv") == 0:
    df.to_csv("global_stats.csv", index=False)
else:
    global_stats = pd.read_csv("global_stats.csv")
    global_stats = pd.concat([global_stats, df], ignore_index=True)
    global_stats.to_csv("global_stats.csv", index=False)

Unnamed: 0,LOSS_NAME,NEG_SET,POOLING_MODE,BATCH_SIZE,EPOCHS,LR,WD,RELU,DROPOUT,MARGIN,MAX_TEST_ACC,MAX_ZERO_ACC
0,ContrastiveLoss,Greaney,max,32,10,0.0001,0.001,0.3,0.5,0.2,0.9116,0.7443
