In [None]:
# Optional setup: uncomment on a fresh environment. `%pip` installs into the
# running kernel's environment (unlike `!pip`); pin versions for reproducibility.
# %pip install datasets transformers torch scikit-learn
# %pip install umap-learn plotly

In [188]:
import sys

from datasets import Dataset, DatasetDict
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, EsmForSequenceClassification, logging, \
    Trainer, TrainingArguments
from sklearn.model_selection import train_test_split

### Load Model

In [189]:
# ESM-2 protein language model. The sequence-classification head on top is
# NOT part of this checkpoint: it is freshly initialised (see the warning
# printed by this cell), so its logits are untrained until fine-tuning.
MODEL_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = EsmForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Load Dataset

In [190]:
# Training table; later cells use the columns epitope_aa, cdr3_alpha_aa,
# cdr3_beta_aa and label_true_pair.
df = pd.read_csv('train-set_full-seq.csv')

# Separator for the commented-out alpha + epitope + beta input variant.
# NOTE(review): '0' is not an amino-acid letter, so the ESM tokenizer most
# likely maps it to <unk> — confirm before using that variant.
sep = '0'

### Filter Dataset to the Most Frequent Epitope

In [191]:
from collections import Counter

# Study a single epitope: the one with the most rows in the table.
# most_common(1) yields [(value, count)]; ties resolve to the first-seen value.
epitope_counts = Counter(df['epitope_aa'])
ept = epitope_counts.most_common(1)[0][0]
# ept = 'AYAQKIFKI'  # alternative epitope to try
ept

'KLGGALQAK'

In [192]:
# Keep only rows for the selected epitope. The model input ('seq') is the
# CDR3-beta sequence alone; the original CSV row index is preserved.
epitope_df = (
    df.loc[df['epitope_aa'] == ept, ['cdr3_beta_aa', 'label_true_pair']]
    .rename(columns={'cdr3_beta_aa': 'seq', 'label_true_pair': 'label'})
)
# Alternative input: build 'seq' from the same filtered rows as
#   rows['cdr3_alpha_aa'] + sep + rows['epitope_aa'] + sep + rows['cdr3_beta_aa']
epitope_df

Unnamed: 0,seq,label
3271,CASGYWKLAGGPQETQYF,True
3272,CASSLYGNLGTGELFF,True
3273,CASSRGGIASGANVLTF,True
3274,CASSQGTVLQPQHF,True
3275,CARQPLRGANVLTF,True
...,...,...
7184,CASSLSTHESYNEQFF,False
7185,CASTPIESSTDTQYF,False
7186,CASSQDPPDTQYF,False
7187,CASRGGLGTEAFF,False


In [193]:
# Convert to a Hugging Face Dataset. epitope_df keeps a non-default (filtered)
# row index, so it is stored as the extra '__index_level_0__' column.
epitope_dataset = Dataset.from_pandas(df=epitope_df)
epitope_dataset

Dataset({
    features: ['seq', 'label', '__index_level_0__'],
    num_rows: 3918
})

In [194]:
# Every sequence is padded to one common length so `.with_format('torch')` can
# stack each column into a single 2-D tensor. The previous code used
# `max_length=len(tokenizer)`, which is the vocabulary size, not a sequence
# length: it only worked while every sequence fit in that many tokens and would
# silently truncate longer inputs (e.g. the alpha + epitope + beta variant).
# Instead, measure the longest tokenized sequence, special tokens included.
MAX_SEQ_LEN = max(len(ids) for ids in tokenizer(list(epitope_dataset['seq']))['input_ids'])


def tokenize_function(dataset, max_length=None):
    """Tokenize a batch of sequences for the ESM model.

    Args:
        dataset: a batch mapping with a 'seq' list of sequence strings, as
            passed by `Dataset.map(..., batched=True)`.
        max_length: length every sequence is padded/truncated to; defaults to
            MAX_SEQ_LEN, computed above from the whole dataset.

    Returns:
        The tokenizer output (input_ids, attention_mask). Tensor conversion is
        left to `.with_format('torch')`, so no `return_tensors` is needed here.
    """
    if max_length is None:
        max_length = MAX_SEQ_LEN
    return tokenizer(dataset['seq'], max_length=max_length, padding='max_length', truncation=True)


epitope_tokenized_dataset = epitope_dataset.map(tokenize_function, batched=True, batch_size=16).with_format('torch')
epitope_tokenized_dataset

Map:   0%|          | 0/3918 [00:00<?, ? examples/s]

Dataset({
    features: ['seq', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
    num_rows: 3918
})

In [195]:
# Run the whole tokenized set through the model in one batch, without
# gradients. Hidden states are requested so the last-layer embeddings can be
# extracted in a later cell.
inputs = {name: epitope_tokenized_dataset[name] for name in ('input_ids', 'attention_mask')}

with torch.no_grad():
    output = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_hidden_states=True,
    )

In [196]:
# Predicted class per sequence from the classification head.
# NOTE: that head was freshly initialised when the model was loaded, so these
# predictions carry no information until the model is fine-tuned.
logits = output.logits
preds = torch.argmax(logits, dim=1)
preds

tensor([1, 1, 1,  ..., 1, 1, 1])

In [197]:
# How many sequences fall into each predicted class (reuses `preds`).
torch.unique(preds, return_counts=True)

(tensor([0, 1]), tensor([  58, 3860]))

In [198]:
# One embedding per sequence: the final-layer hidden state at position 0, i.e.
# the start token the tokenizer prepends (<cls> for ESM — confirm if the
# tokenizer is swapped).
last_hidden_state = output.hidden_states[-1]
cls_tokens = last_hidden_state[:, 0]
cls_tokens.shape

torch.Size([3918, 320])

### Visualization of CLS Embeddings (UMAP and t-SNE)

In [199]:
import plotly.express as px
from umap import UMAP
from sklearn.manifold import TSNE

In [200]:
# Sanity check before plotting: the projections below are coloured by
# epitope_df['label'] row-by-row, so there must be exactly one embedding per
# labelled sequence.
assert cls_tokens.shape[0] == len(epitope_df), (cls_tokens.shape, len(epitope_df))
cls_tokens.shape

torch.Size([3918, 320])

In [201]:
# Re-display the (seq, label) table; its `label` column colours the plots below.
epitope_df

Unnamed: 0,seq,label
3271,CASGYWKLAGGPQETQYF,True
3272,CASSLYGNLGTGELFF,True
3273,CASSRGGIASGANVLTF,True
3274,CASSQGTVLQPQHF,True
3275,CARQPLRGANVLTF,True
...,...,...
7184,CASSLSTHESYNEQFF,False
7185,CASTPIESSTDTQYF,False
7186,CASSQDPPDTQYF,False
7187,CASRGGLGTEAFF,False


In [220]:
# Dimensionality reducers for the 320-d CLS embeddings. All three share one
# seed so the figures are reproducible (t-SNE was previously unseeded).
# UMAP already falls back to a single job once random_state is set; passing
# n_jobs=1 explicitly just silences its "n_jobs value -1 overridden" warning.
SEED = 0
umap_2d = UMAP(n_components=2, init='random', random_state=SEED, n_jobs=1)
umap_3d = UMAP(n_components=3, init='random', random_state=SEED, n_jobs=1)
tsne_2d = TSNE(n_components=2, random_state=SEED)

In [221]:
# Fit each reducer, in turn, on the same embedding matrix and keep the
# low-dimensional coordinates for plotting.
umap_proj_2d, umap_proj_3d, tsne_proj_2d = (
    reducer.fit_transform(cls_tokens) for reducer in (umap_2d, umap_3d, tsne_2d)
)


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.


n_jobs value -1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [226]:
# 2-D UMAP of the CLS embeddings, coloured by binding label. The label Series
# is re-indexed to 0..n-1 so it lines up positionally with the projection rows
# (epitope_df keeps the original, non-contiguous CSV row index).
fig = px.scatter(
    umap_proj_2d, x=0, y=1,
    color=epitope_df['label'].reset_index(drop=True),
    title=f'UMAP (2D) of ESM-2 CLS embeddings — epitope {ept}',
)
fig.update_traces(marker={'size': 3})
fig.update_layout(xaxis_title='UMAP 1', yaxis_title='UMAP 2')
fig.show()

In [223]:
# 3-D UMAP of the CLS embeddings, coloured by binding label. The label Series
# is re-indexed to 0..n-1 so it lines up positionally with the projection rows
# (epitope_df keeps the original, non-contiguous CSV row index).
fig = px.scatter_3d(
    umap_proj_3d, x=0, y=1, z=2,
    color=epitope_df['label'].reset_index(drop=True),
    title=f'UMAP (3D) of ESM-2 CLS embeddings — epitope {ept}',
)
fig.update_traces(marker={'size': 3})
fig.update_layout(scene=dict(xaxis_title='UMAP 1', yaxis_title='UMAP 2', zaxis_title='UMAP 3'))
fig.show()

In [224]:
# 2-D t-SNE of the CLS embeddings, coloured by binding label. The label Series
# is re-indexed to 0..n-1 so it lines up positionally with the projection rows
# (epitope_df keeps the original, non-contiguous CSV row index).
fig = px.scatter(
    tsne_proj_2d, x=0, y=1,
    color=epitope_df['label'].reset_index(drop=True),
    title=f't-SNE (2D) of ESM-2 CLS embeddings — epitope {ept}',
)
fig.update_traces(marker={'size': 3})
fig.update_layout(xaxis_title='t-SNE 1', yaxis_title='t-SNE 2')
fig.show()