In [None]:
import os
from pathlib import Path

# Work from the project root so the relative data/style paths below resolve.
# The root is derived only on the first execution: once os.chdir has run,
# recomputing Path.cwd().parent would climb one more directory per re-run.
# Assumes the notebook is launched from a direct subdirectory of the project
# root — TODO confirm (PROJECT_ROOT may also be pre-set externally).
try:
    PROJECT_ROOT
except NameError:
    PROJECT_ROOT = Path.cwd().parent.resolve()

os.chdir(PROJECT_ROOT)

In [None]:
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
import pandas as pd
from sceptr import variant
from sklearn.decomposition import PCA as Reducer
# from umap import UMAP as Reducer

from cached_representation_model import CachedRepresentationModel
from hugging_face_lms import TcrBert, ProtBert, Esm2

# Styles are applied left to right: the built-in "ggplot" base first, then the
# project's own sheet overrides it. "my.mplstyle" is a relative path, so it is
# looked up in the current working directory (set by os.chdir in the first cell).
plt.style.use(["ggplot", "my.mplstyle"])

In [None]:
# Distinct epitopes seen in the fine-tuning train/validation split; only test
# rows carrying one of these epitopes are drawn in the projection plots below.
finetuning_epitopes = pd.read_csv(
    "tcr_data/preprocessed/benchmarking/train_valid.csv"
)["Epitope"].unique()

# Held-out rows whose representations are computed, projected and plotted.
test_data = pd.read_csv("tcr_data/preprocessed/benchmarking/test.csv")

In [None]:
def generate_rep_projection_plot(model, data=None, epitopes=None, title=None) -> Figure:
    """Project a model's representations to 2D and scatter them, coloured by epitope.

    Representations for every row of ``data`` are reduced with ``Reducer``
    (imported as sklearn's PCA in the imports cell) and the first two output
    dimensions are plotted, one scatter series per epitope in ``epitopes``.

    Parameters
    ----------
    model
        Object exposing ``calc_vector_representations(df)``. Presumably returns
        an array with one row per row of ``df`` — TODO confirm.
    data
        DataFrame with an ``Epitope`` column. Defaults to the notebook-level
        ``test_data`` (looked up at call time).
    epitopes
        Epitopes to draw. Defaults to the notebook-level ``finetuning_epitopes``.
        Rows whose epitope is not listed are still used to fit the reducer but
        are not drawn.
    title
        Optional axes title.

    Returns
    -------
    Figure
        The newly created figure.
    """
    # Fall back to the notebook globals so existing one-argument calls still work.
    if data is None:
        data = test_data
    if epitopes is None:
        epitopes = finetuning_epitopes

    reps = model.calc_vector_representations(data)
    reducer = Reducer()
    # PCA keeps all components by default; only the first two are plotted.
    reps_projected = reducer.fit_transform(reps)[:, :2]

    fig, ax = plt.subplots()

    for epitope in epitopes:
        # Positional boolean mask aligned with the rows of `data` / `reps_projected`.
        mask = (data["Epitope"] == epitope).to_numpy()
        ax.scatter(*reps_projected[mask].T, label=epitope)

    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    if title is not None:
        ax.set_title(title)
    # Without a legend the per-epitope labels passed to scatter are never shown.
    ax.legend(title="Epitope", loc="center left", bbox_to_anchor=(1.0, 0.5))

    return fig

In [None]:
# Projection of the SCEPTR model returned by `variant.default()`.
fig = generate_rep_projection_plot(model=variant.default())

In [None]:
# Projection of the SCEPTR model returned by `variant.finetuned()`.
fig = generate_rep_projection_plot(model=variant.finetuned())

In [None]:
# Projection of TCR-BERT, wrapped in CachedRepresentationModel (presumably
# caches computed representations between runs — TODO confirm).
fig = generate_rep_projection_plot(model=CachedRepresentationModel(TcrBert()))