In [159]:
import os
from pathlib import Path
import sys

if "__project_dir__" not in globals():
    __project_dir__ = Path.cwd().parents[1].resolve()

sys.path.append(__project_dir__)
os.chdir(__project_dir__)

In [160]:
import pandas as pd
from src.model_loader import ModelLoader
import torch
from torch import Tensor
from torch.nn import Linear, Module, SoftMarginLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

### Load data

In [161]:
ep_data = pd.read_csv("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/data/preprocessed/vdjdb/evaluation_beta.csv")
bg_data = pd.read_csv("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/data/preprocessed/tanno/test.csv")

In [162]:
ep_data.Epitope.unique()

array(['ATDALMTGY', 'AVFDRKSDAK', 'CINGVCWTV', 'EIYKRWII', 'ELAGIGILTV',
       'FLKEKGGL', 'FRDYVDRFYKTLRAEQASQE', 'GILGFVFTL', 'GLCTLVAML',
       'GLIYNRMGAVTTEV', 'IVTDFSVIK', 'KAFSPEVIPMF', 'KLGGALQAK',
       'KRWIILGLNK', 'LLLGIGILV', 'LLQTGIHVRVSQPSL', 'LLWNGPMAV',
       'LPRRSGAAGA', 'NEGVKAAW', 'NLSALGIFST', 'NLVPMVATV',
       'PKYVKQNTLKLAT', 'QARQMVQAMRTIGTHP', 'RAKFKQLL', 'RLRAEAQVK',
       'SFHSLHLLF', 'TPQDLNTML', 'TPRVTGGGAM', 'VTEHDTLLY'], dtype=object)

In [163]:
EPITOPE = "ELAGIGILTV"

In [164]:
bg_data.Epitope = "BG"

In [165]:
ep_train = ep_data[ep_data.Epitope == EPITOPE][:50]
ep_valid = ep_data[ep_data.Epitope == EPITOPE][50:]

In [166]:
bg_train = bg_data.sample(n=50, random_state=420)
bg_valid = bg_data.sample(n=50, random_state=421)

In [167]:
train = pd.concat([ep_train, bg_train], ignore_index=True)
train = train[["TRBV", "CDR3B", "TRBJ", "Epitope"]]

valid = pd.concat([ep_valid, bg_valid], ignore_index=True)
valid = valid[["TRBV", "CDR3B", "TRBJ", "Epitope"]]

In [168]:
train_labels = train.Epitope.map({EPITOPE: 1, "BG": -1})
valid_labels = valid.Epitope.map({EPITOPE: 1, "BG": -1})

### Load BLAsTR

In [169]:
blastr = ModelLoader(Path("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/model_saves/BCDRBERT_ACL_small_double_censoring"))

### Vectorise TCRs

In [170]:
train_tcrs = blastr.embed(train)
valid_tcrs = blastr.embed(valid)

### Prepare training objects

In [171]:
class SVM(Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = Linear(64, 1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

In [172]:
model = SVM()

In [173]:
loss_fn = SoftMarginLoss()

In [174]:
optimiser = Adam(params=model.parameters())

In [175]:
train_dl = DataLoader(dataset=list(zip(train_tcrs, train_labels)), batch_size=100, shuffle=True)
valid_dl = DataLoader(dataset=list(zip(valid_tcrs, valid_labels)), batch_size=100, shuffle=True)

In [176]:
accuracy_tracker = []

for epoch in range(1000):
    batch_accuracies = []

    for tcrs, labels in valid_dl:
        preds = model(tcrs)
        labels = labels.unsqueeze(-1)

        accuracy = ((preds * labels) > 0).to(float).mean()
        batch_accuracies.append(accuracy.item())

    accuracy_tracker.append(torch.tensor(batch_accuracies).mean().item())

    for tcrs, labels in train_dl:
        preds = model(tcrs)
        labels = labels.unsqueeze(-1)

        optimiser.zero_grad()
        loss = loss_fn(preds, labels)
        loss.backward()
        optimiser.step()


print(max(accuracy_tracker))

0.6200000047683716
