In [17]:
import os
from pathlib import Path
import sys

# Resolve the project root (two levels above the starting working directory)
# only once per kernel: the guard stops a re-run of this cell, which happens
# after the chdir below, from climbing two more directories.
if "__project_dir__" not in globals():
    __project_dir__ = Path.cwd().parents[1].resolve()

# sys.path entries must be strings; the import system ignores any other type,
# so appending the Path object itself would never make `src` importable.
# The membership check keeps repeated runs from stacking duplicate entries.
if str(__project_dir__) not in sys.path:
    sys.path.append(str(__project_dir__))
os.chdir(__project_dir__)

In [18]:
import pandas as pd
from src.model_loader import ModelLoader
import torch
from torch import Tensor
from torch.nn import Linear, Module, SoftMarginLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

### Prepare data

In [19]:
# Load the training split of the cd4_cd8 dataset into a DataFrame.
# NOTE(review): this is an absolute path inside one user's home directory, so
# the notebook cannot run on another machine as-is; consider building it from
# the project root instead.
train_df = pd.read_csv("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/data/preprocessed/cd4_cd8/train.csv")

In [20]:
# Wrap the saved BCDRBERT_ACL_small_double_censoring model directory in the
# project's ModelLoader; `blastr.embed` is used below to turn DataFrames into
# representations.
# NOTE(review): absolute, user-specific path — not portable.
blastr = ModelLoader(Path("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/model_saves/BCDRBERT_ACL_small_double_censoring"))

In [21]:
# Embed the training rows with the loaded model.
# Presumably yields one 64-dimensional vector per row of train_df (the
# classifier below is Linear(64, 1)) — TODO confirm against ModelLoader.embed.
train_reps = blastr.embed(train_df)

In [22]:
# Encode the cell type as the +1 / -1 targets used by the soft-margin loss:
# CD4 -> 1, CD8 -> -1.
train_ys = train_df["type"].map({"CD4": 1, "CD8": -1}).to_numpy()

# Series.map turns any label missing from the dict into NaN, which would
# silently make the targets float64 and push NaN into the loss. Fail loudly.
if pd.isna(train_ys).any():
    raise ValueError("train_df['type'] contains values other than 'CD4' or 'CD8'")

In [23]:
# Pair every training representation with its label and serve the pairs as
# reshuffled mini-batches of 2048.
train_dl = DataLoader(
    list(zip(train_reps, train_ys)),
    shuffle=True,
    batch_size=2048,
)

In [24]:
# Load the held-out split of the cd4_cd8 dataset.
# NOTE(review): the data called "valid" throughout this notebook is read from
# test.csv, so the accuracies printed during training are test-set numbers.
# NOTE(review): absolute, user-specific path — not portable.
valid_df = pd.read_csv("/home/yutanagano/UCLOneDrive/MBPhD/projects/tcr_embedder/data/preprocessed/cd4_cd8/test.csv")

In [25]:
# Embed the held-out rows with the same loaded model used for the training set.
# Presumably one 64-dimensional vector per row — TODO confirm against
# ModelLoader.embed.
valid_reps = blastr.embed(valid_df)

In [26]:
# Encode the cell type as the +1 / -1 targets used by the soft-margin loss:
# CD4 -> 1, CD8 -> -1.
valid_ys = valid_df["type"].map({"CD4": 1, "CD8": -1}).to_numpy()

# Series.map turns any label missing from the dict into NaN, which would
# silently make the targets float64 and corrupt the accuracy. Fail loudly.
if pd.isna(valid_ys).any():
    raise ValueError("valid_df['type'] contains values other than 'CD4' or 'CD8'")

In [27]:
# Pair every held-out representation with its label; evaluation order does not
# matter, so the batches of 2048 are served without shuffling.
valid_dl = DataLoader(
    list(zip(valid_reps, valid_ys)),
    batch_size=2048,
)

### Prepare finetuning model

In [33]:
class TcellTyper(Module):
    """Linear probe that maps an embedding to a single real-valued score.

    In this notebook the score is trained against +1 (CD4) / -1 (CD8) targets,
    so its sign is read as the predicted cell type.

    Args:
        in_features: Width of the input embedding. Defaults to 64, the value
            previously hard-coded here.
    """

    def __init__(self, in_features: int = 64) -> None:
        super().__init__()
        self.linear = Linear(in_features, 1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return one score per input row.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of size 1.
        """
        return self.linear(x)

In [34]:
# Create the classifier with freshly initialised weights.
# NOTE(review): re-running this cell replaces the model object, but an
# optimiser built earlier from `tcelltyper.parameters()` keeps pointing at the
# old parameters — re-run the optimiser cell afterwards as well.
tcelltyper = TcellTyper()

### Prepare loss

In [35]:
# Soft-margin loss with its default settings; it pairs with the +1 / -1 label
# encoding used for train_ys and valid_ys.
sf_loss = SoftMarginLoss()

### Prepare optimiser

In [36]:
# Adam over the classifier's parameters. No learning rate or other options are
# passed, so the library defaults apply.
optimiser = Adam(params=tcelltyper.parameters())

In [37]:
# Number of full passes over the training data.
N_EPOCHS = 1000

for epoch in range(N_EPOCHS):

    # ---- Evaluate -------------------------------------------------------
    # Runs before this epoch's updates, so the first value printed is the
    # accuracy of the untrained model.
    n_correct = 0
    n_seen = 0

    # Evaluation needs no gradients; skipping autograd avoids building a graph
    # for every validation batch.
    with torch.no_grad():
        for reps, ys in valid_dl:
            preds = tcelltyper(reps)
            # Give the labels the same trailing singleton dimension as preds so
            # the product below is element-wise rather than broadcast.
            ys = ys.unsqueeze(-1)

            # A prediction is correct when its sign matches the +1 / -1 label.
            # Count samples, not batches: averaging per-batch accuracies would
            # over-weight a final batch that is smaller than the others.
            n_correct += ((preds * ys) > 0).sum().item()
            n_seen += ys.numel()

    # An empty validation loader reports nan instead of dividing by zero.
    print(n_correct / n_seen if n_seen else float("nan"))

    # ---- Train ----------------------------------------------------------
    for reps, ys in train_dl:
        preds = tcelltyper(reps)
        ys = ys.unsqueeze(-1)

        optimiser.zero_grad()
        loss = sf_loss(preds, ys)
        loss.backward()
        optimiser.step()

0.5025
0.507
0.512
0.5205
0.521
0.5285
0.539
0.543
0.5435
0.5475
0.551
0.556
0.557
0.5575
0.562
0.566
0.5665
0.569
0.5675
0.5665
0.566
0.5645
0.5645
0.5665
0.568
0.567
0.5665
0.5655
0.566
0.565
0.569
0.57
0.573
0.571
0.5715
0.5735
0.574
0.574
0.5755
0.5765
0.5795
0.5785
0.5775
0.5775
0.58
0.5805
0.5785
0.5785
0.578
0.578
0.578
0.579
0.5785
0.58
0.58
0.5825
0.583
0.583
0.583
0.584
0.583
0.5815
0.5815
0.582
0.582
0.5835
0.5835
0.584
0.5845
0.584
0.584
0.585
0.585
0.586
0.586
0.5855
0.5855
0.5865
0.585
0.5855
0.586
0.586
0.5855
0.585
0.587
0.588
0.5875
0.5875
0.5875
0.587
0.588
0.589
0.5905
0.5895
0.59
0.5905
0.589
0.59
0.59
0.589
0.5895
0.591
0.5925
0.589
0.589
0.587
0.5885
0.588
0.5875
0.588
0.5885
0.5885
0.588
0.5885
0.589
0.5895
0.59
0.5885
0.5895
0.589
0.5895
0.5895
0.5885
0.588
0.5885
0.5885
0.59
0.591
0.591
0.591
0.591
0.5895
0.5905
0.5905
0.59
0.591
0.591
0.591
0.592
0.5915
0.5905
0.5895
0.59
0.5895
0.59
0.589
0.589
0.5905
0.5895
0.59
0.5895
0.591
0.59
0.59
0.5905
0.589
0.5895
0.5