# Boltz Batch Builder
This notebook:
1. Loads the full_positivies_hla_seq CSV (combined IEDB and VDJDB IMMREP positives with HLA sequences).
2. Selects the first 100 usable pairs with peptide, TCRα, TCRβ, and HLA sequence.
3. Writes **per-pair YAMLs** and runs jackhmmer to generate **A3M** files (TCRα, TCRβ, HLA).

> Edit paths as needed to match repo. The generated YAML points to `data/raw/MSA/jackhmmer_msas/*.a3m` paths.
> It also has a second section that does the same for negative pairs downloaded from IEDB

In [3]:
# # split the fasta file into alpha, beta, and mhc (only need to do this once)
# from pathlib import Path

# src = Path("/home/natasha/multimodal_model/data/raw/MSA/big_combo_subset_tcrs_50000_w_mhc_seqs.fasta")
# outdir = Path("/home/natasha/multimodal_model/data/raw/MSA/db_split")
# outdir.mkdir(parents=True, exist_ok=True)

# fa_alpha = outdir / "alpha.fasta"
# fa_beta  = outdir / "beta.fasta"
# fa_mhc   = outdir / "mhc.fasta"

# with src.open() as fin, \
#      fa_alpha.open("w") as fa, \
#      fa_beta.open("w") as fb, \
#      fa_mhc.open("w") as fm:
#     hdr, seq = None, []
#     for ln in fin:
#         if ln.startswith(">"):
#             if hdr:
#                 s = "".join(seq)
#                 if hdr.endswith("_a"):
#                     fa.write(hdr + "\n" + s + "\n")
#                 elif hdr.endswith("_b"):
#                     fb.write(hdr + "\n" + s + "\n")
#                 elif "mhc" in hdr.lower():
#                     fm.write(hdr + "\n" + s + "\n")
#             hdr, seq = ln.strip(), []
#         else:
#             seq.append(ln.strip())
#     # write last record
#     if hdr:
#         s = "".join(seq)
#         if hdr.endswith("_a"):
#             fa.write(hdr + s + "\n")
#         elif hdr.endswith("_b"):
#             fb.write(hdr + s + "\n")
#         elif "mhc" in hdr.lower():
#             fm.write(hdr + s + "\n")

# print("Split complete:")
# print("α:", fa_alpha.stat().st_size, "bytes")
# print("β:", fa_beta.stat().st_size, "bytes")
# print("MHC:", fa_mhc.stat().st_size, "bytes")


In [None]:
# Build MSA files using jackhmmer

from pathlib import Path
import subprocess, shutil, os
import pandas as pd  # you use pd in cell 3
import os, re, textwrap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

# -------- Paths --------
BASE_DIR = Path("/home/natasha/multimodal_model")
DB_COMBINED = BASE_DIR / "data" / "raw" / "MSA" / "big_combo_subset_tcrs_50000_w_mhc_seqs.fasta"

# use the split DBs created
DBS = {
    "tcra": BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "alpha.fasta",
    "tcrb": BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "beta.fasta",
    "mhc":  BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "mhc.fasta",
}
def pick_db_for(stem: str) -> Path:
    return DBS.get(stem, DB_COMBINED)

OUT_ROOT = BASE_DIR / "data" / "raw" / "MSA" / "jackhmmer_msas"
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# -------- Controls --------
VERBOSE = False              # set False to silence prints
KEEP_INTERMEDIATES = False   # set False to delete .sto/.tbl/.a3m after filtering

# -------- jackhmmer / filtering parameters --------
JACK_ITERS = 1
EVALUE = 1e-10
CPU_THREADS = 4
CPU_THREADS = os.cpu_count() or 4

MAX_SEQS = 64
# keep identity de-dup very relaxed so hhfilter mostly just caps:
ID_THR_DEFAULT = 100        # 100 = no ID-based collapse
COV_THR_TCR     = 50
COV_THR_MHC     = 30

def have(cmd): return shutil.which(cmd) is not None
if VERBOSE:
    print("Deps:",
          "jackhmmer" if have("jackhmmer") else "MISSING",
          "esl-reformat" if have("esl-reformat") else ("reformat.pl" if have("reformat.pl") else "MISSING"),
          "hhfilter" if have("hhfilter") else "MISSING")


In [6]:
# === Configure your paths ===
#CSV_PATH = "data/raw/HLA/full_positives_hla_seq.csv"  # update if needed
CSV_PATH = "/home/natasha/multimodal_model/data/raw/HLA/full_positives_hla_seq.csv"
#BASE_DIR = Path(".")                      # point this to your repo root if running elsewhere
BASE_DIR = Path("/home/natasha/multimodal_model") #/ "data" / "raw"
MSA_DIR  = BASE_DIR / "data" / "raw" / "MSA"
PAIR_DIR = BASE_DIR / "data" / "pairs"
MANI_DIR = BASE_DIR / "data" / "manifests"

# create directories if they don't exist
MSA_DIR.mkdir(parents=True, exist_ok=True)
PAIR_DIR.mkdir(parents=True, exist_ok=True)
MANI_DIR.mkdir(parents=True, exist_ok=True)

# === Load CSV and preview ===
df = pd.read_csv(CSV_PATH)
df.head(3)


Unnamed: 0,Peptide,HLA,Va,Ja,CDR1a,CDR2a,CDR3a,CDR3a_extended,TCRa,Vb,...,CDR1b,CDR2b,CDR3b,CDR3b_extended,TCRb,references,receptor_id,just_10X,HLA_sequence,TCR_full
0,TTDPSFLGRY,HLA-A*01:01,TRAV9-2*01,TRAJ6*01,ATGYPS,ATKADDK,AASGGSYIPT,CAASGGSYIPTF,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...,TRBV9*01,...,SGDLS,YYNGEE,ASSVEETSAGGHEQF,CASSVEETSAGGHEQFF,DSGVTQTPKHLITATGQRVTLRCSPRSGDLSVYWYQQSLDQGLQFL...,http://www.iedb.org/reference/1039300,203509,True,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...
1,YLQPRTFLL,HLA-A*02:01,TRAV9-2*01,TRAJ45*01,ATGYPS,ATKADDK,AGGADGLT,CAGGADGLTF,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...,TRBV2*01,...,SNHLY,FYNNEI,ASSEWQGEKLF,CASSEWQGEKLFF,EPEVTQTPSHQVTQMGQEVILRCVPISNHLYFYWYRQILGQKVEFL...,http://www.iedb.org/reference/1040829,208619,False,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...
2,ILTGLNYEV,HLA-A*02:01,TRAV9-2*01,unknown,ATGYPS,ATKADDK,ALADMNRDDKII,CALADMNRDDKIIF,<unk>,TRBV9*01,...,SGDLS,YYNGEE,ASSVDPGQSYEQY,CASSVDPGQSYEQYF,<unk>,http://www.iedb.org/reference/1034376,29673,True,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,<unk><unk>


In [None]:
# === Filter usable rows ===
required = ["Peptide","HLA_sequence","TCRa","TCRb"]
clean = df.dropna(subset=required).copy()

def clean_seq(s: str) -> str:
    import re
    if not isinstance(s, str):
        return ""
    return re.sub(r"[^A-Za-z]", "", s).upper()

for c in required:
    clean[c] = clean[c].apply(clean_seq)

# usable = clean[(clean["Peptide"].str.len()>=8) & 
#                (clean["TCRa"].str.len()>=50) & 
#                (clean["TCRb"].str.len()>=50) & 
#                (clean["HLA_sequence"].str.len()>=100)].head(100).copy()

#usable = clean.head(100).copy()
usable = clean[(clean["TCRa"].str.len()>=50) | 
               (clean["TCRb"].str.len()>=50)].head(1000).copy()
# don't want to impose this restriction, as want to have missing data

len(usable)


6000

In [33]:
#sum(usable['TCRa'] == '<unk>')
for i, row in usable.iterrows():
    usable.loc[i, 'missing_alpha'] = 1 if row['TCRa'] == 'UNK' else 0
    usable.loc[i, 'missing_beta'] = 1 if row['TCRb'] == 'UNK' else 0
    usable.loc[i, 'missing_mhc'] = 1 if row['HLA_sequence'] == 'UNK' else 0

print(sum(usable['missing_alpha']),
sum(usable['missing_beta']),
sum(usable['missing_mhc']))

values_a = [i for i in usable['TCRa']]
values_b = [i for i in usable['TCRb']]


num_unk_a = values_a.count('UNK')
num_unk_b = values_b.count('UNK')
print(f"Number of 'UNK' in values_a: {num_unk_a}")
print(f"Number of 'UNK' in values_b: {num_unk_b}")

# 152 missing alpha and 152 missing beta in the first 1000 runs - seems suspicious that it's the same?

50.0 5.0 0.0
Number of 'UNK' in values_a: 50
Number of 'UNK' in values_b: 5


In [None]:
#usable.to_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_100_runs.csv', index=False)

# save 1000 runs, with missing data 
usable.to_csv('/home/natasha/multimodal_model/data/raw/HLA/boltz_6000_runs.csv', index=False)

In [62]:
def make_yaml(pair_id: str, seqA: str, seqB: str, pep: str, mhc: str):

    proteins = [
        {"protein":{"id":"A","sequence": seqA}}, #, "msa": yaml_msa_A}},
        {"protein":{"id":"B","sequence": seqB}}, #, "msa": yaml_msa_B}},
        {"protein":{"id":"C","sequence": pep}},  #,  "msa": "empty"}},
        {"protein":{"id":"D","sequence": mhc}},  #,  "msa": yaml_msa_D}},
    ]
    # if include_b2m:
    #     proteins.append({"protein":{"id":"E","sequence": B2M_SEQ, "msa": "empty"}})

    yaml_text = "version: 1\nsequences:\n"
    for p in proteins:
        pid = p["protein"]["id"]
        seq = p["protein"]["sequence"]
        #msa = p["protein"]["msa"]
        # below removed "msa: {msa}"
        yaml_text += textwrap.dedent(f"""          - protein:
              id: {pid}
              sequence: {seq}
              msa: empty
        """)
    return yaml_text


In [63]:
PAIR_DIR

PosixPath('/home/natasha/multimodal_model/data/pairs')

In [64]:
# === Generate YAMLs + manifest ===
rows = []
for i, row in usable.reset_index(drop=True).iterrows():
    pair_id = f"pair_{i:03d}"
    yml = make_yaml(pair_id, row["TCRa"], row["TCRb"], row["Peptide"], row["HLA_sequence"])
    yml_path = PAIR_DIR / f"{pair_id}.yaml"
    yml_path.write_text(yml)

    rows.append({
        "pair_id": pair_id,
        "yaml_path": f"data/pairs/{pair_id}.yaml",
        #"msa_A": f"data/raw/MSA/{pair_id}_A_tcra.a3m",
        #"msa_B": f"data/raw/MSA/{pair_id}_B_tcrb.a3m",
        #"msa_D": f"data/raw/MSA/{pair_id}_D_hla.a3m",
        "pep_len": len(row["Peptide"]),
        "tcra_len": len(row["TCRa"]),
        "tcrb_len": len(row["TCRb"]),
        "hla_len": len(row["HLA_sequence"]),
    })

mani = pd.DataFrame(rows)
mani_path = MANI_DIR / "boltz_100_manifest.csv"
mani.to_csv(mani_path, index=False)
mani.head(8)


Unnamed: 0,pair_id,yaml_path,pep_len,tcra_len,tcrb_len,hla_len
0,pair_000,data/pairs/pair_000.yaml,10,112,117,365
1,pair_001,data/pairs/pair_001.yaml,9,110,114,365
2,pair_002,data/pairs/pair_002.yaml,10,112,114,365
3,pair_003,data/pairs/pair_003.yaml,10,112,114,365
4,pair_004,data/pairs/pair_004.yaml,9,114,115,365
5,pair_005,data/pairs/pair_005.yaml,10,113,115,365
6,pair_006,data/pairs/pair_006.yaml,10,113,115,365
7,pair_007,data/pairs/pair_007.yaml,10,113,115,365


In [65]:
# new functions
def run(cmd):
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    out, err = p.communicate()
    return p.returncode, out, err

def sto_to_a3m(sto_path: Path, a3m_path: Path):
    if have("esl-reformat"):
        code, out, err = run(["esl-reformat", "a3m", str(sto_path)])
        if code == 0 and out:
            a3m_path.write_text(out)
            return True
    if have("reformat.pl"):
        code, out, err = run(["reformat.pl", "sto", "a3m", str(sto_path), str(a3m_path)])
        return code == 0 and a3m_path.exists()
    return False

# ---------- small helpers you call ----------
def count_a3m(p: Path) -> int:
    if not p.exists(): return -1
    return sum(1 for ln in p.open() if ln.startswith('>'))

def count_tbl_hits(tbl: Path) -> int:
    if not tbl.exists(): return -1
    return sum(1 for ln in tbl.read_text().splitlines() if ln and ln[0] != '#')

def file_info(p: Path) -> str:
    return f"{p}  exists={p.exists()}  size={p.stat().st_size if p.exists() else 0}"

def assert_single_query(qfa: Path):
    n = count_a3m(qfa)
    if VERBOSE:
        print(f"[CHK] query FASTA {file_info(qfa)}  nseq={n}")
    if n != 1:
        raise ValueError(f"Query FASTA must contain exactly 1 sequence; got {n} in {qfa}")

# ---------- hhfilter wrapper (handles -maxseq vs -n) ----------
def hhfilter_cap(in_a3m: Path, out_a3m: Path, max_seqs=MAX_SEQS, id_thr=ID_THR_DEFAULT, cov_thr=50):
    if not have("hhfilter"):
        if VERBOSE: print("WARN: hhfilter not found on PATH; copying input → output")
        in_a3m.replace(out_a3m)
        return True

    code, out, err = run(["hhfilter", "-h"])
    use_maxseq = ("-maxseq" in (out or "")) or ("-maxseq" in (err or ""))

    cmd = ["hhfilter", "-i", str(in_a3m), "-o", str(out_a3m),
           "-id", str(id_thr), "-cov", str(cov_thr)]
    cmd += (["-maxseq", str(max_seqs)] if use_maxseq else ["-n", str(max_seqs)])

    if VERBOSE: print("[CMD]", " ".join(map(str, cmd)))
    code, out, err = run(cmd)
    if VERBOSE and err: print("[HHFILTER][stderr]\n", (err.strip()[:800]))
    if VERBOSE: print("[HHFILTER] rc:", code)
    return code == 0 and out_a3m.exists()

# ---------- main builder ----------
def build_msa_for_chain(seq: str, out_dir: Path, stem: str) -> Path:
    """
    seq: raw AA sequence (no gaps)
    out_dir: where to write outputs
    stem: base filename (e.g., 'tcra', 'tcrb', 'mhc')
    returns: Path to final .a3m (filtered)
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    qfa = out_dir / f"{stem}.fa"
    qfa.write_text(f">{stem}\n{seq}\n")

    sto = out_dir / f"{stem}.sto"
    raw_a3m = out_dir / f"{stem}.a3m"
    filt_a3m = out_dir / f"{stem}.filt.a3m"
    tbl = out_dir / f"{stem}.tbl"

    db_fasta = pick_db_for(stem)

    if VERBOSE:
        print(f"\n=== {stem} ===")
        print("[PATHS]", qfa.resolve())
        print("[PATHS]", sto.resolve())
        print("[PATHS]", raw_a3m.resolve())
        print("[PATHS]", filt_a3m.resolve())
        print("[PATHS] DB_FASTA:", db_fasta)

    assert_single_query(qfa)

    # jackhmmer (single pass, strict inclusion; split DB already speeds this up)
    cmd = [
        "jackhmmer",
        "-N", str(JACK_ITERS),
        "-A", str(sto),
        "--tblout", str(tbl),
        "-E", str(EVALUE),
        "--incE", str(EVALUE),
        "--incdomE", str(EVALUE),
        "--cpu", str(CPU_THREADS),
        str(qfa), str(db_fasta)
    ]
    if VERBOSE: print("[CMD]", " ".join(map(str, cmd)))
    code, out, err = run(cmd)
    if VERBOSE:
        print("[JACK] rc:", code)
        if err: print("[JACK][stderr]\n", err.strip()[:800])
        print("[CHK] STO:", file_info(sto))
        print("[CHK] TBL:", file_info(tbl), "hits=", count_tbl_hits(tbl))

    # Fallback if sto is bad
    if code != 0 or (not sto.exists()) or sto.stat().st_size < 200:
        if VERBOSE: print("WARN: bad/empty .sto → falling back to single-seq A3M")
        raw_a3m.write_text(f">{stem}\n{seq}\n")
        raw = raw_a3m
    else:
        ok = sto_to_a3m(sto, raw_a3m)
        if VERBOSE:
            print("[CHK] A3M after sto->a3m:", file_info(raw_a3m), "nseq=", count_a3m(raw_a3m))
        raw = raw_a3m if ok else (out_dir / f"{stem}.single.a3m")
        if not ok:
            if VERBOSE: print("WARN: sto->a3m failed; using single-seq fallback")
            raw.write_text(f">{stem}\n{seq}\n")

    # hhfilter (chain-specific coverage)
    cov_thr = COV_THR_TCR if stem in ("tcra","tcrb") else COV_THR_MHC
    ok = hhfilter_cap(raw, filt_a3m, MAX_SEQS, ID_THR_DEFAULT, cov_thr)
    if VERBOSE:
        print("[CHK] filt A3M after hhfilter:", file_info(filt_a3m), "nseq=", count_a3m(filt_a3m))

    # Cleanup intermediates if requested
    if not KEEP_INTERMEDIATES:
        for p in (sto, tbl, raw_a3m):
            try: p.unlink()
            except Exception: pass

    return filt_a3m


In [66]:
def read_yaml_sequences(yaml_path):
    """Read sequences from YAML file and return dict with sequences by protein ID"""
    with open(yaml_path, 'r') as f:
        data = yaml.safe_load(f)
    sequences = {}
    for seq_data in data['sequences']:
        protein = seq_data['protein']
        sequences[protein['id']] = protein['sequence']
    return sequences

def update_yaml_with_msa(yaml_path, msa_paths):
    """Update YAML file with MSA paths for each protein"""
    with open(yaml_path, 'r') as f:
        data = yaml.safe_load(f)
    for i, seq_data in enumerate(data['sequences']):
        pid = seq_data['protein']['id']
        if pid in msa_paths:
            data['sequences'][i]['protein']['msa'] = msa_paths[pid]
        elif pid == 'C':  # peptide
            data['sequences'][i]['protein']['msa'] = 'empty'
    with open(yaml_path, 'w') as f:
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)

In [67]:
# new MSA building

import yaml

manifest = pd.read_csv(mani_path)   # mani_path must be defined earlier
if VERBOSE:
    display(manifest.head(3))

# def read_yaml_sequences(yaml_path):
#     """Read sequences from YAML file and return dict with sequences by protein ID"""
#     with open(yaml_path, 'r') as f:
#         data = yaml.safe_load(f)
#     sequences = {}
#     for seq_data in data['sequences']:
#         protein = seq_data['protein']
#         sequences[protein['id']] = protein['sequence']
#     return sequences

# def update_yaml_with_msa(yaml_path, msa_paths):
#     """Update YAML file with MSA paths for each protein"""
#     with open(yaml_path, 'r') as f:
#         data = yaml.safe_load(f)
#     for i, seq_data in enumerate(data['sequences']):
#         pid = seq_data['protein']['id']
#         if pid in msa_paths:
#             data['sequences'][i]['protein']['msa'] = msa_paths[pid]
#         elif pid == 'C':  # peptide
#             data['sequences'][i]['protein']['msa'] = 'empty'
#     with open(yaml_path, 'w') as f:
#         yaml.dump(data, f, default_flow_style=False, sort_keys=False)

results = []
for _, row in manifest.iterrows():
    pair_id = row["pair_id"]
    yaml_path = BASE_DIR / row["yaml_path"]

    sequences = read_yaml_sequences(yaml_path)

    pair_out = OUT_ROOT / pair_id
    pair_out.mkdir(parents=True, exist_ok=True)

    msa_paths = {}

    if 'A' in sequences:
        msa_tcra = build_msa_for_chain(sequences['A'], pair_out, "tcra")
        msa_paths['A'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas/{pair_id}/tcra.filt.a3m"

    if 'B' in sequences:
        msa_tcrb = build_msa_for_chain(sequences['B'], pair_out, "tcrb")
        msa_paths['B'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas/{pair_id}/tcrb.filt.a3m"

    if 'D' in sequences:
        msa_mhc = build_msa_for_chain(sequences['D'], pair_out, "mhc")
        msa_paths['D'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas/{pair_id}/mhc.filt.a3m"

    update_yaml_with_msa(yaml_path, msa_paths)

    results.append((pair_id, msa_paths))
    if VERBOSE:
        print(f"Completed {pair_id}: {list(msa_paths.keys())}")
    print(f"Completed {pair_id}")

print(f"\nCompleted MSA building for {len(results)} pairs")


Completed pair_000
Completed pair_001
Completed pair_002
Completed pair_003
Completed pair_004
Completed pair_005
Completed pair_006
Completed pair_007
Completed pair_008
Completed pair_009
Completed pair_010
Completed pair_011
Completed pair_012
Completed pair_013
Completed pair_014
Completed pair_015
Completed pair_016
Completed pair_017
Completed pair_018
Completed pair_019
Completed pair_020
Completed pair_021
Completed pair_022
Completed pair_023
Completed pair_024
Completed pair_025
Completed pair_026
Completed pair_027
Completed pair_028
Completed pair_029
Completed pair_030
Completed pair_031
Completed pair_032
Completed pair_033
Completed pair_034
Completed pair_035
Completed pair_036
Completed pair_037
Completed pair_038
Completed pair_039
Completed pair_040
Completed pair_041
Completed pair_042
Completed pair_043
Completed pair_044
Completed pair_045
Completed pair_046
Completed pair_047
Completed pair_048
Completed pair_049
Completed pair_050
Completed pair_051
Completed pa

##### Negative Dataset Create MSAs

In [68]:
# Negative pairs

# === Load CSV and preview ===
CSV_PATH = "/home/natasha/multimodal_model/data/raw/HLA/IEDB_Negatives_HLA_class_I_with_HLA_seq.csv"
df = pd.read_csv(CSV_PATH)
df.head(3)

# === Filter usable rows ===
required = ["Peptide","HLA_sequence","TCR_alpha","TCR_beta"]
clean = df.dropna(subset=required).copy()

def clean_seq(s: str) -> str:
    import re
    if not isinstance(s, str):
        return ""
    return re.sub(r"[^A-Za-z]", "", s).upper()

for c in required:
    clean[c] = clean[c].apply(clean_seq)

usable = clean[(clean["Peptide"].str.len()>=8) & 
               (clean["TCR_alpha"].str.len()>=50) & 
               (clean["TCR_beta"].str.len()>=50) & 
               (clean["HLA_sequence"].str.len()>=100)].head(100).copy()

len(usable)


100

In [69]:
NEG_PAIR_DIR = Path("/home/natasha/multimodal_model/data/negative_pairs")
NEG_PAIR_DIR.mkdir(parents=True, exist_ok=True)

NEG_MANI_DIR = Path("/home/natasha/multimodal_model/data/negative_manifests")
NEG_MANI_DIR.mkdir(parents=True, exist_ok=True)



In [70]:
# do it for negative pairs

# === Generate YAMLs + A3Ms + manifest ===
rows = []
for i, row in usable.reset_index(drop=True).iterrows():
    pair_id = f"pair_{i:03d}"
    yml = make_yaml(pair_id, row["TCR_alpha"], row["TCR_beta"], row["Peptide"], row["HLA_sequence"])
    yml_path = NEG_PAIR_DIR / f"{pair_id}.yaml"
    yml_path.write_text(yml)

    rows.append({
        "pair_id": pair_id,
        "yaml_path": f"data/negative_pairs/{pair_id}.yaml",
        "pep_len": len(row["Peptide"]),
        "tcra_len": len(row["TCR_alpha"]),
        "tcrb_len": len(row["TCR_beta"]),
        "hla_len": len(row["HLA_sequence"]),
    })

mani = pd.DataFrame(rows)
mani_path = NEG_MANI_DIR / "boltz_100_manifest.csv"
mani.to_csv(mani_path, index=False)
mani.head(8)


Unnamed: 0,pair_id,yaml_path,pep_len,tcra_len,tcrb_len,hla_len
0,pair_000,data/negative_pairs/pair_000.yaml,9,203,241,365
1,pair_001,data/negative_pairs/pair_001.yaml,9,194,243,365
2,pair_002,data/negative_pairs/pair_002.yaml,10,194,243,365
3,pair_003,data/negative_pairs/pair_003.yaml,9,194,243,365
4,pair_004,data/negative_pairs/pair_004.yaml,9,201,244,365
5,pair_005,data/negative_pairs/pair_005.yaml,9,205,241,362
6,pair_006,data/negative_pairs/pair_006.yaml,9,206,242,362
7,pair_007,data/negative_pairs/pair_007.yaml,9,206,242,362


In [71]:
OUT_ROOT = Path("/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas_negative")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

In [72]:
import yaml

manifest = pd.read_csv(mani_path)   # mani_path must be defined earlier
if VERBOSE:
    display(manifest.head(3))

results = []
for _, row in manifest.iterrows():
    pair_id = row["pair_id"]
    yaml_path = BASE_DIR / row["yaml_path"]

    sequences = read_yaml_sequences(yaml_path)

    pair_out = OUT_ROOT / pair_id
    pair_out.mkdir(parents=True, exist_ok=True)

    msa_paths = {}

    if 'A' in sequences:
        msa_tcra = build_msa_for_chain(sequences['A'], pair_out, "tcra")
        msa_paths['A'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas_negative/{pair_id}/tcra.filt.a3m"

    if 'B' in sequences:
        msa_tcrb = build_msa_for_chain(sequences['B'], pair_out, "tcrb")
        msa_paths['B'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas_negative/{pair_id}/tcrb.filt.a3m"

    if 'D' in sequences:
        msa_mhc = build_msa_for_chain(sequences['D'], pair_out, "mhc")
        msa_paths['D'] = f"/home/natasha/multimodal_model/data/raw/MSA/jackhmmer_msas_negative/{pair_id}/mhc.filt.a3m"

    update_yaml_with_msa(yaml_path, msa_paths)

    results.append((pair_id, msa_paths))
    if VERBOSE:
        print(f"Completed {pair_id}: {list(msa_paths.keys())}")
    print(f"Completed {pair_id}")

print(f"\nCompleted MSA building for {len(results)} pairs")


Completed pair_000
Completed pair_001
Completed pair_002
Completed pair_003
Completed pair_004
Completed pair_005
Completed pair_006
Completed pair_007
Completed pair_008
Completed pair_009
Completed pair_010
Completed pair_011
Completed pair_012
Completed pair_013
Completed pair_014
Completed pair_015
Completed pair_016
Completed pair_017
Completed pair_018
Completed pair_019
Completed pair_020
Completed pair_021
Completed pair_022
Completed pair_023
Completed pair_024
Completed pair_025
Completed pair_026
Completed pair_027
Completed pair_028
Completed pair_029
Completed pair_030
Completed pair_031
Completed pair_032
Completed pair_033
Completed pair_034
Completed pair_035
Completed pair_036
Completed pair_037
Completed pair_038
Completed pair_039
Completed pair_040
Completed pair_041
Completed pair_042
Completed pair_043
Completed pair_044
Completed pair_045
Completed pair_046
Completed pair_047
Completed pair_048
Completed pair_049
Completed pair_050
Completed pair_051
Completed pa

##### Split Data in Train, Validate and Test and Create Pairs to Run in Boltz

Step 1: Split the data into the categories as outlined in IMMREP2025
- keep the most promiscuous peptides and TCRs in the training set to ensure training set is big enough

In [56]:
# === Configure your paths ===
#CSV_PATH = "data/raw/HLA/full_positives_hla_seq.csv"  # update if needed
CSV_PATH = "/home/natasha/multimodal_model/data/raw/HLA/full_positives_hla_seq.csv"
#BASE_DIR = Path(".")                      # point this to your repo root if running elsewhere
BASE_DIR  = Path("/home/natasha/multimodal_model") #/ "data" / "raw"
MSA_DIR   = BASE_DIR / "data" / "raw" / "MSA"
TRAIN_DIR = BASE_DIR / "data" / "train"
VAL_DIR   = BASE_DIR / "data" / "val"
TEST_DIR  = BASE_DIR / "data" / "test"
#PAIR_DIR  = BASE_DIR / "data" / "pairs"
MANI_DIR  = BASE_DIR / "data" / "manifests"

# create directories if they don't exist
MSA_DIR.mkdir(parents=True, exist_ok=True)
#PAIR_DIR.mkdir(parents=True, exist_ok=True)
MANI_DIR.mkdir(parents=True, exist_ok=True)
TRAIN_DIR.mkdir(parents=True, exist_ok=True)
VAL_DIR.mkdir(parents=True, exist_ok=True)
TEST_DIR.mkdir(parents=True, exist_ok=True)

# === Load CSV and preview ===
df = pd.read_csv(CSV_PATH)
df.head(3)


Unnamed: 0,Peptide,HLA,Va,Ja,CDR1a,CDR2a,CDR3a,CDR3a_extended,TCRa,Vb,...,CDR1b,CDR2b,CDR3b,CDR3b_extended,TCRb,references,receptor_id,just_10X,HLA_sequence,TCR_full
0,TTDPSFLGRY,HLA-A*01:01,TRAV9-2*01,TRAJ6*01,ATGYPS,ATKADDK,AASGGSYIPT,CAASGGSYIPTF,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...,TRBV9*01,...,SGDLS,YYNGEE,ASSVEETSAGGHEQF,CASSVEETSAGGHEQFF,DSGVTQTPKHLITATGQRVTLRCSPRSGDLSVYWYQQSLDQGLQFL...,http://www.iedb.org/reference/1039300,203509,True,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...
1,YLQPRTFLL,HLA-A*02:01,TRAV9-2*01,TRAJ45*01,ATGYPS,ATKADDK,AGGADGLT,CAGGADGLTF,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...,TRBV2*01,...,SNHLY,FYNNEI,ASSEWQGEKLF,CASSEWQGEKLFF,EPEVTQTPSHQVTQMGQEVILRCVPISNHLYFYWYRQILGQKVEFL...,http://www.iedb.org/reference/1040829,208619,False,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...
2,ILTGLNYEV,HLA-A*02:01,TRAV9-2*01,unknown,ATGYPS,ATKADDK,ALADMNRDDKII,CALADMNRDDKIIF,<unk>,TRBV9*01,...,SGDLS,YYNGEE,ASSVDPGQSYEQY,CASSVDPGQSYEQYF,<unk>,http://www.iedb.org/reference/1034376,29673,True,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,<unk><unk>


In [None]:
import numpy as np
import pandas as pd

# ============================================================
# 0. Hyperparams & constants
# ============================================================

RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)

# Peptides you *never* want to end up in unseen categories
PEPTIDES_TO_KEEP = [
    "KLGGALQAK", "GILGFVFTL", "AVFDRKSDAK", "RAKFKQLL", "SPRWYFYYL",
    "YLQPRTFLL", "TTDPSFLGRY", "GLCTLVAML", "RVRAYTYSK", "IVTDFSVIK",
    "LLWNGPMAV", "LLLDRLNQL", "NLVPMVATV", "LLAGIGTVPI", "RLRAEAQVK",
    "ELAGIGILTV", "YVLDHLIVV", "LTDEMIAQY", "CINGVCWTV", "TPRVTGGGAM",
    "VMATRRNVL", "KTFPPTEPK", "QYIKWPWYI", "DATYQRTRALVR", "NQKLIANQF",
    "FLRGRAYGL", "CTELKLSDY", "ATDALMTGF", "RPPIFIRRL", "NYNYLYRLF",
    "FLYALALLL", "VMTTVLATL", "CLGGLLTMV", "KSKRTPMGF", "RPHERNGFTVL",
    "MEVTPSGTWL", "FTSDYYQLY", "RPIIRPATL", "ALAGIGILTV", "LLYDANYFL",
    "HPVTKYIM", "RLPGVLPRA", "RFPLTFGWCF", "VYFLQSINF", "PTDNYITTY",
    "ALWEIQQVV", "QAKWRLQTL", "RTATKQYNV", "LLFGYPVYV"
]

# Backbone HLAs that should always stay in train (can also appear in val/test,
# but not as "unseen_HLA" regimes)
BACKBONE_HLAS = {
    'HLA-A*02:01',
    'HLA-A*01:01',
    'HLA-A*24:02',
    'HLA-B*07:02',
    'HLA-B*08:01',
    'HLA-A*03:01',
    'HLA-A*11:01',
}

# How many HLAs to use as explicit unseen-HLA regimes
N_TEST_UNSEEN_HLA = 2
N_VAL_UNSEEN_HLA  = 2

# Minimum peptides for an HLA to be eligible as an "unseen HLA" regime
MIN_PEPTIDES_UNSEEN_HLA_TEST = 10
MIN_PEPTIDES_UNSEEN_HLA_VAL  = 10  # you can lower this if needed


# ============================================================
# 1. Pre-processing
# ============================================================

def preprocess_df(df: pd.DataFrame) -> pd.DataFrame:
    """Clean up TCR chains, build TCR_full and masks."""
    df = df.copy()

    # 1. Fill empty/nan values with 'X' token
    df['TCRa'] = df['TCRa'].fillna('X')
    df['TCRb'] = df['TCRb'].fillna('X')

    # Replace empty strings with 'X'
    df.loc[df['TCRa'] == '', 'TCRa'] = 'X'
    df.loc[df['TCRb'] == '', 'TCRb'] = 'X'

    # Build TCR_full and alpha/beta masks
    df['TCR_full'] = df['TCRa'] + df['TCRb']
    df['m_alpha'] = 1
    df['m_beta'] = 1
    df.loc[df['TCRa'] == 'X', 'm_alpha'] = 0
    df.loc[df['TCRb'] == 'X', 'm_beta'] = 0

    # Remove rows with invalid TCR_full entries
    df = df.dropna(subset=['TCR_full'])
    df = df[df['TCR_full'] != ' ']
    df = df[df['TCR_full'] != 'nan']

    return df


# ============================================================
# 2. Category builder for unseen TCR / unseen peptide / completely unseen
# ============================================================

def make_unseen_tcr_peptide_categories(
    df_in: pd.DataFrame,
    peptides_to_keep,
    pct_tcr_cat1=0.02,
    pct_tcr_cat2=0.05,
    pct_peptide_cat3=0.01,
    category_col='category',
    prefix=''
):
    """
    Build three categories:
      - completely_unseen   (unseen TCR & unseen peptide)
      - unseen_TCR          (TCR unseen in train, peptide seen in train)
      - unseen_peptide      (peptide unseen in train, TCR seen in train)

    Returns:
      cat1_df, cat2_df, cat3_df, remaining_df
    """
    df = df_in.copy()

    # -----------------------------
    # Category 1: Completely unseen
    # -----------------------------
    unique_tcrs = df['TCR_full'].unique()
    if len(unique_tcrs) == 0:
        return (df.iloc[0:0].copy(),) * 4  # all empty

    n_cat1_tcrs = max(1, int(len(unique_tcrs) * pct_tcr_cat1))
    selected_tcrs_cat1 = set(rng.choice(unique_tcrs, size=n_cat1_tcrs, replace=False))

    tcr_pairs = df[df['TCR_full'].isin(selected_tcrs_cat1)]
    # remove peptides we insist must remain in training
    tcr_pairs = tcr_pairs[~tcr_pairs['Peptide'].isin(peptides_to_keep)]

    selected_peptides_cat1 = set(tcr_pairs['Peptide'].unique())
    selected_tcrs_cat1 = set(tcr_pairs['TCR_full'].unique())

    cat1_df = df[
        df['TCR_full'].isin(selected_tcrs_cat1) &
        df['Peptide'].isin(selected_peptides_cat1)
    ].copy()
    cat1_df[category_col] = f'{prefix}completely_unseen'

    # Candidates for other categories = everything that does NOT use these TCRs OR these peptides
    train_candidate = df[
        ~(df['TCR_full'].isin(selected_tcrs_cat1) | df['Peptide'].isin(selected_peptides_cat1))
    ].copy()

    # -----------------------------
    # Category 2: Unseen TCR (peptide seen)
    # -----------------------------
    remaining_unique_tcrs = train_candidate['TCR_full'].unique()
    if len(remaining_unique_tcrs) > 0:
        n_cat2_tcrs = max(1, int(len(remaining_unique_tcrs) * pct_tcr_cat2))
        selected_tcrs_cat2 = set(
            rng.choice(remaining_unique_tcrs, size=n_cat2_tcrs, replace=False)
        )
    else:
        selected_tcrs_cat2 = set()

    cat2_df = train_candidate[train_candidate['TCR_full'].isin(selected_tcrs_cat2)].copy()
    cat2_df[category_col] = f'{prefix}unseen_TCR'

    # Remove all these rows from candidate pool
    train_candidate = train_candidate.drop(cat2_df.index)

    # -----------------------------
    # Category 3: Unseen peptide (TCR seen)
    # -----------------------------
    remaining_unique_peptides = train_candidate['Peptide'].unique()
    if len(remaining_unique_peptides) > 0:
        n_cat3_peptides = max(1, int(len(remaining_unique_peptides) * pct_peptide_cat3))
        selected_peptides_cat3 = set(
            rng.choice(remaining_unique_peptides, size=n_cat3_peptides, replace=False)
        )
        selected_peptides_cat3 = selected_peptides_cat3 - set(peptides_to_keep)
    else:
        selected_peptides_cat3 = set()

    cat3_df = train_candidate[train_candidate['Peptide'].isin(selected_peptides_cat3)].copy()
    cat3_df[category_col] = f'{prefix}unseen_peptide'

    # Remove all rows containing these peptides from candidate pool
    remaining_df = train_candidate.drop(cat3_df.index)

    return cat1_df, cat2_df, cat3_df, remaining_df


# ============================================================
# 3. Top-level split function
# ============================================================

def split_tcr_dataset(
    df_raw: pd.DataFrame,
    peptides_to_keep=PEPTIDES_TO_KEEP,
    backbone_hlas=BACKBONE_HLAS,
    n_test_unseen_hla=N_TEST_UNSEEN_HLA,
    n_val_unseen_hla=N_VAL_UNSEEN_HLA,
    min_peptides_unseen_hla_test=MIN_PEPTIDES_UNSEEN_HLA_TEST,
    min_peptides_unseen_hla_val=MIN_PEPTIDES_UNSEEN_HLA_VAL
):
    """
    Main entry point:
      - cleans df
      - builds test (HLA unseen + Cat1/2/3)
      - builds validation (HLA unseen + Cat1/2/3 on remaining)
      - builds final train by globally removing all "unseen" entities
    """
    df = preprocess_df(df_raw)

    total_pairs = len(df)
    print(f"Total pairs after cleaning: {total_pairs}")
    print(f"Unique TCRs: {df['TCR_full'].nunique()}")
    print(f"Unique Peptides: {df['Peptide'].nunique()}")
    print(f"Unique HLAs: {df['HLA'].nunique()}")

    # --------------------------------------------------------
    # 3.1 HLA summary on full data
    # --------------------------------------------------------
    hla_table = df.groupby('HLA').agg(
        TCR_full=('TCR_full', 'nunique'),
        Peptide=('Peptide', 'nunique')
    ).reset_index().sort_values('Peptide', ascending=False)

    print("\nTop HLAs by peptide count:")
    print(hla_table.head(10))

    # --------------------------------------------------------
    # 3.2 Choose unseen HLAs for TEST
    # --------------------------------------------------------
    candidate_hlas_test = set(
        hla_table.loc[
            (hla_table['Peptide'] >= min_peptides_unseen_hla_test) &
            (~hla_table['HLA'].isin(backbone_hlas)),
            'HLA'
        ]
    )

    if len(candidate_hlas_test) == 0:
        print("\n[WARN] No HLAs eligible for unseen-HLA TEST regime with current threshold.")
        test_unseen_hlas = set()
    else:
        n_test = min(n_test_unseen_hla, len(candidate_hlas_test))
        test_unseen_hlas = set(
            rng.choice(list(candidate_hlas_test), size=n_test, replace=False)
        )

    print("\nTest unseen HLAs chosen:", test_unseen_hlas)

    # All rows with these HLAs go straight to test (HLA-unseen category)
    cat0_test_df = df[df['HLA'].isin(test_unseen_hlas)].copy()
    cat0_test_df['test_category'] = 'unseen_HLA'

    # Remaining pool for test Cat1/2/3
    pool_after_test_hla = df[~df['HLA'].isin(test_unseen_hlas)].copy()

    # --------------------------------------------------------
    # 3.3 Build Cat1/Cat2/Cat3 for TEST
    # --------------------------------------------------------
    cat1_test, cat2_test, cat3_test, _ = make_unseen_tcr_peptide_categories(
        pool_after_test_hla,
        peptides_to_keep=peptides_to_keep,
        category_col='test_category',
        prefix=''
    )

    print(f"\nTest Cat1 (completely_unseen) pairs: {len(cat1_test)}")
    print(f"Test Cat2 (unseen_TCR) pairs: {len(cat2_test)}")
    print(f"Test Cat3 (unseen_peptide) pairs: {len(cat3_test)}")

    other_test_df = pd.concat([cat1_test, cat2_test, cat3_test]).drop_duplicates()
    test_df = pd.concat([cat0_test_df, other_test_df]).drop_duplicates()

    print(f"\nTotal TEST pairs: {len(test_df)} "
          f"({len(test_df) / total_pairs * 100:.2f}%)")

    # --------------------------------------------------------
    # 3.4 Build initial dev pool (for val + future train)
    # BUT: we will later *globally* exclude unseen entities again for train_df.
    # For val construction, it's enough to remove test rows by index.
    # --------------------------------------------------------
    dev_pool = df[~df.index.isin(test_df.index)].copy()

    # --------------------------------------------------------
    # 3.5 Choose unseen HLAs for VALIDATION from dev_pool
    # We must exclude ANY HLA that appears in test (not just test_unseen_hlas)
    # to avoid overlap of unseen HLA regimes.
    # --------------------------------------------------------
    test_hlas = set(test_df['HLA'])

    hla_table_dev = dev_pool.groupby('HLA').agg(
        TCR_full=('TCR_full', 'nunique'),
        Peptide=('Peptide', 'nunique')
    ).reset_index()

    candidate_hlas_val = set(
        hla_table_dev.loc[
            (hla_table_dev['Peptide'] >= min_peptides_unseen_hla_val) &
            (~hla_table_dev['HLA'].isin(backbone_hlas)) &
            (~hla_table_dev['HLA'].isin(test_hlas)),
            'HLA'
        ]
    )

    if len(candidate_hlas_val) == 0:
        print("\n[WARN] No HLAs eligible for unseen-HLA VAL regime with current threshold.")
        val_unseen_hlas = set()
    else:
        n_val = min(n_val_unseen_hla, len(candidate_hlas_val))
        val_unseen_hlas = set(
            rng.choice(list(candidate_hlas_val), size=n_val, replace=False)
        )

    print("\nVal unseen HLAs chosen:", val_unseen_hlas)

    cat0_val_df = dev_pool[dev_pool['HLA'].isin(val_unseen_hlas)].copy()
    cat0_val_df['val_category'] = 'unseen_HLA'

    pool_after_val_hla = dev_pool[~dev_pool['HLA'].isin(val_unseen_hlas)].copy()

    # --------------------------------------------------------
    # 3.6 Build Cat1/Cat2/Cat3 for VALIDATION
    # --------------------------------------------------------
    cat1_val, cat2_val, cat3_val, _ = make_unseen_tcr_peptide_categories(
        pool_after_val_hla,
        peptides_to_keep=peptides_to_keep,
        category_col='val_category',
        prefix=''
    )

    print(f"\nVal Cat1 (completely_unseen) pairs: {len(cat1_val)}")
    print(f"Val Cat2 (unseen_TCR) pairs: {len(cat2_val)}")
    print(f"Val Cat3 (unseen_peptide) pairs: {len(cat3_val)}")

    val_df = pd.concat([cat0_val_df, cat1_val, cat2_val, cat3_val]).drop_duplicates()

    print(f"\nTotal VAL pairs: {len(val_df)} "
          f"({len(val_df) / total_pairs * 100:.2f}%)")

    # --------------------------------------------------------
    # 3.7 Build FINAL TRAIN by globally excluding all "unseen" entities
    # --------------------------------------------------------
    # TCRs/peptides used in unseen regimes must never appear in train

    # Test unseen sets
    test_cat1_tcrs = set(cat1_test['TCR_full'])
    test_cat1_peps = set(cat1_test['Peptide'])

    test_cat2_tcrs = set(cat2_test['TCR_full'])
    test_cat3_peps = set(cat3_test['Peptide'])

    # Val unseen sets
    val_cat1_tcrs = set(cat1_val['TCR_full'])
    val_cat1_peps = set(cat1_val['Peptide'])

    val_cat2_tcrs = set(cat2_val['TCR_full'])
    val_cat3_peps = set(cat3_val['Peptide'])

    forbidden_tcrs_for_train = (
        test_cat1_tcrs | test_cat2_tcrs |
        val_cat1_tcrs | val_cat2_tcrs
    )
    forbidden_peps_for_train = (
        test_cat1_peps | test_cat3_peps |
        val_cat1_peps | val_cat3_peps
    )
    forbidden_hlas_for_train = test_unseen_hlas | val_unseen_hlas

    train_df = df[
        (~df.index.isin(test_df.index)) &
        (~df.index.isin(val_df.index)) &
        (~df['TCR_full'].isin(forbidden_tcrs_for_train)) &
        (~df['Peptide'].isin(forbidden_peps_for_train)) &
        (~df['HLA'].isin(forbidden_hlas_for_train))
    ].copy()

    print(f"\nTotal TRAIN pairs: {len(train_df)} "
          f"({len(train_df) / total_pairs * 100:.2f}%)")

    meta = {
        'test_unseen_hlas': test_unseen_hlas,
        'val_unseen_hlas': val_unseen_hlas,
        'backbone_hlas': backbone_hlas,
    }

    return train_df, val_df, test_df, meta


# ============================================================
# 4. Sanity checks
# ============================================================

def run_sanity_checks(train_df, val_df, test_df):
    print("\n========== SANITY CHECKS ==========")

    train_tcrs = set(train_df['TCR_full'])
    train_peps = set(train_df['Peptide'])

    # --- Test categories
    if 'test_category' in test_df.columns:
        for cat_name, label in [
            ("Completely Unseen", 'completely_unseen'),
            ("Unseen TCR", 'unseen_TCR'),
            ("Unseen Peptide", 'unseen_peptide')
        ]:
            cat = test_df[test_df['test_category'] == label]
            tcrs = set(cat['TCR_full'])
            peps = set(cat['Peptide'])

            overlap_tcr = len(tcrs & train_tcrs)
            overlap_pep = len(peps & train_peps)

            print(f"\n[TEST] {cat_name} ({label})")
            print(f"  #pairs: {len(cat)}")
            print(f"  Overlap TCR with TRAIN: {overlap_tcr}")
            print(f"  Overlap Peptide with TRAIN: {overlap_pep}")

    # --- HLA overlaps
    train_hlas = set(train_df['HLA'])
    val_hlas   = set(val_df['HLA'])
    test_hlas  = set(test_df['HLA'])

    print("\nHLA counts:")
    print("  Train HLAs:", len(train_hlas))
    print("  Val HLAs:  ", len(val_hlas))
    print("  Test HLAs: ", len(test_hlas))

    unseen_hlas_test = test_hlas - train_hlas
    unseen_hlas_val  = val_hlas - train_hlas

    print("\nUnseen HLAs in TEST vs TRAIN:", unseen_hlas_test)
    print("Unseen HLAs in VAL vs TRAIN:", unseen_hlas_val)
    print("Overlap of unseen-HLA between VAL and TEST:",
          unseen_hlas_test & unseen_hlas_val)

    print("===================================\n")


In [57]:
train_df, val_df, test_df, meta = split_tcr_dataset(df)

run_sanity_checks(train_df, val_df, test_df)
print("Meta:", meta)

# save to data
train_df.to_csv(TRAIN_DIR / 'train_df.csv', index=False)
val_df.to_csv(VAL_DIR / 'val_df.csv', index=False)
test_df.to_csv(TEST_DIR / 'test_df.csv', index=False)





Total pairs after cleaning: 39926
Unique TCRs: 36197
Unique Peptides: 1457
Unique HLAs: 71

Top HLAs by peptide count:
               HLA  TCR_full  Peptide
1      HLA-A*02:01     12012      864
0      HLA-A*01:01      1961      122
23     HLA-B*07:02      1861      107
15     HLA-A*24:02       736       89
25     HLA-B*08:01      2114       59
3   HLA-A*02:01:48        40       35
29     HLA-B*27:05        24       31
32     HLA-B*35:01        60       22
54     HLA-B*57:03         1       22
12     HLA-A*11:01      2344       18

Test unseen HLAs chosen: {np.str_('HLA-A*02:01:48'), np.str_('HLA-B*57:03')}

Test Cat1 (completely_unseen) pairs: 116
Test Cat2 (unseen_TCR) pairs: 1866
Test Cat3 (unseen_peptide) pairs: 33

Total TEST pairs: 2097 (5.25%)

[WARN] No HLAs eligible for unseen-HLA VAL regime with current threshold.

Val unseen HLAs chosen: set()

Val Cat1 (completely_unseen) pairs: 121
Val Cat2 (unseen_TCR) pairs: 1752
Val Cat3 (unseen_peptide) pairs: 26

Total VAL pairs: 1899

In [None]:
# MSA Building and YAML generation

#### Old Code

In [None]:
# separate into train, val and test
# taken from data_test_val_split.py (IMMREP2025 folder)

pd.set_option('display.max_rows', None)


# 1. Fill empty/nan values with <unk> token
df['TCRa'] = df['TCRa'].fillna('X')
df['TCRb'] = df['TCRb'].fillna('X')

# Replace empty strings with <unk>
df.loc[df['TCRa'] == '', 'TCRa'] = 'X'
df.loc[df['TCRb'] == '', 'TCRb'] = 'X'

df['TCR_full'] = df['TCRa'] + df['TCRb']
df['m_alpha'] = 1
df['m_beta'] = 1
df.loc[df['TCRa'] == 'X', 'm_alpha'] = 0
df.loc[df['TCRb'] == 'X', 'm_beta'] = 0

# Remove rows with invalid TCR_full entries
df = df.dropna(subset=['TCR_full'])
df = df[df['TCR_full'] != ' ']
df = df[df['TCR_full'] != 'nan']

# # Calculate the number of peptides per TCR
# peptides_per_tcr = df.groupby('TCR_full')['Peptide'].nunique().reset_index(name='peptide_count')

# # Calculate the number of TCRs per peptide
# tcrs_per_peptide = df.groupby('Peptide')['TCR_full'].nunique().reset_index(name='tcr_count')


# print(peptides_per_tcr.head(3), tcrs_per_peptide.head(3))

# potentially need to separate HLAs as well? 


# ------------------------------
# 2. Overview of the Dataset
# ------------------------------

# how many unique peptides, tcrs and hla
total_pairs = len(df)
unique_tcrs = set(df['TCR_full'])
unique_peptides = set(df['Peptide'])
num_unique_tcrs = len(unique_tcrs)
num_unique_peptides = len(unique_peptides)
num_unique_hlas = len(set(df['HLA']))


print(f"Total pairs: {total_pairs}")
print(f"Unique TCRs: {num_unique_tcrs}")
print(f"Unique Peptides: {num_unique_peptides}")
print(f"Unique HLAs: {num_unique_hlas}")

#number_of_peptides_per_hla = df.groupby('HLA')['Peptide'].nunique().reset_index(name='peptide_count')
avg_peptides_per_tcr = total_pairs / num_unique_tcrs
avg_tcrs_per_peptide = total_pairs / num_unique_peptides
#number_of_tcrs_per_hla = df.groupby('HLA')['TCR_full'].nunique().reset_index(name='tcr_count')
#hlas = number_of_tcrs_per_hla.merge(number_of_peptides_per_hla, on='HLA', how='left')
hlas = df.groupby('HLA').agg({'TCR_full': 'nunique', 'Peptide': 'nunique'}).reset_index()

print(f"Avg. peptides per TCR: {avg_peptides_per_tcr:.2f}")
print(f"Avg. TCRs per Peptide: {avg_tcrs_per_peptide:.2f}")
#print(number_of_peptides_per_hla.head(10))
#print(number_of_tcrs_per_hla.head(10))
print(hlas.head(71))




# data_for_other_cats = combined_df[~combined_df['HLA'].isin(selected_hlas_cat0)]
data_for_other_cats = df.copy()

# Define a list of peptides that must remain in training (if desired)
peptides_to_keep = [
    "KLGGALQAK", "GILGFVFTL", "AVFDRKSDAK", "RAKFKQLL", "SPRWYFYYL",
    "YLQPRTFLL", "TTDPSFLGRY", "GLCTLVAML", "RVRAYTYSK", "IVTDFSVIK",
    "LLWNGPMAV", "LLLDRLNQL", "NLVPMVATV", "LLAGIGTVPI", "RLRAEAQVK",
    "ELAGIGILTV", "YVLDHLIVV", "LTDEMIAQY", "CINGVCWTV", "TPRVTGGGAM",
    "VMATRRNVL", "KTFPPTEPK", "QYIKWPWYI", "DATYQRTRALVR", "NQKLIANQF",
    "FLRGRAYGL", "CTELKLSDY", "ATDALMTGF", "RPPIFIRRL", "NYNYLYRLF",
    "FLYALALLL", "VMTTVLATL", "CLGGLLTMV", "KSKRTPMGF", "RPHERNGFTVL",
    "MEVTPSGTWL", "FTSDYYQLY", "RPIIRPATL", "ALAGIGILTV", "LLYDANYFL",
    "HPVTKYIM", "RLPGVLPRA", "RFPLTFGWCF", "VYFLQSINF", "PTDNYITTY",
    "ALWEIQQVV", "QAKWRLQTL", "RTATKQYNV", "LLFGYPVYV"
]


# --- Category 1: Completely Unseen Pairs ---
# Choose percentages for unique TCRs and peptides (start with 2% each)
pct_tcr_cat1 = 0.02
pct_peptide_cat1 = 0.02

unique_tcrs = set(data_for_other_cats['TCR_full'])
num_unique_tcrs = len(unique_tcrs)
selected_tcrs_cat1 = set(np.random.choice(list(unique_tcrs), size=int(num_unique_tcrs * pct_tcr_cat1), replace=False))

# Get all pairs that involve the selected TCRs
tcr_pairs = data_for_other_cats[data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1)]
# Optionally, remove pairs with peptides we want to keep
tcr_pairs = tcr_pairs[~tcr_pairs['Peptide'].isin(peptides_to_keep)]

# Derive the set of peptides from these pairs
selected_peptides_cat1 = set(tcr_pairs['Peptide'].unique())
# Update the TCR set to only those that remain after filtering
selected_tcrs_cat1 = set(tcr_pairs['TCR_full'].unique())

# Category 1: Define as all pairs where BOTH the TCR is in selected_tcrs_cat1
# AND the peptide is in selected_peptides_cat1
cat1_df = data_for_other_cats[
    data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) &
    data_for_other_cats['Peptide'].isin(selected_peptides_cat1)
].copy()
cat1_df['test_category'] = 'completely_unseen'
cat1_df.loc[cat1_df['TCR_full'].str.startswith('<unk>'), 'test_category'] = 'completely_unseen_unknownalpha'
cat1_df.loc[cat1_df['TCR_full'].str.endswith('<unk>'), 'test_category'] = 'completely_unseen_unknownbeta'

print("Category 1 (completely unseen) pairs:", len(cat1_df))

# print how many unique hlas are in cat1
unique_hlas_cat1 = set(cat1_df['HLA'])
print(f"Unique HLAs in Category 1: {len(unique_hlas_cat1)}")#,
        #'Unique HLAs in cat1', unique_hlas_cat1)

# Now, to ensure these TCRs and peptides do not appear anywhere in training,
# define the training candidate as all rows that do NOT contain any selected TCR or selected peptide.
train_candidate = data_for_other_cats[
    ~(data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) | data_for_other_cats['Peptide'].isin(selected_peptides_cat1))
]

# ------------------------------
# --- Category 2: Unseen TCR but Seen Peptide ---
# From train_candidate, select a percentage of unique TCRs (e.g., 5%)
remaining_unique_tcrs = set(train_candidate['TCR_full'])
pct_tcr_cat2 = 0.05
selected_tcrs_cat2 = set(np.random.choice(list(remaining_unique_tcrs), size=int(len(remaining_unique_tcrs) * pct_tcr_cat2), replace=False))

cat2_df = train_candidate[train_candidate['TCR_full'].isin(selected_tcrs_cat2)].copy()
# Add more specific tags for TCRs with unknown regions
cat2_df['test_category'] = 'unseen_TCR'
cat2_df.loc[cat2_df['TCR_full'].str.startswith('<unk>'), 'test_category'] = 'unseen_TCR_unknownalpha'
cat2_df.loc[cat2_df['TCR_full'].str.endswith('<unk>'), 'test_category'] = 'unseen_TCR_unknownbeta'

print("Category 2 (unseen TCR) pairs:", len(cat2_df))

# print how many unique hlas are in cat2
unique_hlas_cat2 = set(cat2_df['HLA'])
print(f"Unique HLAs in Category 2: {len(unique_hlas_cat2)}")#,
        #'Unique HLAs in cat2', unique_hlas_cat2)

# Remove Category 2 pairs from train_candidate
train_candidate = train_candidate.drop(cat2_df.index)

# ------------------------------
# --- Category 3: Unseen Peptide but Seen TCR ---
# From train_candidate, select a percentage of unique peptides (e.g., 1%)
remaining_unique_peptides = set(train_candidate['Peptide'])
pct_peptide_cat3 = 0.01
selected_peptides_cat3 = set(np.random.choice(list(remaining_unique_peptides), size=int(len(remaining_unique_peptides) * pct_peptide_cat3), replace=False))
# Remove any peptides that need to be kept in training from the selected peptides
selected_peptides_cat3 = selected_peptides_cat3 - set(peptides_to_keep)

cat3_df = train_candidate[train_candidate['Peptide'].isin(selected_peptides_cat3)].copy()

cat3_df['test_category'] = 'unseen_peptide'
print("Category 3 (unseen peptide) pairs:", len(cat3_df))

# print how many unique hlas are in cat3
unique_hlas_cat3 = set(cat3_df['HLA'])
print(f"Unique HLAs in Category 3: {len(unique_hlas_cat3)}")#,
        #'Unique HLAs in cat3', unique_hlas_cat3)

# Combine all test categories except unseen_HLA
other_test_df = pd.concat([cat1_df, cat2_df, cat3_df])
other_test_df = other_test_df.drop_duplicates()
unique_hlas_other_test = set(other_test_df['HLA'])
print(f"Unique HLAs in test set: {len(unique_hlas_other_test)}")#,
        #'Unique HLAs in test set', unique_hlas_other_test)

# Final training set is the remainder of data_for_other_cats that does NOT contain any TCR or peptide selected for Category 1,
# AND also does not contain the pairs selected for Categories 2 and 3.
final_train_df = train_candidate.drop(cat3_df.index)

# Alternatively, if you want to be sure no overlap exists:
final_train_df = data_for_other_cats[
    ~(data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) |
      data_for_other_cats['Peptide'].isin(selected_peptides_cat1) |
      data_for_other_cats['TCR_full'].isin(selected_tcrs_cat2) |
      data_for_other_cats['Peptide'].isin(selected_peptides_cat3))
]
unique_hlas_final_train = set(final_train_df['HLA'])
print(f"Unique HLAs in final train set: {len(unique_hlas_final_train)}")#,
        #'Unique HLAs in final train set', unique_hlas_final_train)

# overlapping hlas between train and test
overlapping_hlas = unique_hlas_final_train.intersection(unique_hlas_other_test)
print(f"Overlapping HLAs between train and test: {len(overlapping_hlas)}")#,
        #'Overlapping HLAs', overlapping_hlas)


print(f"\nFinal Split:")
print(f"Training set: {len(final_train_df)} pairs ({len(final_train_df) / total_pairs * 100:.2f}%)")
#print(f"Test set (unseen HLA): {len(cat0_df)} pairs ({len(cat0_df) / total_pairs * 100:.2f}%)")
print(f"Test set (other): {len(other_test_df)} pairs ({len(other_test_df) / total_pairs * 100:.2f}%)")


# Create sets from the final training set for lookup
train_tcr_set = set(final_train_df['TCR_full'])
train_peptide_set = set(final_train_df['Peptide'])

print("\nRunning Validation Tests...")

# Test Case 1: Completely Unseen Pairs (Category: 'completely_unseen')
cat1 = other_test_df[other_test_df['test_category'] == 'completely_unseen']
cat1_tcrs = set(cat1['TCR_full'])
cat1_peptides = set(cat1['Peptide'])
overlap_tcr_cat1 = cat1_tcrs.intersection(train_tcr_set)
overlap_peptides_cat1 = cat1_peptides.intersection(train_peptide_set)

print("\nCategory 1: Completely Unseen Pairs")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat1)} (should be 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat1)} (should be 0)")
assert len(overlap_tcr_cat1) == 0, "Error: Some TCRs in 'completely_unseen' category are in training!"
assert len(overlap_peptides_cat1) == 0, "Error: Some peptides in 'completely_unseen' category are in training!"

# Test Case 2: Unseen TCR but Seen Peptide (Category: 'unseen_TCR')
cat2 = other_test_df[other_test_df['test_category'] == 'unseen_TCR']
cat2_tcrs = set(cat2['TCR_full'])
cat2_peptides = set(cat2['Peptide'])
overlap_tcr_cat2 = cat2_tcrs.intersection(train_tcr_set)
overlap_peptides_cat2 = cat2_peptides.intersection(train_peptide_set)

print("\nCategory 2: Unseen TCR but Seen Peptide")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat2)} (should be 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat2)} (should be > 0)")
assert len(overlap_tcr_cat2) == 0, "Error: Some TCRs in 'unseen_TCR' category are in training!"
assert len(overlap_peptides_cat2) > 0, "Error: No peptides in 'unseen_TCR' category are in training!"

# Test Case 3: Unseen Peptide but Seen TCR (Category: 'unseen_peptide')
cat3 = other_test_df[other_test_df['test_category'] == 'unseen_peptide']
cat3_tcrs = set(cat3['TCR_full'])
cat3_peptides = set(cat3['Peptide'])
overlap_tcr_cat3 = cat3_tcrs.intersection(train_tcr_set)
overlap_peptides_cat3 = cat3_peptides.intersection(train_peptide_set)

print("\nCategory 3: Unseen Peptide but Seen TCR")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat3)} (should be > 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat3)} (should be 0)")
assert len(overlap_peptides_cat3) == 0, "Error: Some peptides in 'unseen_peptide' category are in training!"
assert len(overlap_tcr_cat3) > 0, "Error: No TCRs in 'unseen_peptide' category are in training!"


Total pairs: 39926
Unique TCRs: 36197
Unique Peptides: 1457
Unique HLAs: 71
Avg. peptides per TCR: 1.10
Avg. TCRs per Peptide: 27.40
                  HLA  TCR_full  Peptide
0         HLA-A*01:01      1961      122
1         HLA-A*02:01     12012      864
2     HLA-A*02:01:110         1        1
3      HLA-A*02:01:48        40       35
4      HLA-A*02:01:59         2        2
5      HLA-A*02:01:98         2        2
6         HLA-A*02:05        22        3
7         HLA-A*02:06         6        3
8   HLA-A*02:06:01:03         1        1
9        HLA-A*02:266         1        1
10        HLA-A*03:01     14831       18
11        HLA-A*08:01        64        1
12        HLA-A*11:01      2344       18
13     HLA-A*11:01:18         1        2
14        HLA-A*24:01        74        2
15        HLA-A*24:02       736       89
16     HLA-A*24:02:33         1        1
17     HLA-A*24:02:84         2        3
18        HLA-A*25:01         1        1
19        HLA-A*29:02         5        6
20    

In [None]:
# Further split train into train and val
train_df, val_df = train_test_split(final_train_df, test_size=0.1, random_state=42)

# Save the splits
# train_df.to_csv(TRAIN_DIR / "train.csv", index=False)
# val_df.to_csv(VAL_DIR / "val.csv", index=False)



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
# # ------------------------------
# # 1. Data Loading & Preprocessing
# # ------------------------------

# # Read the input files
# iedb_df = pd.read_csv('iedb_positives.csv')
# vdjdb_df = pd.read_csv('vdjdb_positives.csv')

# # Add source column to each dataframe
# iedb_df['source'] = 'iedb'
# vdjdb_df['source'] = 'vdjdb'

# # Combine TCRa and TCRb columns to form a unique TCR identifier
# # Fill empty/nan values with <unk> token
# iedb_df['TCRa'] = iedb_df['TCRa'].fillna('<unk>')
# iedb_df['TCRb'] = iedb_df['TCRb'].fillna('<unk>')
# vdjdb_df['TCRa'] = vdjdb_df['TCRa'].fillna('<unk>')
# vdjdb_df['TCRb'] = vdjdb_df['TCRb'].fillna('<unk>')

# # Replace empty strings with <unk>
# iedb_df.loc[iedb_df['TCRa'] == '', 'TCRa'] = '<unk>'
# iedb_df.loc[iedb_df['TCRb'] == '', 'TCRb'] = '<unk>'
# vdjdb_df.loc[vdjdb_df['TCRa'] == '', 'TCRa'] = '<unk>'
# vdjdb_df.loc[vdjdb_df['TCRb'] == '', 'TCRb'] = '<unk>'

# # iedb_df['TCR_full'] = iedb_df['TCRa'] + '<sep>' +  iedb_df['TCRb']
# # vdjdb_df['TCR_full'] = vdjdb_df['TCRa'] + '<sep>' + vdjdb_df['TCRb']

# iedb_df['TCR_full'] = iedb_df['TCRa'] + iedb_df['TCRb']
# vdjdb_df['TCR_full'] = vdjdb_df['TCRa'] + vdjdb_df['TCRb']

# # Combine both datasets
# combined_df = pd.concat([iedb_df, vdjdb_df], ignore_index=True)

# # Remove rows with invalid TCR_full entries
# combined_df = combined_df.dropna(subset=['TCR_full'])
# combined_df = combined_df[combined_df['TCR_full'] != ' ']
# combined_df = combined_df[combined_df['TCR_full'] != 'nan']

# #uknowqn - <unk>
# # separation <sep>
# # TCR vs peptide token - if need to add for just 1 model

# # ------------------------------
# # Plot histograms of data
# # Plot histograms
# # Calculate the number of peptides per TCR
# peptides_per_tcr = combined_df.groupby('TCR_full')['Peptide'].nunique().reset_index(name='peptide_count')

# # Calculate the number of TCRs per peptide
# tcrs_per_peptide = combined_df.groupby('Peptide')['TCR_full'].nunique().reset_index(name='tcr_count')


# # Plot histogram for peptides per TCR
# plt.figure(figsize=(10, 5))
# max_peptides = peptides_per_tcr['peptide_count'].max()
# print(f"Max peptides per TCR: {max_peptides}")
# sns.histplot(peptides_per_tcr['peptide_count'], bins=30)
# plt.xlabel("Number of Peptides per TCR")
# plt.ylabel("Number of TCRs")
# plt.title("Distribution of Peptides per TCR")
# plt.xlim(0, min(max_peptides, peptides_per_tcr['peptide_count'].quantile(0.99)))
# #plt.show()
# plt.savefig('Peptides_per_TCR.png')

# # Plot histogram for TCRs per peptide  
# plt.figure(figsize=(10, 5))
# max_tcrs = tcrs_per_peptide['tcr_count'].max()
# print(f"Max TCRs per peptide: {max_tcrs}")
# sns.histplot(tcrs_per_peptide['tcr_count'], bins=100, color='orange')
# plt.xlabel("Number of TCRs per Peptide")
# plt.ylabel("Number of Peptides")
# plt.title("Distribution of TCRs per Peptide")
# plt.xlim(0, min(max_tcrs, tcrs_per_peptide['tcr_count'].mean() + 3 * tcrs_per_peptide['tcr_count'].std()))
# #plt.show()
# plt.savefig('TCRs_per_pepdite.png')


# ------------------------------
# 2. Overview of the Dataset
# ------------------------------

total_pairs = len(combined_df)
unique_tcrs = set(combined_df['TCR_full'])
unique_peptides = set(combined_df['Peptide'])
num_unique_tcrs = len(unique_tcrs)
num_unique_peptides = len(unique_peptides)

print(f"Total pairs: {total_pairs}")
print(f"Unique TCRs: {num_unique_tcrs}")
print(f"Unique Peptides: {num_unique_peptides}")

avg_peptides_per_tcr = total_pairs / num_unique_tcrs
avg_tcrs_per_peptide = total_pairs / num_unique_peptides

print(f"Avg. peptides per TCR: {avg_peptides_per_tcr:.2f}")
print(f"Avg. TCRs per Peptide: {avg_tcrs_per_peptide:.2f}")

# ------------------------------
# 3. Data Splitting: Create Test Categories
# ------------------------------

# Print the number of unique HLAs at the start
unique_hlas = set(combined_df['HLA'])
num_unique_hlas = len(unique_hlas)
print(f"Number of unique HLAs: {num_unique_hlas}")

# Select 5 HLAs for unseen_HLA (or all if fewer)
num_hlas_to_select = min(5, num_unique_hlas)
if num_unique_hlas < 5:
    print(f"Warning: Only {num_unique_hlas} unique HLAs available, using all for unseen_HLA category.")
selected_hlas_cat0 = set(np.random.choice(list(unique_hlas), size=num_hlas_to_select, replace=False))

cat0_df = combined_df[combined_df['HLA'].isin(selected_hlas_cat0)].copy()
cat0_df['test_category'] = 'unseen_HLA'
print(f"Category 0 (unseen HLA) pairs: {len(cat0_df)}")

# Remove these from the pool before any other split
data_for_other_cats = combined_df[~combined_df['HLA'].isin(selected_hlas_cat0)]

# Define a list of peptides that must remain in training (if desired)
peptides_to_keep = [
    "KLGGALQAK", "GILGFVFTL", "AVFDRKSDAK", "RAKFKQLL", "SPRWYFYYL",
    "YLQPRTFLL", "TTDPSFLGRY", "GLCTLVAML", "RVRAYTYSK", "IVTDFSVIK",
    "LLWNGPMAV", "LLLDRLNQL", "NLVPMVATV", "LLAGIGTVPI", "RLRAEAQVK",
    "ELAGIGILTV", "YVLDHLIVV", "LTDEMIAQY", "CINGVCWTV", "TPRVTGGGAM",
    "VMATRRNVL", "KTFPPTEPK", "QYIKWPWYI", "DATYQRTRALVR", "NQKLIANQF",
    "FLRGRAYGL", "CTELKLSDY", "ATDALMTGF", "RPPIFIRRL", "NYNYLYRLF",
    "FLYALALLL", "VMTTVLATL", "CLGGLLTMV", "KSKRTPMGF", "RPHERNGFTVL",
    "MEVTPSGTWL", "FTSDYYQLY", "RPIIRPATL", "ALAGIGILTV", "LLYDANYFL",
    "HPVTKYIM", "RLPGVLPRA", "RFPLTFGWCF", "VYFLQSINF", "PTDNYITTY",
    "ALWEIQQVV", "QAKWRLQTL", "RTATKQYNV", "LLFGYPVYV"
]

# --- Category 1: Completely Unseen Pairs ---
# Choose percentages for unique TCRs and peptides (start with 2% each)
pct_tcr_cat1 = 0.02
pct_peptide_cat1 = 0.02

unique_tcrs = set(data_for_other_cats['TCR_full'])
num_unique_tcrs = len(unique_tcrs)
selected_tcrs_cat1 = set(np.random.choice(list(unique_tcrs), size=int(num_unique_tcrs * pct_tcr_cat1), replace=False))

# Get all pairs that involve the selected TCRs
tcr_pairs = data_for_other_cats[data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1)]
# Optionally, remove pairs with peptides we want to keep
tcr_pairs = tcr_pairs[~tcr_pairs['Peptide'].isin(peptides_to_keep)]

# Derive the set of peptides from these pairs
selected_peptides_cat1 = set(tcr_pairs['Peptide'].unique())
# Update the TCR set to only those that remain after filtering
selected_tcrs_cat1 = set(tcr_pairs['TCR_full'].unique())

# Category 1: Define as all pairs where BOTH the TCR is in selected_tcrs_cat1
# AND the peptide is in selected_peptides_cat1
cat1_df = data_for_other_cats[
    data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) &
    data_for_other_cats['Peptide'].isin(selected_peptides_cat1)
].copy()
cat1_df['test_category'] = 'completely_unseen'
cat1_df.loc[cat1_df['TCR_full'].str.startswith('<unk>'), 'test_category'] = 'completely_unseen_unknownalpha'
cat1_df.loc[cat1_df['TCR_full'].str.endswith('<unk>'), 'test_category'] = 'completely_unseen_unknownbeta'

print("Category 1 (completely unseen) pairs:", len(cat1_df))

# Now, to ensure these TCRs and peptides do not appear anywhere in training,
# define the training candidate as all rows that do NOT contain any selected TCR or selected peptide.
train_candidate = data_for_other_cats[
    ~(data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) | data_for_other_cats['Peptide'].isin(selected_peptides_cat1))
]

# ------------------------------
# --- Category 2: Unseen TCR but Seen Peptide ---
# From train_candidate, select a percentage of unique TCRs (e.g., 5%)
remaining_unique_tcrs = set(train_candidate['TCR_full'])
pct_tcr_cat2 = 0.05
selected_tcrs_cat2 = set(np.random.choice(list(remaining_unique_tcrs), size=int(len(remaining_unique_tcrs) * pct_tcr_cat2), replace=False))

cat2_df = train_candidate[train_candidate['TCR_full'].isin(selected_tcrs_cat2)].copy()
# Add more specific tags for TCRs with unknown regions
cat2_df['test_category'] = 'unseen_TCR'
cat2_df.loc[cat2_df['TCR_full'].str.startswith('<unk>'), 'test_category'] = 'unseen_TCR_unknownalpha'
cat2_df.loc[cat2_df['TCR_full'].str.endswith('<unk>'), 'test_category'] = 'unseen_TCR_unknownbeta'

print("Category 2 (unseen TCR) pairs:", len(cat2_df))

# Remove Category 2 pairs from train_candidate
train_candidate = train_candidate.drop(cat2_df.index)

# ------------------------------
# --- Category 3: Unseen Peptide but Seen TCR ---
# From train_candidate, select a percentage of unique peptides (e.g., 1%)
remaining_unique_peptides = set(train_candidate['Peptide'])
pct_peptide_cat3 = 0.01
selected_peptides_cat3 = set(np.random.choice(list(remaining_unique_peptides), size=int(len(remaining_unique_peptides) * pct_peptide_cat3), replace=False))
# Remove any peptides that need to be kept in training from the selected peptides
selected_peptides_cat3 = selected_peptides_cat3 - set(peptides_to_keep)

cat3_df = train_candidate[train_candidate['Peptide'].isin(selected_peptides_cat3)].copy()

cat3_df['test_category'] = 'unseen_peptide'
print("Category 3 (unseen peptide) pairs:", len(cat3_df))

# Combine all test categories except unseen_HLA
other_test_df = pd.concat([cat1_df, cat2_df, cat3_df])
other_test_df = other_test_df.drop_duplicates()

# Final training set is the remainder of data_for_other_cats that does NOT contain any TCR or peptide selected for Category 1,
# AND also does not contain the pairs selected for Categories 2 and 3.
final_train_df = train_candidate.drop(cat3_df.index)

# Alternatively, if you want to be sure no overlap exists:
final_train_df = data_for_other_cats[
    ~(data_for_other_cats['TCR_full'].isin(selected_tcrs_cat1) |
      data_for_other_cats['Peptide'].isin(selected_peptides_cat1) |
      data_for_other_cats['TCR_full'].isin(selected_tcrs_cat2) |
      data_for_other_cats['Peptide'].isin(selected_peptides_cat3))
]

print(f"\nFinal Split:")
print(f"Training set: {len(final_train_df)} pairs ({len(final_train_df) / total_pairs * 100:.2f}%)")
print(f"Test set (unseen HLA): {len(cat0_df)} pairs ({len(cat0_df) / total_pairs * 100:.2f}%)")
print(f"Test set (other): {len(other_test_df)} pairs ({len(other_test_df) / total_pairs * 100:.2f}%)")

# Save to three files for correct input
cat0_df.to_csv('test_df_unseen_HLA.csv', index=False)
other_test_df.to_csv('test_df_other.csv', index=False)
final_train_df.to_csv('train_df_other.csv', index=False)

# ------------------------------
# 4. Test Cases to Validate the Splits
# ------------------------------

# Create sets from the final training set for lookup
train_tcr_set = set(final_train_df['TCR_full'])
train_peptide_set = set(final_train_df['Peptide'])

print("\nRunning Validation Tests...")

# Test Case 1: Completely Unseen Pairs (Category: 'completely_unseen')
cat1 = other_test_df[other_test_df['test_category'] == 'completely_unseen']
cat1_tcrs = set(cat1['TCR_full'])
cat1_peptides = set(cat1['Peptide'])
overlap_tcr_cat1 = cat1_tcrs.intersection(train_tcr_set)
overlap_peptides_cat1 = cat1_peptides.intersection(train_peptide_set)

print("\nCategory 1: Completely Unseen Pairs")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat1)} (should be 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat1)} (should be 0)")
assert len(overlap_tcr_cat1) == 0, "Error: Some TCRs in 'completely_unseen' category are in training!"
assert len(overlap_peptides_cat1) == 0, "Error: Some peptides in 'completely_unseen' category are in training!"

# Test Case 2: Unseen TCR but Seen Peptide (Category: 'unseen_TCR')
cat2 = other_test_df[other_test_df['test_category'] == 'unseen_TCR']
cat2_tcrs = set(cat2['TCR_full'])
cat2_peptides = set(cat2['Peptide'])
overlap_tcr_cat2 = cat2_tcrs.intersection(train_tcr_set)
overlap_peptides_cat2 = cat2_peptides.intersection(train_peptide_set)

print("\nCategory 2: Unseen TCR but Seen Peptide")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat2)} (should be 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat2)} (should be > 0)")
assert len(overlap_tcr_cat2) == 0, "Error: Some TCRs in 'unseen_TCR' category are in training!"
assert len(overlap_peptides_cat2) > 0, "Error: No peptides in 'unseen_TCR' category are in training!"

# Test Case 3: Unseen Peptide but Seen TCR (Category: 'unseen_peptide')
cat3 = other_test_df[other_test_df['test_category'] == 'unseen_peptide']
cat3_tcrs = set(cat3['TCR_full'])
cat3_peptides = set(cat3['Peptide'])
overlap_tcr_cat3 = cat3_tcrs.intersection(train_tcr_set)
overlap_peptides_cat3 = cat3_peptides.intersection(train_peptide_set)

print("\nCategory 3: Unseen Peptide but Seen TCR")
print(f"Overlapping TCRs in training: {len(overlap_tcr_cat3)} (should be > 0)")
print(f"Overlapping Peptides in training: {len(overlap_peptides_cat3)} (should be 0)")
assert len(overlap_peptides_cat3) == 0, "Error: Some peptides in 'unseen_peptide' category are in training!"
assert len(overlap_tcr_cat3) > 0, "Error: No TCRs in 'unseen_peptide' category are in training!"

# # Overall Consistency Check: Ensure train + test equals original dataset
# We actually don't want to assert this because the
# Check this amount of data vs

# total_split = len(final_train_df) + len(other_test_df)
# assert total_split == len(combined_df), f"Error: Total split size {total_split} does not equal original dataset size {len(combined_df)}."

print("\nAll test cases passed. Data split is valid!")

# need to split into three
# first one needs to be to train the simple ESM fine tuned model
# second one needs to be to train the ESM fine tuned + NC
# Third is testing data for both. But do I need to have two separate versions of this
# one where the input is just one sequence and the other where the input is two sequences

# After final_train_df is defined, stratify it into two sets by promiscuity

# Calculate promiscuity for peptides and TCRs
peptide_counts = final_train_df.groupby('Peptide')['TCR_full'].nunique()
tcr_counts = final_train_df.groupby('TCR_full')['Peptide'].nunique()

# Top 10% most promiscuous
promiscuous_peptide_thresh = peptide_counts.quantile(0.9)
promiscuous_tcr_thresh = tcr_counts.quantile(0.9)
promiscuous_peptides = set(peptide_counts[peptide_counts >= promiscuous_peptide_thresh].index)
promiscuous_tcrs = set(tcr_counts[tcr_counts >= promiscuous_tcr_thresh].index)

# Label each row as promiscuous/non-promiscuous for both TCR and peptide
final_train_df['promiscuous_peptide'] = final_train_df['Peptide'].isin(promiscuous_peptides)
final_train_df['promiscuous_tcr'] = final_train_df['TCR_full'].isin(promiscuous_tcrs)

# Define stratify_label for stratified split
final_train_df['stratify_label'] = final_train_df['promiscuous_peptide'].astype(str) + '_' + final_train_df['promiscuous_tcr'].astype(str)

# Stratified split: maintain ratio of promiscuous/non-promiscuous in both sets
train1_df, train2_df = train_test_split(
    final_train_df,
    test_size=0.4,
    random_state=42,
    stratify=final_train_df['stratify_label']
)

# Save the two training sets
train1_df.to_csv('train_df_stratified_1.csv', index=False)
train2_df.to_csv('train_df_stratified_2.csv', index=False)

# Print summary statistics for validation
print("\n--- Training Set 1 Stats ---")
print(f"Total: {len(train1_df)}")
print(f"Promiscuous peptides: {train1_df['promiscuous_peptide'].sum()} ({train1_df['promiscuous_peptide'].mean()*100:.2f}%)")
print(f"Promiscuous TCRs: {train1_df['promiscuous_tcr'].sum()} ({train1_df['promiscuous_tcr'].mean()*100:.2f}%)")
print(f"Unique peptides: {train1_df['Peptide'].nunique()}")
print(f"Unique TCRs: {train1_df['TCR_full'].nunique()}")

print("\n--- Training Set 2 Stats ---")
print(f"Total: {len(train2_df)}")
print(f"Promiscuous peptides: {train2_df['promiscuous_peptide'].sum()} ({train2_df['promiscuous_peptide'].mean()*100:.2f}%)")
print(f"Promiscuous TCRs: {train2_df['promiscuous_tcr'].sum()} ({train2_df['promiscuous_tcr'].mean()*100:.2f}%)")
print(f"Unique peptides: {train2_df['Peptide'].nunique()}")
print(f"Unique TCRs: {train2_df['TCR_full'].nunique()}")

# Optionally, sample and print distributions
plt.figure(figsize=(10,4))
plt.hist(train1_df.groupby('TCR_full')['Peptide'].nunique(), bins=30, alpha=0.5, label='Train1: Peptides per TCR')
plt.hist(train2_df.groupby('TCR_full')['Peptide'].nunique(), bins=30, alpha=0.5, label='Train2: Peptides per TCR')
plt.legend()
plt.title('Distribution of Peptides per TCR in Training Sets')
plt.savefig('train_peptides_per_tcr_dist.png')

plt.figure(figsize=(10,4))
plt.hist(train1_df.groupby('Peptide')['TCR_full'].nunique(), bins=30, alpha=0.5, label='Train1: TCRs per Peptide')
plt.hist(train2_df.groupby('Peptide')['TCR_full'].nunique(), bins=30, alpha=0.5, label='Train2: TCRs per Peptide')
plt.legend()
plt.title('Distribution of TCRs per Peptide in Training Sets')
plt.savefig('train_tcrs_per_peptide_dist.png')

# Save top 10% promiscuous TCRs and peptides to CSV
pd.DataFrame({'TCR_full': list(promiscuous_tcrs)}).to_csv('top10pct_promiscuous_tcrs.csv', index=False)
pd.DataFrame({'Peptide': list(promiscuous_peptides)}).to_csv('top10pct_promiscuous_peptides.csv', index=False)

# Create random pairings for negative examples
# First, get the total number of positive examples as a reference
# n_positives = len(final_train_df)

# # Get unique TCRs and peptides from the original dataset
# all_tcrs = final_train_df['TCR_full'].unique()
# all_peptides = final_train_df['Peptide'].unique()

# # Create random pairs
# np.random.seed(42)  # for reproducibility
# random_tcrs = np.random.choice(all_tcrs, size=n_positives)
# random_peptides = np.random.choice(all_peptides, size=n_positives)

# # Create negative dataset
# negatives_df = pd.DataFrame({
#     'TCR_full': random_tcrs,
#     'Peptide': random_peptides
# })

# # Create a set of positive pairs for filtering
# positive_pairs = set(zip(final_train_df['TCR_full'], final_train_df['Peptide']))

# # Filter out any accidental positives from the negative dataset
# negative_pairs = set(zip(negatives_df['TCR_full'], negatives_df['Peptide']))
# true_negatives = negative_pairs - positive_pairs

# # Create final negative dataset from the filtered pairs
# final_negatives = pd.DataFrame(list(true_negatives), columns=['TCR_full', 'Peptide'])

# # Save to CSV
# final_negatives.to_csv('train_negatives.csv', index=False)

# Load train dataset 1
train1_df = pd.read_csv('train_df_stratified_1.csv')

# Get unique TCRs and peptides from train1 only
train1_tcrs = train1_df['TCR_full'].unique()
train1_peptides = train1_df['Peptide'].unique()

# Create random pairs using only train1 data
n_positives = len(train1_df)
np.random.seed(42)  # for reproducibility
random_tcrs = np.random.choice(train1_tcrs, size=n_positives)
random_peptides = np.random.choice(train1_peptides, size=n_positives)

# Create negative dataset
negatives_df = pd.DataFrame({
    'TCR_full': random_tcrs,
    'Peptide': random_peptides
})

# Create a set of positive pairs for filtering
positive_pairs = set(zip(train1_df['TCR_full'], train1_df['Peptide']))

# Filter out any accidental positives from the negative dataset
negative_pairs = set(zip(negatives_df['TCR_full'], negatives_df['Peptide']))
true_negatives = negative_pairs - positive_pairs

# Create final negative dataset from the filtered pairs
final_negatives = pd.DataFrame(list(true_negatives), columns=['TCR_full', 'Peptide'])

# Add binding labels
train1_df['Binding'] = 1  # Positive examples
final_negatives['Binding'] = 0  # Negative examples

# Combine positive and negative examples
train_df_full = pd.concat([train1_df, final_negatives])

# Save combined dataset
train_df_full.to_csv('train_df_stratified_1_full.csv', index=False)


# Print summary statistics
print("\n--- Negative Dataset Stats ---")
print(f"Total negative pairs: {len(final_negatives)}")
print(f"Unique TCRs in negatives: {final_negatives['TCR_full'].nunique()}")
print(f"Unique peptides in negatives: {final_negatives['Peptide'].nunique()}")


# create negatives for test set for baseline model case

# Load your test dataset
test_df = pd.read_csv('test_df_other.csv')  # Replace with your actual test dataset filename

# Get the number of positive examples in the test dataset
n_positives = len(test_df)

# Get unique TCRs and peptides from the test dataset
all_tcrs = test_df['TCR_full'].unique()
all_peptides = test_df['Peptide'].unique()

# Create random pairs
np.random.seed(42)  # for reproducibility
random_tcrs = np.random.choice(all_tcrs, size=n_positives)
random_peptides = np.random.choice(all_peptides, size=n_positives)

# Create negative dataset
negatives_df = pd.DataFrame({
    'TCR_full': random_tcrs,
    'Peptide': random_peptides
})

# Create a set of positive pairs for filtering
positive_pairs = set(zip(test_df['TCR_full'], test_df['Peptide']))

# Filter out any accidental positives from the negative dataset
negative_pairs = set(zip(negatives_df['TCR_full'], negatives_df['Peptide']))
true_negatives = negative_pairs - positive_pairs

# Create final negative dataset and combine with positives for test set
test_df['Binding'] = 1  # Add positive labels
final_negatives = pd.DataFrame(list(true_negatives), columns=['TCR_full', 'Peptide'])
final_negatives['Binding'] = 0  # Add negative labels
test_with_negatives = pd.concat([test_df, final_negatives])



test_with_negatives.to_csv('test_with_negatives.csv', index=False)

