##### This script defines the encoders for the multimodal model
##### Overview:
- The inputs are the ESMC and Boltz embeddings
- Three encoders (TCR, peptide, HLA), each fine-tuned from ESMC with LoRA (PEFT)

In [2]:
# Streamlined imports - removing duplicates
import esm
import pandas as pd
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.nn.utils.rnn import pad_sequence
from torch.amp import autocast
from torch.cuda.amp import GradScaler

import sys
import os
import time
import math
import random
import pickle
import subprocess
import gc
from pathlib import Path
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import display, update_display

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# ESM imports
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinTensor
from esm.models.esmc import _BatchedESMProteinTensor

# Tokenizer imports
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import normalizers
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers import decoders

# PEFT imports
from peft import get_peft_model, LoraConfig, TaskType
from peft.tuners.lora import LoraConfig, LoraModel

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


  import pynvml  # type: ignore[import]


In [3]:
# Get current working directory and create models folder
import os
from pathlib import Path

# Get the current working directory
current_dir = os.getcwd()
print(f"Current working directory: {current_dir}")

# Get the project root: go up two levels from scripts/train (train -> scripts -> root)
project_root = Path(current_dir).parent.parent
print(f"Project root: {project_root}")

# Create models directory in project root
models_dir = project_root / "models"
if not models_dir.exists():
    print(f"Models directory does not exist, creating it at: {models_dir}")
    # exist_ok=True only; missing parent directories are not created here
    models_dir.mkdir(exist_ok=True)
print(f"Models directory at: {models_dir}")

# Also create a checkpoints subdirectory for saving model checkpoints
checkpoints_dir = models_dir / "checkpoints"
if not checkpoints_dir.exists():
    print(f"Checkpoints directory does not exist, creating it at: {checkpoints_dir}")
    checkpoints_dir.mkdir(exist_ok=True)
print(f"Checkpoints directory at: {checkpoints_dir}")


Current working directory: /home/natasha/multimodal_model/scripts/train
Project root: /home/natasha/multimodal_model
Models directory at: /home/natasha/multimodal_model/models
Checkpoints directory at: /home/natasha/multimodal_model/models/checkpoints


##### Get ESM Embeddings

In [4]:
# need to load ESM C with LM head enabled
# expose final token embeddings before the logits head (is logits head the LM head, LM head=language modelling head)?
# NOTE: the MLM collator defined below returns masked/clean input ids, labels,
# attention_mask and string / ESMProtein views of each batch.

# Global device read by the training cells below ("cuda" when a GPU is visible).
# It is reassigned to 'cpu' before the projection step.
device = "cuda" if torch.cuda.is_available() else "cpu"
# load model and allow lora (rather than eval mode?)

In [5]:
# Observed ESMC embedding sizes for single sequences: [1, 7, 960] and [1, 12, 960].
# NOTE(review): presumably (batch, sequence length incl. special tokens, embedding dim),
# with the length varying per sequence - confirm.

# Load the raw pair table and normalise missing TCR chains.
df = pd.read_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv')
# Fill missing (NaN) chains with the placeholder residue 'X'
df['TCRa'] = df['TCRa'].fillna('X')
df['TCRb'] = df['TCRb'].fillna('X')

# Replace empty strings with the same 'X' placeholder
df.loc[df['TCRa'] == '', 'TCRa'] = 'X'
df.loc[df['TCRb'] == '', 'TCRb'] = 'X'

# Full TCR = alpha chain followed by beta chain
df['TCR_full'] = df['TCRa'] + df['TCRb']
# Chain-presence flags: 1 if the chain is present, 0 if it is the 'X' placeholder
df['m_alpha'] = 1
df['m_beta'] = 1
df.loc[df['TCRa'] == 'X', 'm_alpha'] = 0
df.loc[df['TCRb'] == 'X', 'm_beta'] = 0
# NOTE(review): df is not written back (the to_csv below is commented out), and the
# datasets below re-read the CSV from disk. They therefore only see these derived
# columns if the file on disk already contains them.
#df.to_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', index=False)

In [6]:
# hugging face dataset?
class _SequenceColumnDataset(Dataset):
    """Dataset that serves one sequence column of a CSV file, row by row.

    Shared implementation for the TCR / peptide / HLA datasets, which differ only
    in their default column and in the prefix of the row identifier.

    Args:
        data_path: path of the CSV file (read once, eagerly, with pandas).
        column_name: column holding the sequence strings.
        include_label: if True, items are ``(row_id, sequence, binding)`` tuples
            instead of bare sequences.
    """

    # Prefix of the identifier f"{_id_prefix}_{idx}" returned with labels.
    _id_prefix = "seq"

    def __init__(self, data_path, column_name, include_label=False):
        self.data_path = data_path
        self.data = pd.read_csv(data_path)
        self.column_name = column_name
        self.include_label = include_label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sequence stored at row ``idx``.

        When ``include_label`` is True, return ``(row_id, sequence, binding)``.
        ``binding`` is the row's ``'Binding'`` value, or -1 if the CSV has no
        such column.
        """
        row = self.data.iloc[idx]
        sequence = row[self.column_name]
        if self.include_label:
            return f"{self._id_prefix}_{idx}", sequence, row.get('Binding', -1)
        return sequence


class TCR_dataset(_SequenceColumnDataset):
    """Dataset for TCR data, for use in encoder training to propagate through to NC model"""

    _id_prefix = "TCR"

    def __init__(self, data_path, column_name='TCR_full', include_label=False):
        super().__init__(data_path, column_name=column_name, include_label=include_label)


class peptide_dataset(_SequenceColumnDataset):
    """Dataset serving the peptide sequence column (default ``'Peptide'``)."""

    _id_prefix = "peptide"

    def __init__(self, data_path, column_name='Peptide', include_label=False):
        super().__init__(data_path, column_name=column_name, include_label=include_label)


class HLA_dataset(_SequenceColumnDataset):
    """Dataset serving the HLA sequence column (default ``'HLA'``)."""

    _id_prefix = "hla"

    def __init__(self, data_path, column_name='HLA', include_label=False):
        super().__init__(data_path, column_name=column_name, include_label=include_label)

In [7]:
# One dataset per modality, all reading the same CSV, so row i is the same pair in each.
tcr = TCR_dataset(data_path='/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', column_name='TCR_full')
peptide = peptide_dataset(data_path='/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', column_name='Peptide')
hla = HLA_dataset(data_path='/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', column_name='HLA_sequence')

# Open question: these classes may not be needed - below they are only used for
# their .data frames and to iterate the raw sequence strings.

In [8]:
# Wrap every raw sequence as an ESMProtein.
tcrs = [ESMProtein(sequence=s) for s in tcr.data['TCR_full']]
peptides = [ESMProtein(sequence=s) for s in peptide.data['Peptide']]
hlas = [ESMProtein(sequence=s) for s in hla.data['HLA_sequence']]

# can batch at the forward step, not the encoding step

#model = ESMC.from_pretrained("esmc_300m").to(device).eval()
# Used here for .encode and later for its tokenizer; released once the datasets are built.
model = ESMC.from_pretrained("esmc_300m").eval()

# Raw sequence strings: the datasets were built with include_label=False,
# so iterating them yields just the sequence of each row.
tcrs_data = [seq for seq in tcr]
peptides_data = [seq for seq in peptide]
hlas_data = [seq for seq in hla]

# Per-protein encodings. NOTE(review): encoded_* are not used in the cells shown here.
encoded_tcrs = [model.encode(p) for p in tcrs]
encoded_peptides = [model.encode(p) for p in peptides]
encoded_hlas = [model.encode(p) for p in hlas]


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

##### Mask Data and Collate for MLM for Encoders

In [9]:
# do for entire dataset
# (the tokenizer calls further down do return attention_mask; EncodedSeqDataset stores it)

# Special-token ids from the ESMC tokenizer, used by the MLM collator below.
tok = model.tokenizer
CLS_ID = tok.cls_token_id
EOS_ID = tok.eos_token_id
PAD_ID = tok.pad_token_id
MASK_ID = tok.mask_token_id

# NOTE(review): presumably the ESMC token ids of the 20 standard amino acids, used as
# random-substitution targets during masking - confirm against tok's vocabulary.
AA_IDS =  [5,10,17,13,23,16,9,6,21,12,4,15,20,18,14,8,11,22,19,7]


class EncodedSeqDataset(Dataset):
    """Pair raw sequence strings with their pre-tokenised ids and attention masks.

    Args:
        sequences: raw sequence strings, indexed in step with ``enc``.
        enc: mapping with ``'input_ids'`` and ``'attention_mask'`` entries
            (e.g. a tokenizer ``BatchEncoding``); both are indexed per row.

    Items are dicts with ``"sequence"``, ``"input_ids"`` and
    ``"attention_mask"``; the last two are ``torch.long`` tensors.
    """

    def __init__(self, sequences, enc):
        self.sequences = sequences
        self.input_ids, self.attention_mask = enc['input_ids'], enc['attention_mask']

    def __len__(self):
        # The dataset length follows the encodings, not the raw sequence list.
        return len(self.input_ids)

    def __getitem__(self, idx):
        def as_long(values):
            return torch.as_tensor(values, dtype=torch.long)

        item = {"sequence": self.sequences[idx]}
        item["input_ids"] = as_long(self.input_ids[idx])
        item["attention_mask"] = as_long(self.attention_mask[idx])
        return item


In [10]:
class MLMProteinCollator:
    """Collate tokenised protein sequences and apply masked-language-model corruption.

    For every sequence, a subset of the eligible positions is selected for
    supervision. A position is eligible if it is attended and is not CLS/EOS/PAD.
    Most selected positions become the mask token. A fraction ``aa_frac`` (at
    least one whenever two or more are selected) is replaced by a random id
    from ``amino_acids``.

    Args:
        cls_id, eos_id, pad_id, mask_id: special-token ids.
        amino_acids: token ids used for random substitution.
        p: fraction of eligible positions selected per sequence.
        min_per_seq, max_per_seq: clamp on the number selected; it is never
            more than the number of eligible positions.
        aa_frac: fraction of selected positions substituted with a random amino acid.
        tokenizer: tokenizer used by ``__call__`` to rebuild masked strings. It may
            also be assigned after construction via ``collator.tokenizer = ...``.
    """

    def __init__(self, *, cls_id, eos_id, pad_id, mask_id, amino_acids,
                 p=0.15, min_per_seq=2, max_per_seq=45, aa_frac=0.20, tokenizer=None):
        self.CLS = cls_id
        self.EOS = eos_id
        self.PAD = pad_id
        self.MASK = mask_id
        self.aa = torch.as_tensor(amino_acids, dtype=torch.long)
        self.p = p
        self.min_per_seq = min_per_seq
        self.max_per_seq = max_per_seq
        self.aa_frac = aa_frac
        self.tokenizer = tokenizer

    @torch.no_grad()
    def mask_batch(self, input_ids, attention_mask):
        """Corrupt a batch for MLM.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: (B, L); nonzero marks positions that may be selected.

        Returns:
            ``(masked_input_ids, labels)``, both (B, L). ``labels`` holds the
            original token at every selected position and -100 elsewhere.
        """
        device = input_ids.device
        aa = self.aa.to(device)

        B, _ = input_ids.shape
        # Eligible positions: attended and not a special token.
        valid_mask = (
            attention_mask.bool()
            & (input_ids != self.PAD)
            & (input_ids != self.CLS)
            & (input_ids != self.EOS)
        )

        masked_input_ids = input_ids.clone()
        labels = torch.full_like(input_ids, -100)

        for i in range(B):
            vmask = valid_mask[i]
            if not vmask.any():
                continue

            valid_idx = vmask.nonzero(as_tuple=False).squeeze(1)  # (L_valid,)
            L_valid = valid_idx.numel()

            # How many to select: floor(p * L_valid), clamped to [min_per_seq, max_per_seq]
            # and never more than L_valid (torch.clamp returns max when min > max).
            n = torch.floor(self.p * torch.tensor(L_valid, device=device, dtype=torch.float32)).to(torch.int64)
            n = torch.clamp(n, min=self.min_per_seq, max=min(self.max_per_seq, L_valid))
            if n.item() == 0:
                continue

            # Choose n distinct eligible positions.
            chosen = valid_idx[torch.randperm(L_valid, device=device)[:n]]

            # Split into random-AA vs MASK; ensure >= 1 AA substitution if n >= 2.
            n_amino = torch.floor(self.aa_frac * n).to(torch.int64)
            if n.item() >= 2:
                n_amino = torch.clamp(n_amino, min=1)
            n_mask = n - n_amino

            order = torch.randperm(n.item(), device=device)
            mask_pos = chosen[order[:n_mask]]
            amino_pos = chosen[order[n_mask:]]

            # Labels only at supervised positions.
            labels[i, chosen] = input_ids[i, chosen]

            # Apply edits.
            if n_mask.item() > 0:
                masked_input_ids[i, mask_pos] = self.MASK
            if n_amino.item() > 0:
                r_idx = torch.randint(high=aa.numel(), size=(n_amino.item(),), device=device)
                masked_input_ids[i, amino_pos] = aa[r_idx]

        return masked_input_ids, labels

    def __call__(self, features):
        """Stack ``EncodedSeqDataset`` items and build clean and masked views.

        Returns a dict with:
        - ``masked_input_ids``, ``labels``, ``attention_mask``, ``clean_input_ids``: (B, L) tensors.
        - ``clean_sequences`` / ``masked_sequences``: strings. Masked strings keep
          the mask token text and drop CLS/EOS/PAD.
        - ``clean_sequences_ESMprotein`` / ``masked_sequences_ESMprotein``: ESMProtein lists.
        - ``*_ESMprotein_batched``: ``_BatchedESMProteinTensor`` wrappers of the id tensors.

        Raises:
            RuntimeError: if no tokenizer has been set on this collator.
        """
        # Use this instance's tokenizer; the previous version read the global
        # `collator`, which broke any collator bound to a different name.
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise RuntimeError(
                "MLMProteinCollator.tokenizer is not set; pass tokenizer=... or "
                "assign collator.tokenizer before collating."
            )

        input_ids = torch.stack([f["input_ids"] for f in features], dim=0)
        attention_mask = torch.stack([f["attention_mask"] for f in features], dim=0)
        sequences = [f["sequence"] for f in features]
        proteins = [ESMProtein(sequence=s) for s in sequences]
        batched_clean = _BatchedESMProteinTensor(sequence=input_ids)

        masked_input_ids, labels = self.mask_batch(input_ids, attention_mask)

        # Build masked sequences as strings (keep "<mask>", drop CLS/EOS/PAD).
        skipped = {tokenizer.cls_token, tokenizer.eos_token, tokenizer.pad_token}
        masked_sequences = [
            "".join(
                t for t in tokenizer.convert_ids_to_tokens(row, skip_special_tokens=False)
                if t not in skipped
            )
            for row in masked_input_ids.tolist()
        ]

        proteins_masked = [ESMProtein(sequence=s) for s in masked_sequences]
        batched_masked = _BatchedESMProteinTensor(sequence=masked_input_ids)

        return {
            "masked_input_ids": masked_input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "clean_input_ids": input_ids.clone(),
            "clean_sequences": sequences,
            "masked_sequences": masked_sequences,
            "clean_sequences_ESMprotein": proteins,
            "masked_sequences_ESMprotein": proteins_masked,
            "masked_input_ids_ESMprotein_batched": batched_masked,
            "clean_input_ids_ESMprotein_batched": batched_clean,
        }

In [11]:
# Your existing BatchEncodings:
# clean_tcrs_tokenized, clean_peptides_tokenized, clean_hlas_tokenized
# padding=True pads to the longest sequence in the list passed in, i.e. the
# whole dataset for each modality here.
clean_tcrs_tokenized = model.tokenizer(tcrs_data, return_tensors='pt', padding=True)
clean_peptides_tokenized = model.tokenizer(peptides_data, return_tensors='pt', padding=True)
clean_hlas_tokenized = model.tokenizer(hlas_data, return_tensors='pt', padding=True)

tcr_ds = EncodedSeqDataset(tcrs_data,clean_tcrs_tokenized)
pep_ds = EncodedSeqDataset(peptides_data, clean_peptides_tokenized)
hla_ds = EncodedSeqDataset(hlas_data, clean_hlas_tokenized)

collator = MLMProteinCollator(
    cls_id=CLS_ID, eos_id=EOS_ID, pad_id=PAD_ID, mask_id=MASK_ID,
    amino_acids=AA_IDS, p=0.15, min_per_seq=2, max_per_seq=45, aa_frac=0.20
)
# The collator rebuilds masked strings with this tokenizer.
collator.tokenizer = model.tokenizer


# Each loader shuffles independently, so their batches are NOT row-aligned across modalities.
tcr_loader = DataLoader(tcr_ds, batch_size=8, shuffle=True, num_workers=4, collate_fn=collator)
pep_loader = DataLoader(pep_ds, batch_size=8, shuffle=True, num_workers=4, collate_fn=collator)
hla_loader = DataLoader(hla_ds, batch_size=8, shuffle=True, num_workers=4, collate_fn=collator)


# Each loader yields the collator's batch dict, including:
# masked_input_ids, labels (original tokens only at selected positions, -100 everywhere else),
# attention_mask (1 = real token, 0 = padding), clean_input_ids (clean copy of the input for
# the clean forward pass / NC loss), plus clean/masked strings and ESMProtein views.

# model.to("cpu")
# torch.cuda.empty_cache()

def optimizer_to_cpu(optim):
    """Move every tensor in ``optim.state`` to the CPU (detached), in place.

    Non-tensor state entries are left as they are; ``param_groups`` are not touched.
    """
    for param_state in optim.state.values():
        moved = {
            name: value.detach().to("cpu")
            for name, value in param_state.items()
            if torch.is_tensor(value)
        }
        param_state.update(moved)

# Release the ESMC model used for encoding/tokenisation: the tokenised datasets
# are built, and the collator (and `tok`) keep their own tokenizer reference.
model.to("cpu")
del model
torch.cuda.empty_cache()

In [12]:

# Fine-tune a LoRA adapter on ESMC for TCR sequences with the masked-LM objective
# built by the collator, then save a checkpoint and free GPU memory.

# bf16 autocast on CUDA only; bf16 needs no loss scaling, so no GradScaler is used.
use_amp = torch.cuda.is_available()

lora_cfg = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.05, bias="none",
    target_modules=["out_proj", "layernorm_qkv.1"],  # inner Linear inside the Sequential
)

model_tcr = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="tcr")

# freeze everything; unfreeze only LoRA params
for p in model_tcr.parameters():
    p.requires_grad = False
for name, p in model_tcr.named_parameters():
    if "lora_A" in name or "lora_B" in name:
        p.requires_grad = True

# Use the globally chosen device (CPU fallback) instead of hard-coding "cuda",
# which crashed on machines without a GPU.
model_tcr.to(device)
model_tcr.train()

optim_tcr = torch.optim.AdamW(
    (p for p in model_tcr.parameters() if p.requires_grad),
    lr=1e-3, weight_decay=0.01
)

# One pass over the TCR loader.
for batch in tcr_loader:
    input_ids = batch["masked_input_ids"].to(device, dtype=torch.long)
    labels    = batch["labels"].to(device)

    with autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
        out    = model_tcr(input_ids)
        logits = out.sequence_logits  # expected [B, L, V]

        # Only positions selected by the collator carry labels; the rest are -100.
        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100
        )

    loss.backward()
    optim_tcr.step()
    optim_tcr.zero_grad(set_to_none=True)

    # Drop this step's tensors before fetching the next batch.
    del out, logits, loss, input_ids, labels, batch
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# 1.1 seconds for 100
# 11 for 1000
# 110 for 10000
# 350 for 35000 - 5.8 minutes to train one encoder on full dataset

optimizer_to_cpu(optim_tcr)

# save tcr model

checkpoint_filename = 'tcr_encoder_checkpoint.pth'

checkpoint_dict = {
    'tcr_model_state_dict': model_tcr.state_dict(),
    'optimizer_state_dict': optim_tcr.state_dict(),
    # TODO: also record epoch / final loss / binding threshold once they exist.
}

torch.save(checkpoint_dict, checkpoints_dir/checkpoint_filename)

print(f"Checkpoint saved to {checkpoint_filename}")

model_tcr.to("cpu")

del optim_tcr, model_tcr, checkpoint_dict, checkpoint_filename

torch.cuda.empty_cache()  # Free up GPU memory (no-op if CUDA was never initialised)



  scaler  = GradScaler(enabled=False)                     # <-- no scaler


Checkpoint saved to tcr_encoder_checkpoint.pth


In [13]:
# peptide encoder
# Fine-tune a separate LoRA adapter on ESMC for peptide sequences (same MLM recipe
# as the TCR encoder), then save a checkpoint and free GPU memory.

# bf16 autocast on CUDA only; bf16 needs no loss scaling, so no GradScaler is used.
use_amp = torch.cuda.is_available()

# Reuses lora_cfg from the TCR training cell (same config for peptide and HLA - modify later).
model_pep = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="pep")

# freeze everything; unfreeze only LoRA params
for p in model_pep.parameters():
    p.requires_grad = False
for name, p in model_pep.named_parameters():
    if "lora_A" in name or "lora_B" in name:
        p.requires_grad = True

# Use the globally chosen device (CPU fallback) instead of hard-coding "cuda".
model_pep.to(device)
model_pep.train()

optim_pep = torch.optim.AdamW(
    (p for p in model_pep.parameters() if p.requires_grad),
    lr=1e-3, weight_decay=0.01
)

# One pass over the peptide loader.
for batch in pep_loader:
    input_ids = batch["masked_input_ids"].to(device, dtype=torch.long)
    labels    = batch["labels"].to(device)

    with autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
        out    = model_pep(input_ids)
        logits = out.sequence_logits  # expected [B, L, V]

        # Only positions selected by the collator carry labels; the rest are -100.
        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100
        )

    loss.backward()
    optim_pep.step()
    optim_pep.zero_grad(set_to_none=True)

    # Drop this step's tensors before fetching the next batch.
    del out, logits, loss, input_ids, labels, batch
    if torch.cuda.is_available():
        torch.cuda.synchronize()

optimizer_to_cpu(optim_pep)

# save peptide model

checkpoint_filename = 'peptide_encoder_checkpoint.pth'

checkpoint_dict = {
    'peptide_model_state_dict': model_pep.state_dict(),
    'optimizer_state_dict': optim_pep.state_dict(),
    # TODO: also record epoch / final loss / binding threshold once they exist.
}

torch.save(checkpoint_dict, checkpoints_dir/checkpoint_filename)

print(f"Checkpoint saved to {checkpoint_filename}")

model_pep.to("cpu")
del optim_pep, model_pep, checkpoint_dict, checkpoint_filename
torch.cuda.empty_cache()

#print(torch.cuda.memory_summary())           # "Active" bytes should drop


  scaler  = GradScaler(enabled=False)                     # <-- no scaler


Checkpoint saved to peptide_encoder_checkpoint.pth


In [14]:
# HLA encoder
# Fine-tune a separate LoRA adapter on ESMC for HLA sequences (same MLM recipe
# as the TCR encoder), then save a checkpoint and free GPU memory.

# bf16 autocast on CUDA only; bf16 needs no loss scaling, so no GradScaler is used.
use_amp = torch.cuda.is_available()

# Reuses lora_cfg from the TCR training cell (same config for peptide and HLA - modify later).
model_hla = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="hla")


# freeze everything; unfreeze only LoRA params
for p in model_hla.parameters():
    p.requires_grad = False
for name, p in model_hla.named_parameters():
    if "lora_A" in name or "lora_B" in name:
        p.requires_grad = True

# Use the globally chosen device (CPU fallback) instead of hard-coding "cuda".
model_hla.to(device)
model_hla.train()

optim_hla = torch.optim.AdamW(
    (p for p in model_hla.parameters() if p.requires_grad),
    lr=1e-3, weight_decay=0.01
)

# One pass over the HLA loader.
for batch in hla_loader:
    input_ids = batch["masked_input_ids"].to(device, dtype=torch.long)
    labels    = batch["labels"].to(device)

    with autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
        out    = model_hla(input_ids)
        logits = out.sequence_logits  # expected [B, L, V]

        # Only positions selected by the collator carry labels; the rest are -100.
        loss   = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100
        )

    loss.backward()
    optim_hla.step()
    optim_hla.zero_grad(set_to_none=True)

    # clear memory
    del out, logits, loss, input_ids, labels, batch
    if torch.cuda.is_available():
        torch.cuda.synchronize()


optimizer_to_cpu(optim_hla)

# save HLA model

checkpoint_filename = 'hla_encoder_checkpoint.pth'

checkpoint_dict = {
    'hla_model_state_dict': model_hla.state_dict(),
    'optimizer_state_dict': optim_hla.state_dict(),
    # TODO: also record epoch / final loss / binding threshold once they exist.
}

torch.save(checkpoint_dict, checkpoints_dir/checkpoint_filename)

print(f"Checkpoint saved to {checkpoint_filename}")

model_hla.to("cpu")
del optim_hla, model_hla, checkpoint_dict, checkpoint_filename
torch.cuda.empty_cache()

  scaler  = GradScaler(enabled=False)                     # <-- no scaler


Checkpoint saved to hla_encoder_checkpoint.pth


#### Projection Step

In [15]:
device = 'cpu'  # from here on, later cells place modules/tensors on the CPU (overrides the earlier choice)

In [16]:
"""
LOADING ENCODERS FOR USE AS FIXED FEATURE EXTRACTORS

Use case: You're NOT continuing to train the encoders. Instead, you'll use them
as fixed/pre-trained components in the next part of your training pipeline (likely
the multimodal model that combines TCR, peptide, and HLA embeddings).

WHAT YOU NEED:
✅ Model weights (to get the fine-tuned LoRA adapters)
✅ Set to .eval() mode (to disable dropout, use batch norm stats, etc.)

WHAT YOU DON'T NEED:
❌ Optimizer state - only needed if you were continuing to train the encoders

NOTE: The checkpoint contains 'optimizer_state_dict', but we're ignoring it
since we're using these models as frozen feature extractors.
"""

# Step 1: Load the checkpoint dictionaries (just Python dicts from disk)
# checkpoint_tcr = torch.load(checkpoints_dir/'tcr_encoder_checkpoint.pth', map_location=device)
# checkpoint_pep = torch.load(checkpoints_dir/'peptide_encoder_checkpoint.pth', map_location=device)
# checkpoint_hla = torch.load(checkpoints_dir/'hla_encoder_checkpoint.pth', map_location=device)
checkpoint_tcr = torch.load(checkpoints_dir/'tcr_encoder_checkpoint.pth', map_location='cpu')
checkpoint_pep = torch.load(checkpoints_dir/'peptide_encoder_checkpoint.pth', map_location='cpu')
checkpoint_hla = torch.load(checkpoints_dir/'hla_encoder_checkpoint.pth', map_location='cpu')

# Step 2: Recreate the model architectures (same as during training)
lora_cfg = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.05, bias="none",
    target_modules=["out_proj", "layernorm_qkv.1"],
)

# Create models with same architecture as during training
tcr_encoder = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="tcr")
peptide_encoder = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="pep")
hla_encoder = LoraModel(ESMC.from_pretrained("esmc_300m"), lora_cfg, adapter_name="hla")

# Step 3: Load the model state dictionaries (this loads base weights + LoRA adapters)
# NOTE: We're NOT loading optimizer_state_dict - not needed for inference/feature extraction
tcr_encoder.load_state_dict(checkpoint_tcr['tcr_model_state_dict'])
peptide_encoder.load_state_dict(checkpoint_pep['peptide_model_state_dict'])
hla_encoder.load_state_dict(checkpoint_hla['hla_model_state_dict'])

# Step 4: Set to evaluation mode and move to CPU
# .eval() mode ensures: no dropout, batch norm uses running stats, etc.
tcr_encoder.to('cpu').eval()
peptide_encoder.to('cpu').eval()
hla_encoder.to('cpu').eval()

print("✓ Encoders loaded successfully as fixed feature extractors!")
print("  (Optimizer state skipped - not needed for this use case)")



✓ Encoders loaded successfully as fixed feature extractors!
  (Optimizer state skipped - not needed for this use case)


##### a) Factorisation/Projection for ESMC Encoders
- N.B. From the next cell onwards, only a single batch is processed; this still needs to be reworked into a loop over all batches.

In [17]:
# 1. Pick ONE set of rows and collate it for every modality.
# The three datasets are built from the same CSV in the same order, so index i is
# the same TCR-peptide-HLA pair in each. Calling next(iter(...)) on the three
# independently shuffled loaders would instead give each modality a different
# random set of rows. zT and zPH below would then describe unrelated pairs.
n_rows = len(tcr_ds)
if len(pep_ds) != n_rows or len(hla_ds) != n_rows:
    raise ValueError(
        f"modality datasets are not row-aligned: {n_rows} TCR, "
        f"{len(pep_ds)} peptide, {len(hla_ds)} HLA rows"
    )
batch_idx = torch.randperm(n_rows)[:tcr_loader.batch_size].tolist()

tcr_batch = collator([tcr_ds[i] for i in batch_idx])
pep_batch = collator([pep_ds[i] for i in batch_idx])
hla_batch = collator([hla_ds[i] for i in batch_idx])

tcr_ids = tcr_batch['clean_input_ids']
tcr_mask = tcr_batch['attention_mask']

pep_ids = pep_batch['clean_input_ids']
pep_mask = pep_batch['attention_mask']

hla_ids = hla_batch['clean_input_ids']
hla_mask = hla_batch['attention_mask']

# Unsqueeze to make the right dims
mT = tcr_mask.unsqueeze(-1).float()   # (B, L_T_pad, 1)
mP = pep_mask.unsqueeze(-1).float()   # (B, L_P_pad, 1)
mH = hla_mask.unsqueeze(-1).float()   # (B, L_H_pad, 1)


B = tcr_ids.size(0)


In [18]:
# Forward the clean (unmasked) token ids through each fine-tuned encoder, without gradients.
# `.model` is the ESMC module wrapped by LoraModel. NOTE(review): PEFT injects the LoRA
# layers into that module in place, so this call should still apply the trained adapters;
# confirm whether adapter-free "base" embeddings were intended instead.
with torch.no_grad():
    out_T = tcr_encoder.model.forward(sequence_tokens=tcr_ids)
    out_P = peptide_encoder.model.forward(sequence_tokens=pep_ids)
    out_H = hla_encoder.model.forward(sequence_tokens=hla_ids)

# Per-token embeddings from the ESMC output's .embeddings field (no fallback is implemented)
emb_T = out_T.embeddings 
emb_P = out_P.embeddings 
emb_H = out_H.embeddings

# Embedding shape is (B, L_pad, D), as checked by the author.
# Takes ~18 seconds on CPU for 24 embeddings (3 modalities x 8 rows).

In [19]:
# True (unpadded) lengths per row = attention-mask sums (these include CLS/EOS).
# The *_max values are 0-dim tensors; the same values are recomputed in the encoder cell below.

# maximum true tcr length
L_T_true = tcr_mask.sum(dim=1)
L_T_max = L_T_true.max()

# maximum true peptide length
L_P_true = pep_mask.sum(dim=1)
L_P_max = L_P_true.max()

# maximum true HLA length
L_H_true = hla_mask.sum(dim=1)
L_H_max = L_H_true.max()

In [20]:
# Factorised Encoder to get z_T and Z_pMHC
# z = vec(A^T X B) H
# X - (B, L_pad, D)   token embeddings (here B is the batch size)
# B - (D, rD)         channel mixing (B_c in the class; not the batch size)
# A - (L_pad, rL)     positional mixing (A_c; only the first L_true rows are used per sequence)
# H - (rD * rL, d)    final map to the latent (H_c)


# Small constant that avoids division by zero when normalising latents
eps = 1e-8

class ESMFactorisedEncoder(nn.Module):
    """Low-rank factorised pooling of per-token embeddings into one latent vector.

    For a sequence with ``Lb`` real tokens and embeddings ``X`` of shape ``(Lb, D)``,
    computes ``z = vec(A[:Lb]^T (X B)) H`` with learned channel mixing
    ``B: (D, rD)``, positional mixing ``A: (L_max, rL)`` and output map
    ``H: (rL * rD, d)``. The result is not normalised.
    """

    def __init__(self, D, rL, rD, d, L_max):
        """
        D    : ESM embedding dim (e.g. 960)
        rL   : positional rank
        rD   : channel rank
        d    : latent dim
        L_max: largest number of real tokens (mask sum) any input row may have;
               sizes the positional matrix. May be an int or a 0-dim tensor.
        """
        super().__init__()
        # Callers pass e.g. mask.sum(1).max(), a 0-dim tensor; store a plain int.
        L_max = int(L_max)
        self.D   = D
        self.rL  = rL
        self.rD  = rD
        self.d   = d
        self.L_max = L_max

        # Channel mixing: D -> rD
        self.B_c = nn.Parameter(torch.empty(D, rD))
        nn.init.xavier_uniform_(self.B_c)

        # Positional mixing: positions 0..L_max-1 -> rL
        self.A_c = nn.Parameter(torch.empty(L_max, rL))
        nn.init.xavier_uniform_(self.A_c)

        # Final map: (rL * rD) -> d
        self.H_c = nn.Parameter(torch.empty(rL * rD, d))
        nn.init.xavier_uniform_(self.H_c)

    def forward(self, emb, mask):
        """
        emb  : (B, L_pad, D) token embeddings
        mask : (B, L_pad)   1 = real token, 0 = pad. The first mask-sum positions of
               each row are used, so real tokens must come first (right padding).
        returns z : (B, d); rows with no real tokens are zero vectors.

        Raises:
            ValueError: if the embedding dim is not D, or a row has more real
                tokens than L_max (A_c has no rows for those positions).
        """
        B, L_pad, D = emb.shape
        if D != self.D:
            raise ValueError(f"expected embedding dim {self.D}, got {D}")
        if B == 0:
            return torch.zeros((0, self.d), device=emb.device, dtype=self.H_c.dtype)

        # Compute true lengths
        L_true = mask.sum(dim=1)            # (B,)
        longest = int(L_true.max().item())
        if longest > self.L_max:
            raise ValueError(
                f"a row has {longest} real tokens but this encoder was built with "
                f"L_max={self.L_max}; rebuild it with L_max >= the longest sequence "
                f"in the dataset"
            )

        z_list = []
        for b in range(B):
            Lb = int(L_true[b].item())
            if Lb == 0:
                # Degenerate case: no tokens -> zero vector
                z_list.append(torch.zeros(self.d, device=emb.device, dtype=self.H_c.dtype))
                continue

            Xb = emb[b, :Lb, :]                      # (Lb, D)
            mb = mask[b, :Lb].unsqueeze(-1).float()  # (Lb, 1)
            Xb = Xb * mb                             # (Lb, D)

            # 1) Channel compression: D -> rD
            Yb = Xb @ self.B_c                       # (Lb, rD)

            # 2) Positional compression: Lb -> rL
            A_pos = self.A_c[:Lb, :]                 # (Lb, rL)
            Ub = A_pos.T @ Yb                        # (rL, rD)

            # 3) Flatten and map to latent d
            Ub_flat = Ub.reshape(-1)                 # (rL * rD,)
            z_b = Ub_flat @ self.H_c                 # (d,)

            # 4) No normalisation here: peptide and HLA latents are combined first,
            #    then normalised by the caller.
            z_list.append(z_b)

        return torch.stack(z_list, dim=0)            # (B, d)


In [21]:
# Shape (D is reused for peptide and HLA: all three encoders are esmc_300m)
B, L_T_pad, D = emb_T.shape

# Latent ranks and final dimension (hyperparameters, shared by all three modalities)
rL = 8      # positional rank (tunable)
rD = 16     # channel rank (tunable)
d    = 128    # final latent dimension (same d as in Z*)

# ratio of peptide to HLA (hyperparameter)
#R = 0.7
# Epsilon for numerical stability
eps=1e-8

# maximum true lengths of the sequences (same values as the previous cell)
# maximum true tcr length
L_T_true = tcr_mask.sum(dim=1)
L_T_max = L_T_true.max()
# maximum true peptide length
L_P_true = pep_mask.sum(dim=1)
L_P_max = L_P_true.max()
# maximum true HLA length
L_H_true = hla_mask.sum(dim=1)
L_H_max = L_H_true.max()


# Freshly initialised (untrained) factorised encoders. NOTE(review): L_max comes from
# this single batch, so a later batch with a longer sequence will not fit A_c.
tcr_encoder_new = ESMFactorisedEncoder(D, rL, rD, d, L_max=L_T_max).to(device)
pep_encoder_new = ESMFactorisedEncoder(D, rL, rD, d, L_max=L_P_max).to(device)
hla_encoder_new = ESMFactorisedEncoder(D, rL, rD, d, L_max=L_H_max).to(device)

# when you call module(), you are calling the forward method (PyTorch convention)
zT = tcr_encoder_new(emb_T, tcr_mask)
zP = pep_encoder_new(emb_P, pep_mask)
zH = hla_encoder_new(emb_H, hla_mask)

# normalise T as we are not gating this (L2 norm per row, eps avoids division by zero)
zT = zT / (zT.norm(dim=-1, keepdim=True) + eps)



In [22]:
# change to what Barbara suggested in meeting - scale within the projection learning

R_PH = 0.7  # peptide gets more weight

# Gates chosen so that gP**2 + gH**2 == 1: R_PH and 1 - R_PH split the squared weight.
gP = (R_PH ** 0.5)          # scalar
gH = ((1.0 - R_PH) ** 0.5)  # scalar

gP_t = torch.tensor(gP, device=device)
gH_t = torch.tensor(gH, device=device)

# zP, zH: (B, d)
zPH = gP_t * zP + gH_t * zH      # (B, d)
# Optionally normalise (L2 per row):
zPH = zPH / (zPH.norm(dim=-1, keepdim=True) + eps)

# concatenate into e_hat: (B, 2 * d) = [TCR latent | gated pMHC latent]
e_hat = torch.cat([zT, zPH], dim=-1)



In [23]:
# Next step: apply non-linear layers (e.g. ReLU in the hidden layers), keeping the final layer linear so the output can take negative values.

##### Get Boltz Embeddings

In [24]:
# get Boltz embeddings
# Example embedding file for a single pair. NOTE(review): file_path is not used by the cells shown here.
file_path = '/home/natasha/multimodal_model/outputs/boltz_runs/positives/pair_000/boltz_results_pair_000/predictions/pair_000/embeddings_pair_000.npz'
manifest_path = '/home/natasha/multimodal_model/data/manifests/boltz_100_manifest.csv'

In [25]:
import os
print(os.getcwd())

# Hard-coded project root; it is the root the Boltz embedding paths are built under.
home = '/home/natasha/multimodal_model'
# Rebuilds the same manifest path as the previous cell, from `home`.
manifest_path = os.path.join(home, 'data', 'manifests', 'boltz_100_manifest.csv')
print(manifest_path)

/home/natasha/multimodal_model/scripts/train
/home/natasha/multimodal_model/data/manifests/boltz_100_manifest.csv


In [26]:
class BoltzDataset(Dataset):
    """
    Lazily load one Boltz ``z`` embedding per manifest row, plus the chain
    lengths recorded in the manifest.

    Each pair's embedding lives in its own ``.npz`` file at
    ``<base_path>/outputs/boltz_runs/positives/<pair_id>/boltz_results_<pair_id>/predictions/<pair_id>/embeddings_<pair_id>.npz``.
    Here ``pair_id`` is the manifest ``yaml_path`` file name without its extension.
    ORIGINAL VERSION - returns numpy arrays

    Items are ``(z, pep_len, tcra_len, tcrb_len, hla_len)``.
    """
    def __init__(self, manifest_path, base_path):
        self.manifest = pd.read_csv(manifest_path)
        self.base_path = base_path

    def __len__(self):
        return len(self.manifest)

    def _embedding_path(self, pair_id):
        # Directory layout produced by the Boltz runs for one pair.
        parts = (
            'outputs', 'boltz_runs', 'positives', pair_id,
            f'boltz_results_{pair_id}', 'predictions', pair_id,
            f'embeddings_{pair_id}.npz',
        )
        return os.path.join(self.base_path, *parts)

    def __getitem__(self, idx):
        record = self.manifest.iloc[idx]
        stem, _ext = os.path.splitext(os.path.basename(record['yaml_path']))
        with np.load(self._embedding_path(stem)) as archive:
            z = archive['z']  # numpy array, returned as stored
        lengths = tuple(record[col] for col in ('pep_len', 'tcra_len', 'tcrb_len', 'hla_len'))
        return (z, *lengths)


def boltz_collate_fn(batch):
    """
    Collate BoltzDataset samples into one zero-padded batch.

    Each sample is (z, pep_len, tcra_len, tcrb_len, hla_len) where z is a numpy array
    of shape (L, L, dim) or (1, L, L, dim), L being the sum of the chain lengths.
    Padding is done in numpy; the result is converted to torch once at the end.

    Returns a dict:
    - "z": float32 tensor (B, L_max, L_max, dim); entries beyond each sample's own L are 0
    - "pep_len", "tcra_len", "tcrb_len", "hla_len": numpy arrays of shape (B,)
    """
    zs, pep_lens, tcra_lens, tcrb_lens, hla_lens = zip(*batch)
    # Drop the leading singleton axis only when it is present: squeezing axis 0
    # unconditionally raises ValueError for arrays that are already (L, L, dim).
    zs = [np.squeeze(z, axis=0) if z.ndim == 4 else z for z in zs]
    max_len = max(z.shape[0] for z in zs)
    # channel dimension taken from the first z
    dim = zs[0].shape[-1]
    padded_zs = np.zeros((len(zs), max_len, max_len, dim), dtype=zs[0].dtype)
    for i, z in enumerate(zs):
        n = z.shape[0]
        padded_zs[i, :n, :n, :] = z

    return {
        "z": torch.from_numpy(padded_zs).float(),  # single padded (B, L_max, L_max, dim) tensor
        "pep_len": np.array(pep_lens),
        "tcra_len": np.array(tcra_lens),
        "tcrb_len": np.array(tcrb_lens),
        "hla_len": np.array(hla_lens),
    }

# Example usage with original version:
dataset_original = BoltzDataset(manifest_path, home)
dataloader_original = DataLoader(
    dataset_original, batch_size=8, shuffle=True, collate_fn=boltz_collate_fn
)


batch = next(iter(dataloader_original))
print("Batch z shapes:", [z.shape for z in batch["z"]])
print("Pep lengths:", batch["pep_len"])

# 5.6s to load 100 pairs
# 56 s to load 1000 pairs
# 560 s to load 10000 pairs (9 mins)
# 5600 s to load 100000 pairs (16 mins)

Batch z shapes: [torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128]), torch.Size([604, 604, 128])]
Pep lengths: [ 9  9 10  9  9 10  9 10]


In [30]:
# Per-sample lengths: TCR = alpha + beta, plus peptide and HLA
L_T = batch["tcra_len"] + batch["tcrb_len"]
L_P = batch["pep_len"]
L_H = batch["hla_len"]

# Batch-wide maxima of each length
L_T_max, L_P_max, L_H_max = L_T.max(), L_P.max(), L_H.max()

print(L_T_max, L_P_max, L_H_max,
      *(batch[key] for key in ('tcra_len', 'tcrb_len', 'pep_len', 'hla_len')))

230 10 365 [112 114 112 110 113 112 113 112] [116 116 113 114 115 114 115 117] [ 9  9 10  9  9 10  9 10] [362 365 365 365 365 365 365 365]


In [None]:
# Factorisation Z: batch-level maximum lengths and factorisation hyperparameters

# Maximum true lengths over the batch, as Python ints. The TCR length is alpha + beta,
# and the pMHC length is peptide + HLA, which is how BoltzFactorised slices z.
# (The previous per-sample loop only kept tcra_len, discarded its results, and then
# `L_T_max = int()` reset the TCR maximum to 0.)
L_T_max = int((batch["tcra_len"] + batch["tcrb_len"]).max())
L_P_max = int(batch["pep_len"].max())
L_H_max = int(batch["hla_len"].max())
L_PH_max = int((batch["pep_len"] + batch["hla_len"]).max())

dB  = 128 # dimension of Boltz embeddings
rB  = 16  # Boltz channel rank
rT  = 8   # TCR positional rank
rPH = 8   # pMHC positional rank

B_Z = torch.nn.Parameter(torch.empty(dB, rB))
nn.init.xavier_uniform_(B_Z)

In [None]:
class BoltzFactorised(nn.Module):
    """
    Factorised Boltz embeddings for projection into latent shared space before NC loss.

    For each sample, the true-length part of the Boltz pair tensor z is:
      1. gated per chain (alpha, beta, peptide, HLA),
      2. compressed over channels with B_Z (dB -> rB),
      3. split into TCR/pMHC blocks and compressed over positions with
         A_T (TCR tokens -> rT) and A_PH (pMHC tokens -> rPH),
      4. mapped by H_TT / H_TPH / H_PHT / H_PHPH to four d x d blocks K_*,
      5. assembled as [[I + K_TT, I + K_TPH], [I + K_PHT, I + K_PHPH]].

    Token order: z is sliced as [alpha | beta | peptide | HLA] using the given lengths.
    NOTE(review): this assumes Boltz lays tokens out in that chain order; confirm
    against the input YAMLs.

    Inputs (forward):
    - z_boltz: (B, L_pad, L_pad, dB) padded Boltz z for the batch
    - L_alpha, L_beta, L_p, L_h: (B,) lengths of the TCR alpha, TCR beta, peptide, HLA
    - gP, gH: scalar or (B,) peptide/HLA gates. They are intended to be norm-preserving
      in quadrature (gP**2 + gH**2 = 1), but this is not enforced here.

    Outputs:
    - Zstar_batch: (B, 2d, 2d) operator acting on [e_hat_t, e_hat_pmhc] in R^{2d}
    """

    def __init__(self, dB, rB, rT, rPH, d, L_max, L_PH_max):
        """
        dB       : channel dimension of Boltz embeddings
        rB       : rank of Boltz channel factorisation
        rT       : rank of TCR positional encoding
        rPH      : rank of pMHC positional encoding
        d        : latent dimension of shared space
        L_max    : maximum TCR (alpha + beta) length supported (rows of A_T)
        L_PH_max : maximum pMHC (peptide + HLA) length supported (rows of A_PH)
        """
        super().__init__()
        self.dB       = dB
        self.rB       = rB
        self.rT       = rT
        self.rPH      = rPH
        self.d        = d
        self.L_T_max  = L_max
        self.L_PH_max = L_PH_max

        # ---- 1) Channel mixing: dB -> rB ----
        self.B_Z = nn.Parameter(torch.empty(dB, rB))
        nn.init.xavier_uniform_(self.B_Z)

        # ---- 2)a) TCR positional encoding: L_max positions -> rT ----
        # Sized from the L_max argument, not from any global variable.
        self.A_T = nn.Parameter(torch.empty(L_max, rT))
        nn.init.xavier_uniform_(self.A_T)

        # ---- 2)b) pMHC positional encoding: L_PH_max positions -> rPH ----
        self.A_PH = nn.Parameter(torch.empty(L_PH_max, rPH))
        nn.init.xavier_uniform_(self.A_PH)

        # ---- 3) Learnable maps from factorised z (r* x r* x rB) -> d x d ----
        dd = d * d
        self.H_TT   = nn.Parameter(torch.empty(rT  * rT  * rB, dd))
        self.H_TPH  = nn.Parameter(torch.empty(rT  * rPH * rB, dd))
        self.H_PHT  = nn.Parameter(torch.empty(rPH * rT  * rB, dd))
        self.H_PHPH = nn.Parameter(torch.empty(rPH * rPH * rB, dd))

        nn.init.xavier_uniform_(self.H_TT)
        nn.init.xavier_uniform_(self.H_TPH)
        nn.init.xavier_uniform_(self.H_PHT)
        nn.init.xavier_uniform_(self.H_PHPH)

        # ---- 4) Final linear layer: d -> d ----
        # NOTE(review): W_out is created but not used in forward().
        self.W_out = nn.Parameter(torch.empty(d, d))
        nn.init.xavier_uniform_(self.W_out)

    def _get_gate_scalar(self, g, b):
        """
        Return the gate for sample b as a Python float.

        g may be a 0-dim tensor, a (B,) tensor/array, or a plain Python/numpy scalar.
        """
        g = torch.as_tensor(g)
        return float(g.item() if g.dim() == 0 else g[b].item())

    @staticmethod
    def _tcr_gates(La, Lb):
        """Alpha/beta gates: 1/sqrt(2) each when both chains are present, else 1 for the present chain."""
        if La > 0 and Lb > 0:
            return 2 ** -0.5, 2 ** -0.5
        return (1.0 if La > 0 else 0.0), (1.0 if Lb > 0 else 0.0)

    @staticmethod
    def _compress_block(A_row, Y, A_col):
        """
        Positional compression A_row^T @ Y @ A_col of one block.

        A_row: (n_row, r_row), Y: (n_row, n_col, rB), A_col: (n_col, r_col)
        Returns (r_row, r_col, rB); zeros when either side has no tokens.
        """
        if A_row.shape[0] == 0 or A_col.shape[0] == 0:
            return Y.new_zeros(A_row.shape[1], A_col.shape[1], Y.shape[-1])
        # The intermediate U tensors are not learned; they are only compression steps.
        U = torch.einsum('ip,ijr->pjr', A_row, Y)    # (r_row, n_col, rB)
        return torch.einsum('pjr,jq->pqr', U, A_col)  # (r_row, r_col, rB)

    def forward(self, z_boltz, L_alpha, L_beta, L_p, L_h, gP, gH):
        """
        z_boltz : (B, L_pad, L_pad, dB)
        L_alpha : (B,) true alpha lengths
        L_beta  : (B,) true beta lengths
        L_p     : (B,) true peptide lengths
        L_h     : (B,) true HLA lengths
        gP      : scalar or (B,) peptide gate (already sqrt(R_PH))
        gH      : scalar or (B,) HLA gate (already sqrt(1-R_PH))

        Returns:
          Zstar_batch: (B, 2d, 2d)

        Raises:
          ValueError if the channel dimension does not match dB, or a sample's lengths
          exceed L_pad, L_max (TCR) or L_PH_max (pMHC).
        """
        device = z_boltz.device
        B, L_pad, _, dB = z_boltz.shape
        if dB != self.dB:
            raise ValueError(f"expected {self.dB} Boltz channels, got {dB}")

        d = self.d
        Zstar_list = []

        for b in range(B):
            La, Lb, Lp_, Lh_ = (int(lengths[b]) for lengths in (L_alpha, L_beta, L_p, L_h))

            L_T  = La + Lb
            L_PH = Lp_ + Lh_
            L    = L_T + L_PH

            # if we have missing z it just returns identity
            if L == 0:
                Zstar_list.append(torch.eye(2 * d, device=device))
                continue

            if L > L_pad:
                raise ValueError(f"sample {b}: total length {L} exceeds padded length {L_pad}")
            if L_T > self.L_T_max:
                raise ValueError(f"sample {b}: TCR length {L_T} exceeds L_max={self.L_T_max}")
            if L_PH > self.L_PH_max:
                raise ValueError(f"sample {b}: pMHC length {L_PH} exceeds L_PH_max={self.L_PH_max}")

            # Restrict to the sample's true tokens: (L, L, dB).
            # Per-channel normalisation is intentionally omitted for now.
            Z = z_boltz[b, :L, :L, :]

            # ---- Gating: token-level gate vector over [alpha | beta | p | h] ----
            gA_b, gB_b = self._tcr_gates(La, Lb)
            gP_b = self._get_gate_scalar(gP, b)
            gH_b = self._get_gate_scalar(gH, b)

            gate = torch.zeros(L, device=device)  # (L,)
            gate[:La] = gA_b
            gate[La:L_T] = gB_b
            gate[L_T:L_T + Lp_] = gP_b
            gate[L_T + Lp_:L] = gH_b

            # Gate both the row and the column token of every pair
            Zg = Z * gate[:, None, None] * gate[None, :, None]  # (L, L, dB)

            # ---- Channel compression dB -> rB (once, before splitting into blocks) ----
            Y = Zg @ self.B_Z  # (L, L, rB)

            # ---- TCR/pMHC blocks and positional compression ----
            sT  = slice(0, L_T)  # TCR (alpha + beta)
            sPH = slice(L_T, L)  # pMHC (p + h)
            A_T_b  = self.A_T[:L_T]    # (L_T, rT), may be empty
            A_PH_b = self.A_PH[:L_PH]  # (L_PH, rPH), may be empty

            V_TT   = self._compress_block(A_T_b,  Y[sT,  sT],  A_T_b)   # (rT,  rT,  rB)
            V_TPH  = self._compress_block(A_T_b,  Y[sT,  sPH], A_PH_b)  # (rT,  rPH, rB)
            V_PHT  = self._compress_block(A_PH_b, Y[sPH, sT],  A_T_b)   # (rPH, rT,  rB)
            V_PHPH = self._compress_block(A_PH_b, Y[sPH, sPH], A_PH_b)  # (rPH, rPH, rB)

            # ---- Flatten factorised blocks and map to d x d via H_* ----
            K_TT   = (V_TT.reshape(-1)   @ self.H_TT).view(d, d)
            K_TPH  = (V_TPH.reshape(-1)  @ self.H_TPH).view(d, d)
            K_PHT  = (V_PHT.reshape(-1)  @ self.H_PHT).view(d, d)
            K_PHPH = (V_PHPH.reshape(-1) @ self.H_PHPH).view(d, d)

            # Enforce symmetry on the diagonal blocks (off-diagonal blocks left as-is)
            K_TT   = 0.5 * (K_TT   + K_TT.t())
            K_PHPH = 0.5 * (K_PHPH + K_PHPH.t())

            # ---- Assemble the 2d x 2d operator for this sample ----
            I_d = torch.eye(d, device=device, dtype=K_TT.dtype)
            Zstar_b = torch.cat([
                torch.cat([I_d + K_TT,  I_d + K_TPH],  dim=1),
                torch.cat([I_d + K_PHT, I_d + K_PHPH], dim=1),
            ], dim=0)

            Zstar_list.append(Zstar_b)

        return torch.stack(Zstar_list, dim=0)  # (B, 2d, 2d)


# index convention for the einsum
# i = row token position, “i” is historically “index” or “first axis”
# j	= column token position, second positional axis (matrix-like)
# c	= channel
# p, q	= latent positional modes, P = projection / latent position
# r = latent channel modes, R = rank / channel rank
# b	= batch index, B = batch

NameError: name 'nn' is not defined

In [None]:
# putting it all together



#### Old/Unused Code + Notes

##### Masking/Encoder Fine-tuning Notes
1. Model selection & freeze policy
- Load ESM-C base.
- Freeze embeddings + lower N blocks; leave top K blocks trainable (or use adapters/LoRA). Record N, K.
2. Two forwards you will run
- Masked forward (for loss): feed masked_sequences → get logits over vocab (B, L, V).
- Clean forward (for caching embeddings): feed sequences → get token_hidden_states (B, L, d_model).
3. Loss and metrics
- Compute token-level cross-entropy on logits vs labels (ignore_index = -100).
- Track masked-token accuracy and perplexity per modality.
4. Optimisation
- AdamW; no weight decay on LayerNorm/bias; gradient clip ~1.0; fp16/bf16 if available.
- LR schedule: warm-up then cosine/linear decay.
- Validation sets per modality with identical masking policy.
5. Checkpoints
- Save model weights + adapter/LoRA if used, optimizer, scheduler, tokenizer hash, mask rate, freeze policy.

In [51]:
# Check GPU memory
def print_gpu_memory(device=None):
    """
    Print allocated, reserved ("Cached"), total and available memory (MB) for a CUDA device.

    device: CUDA device (index, string or torch.device); None means the current device.
    Prints nothing when CUDA is unavailable.

    Total memory is read from torch.cuda.get_device_properties rather than parsed from
    `nvidia-smi`. The CLI prints one line per GPU, which breaks float() on multi-GPU hosts,
    and it is not always on PATH.
    "Available" is total minus memory reserved by this process's caching allocator.
    """
    if not torch.cuda.is_available():
        return

    mb = 1024 ** 2
    allocated = torch.cuda.memory_allocated(device) / mb
    reserved = torch.cuda.memory_reserved(device) / mb
    total_memory = torch.cuda.get_device_properties(device).total_memory / mb

    print("\nGPU Memory Usage:")
    print(f"Allocated: {allocated:.2f} MB")
    print(f"Cached: {reserved:.2f} MB")
    print(f"Total GPU Memory: {total_memory:.2f} MB")
    print(f"Available: {total_memory - reserved:.2f} MB")

# Check memory before model loading
print("Before model loading:")
print_gpu_memory()


Before model loading:

GPU Memory Usage:
Allocated: 1168.29 MB
Cached: 1272.00 MB
Total GPU Memory: 16376.00 MB
Available: 15104.00 MB


In [None]:
# Not needed in practice: ESM C ships its own, more complex attention layers. Kept for reference.
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with bias-free Q/K/V/output projections.

    Masks are additive: they are added to the attention logits before the softmax
    (e.g. 0 to keep a position, -inf to block it). After every forward call the
    softmax weights, shaped (batch_size, num_heads, query_len, key_len), are kept
    on `self.attention_weights`.
    """

    def __init__(self, hidden_dim=256, num_heads=4):
        """
        hidden_dim: feature width of inputs and outputs
        num_heads: how many heads the features are split across; must divide hidden_dim
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        # Value, Key, Query and output projections, registered in exactly this order
        # (the order fixes the RNG sequence used for weight initialisation).
        for proj_name in ("Wv", "Wk", "Wq", "Wo"):
            setattr(self, proj_name, nn.Linear(hidden_dim, hidden_dim, bias=False))

    def check_sdpa_inputs(self, x):
        """Assert x is (batch, num_heads, seq, hidden_dim // num_heads)."""
        per_head = self.hidden_dim // self.num_heads
        assert x.size(1) == self.num_heads, f"Expected shape (-1, {self.num_heads}, -1, {self.hidden_dim // self.num_heads}), got {tuple(x.size())}"
        assert x.size(3) == per_head

    def scaled_dot_product_attention(self,
            query,
            key,
            value,
            attention_mask=None,
            key_padding_mask=None):
        """
        query: (batch_size, num_heads, query_len, hidden_dim // num_heads)
        key:   (batch_size, num_heads, key_len,   hidden_dim // num_heads)
        value: (batch_size, num_heads, key_len,   hidden_dim // num_heads)
        attention_mask: additive mask of shape (query_len, key_len); other ranks raise ValueError
        key_padding_mask: additive mask of shape (batch_size, key_len), broadcast over heads and queries

        Returns (output, attention):
          output:    (batch_size, num_heads, query_len, hidden_dim // num_heads)
          attention: (batch_size, num_heads, query_len, key_len)
        """
        for tensor in (query, key, value):
            self.check_sdpa_inputs(tensor)

        head_dim = query.size(-1)
        n_query, n_key = query.size(-2), key.size(-2)
        # (B, H, n_query, E) @ (B, H, E, n_key) -> (B, H, n_query, n_key)
        scores = query @ key.transpose(-2, -1)
        scores = scores / math.sqrt(head_dim)

        if attention_mask is not None:
            if attention_mask.dim() != 2:
                raise ValueError(f"Invalid size of attention_mask: {attention_mask.size()}")
            assert attention_mask.size() == (n_query, n_key)
            scores = scores + attention_mask.unsqueeze(0)

        if key_padding_mask is not None:
            # (B, n_key) -> (B, 1, 1, n_key): broadcast over heads and query positions
            padding = key_padding_mask.unsqueeze(1).unsqueeze(2)
            scores = scores + padding

        weights = scores.softmax(dim=-1)
        return weights @ value, weights

    def split_into_heads(self, x, num_heads):
        """(B, S, hidden_dim) -> (B, num_heads, S, hidden_dim // num_heads)."""
        n_batch, n_seq, width = x.size()
        per_head = width // num_heads
        return x.view(n_batch, n_seq, num_heads, per_head).transpose(1, 2)

    def combine_heads(self, x):
        """(B, num_heads, S, head_dim) -> (B, S, num_heads * head_dim)."""
        n_batch, n_heads, n_seq, per_head = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(n_batch, n_seq, n_heads * per_head)

    def forward(
            self,
            q,
            k,
            v,
            attention_mask=None,
            key_padding_mask=None):
        """
        q: (batch_size, query_len, hidden_dim)
        k: (batch_size, key_len, hidden_dim)
        v: (batch_size, key_len, hidden_dim)
        attention_mask: additive mask of shape (query_len, key_len)
        key_padding_mask: additive mask of shape (batch_size, key_len)

        Returns (batch_size, query_len, hidden_dim) and stores the attention weights
        on self.attention_weights.
        """
        projected = [self.Wq(q), self.Wk(k), self.Wv(v)]
        q_heads, k_heads, v_heads = [
            self.split_into_heads(t, self.num_heads) for t in projected
        ]

        attn_values, attn_weights = self.scaled_dot_product_attention(
            query=q_heads,
            key=k_heads,
            value=v_heads,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
        )
        result = self.Wo(self.combine_heads(attn_values))

        self.attention_weights = attn_weights
        return result
        
        


In [None]:
### masking for one sequence

# # tokenize sequences
# clean_tcrs_tokenized = model.tokenizer(tcrs_data, return_tensors='pt', padding=True)
# clean_peptides_tokenized = model.tokenizer(peptides_data, return_tensors='pt', padding=True)
# clean_hlas_tokenized = model.tokenizer(hlas_data, return_tensors='pt', padding=True)

# # clean_tcrs_tokenized.keys()
# # ids = clean_tcrs_tokenized['input_ids']
# # attention_mask = clean_tcrs_tokenized['attention_mask']

# #copies for masking
# masked_tcrs_tokenized = clean_tcrs_tokenized.copy()
# masked_peptides_tokenized = clean_peptides_tokenized.copy()
# masked_hlas_tokenized = clean_hlas_tokenized.copy()

# # special tokens and amino acids
# CLS = model.tokenizer.cls_token_id
# EOS = model.tokenizer.eos_token_id
# PAD = model.tokenizer.pad_token_id
# amino_acids = torch.tensor([5, 10, 17, 13, 23, 16, 9, 6, 21, 12, 4, 15, 20, 18, 14, 8, 11, 22, 19, 7], device=ids.device)
# mask_token = 32

# # beginning of per sequence masking
# # IDS
# ids = masked_tcrs_tokenized['input_ids'][0]
# valid_mask = (ids != CLS) & (ids != EOS) & (ids != PAD)
# valid_idx = valid_mask.nonzero(as_tuple=False).squeeze(1)

# #L_valid = valid_mask.sum()
# L_valid = valid_idx.numel()
# if L_valid == 0:
#     # nothing to mask, make labels -100 to be ignored in the loss
#     masked_ids = ids.clone()
#     labels = torch.full_like(ids, -100)
# else:
#     # L_valid = valid_mask.sum(dim=1) # for batch, but currently only 1 dimension
#     n = torch.floor(0.15 * torch.tensor(L_valid, device=ids.device, dtype=torch.float32)).to(torch.int64)
#     n = n.clamp(min=2, max=min(45, L_valid))
#     # n = torch.floor(
#     #         0.15 * torch.tensor(L_valid, device=ids.device, dtype=torch.float32)
#     #         ).to(torch.int64).clamp(min=2, max=45)
    
#     # choose n distinct valid positions 
#     # reshuffle
#     perm = torch.randperm(L_valid, device=ids.device)
#     # choose first n
#     chosen_local = perm[:n]
#     # get correct indices based off of valid indices
#     chosen = valid_idx[chosen_local]

#     # 4) split into amino acids and mask tokens
#     n_amino = torch.floor(0.2 * n).to(torch.int64).clamp(min=1)
#     n_mask = n - n_amino

#     # 5) shuffle and create positions
#     # need to shuffle as otherwise will always have the first part being masked with the mask token
#     # and the same goes for the amino acids, always masked at the second part of the sequence
#     shuffle   = torch.randperm(n.item(), device=ids.device)
#     mask_pos  = chosen[shuffle[:n_mask]]
#     amino_pos = chosen[shuffle[n_mask:]]

#     # 6) build labels - originals at selected positions, -100 for non-masked positions
#     masked_ids = ids.clone()
#     labels     = torch.full_like(ids, -100)
#     labels[chosen] = ids[chosen]

#     # 7) apply mask
#     masked_ids[mask_pos] = mask_token

#     # 8) assign amino acids masks from random amino acids
#     if n_amino > 0:
#         rand_idx = torch.randint(high=amino_acids.numel(), size=(n_amino.item(),), device=ids.device)
#         masked_ids[amino_pos] = amino_acids[rand_idx]

In [None]:
# forward pass

# for batch in tcr_loader:
#     batched = batch["masked_input_ids_ESMprotein_batched"]
#     batched = batched.to(device)
#     labels = batch['labels'].to(device)

#     logits_output = model_tcr.logits(
#         batched, LogitsConfig(sequence=True, return_embeddings=True)
#         )
    
#     logits_s = logits_output.logits.sequence
#     loss = F.cross_entropy(
#         logits_s.view(-1, logits_s.size(-1)), # [B*L, V]
#         labels.view(-1),                      # [B*L]
#         ignore_index=-100
#     )

In [None]:
# List every nn.Linear submodule (qualified name + weight shape) to pick LoRA target modules
for module_name, module in base.named_modules():
    if not isinstance(module, torch.nn.Linear):
        continue
    print(module_name, module.weight.shape)

In [None]:
# Inspect the ESM C tokenizer: vocabulary size, tokens and special tokens
print(getattr(tok, "vocab_size", None))

get_vocab = getattr(tok, "get_vocab", None)
if callable(get_vocab):
    vocab = get_vocab()
    print("num tokens:", len(vocab))
    print("sample tokens:", list(vocab.keys())[:50])

# Special tokens (common names; guarded)
# (The loop/if bodies below were left as commented-out prints only, which is a
#  SyntaxError; the prints are restored, matching the recorded output.)
for name in ["bos_token", "eos_token", "unk_token", "pad_token", "mask_token"]:
    print(name, getattr(tok, name, None))


# Round-trip the 20 standard amino acids plus <pad>/<mask>/<unk> through the tokenizer
seq = "ARNDCQEGHILKMFPSTWYV<pad><mask><unk>"
ids = tok.encode(seq) if hasattr(tok, "encode") else None
print("ids:", ids)
if ids is not None and hasattr(tok, "decode"):
    print("decoded:", tok.decode(ids))


# Load the ESM C 300M base model and run one forward pass to inspect logits and embeddings.
# Falls back to CPU when no GPU is available.
esmc_device = "cuda" if torch.cuda.is_available() else "cpu"
protein = ESMProtein(sequence="AACGTATTTA<unk>")
model = ESMC.from_pretrained("esmc_300m").to(esmc_device)
protein_tensor = model.encode(protein)
logits_output = model.logits(
   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings.size())

33
num tokens: 33
sample tokens: ['O', 'W', '-', 'S', 'T', '<eos>', 'G', 'B', 'K', '<pad>', 'H', 'M', 'L', 'R', '<mask>', 'A', 'F', 'U', 'Z', '<unk>', 'P', 'X', 'N', 'Q', 'C', 'I', 'V', 'E', '<cls>', 'D', '|', 'Y', '.']
bos_token <cls>
eos_token <eos>
unk_token <unk>
pad_token <pad>
mask_token <mask>
ids: [0, 5, 10, 17, 13, 23, 16, 9, 6, 21, 12, 4, 15, 20, 18, 14, 8, 11, 22, 19, 7, 1, 32, 3, 2]
decoded: <cls> A R N D C Q E G H I L K M F P S T W Y V <pad> <mask> <unk> <eos>
ForwardTrackData(sequence=tensor([[[-36.7500, -36.7500, -36.7500,  11.6250,  18.7500,  19.1250,  19.1250,
           18.7500,  18.7500,  18.7500,  18.8750,  18.7500,  18.1250,  18.3750,
           18.7500,  18.2500,  18.1250,  18.0000,  17.6250,  17.8750,  18.5000,
           17.6250,  17.3750,  17.1250,  18.1250,  -0.1055,  -3.7031,  -3.7656,
          -20.0000, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
          -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500, -36.7500,
          -36.75

In [None]:
# from esm.models.esmc import ESMC
# model = ESMC.from_pretrained("esmc_300m").eval()

# print([a for a in dir(model) if "token" in a.lower()])
# # Common: 'tokenizer' shows up; then introspect it:
# tok = getattr(model, "tokenizer", None)
# print(type(tok))
# print([a for a in dir(tok) if not a.startswith("_")])

# # use model's input embeddings to get the vocab 
# print(model.named_parameters)

['_detokenize', '_tokenize', 'tokenizer']
<class 'esm.tokenization.sequence_tokenizer.EsmSequenceTokenizer'>
<bound method Module.named_parameters of ESMC(
  (embed): Embedding(64, 960)
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0-29): 30 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=960, out_features=2880, bias=False)
          )
          (out_proj): Linear(in_features=960, out_features=960, bias=False)
          (q_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=960, out_features=5120, bias=False)
          (2): SwiGLU()
          (3): Linear

###### Raw code before making into a class

In [None]:
# Hyperparameters and defining dimensions + projection heads
# N.B. Turned this into a class

# Shape: emb_T is the padded TCR encoder output, (B, L_T_pad, D)
B, L_T_pad, D = emb_T.shape

# Latent ranks and final dimension (hyperparameters)
rL_T = 8      # positional rank for TCR (tunable)
rD_T = 16     # channel rank for TCR (tunable)
d    = 128    # final latent dimension (same d as in Z*)

# ratio of peptide to HLA (hyperparameter)
R = 0.7
# Epsilon for numerical stability
eps = 1e-8

# Maximum true lengths of the sequences (number of real tokens per sample).
# Cast to Python ints so they are plain shape arguments for torch.empty below,
# rather than 0-dim tensors that may live on the GPU.
L_T_true = tcr_mask.sum(dim=1)
L_T_max = int(L_T_true.max())
L_P_true = pep_mask.sum(dim=1)
L_P_max = int(L_P_true.max())
L_H_true = hla_mask.sum(dim=1)
L_H_max = int(L_H_true.max())

# Factorisation matrices

# Channel mixing: D -> rD_T
B_Tc = torch.nn.Parameter(torch.empty(D, rD_T, device=device))
torch.nn.init.xavier_uniform_(B_Tc)

# Positional mixing: positions 0..L_T_max-1 -> rL_T
A_Tc = torch.nn.Parameter(torch.empty(L_T_max, rL_T, device=device))
torch.nn.init.xavier_uniform_(A_Tc)

# Final map: flattened (rL_T, rD_T) projection -> d
H_Tc = torch.nn.Parameter(torch.empty(rD_T * rL_T, d, device=device))
torch.nn.init.xavier_uniform_(H_Tc)


# Loop over the batch to get zT (the factorised TCR embedding), z = vec(A^T X B) H:
#   X : (Lb, D)          true-length slice of emb_T[b]   (emb_T is (B, L_T_pad, D))
#   B : (D, rD_T)        channel mixing B_Tc
#   A : (Lb, rL_T)       first Lb rows of the positional mixing A_Tc
#   H : (rL_T * rD_T, d) final map H_Tc
# NOTE(review): taking the first Lb positions assumes real tokens are left-aligned in
# emb_T / tcr_mask; confirm.

zT_list = []

for b in range(B):
    # ---- 1. Get true length and slice (drop padded positions) ----
    Lb = int(L_T_true[b].item())          # scalar length for sample b
    X = emb_T[b, :Lb, :]                  # (Lb, D) encoder input X_T
    m = tcr_mask[b, :Lb].unsqueeze(-1)    # (Lb, 1), 1 = real token, 0 = pad

    # Apply mask to zero out padded positions (not strictly necessary, but explicit)
    X = X * m

    # ---- 2. Channel compression: D -> rD_T (same mixing for every position) ----
    Y = X @ B_Tc                          # (Lb, rD_T)

    # ---- 3. Positional compression: Lb -> rL_T ----
    A_pos = A_Tc[:Lb, :]                  # (Lb, rL_T)
    U = A_pos.T @ Y                       # (rL_T, rD_T)

    # ---- 4. Flatten and map to the final d-dimensional vector ----
    U_flat = U.reshape(-1)                # (rL_T * rD_T,)
    z_b = U_flat @ H_Tc                   # (d,)

    # ---- 5. Normalise the resulting vector ----
    z_b = z_b / (z_b.norm() + eps)        # (d,)

    zT_list.append(z_b)

# Stack into a single tensor of shape (B, d)
zT = torch.stack(zT_list, dim=0)          # (B, d)




In [None]:
# loading individual z embeddings
file_path = '/home/natasha/multimodal_model/outputs/boltz_runs/positives/pair_000/boltz_results_pair_000/predictions/pair_000/embeddings_pair_000.npz'

# Contents of a Boltz embeddings file (sum_of_len = total length of TCR, HLA and peptide):
#   's': [batch, sum_of_len, dim (384)]
#   'z': [batch, sum_of_len, sum_of_len, dim (128)]
with np.load(file_path) as data:
    z = data['z']
    print(z.shape)
    s = data['s']

print(z)
    
    