**Table of contents**<a id='toc0_'></a>    
- [Imports](#toc1_1_)    
  - [Building dataset](#toc1_2_)    
  - [Model training](#toc1_3_)    
  - [Lora implementation](#toc1_4_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# What is Done In This Notebook ?
- Select the appropriate model and dataset.
- Drop the unnecessary and NA columns.
- Ignore the unnecessary string parts in the columns. (For ex: the last 140 characters in a column would always be the same. Then remove that part.)
- Build the SequenceDataset class.
- Build the evaluate function 
- Train the model using accelerate library.

## <a id='toc1_1_'></a>[Imports](#toc0_)

In [22]:
!pip install focal_loss_torch
!pip install peft==0.4.0
!pip install accelerate==0.22.0

[0m

In [23]:
import torch
from transformers import AutoTokenizer, EsmForSequenceClassification
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, recall_score, precision_score, classification_report, roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.metrics import auc, precision_recall_curve, average_precision_score
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import pandas as pd
from focal_loss.focal_loss import FocalLoss
from torch.optim import AdamW
from transformers import get_scheduler
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

# Accelerate parts
from accelerate import Accelerator, notebook_launcher # main interface, distributed launcher
from accelerate.utils import set_seed # reproducability across devices

sns.set_theme()
sns.set_context('paper')

In [24]:
# cuda = torch.cuda.is_available()

In [25]:
from pathlib import Path

## Parameters

In [26]:
experimental = False
# if you set 'experimental' to True, then make this 1 to prevent confusion
experiment_on = 1 # example: 0.01
data = 'plusnegs'

model_tag = '8m' # options are 8m, 35m, 150m

# Feel free to play around with these hyper-parameters
lr = 5e-5
num_epochs = 3

# This is somewhat important, as sequences that are longer than this just get truncated to this length. 
# For very long sequences, this can render the model useless. In our experience, this generally still works fine. Notice that memory requirements grow dramatically with max_length!
# max_length = 1028
max_length = 630

if model_tag == '8m':
    model_name = "facebook/esm2_t6_8M_UR50D"
elif model_tag == '35m':
    model_name = "facebook/esm2_t12_35M_UR50D"
elif model_tag == '150m':
    model_name = "facebook/esm2_t30_150M_UR50D"
else:
    raise ValueError("Invalid model tag")

In [27]:
if data == 'plusnegs':
    train_path = Path('/kaggle/input/plusnegswholedata/traintest_plusnegs_plm.csv')
    test_path = Path('/kaggle/input/plusnegswholedata/indep_plusnegs_plm.csv')
    
if data == 'ds_complex-antigen':
    train_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_train_only-antigen.csv')
    test_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_test_only-antigen.csv')
if data == 'ds_complex-TRA':
    train_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_train_only-tra.csv')
    test_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_test_only-tra.csv')
if data == 'ds_complex-TRB':
    train_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_train_only-trb.csv')
    test_path = Path('/kaggle/input/ds-complex-single-columns/ds_complex_eos_tcrA_tcrB_peptide_linker_test_only-trb.csv')

if data == 'plusnegs_antigen':
    train_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/traintest_plusnegs_plm_only_antigen.csv')
    test_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/indep_plusnegs_plm_only_antigen.csv')
if data == 'plusnegs_TRA':
    train_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/traintest_plusnegs_plm_only_TRA.csv')
    test_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/indep_plusnegs_plm_only_TRA.csv')
if data == 'plusnegs_TRB':
    train_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/traintest_plusnegs_plm_only_TRB.csv')
    test_path = Path('/kaggle/input/d/yasininal/plusnegssinglecolumn/indep_plusnegs_plm_only_TRB.csv')
    

In [28]:
df_train = pd.read_csv(train_path)
df_test = pd.read_csv(test_path)
display( df_train.head(1))

# drop unnecessary columns
df_train = df_train.drop(columns=['COMPLEX_ID'])
df_test = df_test.drop(columns=['COMPLEX_ID'])
display( df_train.head(1) )
print( df_train.shape )

# drop null columns
print('after drop null columns')
df_train = df_train.dropna()
df_test = df_test.dropna()
print(df_train.shape)

Unnamed: 0,COMPLEX_ID,antigen.epitope,TRA_aa,TRB_aa,tcr_affinity
0,pos_vdj_plm_train_tcr_ep_1,KAFSPEVIPMF,MLLLLVPVLEVIFTLGGTRAQSVTQLGSHVSVSEGALVLLRCNYSS...,MSNQVLCCVVLCFLGANTVDGGITQSPKYLFRKEGQNVTLSCEQNL...,1


Unnamed: 0,antigen.epitope,TRA_aa,TRB_aa,tcr_affinity
0,KAFSPEVIPMF,MLLLLVPVLEVIFTLGGTRAQSVTQLGSHVSVSEGALVLLRCNYSS...,MSNQVLCCVVLCFLGANTVDGGITQSPKYLFRKEGQNVTLSCEQNL...,1


(38536, 4)
after drop null columns
(38528, 4)


In [29]:
if experimental:
    df_train = df_train.sample(n=int(df_train.shape[0] * experiment_on), random_state=44).reset_index(drop=True)
if experimental:
    df_test = df_test.sample(n=int(df_test.shape[0] * experiment_on), random_state=44).reset_index(drop=True)

print('the shape after sampling')
df_train.shape

the shape after sampling


(38528, 4)

# Data Analysis

> Calculate the length of the longest string in each column. 

In [30]:
max_antigen_length = df_train['antigen.epitope'].apply(len).max()
max_TRA_length = df_train['TRA_aa'].apply(len).max()
max_TRB_length = df_train['TRB_aa'].apply(len).max()

max_antigen_length, max_TRA_length, max_TRB_length

(13, 285, 331)

In [31]:
df_train['TRA_aa'].str[-140:].unique().shape

(1,)

> The last 140 characters of TRA_aa column is **always** the same. So we will ignore them.

In [32]:
df_train['TRA_aa'] = df_train['TRA_aa'].str[-140:]

In [33]:
max_antigen_length = df_train['antigen.epitope'].apply(len).max()
max_TRA_length = df_train['TRA_aa'].apply(len).max()
max_TRB_length = df_train['TRB_aa'].apply(len).max()

max_antigen_length, max_TRA_length, max_TRB_length

(13, 140, 331)

> Now the length of TRA_aa is halved.

# Tokenize Items

In [34]:
# # Initialize pre-trained model and tokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# def tokenize_function(examples):
#     outputs = tokenizer(examples, 
#                         return_tensors='pt',
#                         truncation=True, 
#                         padding="max_length", 
#                         max_length=max_length)
#     return outputs

In [35]:
df_train_tk = pd.DataFrame()
df_train_tk['x'] = df_train['antigen.epitope'] + '<eos>' + df_train['TRA_aa'] + '<eos>' + df_train['TRB_aa']
df_train_tk['y'] = df_train['tcr_affinity']

df_test_tk = pd.DataFrame()
df_test_tk['x'] = df_test['antigen.epitope'] + '<eos>' + df_test['TRA_aa'] + '<eos>' + df_test['TRB_aa']
df_test_tk['y'] = df_train['tcr_affinity']

## <a id='toc1_2_'></a>[Building dataset](#toc0_)

In [36]:
df_train.columns

Index(['antigen.epitope', 'TRA_aa', 'TRB_aa', 'tcr_affinity'], dtype='object')

First, we need to construct our data set. Since we are training a binary sequence classifier, we need positive and negative examples. 

We simply split our data into a train/test split and place the positive and negative example sequences into separate text files.

In [37]:
class SequenceDataset(torch.utils.data.Dataset):
    """
    What is happens here is that:
    1. we get a dataframe with an x and y column.
    2. encode the x column, and reshape accordingly
    3. return encoded x column, and raw y column
    """
    def __init__(self, df, plot=False):
        self.data = df
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        if plot:
            plt.hist([len(s[0]) for s in self.data], bins=64)
            plt.show()
            
        print(f"Initialized dataset consisting of {self.data.shape[0]} sequences.")
        
    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]
        sequence_x = sequence['x']
        encoded_sequence = self.tokenizer(sequence_x,
                                          padding="max_length",
                                          max_length=max_length,
                                          truncation=True, 
                                          return_tensors="pt")

        encoded_sequence["input_ids"] = encoded_sequence["input_ids"].squeeze()  # Remove the extra dimension
        encoded_sequence["attention_mask"] = encoded_sequence["attention_mask"].squeeze()
        
        return encoded_sequence, sequence['y']
    
    def __len__(self):
        return self.data.shape[0]

In [38]:
train = SequenceDataset(df_train_tk)
test = SequenceDataset(df_test_tk)

Initialized dataset consisting of 38528 sequences.
Initialized dataset consisting of 1498 sequences.


## Data Loader

## <a id='toc1_3_'></a>[Model training](#toc0_)

We are using the [ESM-2 model](https://huggingface.co/docs/transformers/model_doc/esm) by Meta.


Given a trained model, we'd like to evaluate how well we are doing.

Make a prediction for every example in the test set and report total accuracy (only makes sense if our test set is balanced!) and a calibration curve.

In [41]:
def validate(model, test_loader):
    model.eval()
    
    targets = []
    probabilities = []
    predictions = []
    
    for s, l in tqdm(test_loader):
        with torch.no_grad():
            logits = model(**s, labels=l).logits
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            
#         predictions, labels = accelerator.gather_for_metrics((
#                 predictions, probs, preds, batch["label"]
#             ))
        predictions.append(preds)    
        probabilities.append(probs)
        targets.append(l)
            
    targets = torch.cat(targets).cpu()
    probabilities = torch.cat(probabilities).cpu()
    predictions = torch.cat(predictions).cpu()
            
    # Visualize calibration curve
    # prob_true, prob_pred = calibration_curve(targets, probabilities[:,1], pos_label=None, n_bins=10, strategy='uniform')
    # sns.lineplot(x=prob_true, y=prob_pred)
    # sns.lineplot(x=np.linspace(0,1,5), y=np.linspace(0,1,5))
    # plt.show()
    
    print("Test accuracy:", accuracy_score(targets, predictions))
    print("Test precision:", precision_score(targets, predictions))
    print("Test recall:", recall_score(targets, predictions))
    print('Confusion matrix:\n', confusion_matrix(targets, predictions))
    
    # Reset our model to trraining mode before exiting evaluation, 
    # so we don't forget to do this later!
    model.train()
    
    return(targets, probabilities, predictions)


> Now, we simply train our model and see how it performs.

In [47]:
# Reference for acceleration: https://www.kaggle.com/code/muellerzr/multi-gpu-and-accelerate

def accl_training_loop(mixed_precision:str="fp16", seed:int=42, batch_size:int=64):
    set_seed(42)
    accelerator = Accelerator(mixed_precision=mixed_precision) # Change to be able to run in old kaggle gpus

    with accelerator.main_process_first():
        model = EsmForSequenceClassification.from_pretrained(model_name)

    train_loader = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=4, shuffle=False)
    
    optimizer = AdamW(model.parameters(), lr=lr)

    # A linear lr schedule worked well in our experiments. 
    num_training_steps = num_epochs * len(train_loader)
    lr_scheduler = get_scheduler(
         name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    # Replacing the standard cross-entropy loss by a focal loss improved our results slightly.
    focal = FocalLoss(gamma=0.7)

    model, optimizer, train_loader, test_loader, lr_scheduler = accelerator.prepare(model, 
                                                                                    optimizer, 
                                                                                    train_loader, 
                                                                                    test_loader, 
                                                                                    lr_scheduler)
    
    progress_bar = tqdm(range(num_training_steps))
    print('num_epochs:', num_epochs)
    
    for epoch in range(num_epochs):
        model.train()

        for idx, (s, l) in enumerate(train_loader):

            logits = model(**s, labels=l).logits
            
            loss = focal(torch.softmax(logits, dim=1), l)

            accelerator.backward(loss)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)

        validate(model, test_loader)
        # To save the model afterwards to use for inference we first wait for all of the processes to be aligned
        accelerator.wait_for_everyone() 

        # Then we unwrap the model from any distributed wrapping that was performed
        model = accelerator.unwrap_model(model)
        
        print('PARAMETERS:')
        print('experimental:', experimental)
        print('experiment_on:', experiment_on)
        print('data:', data)

In [None]:
args = ("fp16", 42, 64)
notebook_launcher(accl_training_loop, args, num_processes=2)

Launching training on one GPU.


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.dense.bias']
You

  0%|          | 0/28896 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

Test accuracy: 0.9439252336448598
Test precision: 1.0
Test recall: 0.9439252336448598
Confusion matrix:
 [[   0    0]
 [  84 1414]]
PARAMETERS:
experimental: False
experiment_on: 1
data: plusnegs


In [None]:
model.save_pretrained('./lora_tra_trb_training_peptide_all_dataset_eos')

## <a id='toc1_4_'></a>[Lora implementation](#toc0_)

For larger models, you might have to do parameter efficient fine-tuning. Here, we quickly demonstrate this using a technique called [LoRA](https://arxiv.org/abs/2106.09685).

In [None]:
model_name = "facebook/esm2_t12_35M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(model_name)
print('you may need to add with accelerator.main_process_first():')
model = EsmForSequenceClassification.from_pretrained(model_name)
# model = EsmForSequenceClassification.from_pretrained(model_name).cuda()


#
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, 
    lora_dropout=0.1, bias="none", target_modules=["query", "value"], 
    modules_to_save=["decode_head"], # make sure to set classification heads here to save so we do train them
)

lora_model = get_peft_model(model, peft_config)
lora_model.print_trainable_parameters()

optimizer = AdamW(lora_model.parameters(), lr=lr)
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
     name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

In [None]:
for epoch in range(num_epochs):
    lora_model.train()
    
    for s, l in train_loader:
        
        inputs = tokenizer(s, return_tensors='pt', padding="max_length", truncation=True, max_length=max_length).to("cuda")
        logits = lora_model(**inputs, labels=l).logits
        
        loss = focal(torch.softmax(logits, dim=1), l.cuda())
        
        loss.backward()
        lr_scheduler.step()
        optimizer.step()
        optimizer.zero_grad()
        
        progress_bar.update(1)
        
    validate(lora_model)

Results with Lora and a larger model do not get better, but maybe increasing the number of epochs would do the trick. Actually, the test accuracy keeps improving.

## Results:
data: plusnegs with only antigen column
model: 8M

    Test accuracy: 0.07142857142857142
    Test precision: 0.07142857142857142
    Test recall: 1.0
    Confusion matrix:
     [[   0 1872]
     [   0  144]]

data: plusnegs with only TRA columns (1/4 sample)
model: plusnegs

    Test accuracy: 0.07936507936507936
    Test precision: 0.07142857142857142
    Test recall: 0.9907407407407407
    Confusion matrix:
     [[  13 1391]
     [   1  107]]

data: plusnegs with only TRB columns (1/4 sample)
model: plusnegs

    Test accuracy: 0.07142857142857142
    Test precision: 0.07142857142857142
    Test recall: 1.0
    Confusion matrix:
     [[   0 1859]
     [   0  143]]

     
data: ds-complex with only antigen column (1/100 sample)
model: 8M

    Test accuracy: 0.9055211803902904
    Test precision: 0.9307381193124368
    Test recall: 0.8762494050452165
    Confusion matrix:
     [[1964  137]
     [ 260 1841]]
data: ds-complex with only TRA column (1/100 sample)
model: 8M

    Test accuracy: 0.5
    Test precision: 0.5
    Test recall: 1.0
    Confusion matrix:
     [[   0 2101]
     [   0 2101]]
data: ds-complex with only TRB column (1/100 sample)
model: 8M

    Test accuracy: 0.5
    Test precision: 0.5
    Test recall: 1.0
    Confusion matrix:
     [[   0 2101]
     [   0 2101]]
     