##### This script defines the encoders for the multimodal model
##### Overview:
- The inputs are the ESM C and Boltz embeddings
- Three encoders (TCR, peptide, HLA), each obtained by parameter-efficient fine-tuning (PEFT / LoRA) of ESM C

In [1]:
# Streamlined imports - removing duplicates
import esm
import pandas as pd
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.nn.utils.rnn import pad_sequence
from torch.amp import autocast
from torch.cuda.amp import GradScaler

import sys
import os
import time
import math
import random
import pickle
import subprocess
import gc
from pathlib import Path
from typing import List, Tuple

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import display, update_display

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# ESM imports
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig, ESMProteinTensor
from esm.models.esmc import _BatchedESMProteinTensor

# Tokenizer imports
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers import normalizers
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from tokenizers import decoders

# PEFT imports
from peft import get_peft_model, LoraConfig, TaskType
from peft.tuners.lora import LoraConfig, LoraModel

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [4]:
# Locate the project root from the notebook's working directory and make sure
# the folders used for saved models / checkpoints exist.
import os
from pathlib import Path

# Where the notebook is being executed from
current_dir = os.getcwd()
print(f"Current working directory: {current_dir}")

# Project root is two levels up (the notebook lives in scripts/train)
project_root = Path(current_dir).parent.parent
print(f"Project root: {project_root}")

# <project_root>/models holds trained models,
# <project_root>/models/checkpoints holds intermediate checkpoints.
models_dir = project_root / "models"
checkpoints_dir = models_dir / "checkpoints"

# Create parent before child; report each directory once it is in place.
for label, directory in (("Models", models_dir), ("Checkpoints", checkpoints_dir)):
    if not directory.exists():
        print(f"{label} directory does not exist, creating it at: {directory}")
        directory.mkdir(exist_ok=True)
    print(f"{label} directory at: {directory}")


Current working directory: /home/natasha/multimodal_model/scripts/train
Project root: /home/natasha/multimodal_model
Models directory at: /home/natasha/multimodal_model/models
Checkpoints directory at: /home/natasha/multimodal_model/models/checkpoints


##### Get ESM Embeddings

In [5]:
# Plan for this section:
# - load ESM C with its language-modelling (LM) head enabled;
# - expose the final token embeddings that feed the logits head
#   (NOTE(review): the "logits head" is presumably the LM head - confirm);
# - the collator should return at least input_ids and attention_mask.

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# TODO: decide whether the model should be loaded for LoRA training rather than in eval mode.

In [6]:
# from esm.models.esmc import ESMC
# model = ESMC.from_pretrained("esmc_300m").eval()

# print([a for a in dir(model) if "token" in a.lower()])
# # Common: 'tokenizer' shows up; then introspect it:
# tok = getattr(model, "tokenizer", None)
# print(type(tok))
# print([a for a in dir(tok) if not a.startswith("_")])

In [7]:
# # use model's input embeddings to get the vocab 
# print(model.named_parameters)

In [8]:
# print(getattr(tok, "vocab_size", None))

# get_vocab = getattr(tok, "get_vocab", None)
# if callable(get_vocab):
#     vocab = get_vocab()
#     print("num tokens:", len(vocab))
#     print("sample tokens:", list(vocab.keys())[:50])

# # Special tokens (common names; guarded)
# for name in ["bos_token", "eos_token", "unk_token", "pad_token", "mask_token"]:
#     print(name, getattr(tok, name, None))

In [9]:
# seq = "ARNDCQEGHILKMFPSTWYV<pad><mask><unk>"
# ids = tok.encode(seq) if hasattr(tok, "encode") else None
# print("ids:", ids)
# if ids is not None and hasattr(tok, "decode"):
#     print("decoded:", tok.decode(ids))

In [10]:
# # Load the base protein language model (ESM C 300M below; this note originally referred to ESM-2 8M, 6 layers). See the ESM repo for other model sizes.

# protein = ESMProtein(sequence="AACGTATTTA<unk>")
# model = ESMC.from_pretrained("esmc_300m").to("cuda") # or "cpu"
# protein_tensor = model.encode(protein)
# logits_output = model.logits(
#    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
# )
# print(logits_output.logits, logits_output.embeddings.size())
# #protein_batch_converter = protein_LM_alphabet.get_batch_converter()

In [11]:
# size = [1,7,960], always the case?
# size = [1, 12, 960]
# size is I think batch number, sequence length, embedding dimension

df = pd.read_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv')

# A missing chain (NaN or empty string) is represented by the <unk> token so
# that the two chains can always be concatenated as strings.
for chain_col in ('TCRa', 'TCRb'):
    filled = df[chain_col].fillna('<unk>')
    filled[filled == ''] = '<unk>'
    df[chain_col] = filled

# Full TCR sequence = alpha chain followed directly by beta chain
df['TCR_full'] = df['TCRa'] + df['TCRb']
#df.to_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', index=False)

In [12]:
# hugging face dataset?
class _SequenceColumnDataset(Dataset):
    """Shared implementation for datasets that serve one sequence column of a CSV.

    The CSV at ``data_path`` is read once with ``pandas.read_csv`` and kept in
    ``self.data``. Indexing returns the value of ``column_name`` for that row.

    Args:
        data_path: path of the CSV file to read.
        column_name: name of the column holding the sequence.
        include_label: when True, ``__getitem__`` returns a tuple
            ``(item_id, sequence, label)`` instead of the bare sequence, where
            ``item_id`` is ``f"{id_prefix}_{idx}"`` and ``label`` is the row's
            ``'Binding'`` value (``-1`` if that column is absent).
    """

    # Prefix used to build the per-item identifier; overridden by subclasses.
    id_prefix = 'seq'

    def __init__(self, data_path, column_name, include_label=False):
        self.data_path = data_path
        self.data = pd.read_csv(data_path)
        self.column_name = column_name
        self.include_label = include_label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional row lookup; an out-of-range idx raises IndexError, which
        # is also what terminates plain iteration over the dataset.
        row = self.data.iloc[idx]
        sequence = row[self.column_name]
        if self.include_label:
            return f'{self.id_prefix}_{idx}', sequence, row.get('Binding', -1)
        return sequence


class TCR_dataset(_SequenceColumnDataset):
    """Dataset for TCR data, for use in encoder training to propagate through to NC model"""

    id_prefix = 'TCR'

    def __init__(self, data_path, column_name='TCR_full', include_label=False):
        super().__init__(data_path, column_name, include_label)


class peptide_dataset(_SequenceColumnDataset):
    """Dataset serving the peptide column of the CSV (default column ``'Peptide'``)."""

    id_prefix = 'peptide'

    def __init__(self, data_path, column_name='Peptide', include_label=False):
        super().__init__(data_path, column_name, include_label)


class HLA_dataset(_SequenceColumnDataset):
    """Dataset serving the HLA column of the CSV (default column ``'HLA'``)."""

    id_prefix = 'hla'

    def __init__(self, data_path, column_name='HLA', include_label=False):
        super().__init__(data_path, column_name, include_label)

In [13]:
# All three modalities are read from the same Boltz-run table; only the
# column differs.
_boltz_runs_csv = '/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv'

tcr = TCR_dataset(data_path=_boltz_runs_csv, column_name='TCR_full')
peptide = peptide_dataset(data_path=_boltz_runs_csv, column_name='Peptide')
hla = HLA_dataset(data_path=_boltz_runs_csv, column_name='HLA_sequence')

# NOTE: not sure these Dataset classes are really needed.

In [14]:
# Wrap every raw sequence string in an ESMProtein.
tcrs = [ESMProtein(sequence=raw) for raw in tcr.data['TCR_full']]
peptides = [ESMProtein(sequence=raw) for raw in peptide.data['Peptide']]
hlas = [ESMProtein(sequence=raw) for raw in hla.data['HLA_sequence']]

# Batching happens at the forward step, not at the encoding step.

# Pretrained ESM C (300M) in eval mode on the selected device.
model = ESMC.from_pretrained("esmc_300m").to(device).eval()

# Materialise each dataset into a plain list of its items.
tcrs_data = list(tcr)
peptides_data = list(peptide)
hlas_data = list(hla)

# Encode each protein individually with the model's own encoder.
encoded_tcrs = [model.encode(protein) for protein in tcrs]
encoded_peptides = [model.encode(protein) for protein in peptides]
encoded_hlas = [model.encode(protein) for protein in hlas]


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

##### Mask Data and Collate Data for Model

In [15]:
# do for entire dataset
# do we also want to output attention_mask from the tokenizer?

# Special-token ids are read from the model's own tokenizer.
tok = model.tokenizer
CLS_ID, EOS_ID, PAD_ID, MASK_ID = (
    tok.cls_token_id,
    tok.eos_token_id,
    tok.pad_token_id,
    tok.mask_token_id,
)

# Token ids the collator samples from when it substitutes a random residue.
# NOTE(review): presumably the ids of the 20 standard amino acids in this
# tokenizer's vocabulary - confirm against `tok` before changing models.
AA_IDS = [
    5, 10, 17, 13, 23, 16, 9, 6, 21, 12,
    4, 15, 20, 18, 14, 8, 11, 22, 19, 7,
]


class EncodedSeqDataset(Dataset):
    """Pairs raw sequence strings with their already-tokenised encodings.

    Args:
        sequences: indexable collection of raw sequence strings.
        enc: mapping that provides ``'input_ids'`` and ``'attention_mask'``,
            each indexable by item position.

    Each item is a dict with the keys ``"sequence"``, ``"input_ids"`` and
    ``"attention_mask"``; the last two are returned as ``torch.long`` tensors.
    The dataset length is the length of ``enc['input_ids']``.
    """

    def __init__(self, sequences, enc):
        self.sequences = sequences
        self.input_ids, self.attention_mask = enc['input_ids'], enc['attention_mask']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        item = {"sequence": self.sequences[idx]}
        # Convert the two encodings for this position to int64 tensors.
        for key, column in (("input_ids", self.input_ids),
                            ("attention_mask", self.attention_mask)):
            item[key] = torch.as_tensor(column[idx], dtype=torch.long)
        return item


In [16]:
class MLMProteinCollator:
    """Collate tokenised protein sequences into a masked-language-modelling batch.

    For every sequence a subset of the non-special positions is selected for
    supervision. Most of the selected positions are replaced by the mask
    token; the remainder are replaced by a token drawn at random from
    ``amino_acids``. ``labels`` holds the original token at every selected
    position and ``-100`` everywhere else.

    Args (all keyword-only):
        cls_id, eos_id, pad_id: special-token ids that are never selected.
        mask_id: id written at positions that are masked.
        amino_acids: sequence of token ids to sample random substitutions from.
        p: fraction of the valid positions to select per sequence.
        min_per_seq, max_per_seq: bounds on the number of selected positions;
            the count is additionally capped at the number of valid positions.
        aa_frac: fraction of the selected positions that receive a random
            substitution instead of the mask token (at least one whenever two
            or more positions are selected).
        tokenizer: tokenizer used by ``__call__`` to turn masked ids back into
            strings. It may also be assigned to ``self.tokenizer`` after
            construction.
    """

    def __init__(self, *, cls_id, eos_id, pad_id, mask_id, amino_acids,
                 p=0.15, min_per_seq=2, max_per_seq=45, aa_frac=0.20,
                 tokenizer=None):
        self.CLS = cls_id
        self.EOS = eos_id
        self.PAD = pad_id
        self.MASK = mask_id
        self.aa = torch.as_tensor(amino_acids, dtype=torch.long)
        self.p = p
        self.min_per_seq = min_per_seq
        self.max_per_seq = max_per_seq
        self.aa_frac = aa_frac
        self.tokenizer = tokenizer

    @torch.no_grad()
    def mask_batch(self, input_ids, attention_mask):
        """Return ``(masked_input_ids, labels)`` for a ``(B, L)`` batch of ids.

        ``input_ids`` is not modified. Positions where ``attention_mask`` is
        zero, or whose id is PAD, CLS or EOS, are never selected.
        """
        device = input_ids.device
        aa = self.aa.to(device)

        valid_mask = (
            attention_mask.bool()
            & (input_ids != self.PAD)
            & (input_ids != self.CLS)
            & (input_ids != self.EOS)
        )

        masked_input_ids = input_ids.clone()
        labels = torch.full_like(input_ids, -100)

        for i in range(input_ids.size(0)):
            valid_idx = valid_mask[i].nonzero(as_tuple=False).squeeze(1)  # (L_valid,)
            L_valid = valid_idx.numel()
            if L_valid == 0:
                continue

            # How many to select: floor(p * L_valid), raised to min_per_seq,
            # then capped by max_per_seq and by L_valid (the cap wins when
            # the sequence is shorter than min_per_seq).
            n = int(self.p * L_valid)
            n = min(max(n, self.min_per_seq), min(self.max_per_seq, L_valid))
            if n <= 0:
                continue

            # n distinct valid positions
            chosen = valid_idx[torch.randperm(L_valid, device=device)[:n]]

            # Split into random-residue vs MASK; ensure >= 1 residue if n >= 2
            n_amino = int(self.aa_frac * n)
            if n >= 2:
                n_amino = max(n_amino, 1)
            n_mask = n - n_amino

            order = torch.randperm(n, device=device)
            mask_pos = chosen[order[:n_mask]]
            amino_pos = chosen[order[n_mask:]]

            # Labels only at supervised positions
            labels[i, chosen] = input_ids[i, chosen]

            if n_mask > 0:
                masked_input_ids[i, mask_pos] = self.MASK
            if n_amino > 0:
                r_idx = torch.randint(high=aa.numel(), size=(n_amino,), device=device)
                masked_input_ids[i, amino_pos] = aa[r_idx]

        return masked_input_ids, labels

    def __call__(self, features):
        """Build one batch from a list of dataset items.

        Each item must provide ``"input_ids"`` and ``"attention_mask"``
        tensors of equal length across the batch (they are stacked) and the
        raw ``"sequence"`` string.

        Raises:
            RuntimeError: if no tokenizer has been set on this collator.
        """
        # Use this instance's tokenizer rather than a module-level global so
        # the collator works whatever variable name it is bound to.
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise RuntimeError(
                "MLMProteinCollator has no tokenizer: pass tokenizer=... to the "
                "constructor or assign the `tokenizer` attribute before collating."
            )

        input_ids = torch.stack([f["input_ids"] for f in features], dim=0)
        attention_mask = torch.stack([f["attention_mask"] for f in features], dim=0)
        sequences = [f["sequence"] for f in features]
        proteins = [ESMProtein(sequence=f["sequence"]) for f in features]
        batched_clean = _BatchedESMProteinTensor(sequence=input_ids)

        masked_input_ids, labels = self.mask_batch(input_ids, attention_mask)

        # Build masked sequences as strings (keep <mask>, drop CLS/EOS/PAD)
        dropped = (tokenizer.cls_token, tokenizer.eos_token, tokenizer.pad_token)
        masked_sequences = []
        for row in masked_input_ids.tolist():
            toks = tokenizer.convert_ids_to_tokens(row, skip_special_tokens=False)
            masked_sequences.append("".join(t for t in toks if t not in dropped))

        proteins_masked = [ESMProtein(sequence=s) for s in masked_sequences]
        batched_masked = _BatchedESMProteinTensor(sequence=masked_input_ids)

        return {
            "masked_input_ids": masked_input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "clean_input_ids": input_ids.clone(),
            "clean_sequences": sequences,            # clean strings
            "masked_sequences": masked_sequences,    # masked strings
            "clean_sequences_ESMprotein": proteins,
            "masked_sequences_ESMprotein": proteins_masked,
            "masked_input_ids_ESMprotein_batched": batched_masked,
            "clean_input_ids_ESMprotein_batched": batched_clean,
        }

In [None]:
# Tokenise each modality in one call; padding=True pads every sequence of a
# modality to a common length so the per-item tensors can be stacked later.
_tokenize_kwargs = dict(return_tensors='pt', padding=True)
clean_tcrs_tokenized = model.tokenizer(tcrs_data, **_tokenize_kwargs)
clean_peptides_tokenized = model.tokenizer(peptides_data, **_tokenize_kwargs)
clean_hlas_tokenized = model.tokenizer(hlas_data, **_tokenize_kwargs)

# Pair the raw strings with their encodings
tcr_ds = EncodedSeqDataset(tcrs_data, clean_tcrs_tokenized)
pep_ds = EncodedSeqDataset(peptides_data, clean_peptides_tokenized)
hla_ds = EncodedSeqDataset(hlas_data, clean_hlas_tokenized)

# One collator is shared by all three loaders
collator = MLMProteinCollator(
    cls_id=CLS_ID,
    eos_id=EOS_ID,
    pad_id=PAD_ID,
    mask_id=MASK_ID,
    amino_acids=AA_IDS,
    p=0.15,
    min_per_seq=2,
    max_per_seq=45,
    aa_frac=0.20,
)
collator.tokenizer = model.tokenizer

_loader_kwargs = dict(batch_size=8, shuffle=True, num_workers=4, collate_fn=collator)
tcr_loader = DataLoader(tcr_ds, **_loader_kwargs)
pep_loader = DataLoader(pep_ds, **_loader_kwargs)
hla_loader = DataLoader(hla_ds, **_loader_kwargs)

# Each batch is the dict built by the collator. The tensors used for training are:
#   masked_input_ids - ids after masking / random substitution
#   labels           - original tokens at the selected positions, -100 everywhere else
#   attention_mask   - 1 for real tokens, 0 for padding
#   clean_input_ids  - unmasked copy of the ids (for the clean forward pass using boltz for NC loss)
# The dict also carries the clean/masked strings, ESMProtein objects and batched protein tensors.

# The tokenising model is no longer needed on the GPU
model.to("cpu")
torch.cuda.empty_cache()

In [None]:
# define LoRA model

# Mixed precision is only used when CUDA is available
use_amp = torch.cuda.is_available()

# Fresh copy of the pretrained backbone for the TCR encoder
base = ESMC.from_pretrained("esmc_300m").to(device)

# LoRA adapters on the attention projections
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # "layernorm_qkv.1": inner Linear inside the Sequential
    target_modules=["out_proj", "layernorm_qkv.1"],
)

model_tcr = LoraModel(base, lora_cfg, adapter_name="tcr")

# Only the LoRA matrices are trainable; every other parameter is frozen.
for name, p in model_tcr.named_parameters():
    p.requires_grad = ("lora_A" in name) or ("lora_B" in name)

model_tcr.train()

# The optimiser only receives the trainable (LoRA) parameters
optim_tcr = torch.optim.AdamW(
    [param for param in model_tcr.parameters() if param.requires_grad],
    lr=1e-3,
    weight_decay=0.01,
)
#scaler = GradScaler(enabled=use_amp)


In [20]:
# torch.cuda.amp.GradScaler is deprecated; torch.amp provides the
# device-agnostic replacement alongside autocast.
from torch.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
# Loss scaling is not used with bf16 autocast; the (disabled) scaler is kept
# only so the name stays defined for later cells.
scaler = GradScaler("cuda", enabled=False)

# One pass over the TCR loader: masked-token cross-entropy on the LoRA model.
for batch in tcr_loader:
    input_ids = batch["masked_input_ids"].to(device, dtype=torch.long)
    labels = batch["labels"].to(device)

    with autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):  # bf16
        out = model_tcr(input_ids)
        logits = out.sequence_logits
        #print(type(logits), logits.shape, logits.requires_grad)  # expect: Tensor, [B,L,V], True

        # Flatten to (B*L, V) vs (B*L,); positions labelled -100 are ignored
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
        )

    loss.backward()  # standard backward, no loss scaling
    optim_tcr.step()
    optim_tcr.zero_grad(set_to_none=True)


# 1.1 seconds for 100
# 11 for 1000
# 110 for 10000
# 350 for 35000 - 5.8 minutes to train one encoder on full dataset

  scaler  = GradScaler(enabled=False)                     # <-- no scaler


In [21]:
# save model and offload to CPU

# save tcr model

checkpoint_filename = 'tcr_encoder_checkpoint.pth'
checkpoint_path = checkpoints_dir / checkpoint_filename

# Only the model and optimiser state are stored for now. Fields considered
# for later: epoch, peptide model state ('pep_model_state_dict'), final loss,
# binding threshold.
checkpoint_dict = {
    'tcr_model_state_dict': model_tcr.state_dict(),
    'optimizer_state_dict': optim_tcr.state_dict(),
}

torch.save(checkpoint_dict, checkpoint_path)

print(f"Checkpoint saved to {checkpoint_filename}")

# Offload the trained encoder and release cached GPU memory
model_tcr.to("cpu")
torch.cuda.empty_cache()


Checkpoint saved to tcr_encoder_checkpoint.pth


In [None]:
# peptide encoder





1. Model selection & freeze policy
- Load ESM-C base.
- Freeze embeddings + lower N blocks; leave top K blocks trainable (or use adapters/LoRA). Record N, K.
2. Two forwards you will run
- Masked forward (for loss): feed masked_sequences → get logits over vocab (B, L, V).
- Clean forward (for caching embeddings): feed sequences → get token_hidden_states (B, L, d_model).
3. Loss and metrics
- Compute token-level cross-entropy on logits vs labels (ignore_index = −100).
- Track masked-token accuracy and perplexity per modality.
4. Pooling you will later reuse
- From token_hidden_states (clean forward), compute a pooled sequence vector per item using a masked mean over non-special tokens → (B, d_model).
- Save both: token states (for optional analysis) and pooled vectors (for projection/alignment).
5. Optimisation
- AdamW; no weight decay on LayerNorm/bias; gradient clip ~1.0; fp16/bf16 if available.
- LR schedule: warm-up then cosine/linear decay.
- Validation sets per modality with identical masking policy.
6. Checkpoints
- Save model weights + adapter/LoRA if used, optimizer, scheduler, tokenizer hash, mask rate, freeze policy.

In [51]:
# Check GPU memory
def print_gpu_memory():
    """Print allocated, reserved, total and remaining memory of the current CUDA device.

    Does nothing when CUDA is not available. All figures are in MB
    (bytes / 1024**2). The total is taken from PyTorch's device properties
    rather than by parsing ``nvidia-smi`` output, which prints one line per
    GPU (breaking ``float()`` on multi-GPU hosts) and is not tied to the
    device PyTorch is actually using.
    """
    if not torch.cuda.is_available():
        return

    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated() / mib
    reserved = torch.cuda.memory_reserved() / mib
    current = torch.cuda.current_device()
    total_memory = torch.cuda.get_device_properties(current).total_memory / mib

    print("\nGPU Memory Usage:")
    print(f"Allocated: {allocated:.2f} MB")
    print(f"Cached: {reserved:.2f} MB")
    print(f"Total GPU Memory: {total_memory:.2f} MB")
    # "Available" = total minus what PyTorch's caching allocator has reserved
    print(f"Available: {total_memory - reserved:.2f} MB")

# Check memory before model loading
print("Before model loading:")
print_gpu_memory()


Before model loading:

GPU Memory Usage:
Allocated: 1168.29 MB
Cached: 1272.00 MB
Total GPU Memory: 16376.00 MB
Available: 15104.00 MB


##### Get Boltz Embeddings

In [31]:
# Path to the Boltz embeddings (.npz) saved for one example run (positives, pair_000).
file_path = '/home/natasha/multimodal_model/outputs/boltz_runs/positives/pair_000/boltz_results_pair_000/predictions/pair_000/embeddings_pair_000.npz'

In [32]:
# s: [batch, sum_of_len, 384], where sum_of_len is the summed length of the
#    TCR, HLA and peptide sequences.
# z: [batch, sum_of_len, sum_of_len, 128] (same sum_of_len on both axes).

# Arrays in an .npz archive are read from disk on every key access, so fetch
# each one exactly once and reuse it for both the shape report and the result.
with np.load(file_path) as data:
    print(list(data.keys()))
    s = data['s']
    z = data['z']

print(s.shape)
print(z.shape)

print(z)
    
    

['s', 'z']
(1, 604, 384)
(1, 604, 604, 128)
[[[[ -88.07086     -52.60885     108.844284   ...  190.84561
      36.47628    -110.61696   ]
   [ -57.74353     -54.678726    -54.20086    ...  -43.508102
       5.2886047   -22.915913  ]
   [  -5.9032555   -26.453156     -5.7282715  ...  -50.97639
      46.78975     -54.079994  ]
   ...
   [  -7.059223     20.620789      9.03931    ...  -15.813938
      20.766754    -12.156269  ]
   [  -4.7150135    15.798729      8.588043   ...  -16.762444
      14.437408     -7.1301365 ]
   [ -21.511803     37.621155     17.605013   ...   14.496872
      17.363136     -8.831755  ]]

  [[ -35.09397       7.525936     71.0813     ...   47.07572
      20.585098     46.14451   ]
   [ -31.209412     -7.549141     64.2861     ...   86.585236
      -2.7985687  -130.94954   ]
   [ -10.455566     -6.6848755    30.168354   ...   14.787491
      17.332607    -59.565666  ]
   ...
   [   0.5986786    -1.5232773    -2.8792143  ...   -5.399969
      14.654137    -17.363

##### Old/Unused Code

In [None]:
# not necessary, ESM C has their own much more complex attention layers
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    The most recent attention weights are kept on ``self.attention_weights``
    after every forward pass.
    """

    def __init__(self, hidden_dim=256, num_heads=4):
        """
        hidden_dim: size of the last dimension of the inputs and of the output
        num_heads: number of heads the hidden dimension is split into
                   (must divide hidden_dim)
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        # Value, key, query and output projections (created in this order)
        self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def check_sdpa_inputs(self, x):
        """Assert that x is laid out as (batch, num_heads, length, head_dim)."""
        head_dim = self.hidden_dim // self.num_heads
        assert x.size(1) == self.num_heads, f"Expected shape (-1, {self.num_heads}, -1, {head_dim}), got {tuple(x.size())}"
        assert x.size(3) == head_dim

    def scaled_dot_product_attention(self,
            query,
            key,
            value,
            attention_mask=None,
            key_padding_mask=None):
        """
        query: (batch_size, num_heads, query_sequence_length, hidden_dim // num_heads)
        key:   (batch_size, num_heads, key_sequence_length, hidden_dim // num_heads)
        value: (batch_size, num_heads, value_sequence_length, hidden_dim // num_heads)
        attention_mask: optional 2-D tensor of shape
            (query_sequence_length, key_sequence_length), added to the scores
        key_padding_mask: optional tensor added to the scores after two
            singleton dimensions are inserted behind its first dimension

        Returns (output, attention): the weighted values and the softmax weights.
        """
        for tensor in (query, key, value):
            self.check_sdpa_inputs(tensor)

        scale = math.sqrt(query.size(-1))  # sqrt of the per-head dimension
        tgt_len, src_len = query.size(-2), key.size(-2)
        # (B, H, tgt_len, E) @ (B, H, E, src_len) -> (B, H, tgt_len, src_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        # Additive attention mask: only a 2-D mask is accepted
        if attention_mask is not None:
            if attention_mask.dim() != 2:
                raise ValueError(f"Invalid size of attention_mask: {attention_mask.size()}")
            assert attention_mask.size() == (tgt_len, src_len)
            scores = scores + attention_mask.unsqueeze(0)

        # Additive key padding mask, broadcast against the scores
        if key_padding_mask is not None:
            scores = scores + key_padding_mask[:, None, None]

        attention = scores.softmax(dim=-1)
        # (B, H, tgt_len, src_len) @ (B, H, src_len, E) -> (B, H, tgt_len, E)
        output = torch.matmul(attention, value)

        return output, attention

    def split_into_heads(self, x, num_heads):
        """(batch, seq, hidden) -> (batch, num_heads, seq, hidden // num_heads)."""
        batch_size, seq_length, hidden_dim = x.size()
        per_head = x.view(batch_size, seq_length, num_heads, hidden_dim // num_heads)
        return per_head.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """(batch, num_heads, seq, head_dim) -> (batch, seq, num_heads * head_dim)."""
        batch_size, num_heads, seq_length, head_hidden_dim = x.size()
        merged = x.permute(0, 2, 1, 3).contiguous()
        return merged.view(batch_size, seq_length, num_heads * head_hidden_dim)

    def forward(
            self,
            q,
            k,
            v,
            attention_mask=None,
            key_padding_mask=None):
        """
        q: (batch_size, query_sequence_length, hidden_dim)
        k: (batch_size, key_sequence_length, hidden_dim)
        v: (batch_size, value_sequence_length, hidden_dim)
        attention_mask, key_padding_mask: forwarded unchanged to
            scaled_dot_product_attention

        Returns a tensor of shape (batch_size, query_sequence_length, hidden_dim).
        """
        # Project, then split each projection into heads
        heads_q = self.split_into_heads(self.Wq(q), self.num_heads)
        heads_k = self.split_into_heads(self.Wk(k), self.num_heads)
        heads_v = self.split_into_heads(self.Wv(v), self.num_heads)

        attn_values, attn_weights = self.scaled_dot_product_attention(
            query=heads_q,
            key=heads_k,
            value=heads_v,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
        )

        # Keep the weights of the latest call available for inspection
        self.attention_weights = attn_weights

        return self.Wo(self.combine_heads(attn_values))
        
        


In [None]:
### masking for one sequence

# # tokenize sequences
# clean_tcrs_tokenized = model.tokenizer(tcrs_data, return_tensors='pt', padding=True)
# clean_peptides_tokenized = model.tokenizer(peptides_data, return_tensors='pt', padding=True)
# clean_hlas_tokenized = model.tokenizer(hlas_data, return_tensors='pt', padding=True)

# # clean_tcrs_tokenized.keys()
# # ids = clean_tcrs_tokenized['input_ids']
# # attention_mask = clean_tcrs_tokenized['attention_mask']

# #copies for masking
# masked_tcrs_tokenized = clean_tcrs_tokenized.copy()
# masked_peptides_tokenized = clean_peptides_tokenized.copy()
# masked_hlas_tokenized = clean_hlas_tokenized.copy()

# # special tokens and amino acids
# CLS = model.tokenizer.cls_token_id
# EOS = model.tokenizer.eos_token_id
# PAD = model.tokenizer.pad_token_id
# amino_acids = torch.tensor([5, 10, 17, 13, 23, 16, 9, 6, 21, 12, 4, 15, 20, 18, 14, 8, 11, 22, 19, 7], device=ids.device)
# mask_token = 32

# # beginning of per sequence masking
# # IDS
# ids = masked_tcrs_tokenized['input_ids'][0]
# valid_mask = (ids != CLS) & (ids != EOS) & (ids != PAD)
# valid_idx = valid_mask.nonzero(as_tuple=False).squeeze(1)

# #L_valid = valid_mask.sum()
# L_valid = valid_idx.numel()
# if L_valid == 0:
#     # nothing to mask, make labels -100 to be ignored in the loss
#     masked_ids = ids.clone()
#     labels = torch.full_like(ids, -100)
# else:
#     # L_valid = valid_mask.sum(dim=1) # for batch, but currently only 1 dimension
#     n = torch.floor(0.15 * torch.tensor(L_valid, device=ids.device, dtype=torch.float32)).to(torch.int64)
#     n = n.clamp(min=2, max=min(45, L_valid))
#     # n = torch.floor(
#     #         0.15 * torch.tensor(L_valid, device=ids.device, dtype=torch.float32)
#     #         ).to(torch.int64).clamp(min=2, max=45)
    
#     # choose n distinct valid positions 
#     # reshuffle
#     perm = torch.randperm(L_valid, device=ids.device)
#     # choose first n
#     chosen_local = perm[:n]
#     # get correct indices based off of valid indices
#     chosen = valid_idx[chosen_local]

#     # 4) split into amino acids and mask tokens
#     n_amino = torch.floor(0.2 * n).to(torch.int64).clamp(min=1)
#     n_mask = n - n_amino

#     # 5) shuffle and create positions
#     # need to shuffle as otherwise will always have the first part being masked with the mask token
#     # and the same goes for the amino acids, always masked at the second part of the sequence
#     shuffle   = torch.randperm(n.item(), device=ids.device)
#     mask_pos  = chosen[shuffle[:n_mask]]
#     amino_pos = chosen[shuffle[n_mask:]]

#     # 6) build labels - originals at selected positions, -100 for non-masked positions
#     masked_ids = ids.clone()
#     labels     = torch.full_like(ids, -100)
#     labels[chosen] = ids[chosen]

#     # 7) apply mask
#     masked_ids[mask_pos] = mask_token

#     # 8) assign amino acids masks from random amino acids
#     if n_amino > 0:
#         rand_idx = torch.randint(high=amino_acids.numel(), size=(n_amino.item(),), device=ids.device)
#         masked_ids[amino_pos] = amino_acids[rand_idx]

In [None]:
# # forward pass

# # for batch in tcr_loader:
# #     batched = batch["masked_input_ids_ESMprotein_batched"]
# #     batched = batched.to(device)
# #     labels = batch['labels'].to(device)

# #     logits_output = model_tcr.logits(
# #         batched, LogitsConfig(sequence=True, return_embeddings=True)
# #         )
    
# #     logits_s = logits_output.logits.sequence
# #     loss = F.cross_entropy(
# #         logits_s.view(-1, logits_s.size(-1)), # [B*L, V]
# #         labels.view(-1),                      # [B*L]
# #         ignore_index=-100
# #     )

In [None]:
# to print the modules within linear layers to find what to add Lora to

# List every nn.Linear in the backbone with its weight shape; the printed
# names are the candidates for LoRA `target_modules`.
for name, mod in base.named_modules():
    if not isinstance(mod, torch.nn.Linear):
        continue
    print(name, mod.weight.shape)