This script splits data into train, validate, test.
The script splits into three categories:
- unseen tcrs
- unseen peptides
- completely unseen pairs

There are also some further limitations on the data:
1. Keep most cross-reactive TCRs
2. Makes sure HLA is as balanced as possible in train and val/test

#### 0. Imports and Paths

In [2]:
from pathlib import Path
import subprocess, shutil, os
import pandas as pd  # you use pd in cell 3
import os, re, textwrap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import shutil

In [3]:
# === Paths ===
#CSV_PATH = "data/raw/HLA/full_positives_hla_seq.csv"  # update if needed
CSV_PATH = "/home/natasha/multimodal_model/data/raw/HLA/full_positives_hla_seq.csv"
BASE_DIR  = Path("/home/natasha/multimodal_model") #/ "data" / "raw"
MSA_DIR   = BASE_DIR / "data" / "raw" / "MSA"
TRAIN_DIR = BASE_DIR / "data" / "train"
VAL_DIR   = BASE_DIR / "data" / "val"
TEST_DIR  = BASE_DIR / "data" / "test"
MANI_DIR  = BASE_DIR / "data" / "manifests"

# create directories if they don't exist
MSA_DIR.mkdir(parents=True, exist_ok=True)
MANI_DIR.mkdir(parents=True, exist_ok=True)
TRAIN_DIR.mkdir(parents=True, exist_ok=True)
VAL_DIR.mkdir(parents=True, exist_ok=True)
TEST_DIR.mkdir(parents=True, exist_ok=True)

# === Load CSV and preview ===
df = pd.read_csv(CSV_PATH)
#df.head(3)


In [4]:
# -------- Paths --------
#BASE_DIR = Path("/home/natasha/multimodal_model")
DB_COMBINED = BASE_DIR / "data" / "raw" / "MSA" / "big_combo_subset_tcrs_50000_w_mhc_seqs.fasta"

# use the split DBs created
DBS = {
    "tcra": BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "alpha.fasta",
    "tcrb": BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "beta.fasta",
    "mhc":  BASE_DIR / "data" / "raw" / "MSA" / "db_split" / "mhc.fasta",
}
def pick_db_for(stem: str) -> Path:
    return DBS.get(stem, DB_COMBINED)

OUT_ROOT = BASE_DIR / "data" / "raw" / "MSA" / "jackhmmer_msas"
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# -------- Controls --------
VERBOSE = False              # set False to silence prints
KEEP_INTERMEDIATES = False   # set False to delete .sto/.tbl/.a3m after filtering

# -------- jackhmmer / filtering parameters --------
JACK_ITERS = 1
EVALUE = 1e-10
CPU_THREADS = 4
CPU_THREADS = os.cpu_count() or 4

MAX_SEQS = 64
# keep identity de-dup very relaxed so hhfilter mostly just caps:
ID_THR_DEFAULT = 100        # 100 = no ID-based collapse
COV_THR_TCR     = 50
COV_THR_MHC     = 30

def have(cmd): return shutil.which(cmd) is not None
if VERBOSE:
    print("Deps:",
          "jackhmmer" if have("jackhmmer") else "MISSING",
          "esl-reformat" if have("esl-reformat") else ("reformat.pl" if have("reformat.pl") else "MISSING"),
          "hhfilter" if have("hhfilter") else "MISSING")


#### 1. Split Data into Train, Validate, Test

In [6]:
# === Load CSV and preview ===
df = pd.read_csv(CSV_PATH)
df.head(3)

df = df.reset_index(drop=True)
df["orig_row"] = df.index.astype(int)
df["pair_id"] = df["orig_row"].map(lambda i: f"pair_{i:06d}")


In [5]:
import numpy as np
import pandas as pd

# ============================================================
# 0. Hyperparams & constants
# ============================================================

RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)

# Peptides you *never* want to end up in unseen categories
PEPTIDES_TO_KEEP = [
    "KLGGALQAK", "GILGFVFTL", "AVFDRKSDAK", "RAKFKQLL", "SPRWYFYYL",
    "YLQPRTFLL", "TTDPSFLGRY", "GLCTLVAML", "RVRAYTYSK", "IVTDFSVIK",
    "LLWNGPMAV", "LLLDRLNQL", "NLVPMVATV", "LLAGIGTVPI", "RLRAEAQVK",
    "ELAGIGILTV", "YVLDHLIVV", "LTDEMIAQY", "CINGVCWTV", "TPRVTGGGAM",
    "VMATRRNVL", "KTFPPTEPK", "QYIKWPWYI", "DATYQRTRALVR", "NQKLIANQF",
    "FLRGRAYGL", "CTELKLSDY", "ATDALMTGF", "RPPIFIRRL", "NYNYLYRLF",
    "FLYALALLL", "VMTTVLATL", "CLGGLLTMV", "KSKRTPMGF", "RPHERNGFTVL",
    "MEVTPSGTWL", "FTSDYYQLY", "RPIIRPATL", "ALAGIGILTV", "LLYDANYFL",
    "HPVTKYIM", "RLPGVLPRA", "RFPLTFGWCF", "VYFLQSINF", "PTDNYITTY",
    "ALWEIQQVV", "QAKWRLQTL", "RTATKQYNV", "LLFGYPVYV"
]

# Backbone HLAs that should always stay in train (can also appear in val/test,
# but not as "unseen_HLA" regimes)
BACKBONE_HLAS = {
    'HLA-A*02:01',
    'HLA-A*01:01',
    'HLA-A*24:02',
    'HLA-B*07:02',
    'HLA-B*08:01',
    'HLA-A*03:01',
    'HLA-A*11:01',
}

# How many HLAs to use as explicit unseen-HLA regimes
N_TEST_UNSEEN_HLA = 2
N_VAL_UNSEEN_HLA  = 2

# Minimum peptides for an HLA to be eligible as an "unseen HLA" regime
MIN_PEPTIDES_UNSEEN_HLA_TEST = 10
MIN_PEPTIDES_UNSEEN_HLA_VAL  = 10  # you can lower this if needed


# ============================================================
# 1. Pre-processing
# ============================================================

def preprocess_df(df: pd.DataFrame) -> pd.DataFrame:
    """Clean up TCR chains, build TCR_full and masks."""
    df = df.copy()

    # 1. Fill empty/nan values with 'X' token
    df['TCRa'] = df['TCRa'].fillna('X')
    df['TCRb'] = df['TCRb'].fillna('X')

    # Replace empty strings with 'X'
    df.loc[df['TCRa'] == '', 'TCRa'] = 'X'
    df.loc[df['TCRb'] == '', 'TCRb'] = 'X'

    # Build TCR_full and alpha/beta masks
    df['TCR_full'] = df['TCRa'] + df['TCRb']
    df['m_alpha'] = 1
    df['m_beta'] = 1
    df.loc[df['TCRa'] == 'X', 'm_alpha'] = 0
    df.loc[df['TCRb'] == 'X', 'm_beta'] = 0

    # Remove rows with invalid TCR_full entries
    df = df.dropna(subset=['TCR_full'])
    df = df[df['TCR_full'] != ' ']
    df = df[df['TCR_full'] != 'nan']

    return df


# ============================================================
# 2. Category builder for unseen TCR / unseen peptide / completely unseen
# ============================================================

def make_unseen_tcr_peptide_categories(
    df_in: pd.DataFrame,
    peptides_to_keep,
    pct_tcr_cat1=0.02,
    pct_tcr_cat2=0.05,
    pct_peptide_cat3=0.01,
    category_col='category',
    prefix=''
):
    """
    Build three categories:
      - completely_unseen   (unseen TCR & unseen peptide)
      - unseen_TCR          (TCR unseen in train, peptide seen in train)
      - unseen_peptide      (peptide unseen in train, TCR seen in train)

    Returns:
      cat1_df, cat2_df, cat3_df, remaining_df
    """
    df = df_in.copy()

    # -----------------------------
    # Category 1: Completely unseen
    # -----------------------------
    unique_tcrs = df['TCR_full'].unique()
    if len(unique_tcrs) == 0:
        return (df.iloc[0:0].copy(),) * 4  # all empty

    n_cat1_tcrs = max(1, int(len(unique_tcrs) * pct_tcr_cat1))
    selected_tcrs_cat1 = set(rng.choice(unique_tcrs, size=n_cat1_tcrs, replace=False))

    tcr_pairs = df[df['TCR_full'].isin(selected_tcrs_cat1)]
    # remove peptides we insist must remain in training
    tcr_pairs = tcr_pairs[~tcr_pairs['Peptide'].isin(peptides_to_keep)]

    selected_peptides_cat1 = set(tcr_pairs['Peptide'].unique())
    selected_tcrs_cat1 = set(tcr_pairs['TCR_full'].unique())

    cat1_df = df[
        df['TCR_full'].isin(selected_tcrs_cat1) &
        df['Peptide'].isin(selected_peptides_cat1)
    ].copy()
    cat1_df[category_col] = f'{prefix}completely_unseen'

    # Candidates for other categories = everything that does NOT use these TCRs OR these peptides
    train_candidate = df[
        ~(df['TCR_full'].isin(selected_tcrs_cat1) | df['Peptide'].isin(selected_peptides_cat1))
    ].copy()

    # -----------------------------
    # Category 2: Unseen TCR (peptide seen)
    # -----------------------------
    remaining_unique_tcrs = train_candidate['TCR_full'].unique()
    if len(remaining_unique_tcrs) > 0:
        n_cat2_tcrs = max(1, int(len(remaining_unique_tcrs) * pct_tcr_cat2))
        selected_tcrs_cat2 = set(
            rng.choice(remaining_unique_tcrs, size=n_cat2_tcrs, replace=False)
        )
    else:
        selected_tcrs_cat2 = set()

    cat2_df = train_candidate[train_candidate['TCR_full'].isin(selected_tcrs_cat2)].copy()
    cat2_df[category_col] = f'{prefix}unseen_TCR'

    # Remove all these rows from candidate pool
    train_candidate = train_candidate.drop(cat2_df.index)

    # -----------------------------
    # Category 3: Unseen peptide (TCR seen)
    # -----------------------------
    remaining_unique_peptides = train_candidate['Peptide'].unique()
    if len(remaining_unique_peptides) > 0:
        n_cat3_peptides = max(1, int(len(remaining_unique_peptides) * pct_peptide_cat3))
        selected_peptides_cat3 = set(
            rng.choice(remaining_unique_peptides, size=n_cat3_peptides, replace=False)
        )
        selected_peptides_cat3 = selected_peptides_cat3 - set(peptides_to_keep)
    else:
        selected_peptides_cat3 = set()

    cat3_df = train_candidate[train_candidate['Peptide'].isin(selected_peptides_cat3)].copy()
    cat3_df[category_col] = f'{prefix}unseen_peptide'

    # Remove all rows containing these peptides from candidate pool
    remaining_df = train_candidate.drop(cat3_df.index)

    return cat1_df, cat2_df, cat3_df, remaining_df


# ============================================================
# 3. Top-level split function
# ============================================================

def split_tcr_dataset(
    df_raw: pd.DataFrame,
    peptides_to_keep=PEPTIDES_TO_KEEP,
    backbone_hlas=BACKBONE_HLAS,
    n_test_unseen_hla=N_TEST_UNSEEN_HLA,
    n_val_unseen_hla=N_VAL_UNSEEN_HLA,
    min_peptides_unseen_hla_test=MIN_PEPTIDES_UNSEEN_HLA_TEST,
    min_peptides_unseen_hla_val=MIN_PEPTIDES_UNSEEN_HLA_VAL
):
    """
    Main entry point:
      - cleans df
      - builds test (HLA unseen + Cat1/2/3)
      - builds validation (HLA unseen + Cat1/2/3 on remaining)
      - builds final train by globally removing all "unseen" entities
    """
    df = preprocess_df(df_raw)

    total_pairs = len(df)
    print(f"Total pairs after cleaning: {total_pairs}")
    print(f"Unique TCRs: {df['TCR_full'].nunique()}")
    print(f"Unique Peptides: {df['Peptide'].nunique()}")
    print(f"Unique HLAs: {df['HLA'].nunique()}")

    # --------------------------------------------------------
    # 3.1 HLA summary on full data
    # --------------------------------------------------------
    hla_table = df.groupby('HLA').agg(
        TCR_full=('TCR_full', 'nunique'),
        Peptide=('Peptide', 'nunique')
    ).reset_index().sort_values('Peptide', ascending=False)

    print("\nTop HLAs by peptide count:")
    print(hla_table.head(10))

    # --------------------------------------------------------
    # 3.2 Choose unseen HLAs for TEST
    # --------------------------------------------------------
    candidate_hlas_test = set(
        hla_table.loc[
            (hla_table['Peptide'] >= min_peptides_unseen_hla_test) &
            (~hla_table['HLA'].isin(backbone_hlas)),
            'HLA'
        ]
    )

    if len(candidate_hlas_test) == 0:
        print("\n[WARN] No HLAs eligible for unseen-HLA TEST regime with current threshold.")
        test_unseen_hlas = set()
    else:
        n_test = min(n_test_unseen_hla, len(candidate_hlas_test))
        test_unseen_hlas = set(
            rng.choice(list(candidate_hlas_test), size=n_test, replace=False)
        )

    print("\nTest unseen HLAs chosen:", test_unseen_hlas)

    # All rows with these HLAs go straight to test (HLA-unseen category)
    cat0_test_df = df[df['HLA'].isin(test_unseen_hlas)].copy()
    cat0_test_df['test_category'] = 'unseen_HLA'

    # Remaining pool for test Cat1/2/3
    pool_after_test_hla = df[~df['HLA'].isin(test_unseen_hlas)].copy()

    # --------------------------------------------------------
    # 3.3 Build Cat1/Cat2/Cat3 for TEST
    # --------------------------------------------------------
    cat1_test, cat2_test, cat3_test, _ = make_unseen_tcr_peptide_categories(
        pool_after_test_hla,
        peptides_to_keep=peptides_to_keep,
        category_col='test_category',
        prefix=''
    )

    print(f"\nTest Cat1 (completely_unseen) pairs: {len(cat1_test)}")
    print(f"Test Cat2 (unseen_TCR) pairs: {len(cat2_test)}")
    print(f"Test Cat3 (unseen_peptide) pairs: {len(cat3_test)}")

    other_test_df = pd.concat([cat1_test, cat2_test, cat3_test]).drop_duplicates()
    test_df = pd.concat([cat0_test_df, other_test_df]).drop_duplicates()

    print(f"\nTotal TEST pairs: {len(test_df)} "
          f"({len(test_df) / total_pairs * 100:.2f}%)")

    # --------------------------------------------------------
    # 3.4 Build initial dev pool (for val + future train)
    # BUT: we will later *globally* exclude unseen entities again for train_df.
    # For val construction, it's enough to remove test rows by index.
    # --------------------------------------------------------
    dev_pool = df[~df.index.isin(test_df.index)].copy()

    # --------------------------------------------------------
    # 3.5 Choose unseen HLAs for VALIDATION from dev_pool
    # We must exclude ANY HLA that appears in test (not just test_unseen_hlas)
    # to avoid overlap of unseen HLA regimes.
    # --------------------------------------------------------
    test_hlas = set(test_df['HLA'])

    hla_table_dev = dev_pool.groupby('HLA').agg(
        TCR_full=('TCR_full', 'nunique'),
        Peptide=('Peptide', 'nunique')
    ).reset_index()

    candidate_hlas_val = set(
        hla_table_dev.loc[
            (hla_table_dev['Peptide'] >= min_peptides_unseen_hla_val) &
            (~hla_table_dev['HLA'].isin(backbone_hlas)) &
            (~hla_table_dev['HLA'].isin(test_hlas)),
            'HLA'
        ]
    )

    if len(candidate_hlas_val) == 0:
        print("\n[WARN] No HLAs eligible for unseen-HLA VAL regime with current threshold.")
        val_unseen_hlas = set()
    else:
        n_val = min(n_val_unseen_hla, len(candidate_hlas_val))
        val_unseen_hlas = set(
            rng.choice(list(candidate_hlas_val), size=n_val, replace=False)
        )

    print("\nVal unseen HLAs chosen:", val_unseen_hlas)

    cat0_val_df = dev_pool[dev_pool['HLA'].isin(val_unseen_hlas)].copy()
    cat0_val_df['val_category'] = 'unseen_HLA'

    pool_after_val_hla = dev_pool[~dev_pool['HLA'].isin(val_unseen_hlas)].copy()

    # --------------------------------------------------------
    # 3.6 Build Cat1/Cat2/Cat3 for VALIDATION
    # --------------------------------------------------------
    cat1_val, cat2_val, cat3_val, _ = make_unseen_tcr_peptide_categories(
        pool_after_val_hla,
        peptides_to_keep=peptides_to_keep,
        category_col='val_category',
        prefix=''
    )

    print(f"\nVal Cat1 (completely_unseen) pairs: {len(cat1_val)}")
    print(f"Val Cat2 (unseen_TCR) pairs: {len(cat2_val)}")
    print(f"Val Cat3 (unseen_peptide) pairs: {len(cat3_val)}")

    val_df = pd.concat([cat0_val_df, cat1_val, cat2_val, cat3_val]).drop_duplicates()

    print(f"\nTotal VAL pairs: {len(val_df)} "
          f"({len(val_df) / total_pairs * 100:.2f}%)")

    # --------------------------------------------------------
    # 3.7 Build FINAL TRAIN by globally excluding all "unseen" entities
    # --------------------------------------------------------
    # TCRs/peptides used in unseen regimes must never appear in train

    # Test unseen sets
    test_cat1_tcrs = set(cat1_test['TCR_full'])
    test_cat1_peps = set(cat1_test['Peptide'])

    test_cat2_tcrs = set(cat2_test['TCR_full'])
    test_cat3_peps = set(cat3_test['Peptide'])

    # Val unseen sets
    val_cat1_tcrs = set(cat1_val['TCR_full'])
    val_cat1_peps = set(cat1_val['Peptide'])

    val_cat2_tcrs = set(cat2_val['TCR_full'])
    val_cat3_peps = set(cat3_val['Peptide'])

    forbidden_tcrs_for_train = (
        test_cat1_tcrs | test_cat2_tcrs |
        val_cat1_tcrs | val_cat2_tcrs
    )
    forbidden_peps_for_train = (
        test_cat1_peps | test_cat3_peps |
        val_cat1_peps | val_cat3_peps
    )
    forbidden_hlas_for_train = test_unseen_hlas | val_unseen_hlas

    train_df = df[
        (~df.index.isin(test_df.index)) &
        (~df.index.isin(val_df.index)) &
        (~df['TCR_full'].isin(forbidden_tcrs_for_train)) &
        (~df['Peptide'].isin(forbidden_peps_for_train)) &
        (~df['HLA'].isin(forbidden_hlas_for_train))
    ].copy()

    print(f"\nTotal TRAIN pairs: {len(train_df)} "
          f"({len(train_df) / total_pairs * 100:.2f}%)")

    meta = {
        'test_unseen_hlas': test_unseen_hlas,
        'val_unseen_hlas': val_unseen_hlas,
        'backbone_hlas': backbone_hlas,
    }

    return train_df, val_df, test_df, meta


# ============================================================
# 4. Sanity checks
# ============================================================

def run_sanity_checks(train_df, val_df, test_df):
    print("\n========== SANITY CHECKS ==========")

    train_tcrs = set(train_df['TCR_full'])
    train_peps = set(train_df['Peptide'])

    # --- Test categories
    if 'test_category' in test_df.columns:
        for cat_name, label in [
            ("Completely Unseen", 'completely_unseen'),
            ("Unseen TCR", 'unseen_TCR'),
            ("Unseen Peptide", 'unseen_peptide')
        ]:
            cat = test_df[test_df['test_category'] == label]
            tcrs = set(cat['TCR_full'])
            peps = set(cat['Peptide'])

            overlap_tcr = len(tcrs & train_tcrs)
            overlap_pep = len(peps & train_peps)

            print(f"\n[TEST] {cat_name} ({label})")
            print(f"  #pairs: {len(cat)}")
            print(f"  Overlap TCR with TRAIN: {overlap_tcr}")
            print(f"  Overlap Peptide with TRAIN: {overlap_pep}")

    # --- HLA overlaps
    train_hlas = set(train_df['HLA'])
    val_hlas   = set(val_df['HLA'])
    test_hlas  = set(test_df['HLA'])

    print("\nHLA counts:")
    print("  Train HLAs:", len(train_hlas))
    print("  Val HLAs:  ", len(val_hlas))
    print("  Test HLAs: ", len(test_hlas))

    unseen_hlas_test = test_hlas - train_hlas
    unseen_hlas_val  = val_hlas - train_hlas

    print("\nUnseen HLAs in TEST vs TRAIN:", unseen_hlas_test)
    print("Unseen HLAs in VAL vs TRAIN:", unseen_hlas_val)
    print("Overlap of unseen-HLA between VAL and TEST:",
          unseen_hlas_test & unseen_hlas_val)

    print("===================================\n")


In [None]:
train_df, val_df, test_df, meta = split_tcr_dataset(df)

run_sanity_checks(train_df, val_df, test_df)
print("Meta:", meta)

# # save to data
# train_df.to_csv(TRAIN_DIR / 'train_df.csv', index=False)
# val_df.to_csv(VAL_DIR / 'val_df.csv', index=False)
# test_df.to_csv(TEST_DIR / 'test_df.csv', index=False)





Total pairs after cleaning: 39926
Unique TCRs: 36197
Unique Peptides: 1457
Unique HLAs: 71

Top HLAs by peptide count:
               HLA  TCR_full  Peptide
1      HLA-A*02:01     12012      864
0      HLA-A*01:01      1961      122
23     HLA-B*07:02      1861      107
15     HLA-A*24:02       736       89
25     HLA-B*08:01      2114       59
3   HLA-A*02:01:48        40       35
29     HLA-B*27:05        24       31
32     HLA-B*35:01        60       22
54     HLA-B*57:03         1       22
12     HLA-A*11:01      2344       18

Test unseen HLAs chosen: {np.str_('HLA-A*02:01:48'), np.str_('HLA-B*57:01')}

Test Cat1 (completely_unseen) pairs: 114
Test Cat2 (unseen_TCR) pairs: 1863
Test Cat3 (unseen_peptide) pairs: 32

Total TEST pairs: 2214 (5.55%)

Val unseen HLAs chosen: {np.str_('HLA-B*35:08'), np.str_('HLA-B*57:03')}

Val Cat1 (completely_unseen) pairs: 112
Val Cat2 (unseen_TCR) pairs: 1769
Val Cat3 (unseen_peptide) pairs: 28

Total VAL pairs: 1968 (4.93%)

Total TRAIN pairs: 339

In [7]:
# Additionally, remove any rows where the whole TCR is missing?

#### 2. Build manifest file and YAML files for Boltz runs

##### 2a) Split the fasta file into alpha, beta, and mhc (only need to do this once)

In [8]:
# # split the fasta file into alpha, beta, and mhc (only need to do this once)
# from pathlib import Path

# src = Path("/home/natasha/multimodal_model/data/raw/MSA/big_combo_subset_tcrs_50000_w_mhc_seqs.fasta")
# outdir = Path("/home/natasha/multimodal_model/data/raw/MSA/db_split")
# outdir.mkdir(parents=True, exist_ok=True)

# fa_alpha = outdir / "alpha.fasta"
# fa_beta  = outdir / "beta.fasta"
# fa_mhc   = outdir / "mhc.fasta"

# with src.open() as fin, \
#      fa_alpha.open("w") as fa, \
#      fa_beta.open("w") as fb, \
#      fa_mhc.open("w") as fm:
#     hdr, seq = None, []
#     for ln in fin:
#         if ln.startswith(">"):
#             if hdr:
#                 s = "".join(seq)
#                 if hdr.endswith("_a"):
#                     fa.write(hdr + "\n" + s + "\n")
#                 elif hdr.endswith("_b"):
#                     fb.write(hdr + "\n" + s + "\n")
#                 elif "mhc" in hdr.lower():
#                     fm.write(hdr + "\n" + s + "\n")
#             hdr, seq = ln.strip(), []
#         else:
#             seq.append(ln.strip())
#     # write last record
#     if hdr:
#         s = "".join(seq)
#         if hdr.endswith("_a"):
#             fa.write(hdr + s + "\n")
#         elif hdr.endswith("_b"):
#             fb.write(hdr + s + "\n")
#         elif "mhc" in hdr.lower():
#             fm.write(hdr + s + "\n")

# print("Split complete:")
# print("α:", fa_alpha.stat().st_size, "bytes")
# print("β:", fa_beta.stat().st_size, "bytes")
# print("MHC:", fa_mhc.stat().st_size, "bytes")


##### 2b) Build MSA + YAML + Manifest Files

In [9]:
## run this on GPU - would it speed up?
# unlikely, as it's the search that is slow not the actual memory compute

from concurrent.futures import ProcessPoolExecutor, as_completed

from __future__ import annotations

import os, re, shutil, subprocess, textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict

import pandas as pd


MSA_ROOT = BASE_DIR / "data" / "raw" / "MSA" / "jackhmmer_msas"


In [10]:
# load dataframes
# save to data
train_df = pd.read_csv(TRAIN_DIR / 'train_df.csv')
val_df   = pd.read_csv(VAL_DIR / 'val_df.csv')
test_df  = pd.read_csv(TEST_DIR / 'test_df.csv')


In [11]:
# Remove rows where both TCRa and TCRb are <unk>
train_df = train_df[~((train_df['TCRa'] == "<unk>") & (train_df['TCRb'] == "<unk>"))]


In [12]:
train_df.to_csv(TRAIN_DIR / 'train_df.csv', index=False)

In [None]:
# -------- Build MSA files, YAML and manifest files --------
# Correcting for missing chains, skip in YAML if no chain is present

# ----------------------------
# Config
# ----------------------------
@dataclass
class MSAConfig:
    out_root: Path
    db_combined: Path
    dbs: Dict[str, Path]                 # {"tcra": ..., "tcrb": ..., "mhc": ...}

    verbose: bool = False
    keep_intermediates: bool = False

    jack_iters: int = 1
    evalue: float = 1e-10
    cpu_threads: int = os.cpu_count() or 4

    max_seqs: int = 64
    id_thr: int = 100
    cov_thr_tcr: int = 50
    cov_thr_mhc: int = 30


def have(cmd: str) -> bool:
    return shutil.which(cmd) is not None


def run(cmd):
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    out, err = p.communicate()
    return p.returncode, out, err


def pick_db_for(cfg: MSAConfig, stem: str) -> Path:
    return cfg.dbs.get(stem, cfg.db_combined)


# ----------------------------
# Sequence cleaning / presence
# ----------------------------
def clean_seq(s) -> str:
    if not isinstance(s, str):
        return ""
    return re.sub(r"[^A-Za-z]", "", s).upper()


def has_seq(s: str, min_len: int = 1) -> bool:
    return isinstance(s, str) and len(s) >= min_len


MISSING_TOKENS = {
    "", "<unk>", "<UNK>", "UNK", "UNKNOWN",
    "NA", "N/A", "NONE", "NULL", "NAN"
}

def normalise_seq_for_yaml(s) -> str:
    s0 = "" if not isinstance(s, str) else s.strip()
    if s0 in {"<unk>", "<UNK>"}:
        return ""
    s = clean_seq(s0)     # strips <> and lowercases etc -> "UNK"
    return "" if s in MISSING_TOKENS else s




# ----------------------------
# MSA helpers
# ----------------------------
def sto_to_a3m(sto_path: Path, a3m_path: Path) -> bool:
    if have("esl-reformat"):
        code, out, err = run(["esl-reformat", "a3m", str(sto_path)])
        if code == 0 and out:
            a3m_path.write_text(out)
            return True
    if have("reformat.pl"):
        code, out, err = run(["reformat.pl", "sto", "a3m", str(sto_path), str(a3m_path)])
        return code == 0 and a3m_path.exists()
    return False


def count_a3m(p: Path) -> int:
    if not p.exists():
        return -1
    return sum(1 for ln in p.open() if ln.startswith(">"))


def hhfilter_cap(cfg: MSAConfig, in_a3m: Path, out_a3m: Path, cov_thr: int) -> bool:
    """
    Runs hhfilter if available; otherwise COPY (not move!) input → output.
    """
    if not have("hhfilter"):
        if cfg.verbose:
            print("WARN: hhfilter not found on PATH; copying input → output")
        shutil.copyfile(in_a3m, out_a3m)
        return True

    # detect which flag is supported
    code, out, err = run(["hhfilter", "-h"])
    use_maxseq = ("-maxseq" in (out or "")) or ("-maxseq" in (err or ""))

    cmd = [
        "hhfilter",
        "-i", str(in_a3m),
        "-o", str(out_a3m),
        "-id", str(cfg.id_thr),
        "-cov", str(cov_thr),
    ]
    cmd += (["-maxseq", str(cfg.max_seqs)] if use_maxseq else ["-n", str(cfg.max_seqs)])

    if cfg.verbose:
        print("[CMD]", " ".join(map(str, cmd)))
    code, out, err = run(cmd)
    if cfg.verbose and err:
        print("[HHFILTER][stderr]\n", err.strip()[:800])
    return code == 0 and out_a3m.exists()


def build_msa_for_chain(cfg: MSAConfig, seq: str, out_dir: Path, stem: str) -> Optional[Path]:
    """
    Returns Path to filtered .a3m OR None if seq is missing/empty.
    """  
    seq = normalise_seq_for_yaml(seq)
    if not has_seq(seq, min_len=1):
        return None


    out_dir.mkdir(parents=True, exist_ok=True)
    qfa = out_dir / f"{stem}.fa"
    qfa.write_text(f">{stem}\n{seq}\n")

    sto = out_dir / f"{stem}.sto"
    raw_a3m = out_dir / f"{stem}.a3m"
    filt_a3m = out_dir / f"{stem}.filt.a3m"
    tbl = out_dir / f"{stem}.tbl"

    db_fasta = pick_db_for(cfg, stem)

    # jackhmmer
    cmd = [
        "jackhmmer",
        "-N", str(cfg.jack_iters),
        "-A", str(sto),
        "--tblout", str(tbl),
        "-E", str(cfg.evalue),
        "--incE", str(cfg.evalue),
        "--incdomE", str(cfg.evalue),
        "--cpu", str(cfg.cpu_threads),
        str(qfa), str(db_fasta),
    ]
    if cfg.verbose:
        print(f"\n=== {stem} ===")
        print("[CMD]", " ".join(map(str, cmd)))

    code, out, err = run(cmd)

    # If jackhmmer failed or STO empty → just use single-seq A3M
    if code != 0 or (not sto.exists()) or sto.stat().st_size < 200:
        if cfg.verbose:
            print("WARN: bad/empty .sto → falling back to single-seq A3M")
            if err:
                print("[JACK][stderr]\n", err.strip()[:800])
        raw_a3m.write_text(f">{stem}\n{seq}\n")
        raw = raw_a3m
    else:
        ok = sto_to_a3m(sto, raw_a3m)
        if not ok:
            if cfg.verbose:
                print("WARN: sto->a3m failed; using single-seq fallback")
            raw_a3m.write_text(f">{stem}\n{seq}\n")
        raw = raw_a3m

    # hhfilter with chain-specific coverage
    cov_thr = cfg.cov_thr_tcr if stem in ("tcra", "tcrb") else cfg.cov_thr_mhc
    ok = hhfilter_cap(cfg, raw, filt_a3m, cov_thr=cov_thr)

    if cfg.verbose:
        print("[CHK] filt a3m:", filt_a3m, "nseq=", count_a3m(filt_a3m))

    # cleanup
    if not cfg.keep_intermediates:
        for p in (sto, tbl, raw_a3m, qfa):
            try:
                p.unlink()
            except Exception:
                pass

    return filt_a3m if ok else None


def relpath_from_base(p: Path, base: Path) -> str:
    return str(p.relative_to(base))


def make_yaml(
    base_dir: Path,
    tcra_seq: str,
    tcrb_seq: str,
    pep: str,
    mhc_seq: str,
    tcra_msa: Optional[Path],
    tcrb_msa: Optional[Path],
    mhc_msa: Optional[Path],
) -> str:
    def msa_field(p: Optional[Path]) -> str:
        return "empty" if p is None else relpath_from_base(p, base_dir)

    # normalise sequences (turn UNK/NA/etc into "")
    tcra_seq = normalise_seq_for_yaml(tcra_seq)
    tcrb_seq = normalise_seq_for_yaml(tcrb_seq)
    pep      = normalise_seq_for_yaml(pep)
    mhc_seq  = normalise_seq_for_yaml(mhc_seq)

    lines = ["version: 1", "sequences:"]

    def add(pid: str, seq: str, msa: str):
        lines.append(textwrap.dedent(f"""\
        - protein:
            id: {pid}
            sequence: {seq}
            msa: {msa}
        """).rstrip())

    # Only include a chain if sequence exists
    if tcra_seq:
        add("A", tcra_seq, msa_field(tcra_msa))
    if tcrb_seq:
        add("B", tcrb_seq, msa_field(tcrb_msa))
    if pep:
        add("C", pep, "empty")          # keep peptide msa empty
    if mhc_seq:
        add("D", mhc_seq, msa_field(mhc_msa))

    return "\n".join(lines) + "\n"



def _pair_id(i: int, pad: int = 3) -> str:
    return f"pair_{i:0{pad}d}"


def get_pair_id(row, i: int, pad: int = 3) -> str:
    """
    Prefer an existing pair_id column; otherwise fall back to index-based IDs.
    """
    pid = row.get("pair_id", None)
    if isinstance(pid, str) and pid.strip():
        return pid.strip()
    return f"pair_{i:0{pad}d}"


def _yaml_done(yml_path: Path) -> bool:
    # “Done” = YAML exists and is non-trivial size
    return yml_path.exists() and yml_path.stat().st_size > 50


def _msas_done(pair_msa_dir: Path, stems=("tcra", "tcrb", "mhc")) -> bool:
    # “Done” = for every stem: either MSA exists OR the chain may be missing.
    # We can’t know missingness from filesystem alone, so this is best-effort.
    # Use only if you want aggressive skipping when MSAs are present.
    for st in stems:
        if (pair_msa_dir / f"{st}.filt.a3m").exists():
            continue
        # if no file exists, consider not done (so it can be rebuilt)
        return False
    return True


def _process_one_row(args):
    (i, row_dict, split_name, base_dir_str, yaml_dir_str, msa_root_str, cfg, pad) = args
    base_dir = Path(base_dir_str)
    yaml_dir = Path(yaml_dir_str)
    msa_root = Path(msa_root_str)

    pair_id = get_pair_id(row_dict, i=i, pad=pad)

    yml_path = yaml_dir / f"{pair_id}.yaml"
    pair_msa_dir = msa_root / split_name / pair_id

    pep  = normalise_seq_for_yaml(row_dict.get("Peptide", ""))
    mhc  = normalise_seq_for_yaml(row_dict.get("HLA_sequence", ""))
    tcra = normalise_seq_for_yaml(row_dict.get("TCRa", ""))
    tcrb = normalise_seq_for_yaml(row_dict.get("TCRb", ""))

    tcra_a3m = build_msa_for_chain(cfg, tcra, pair_msa_dir, "tcra")
    tcrb_a3m = build_msa_for_chain(cfg, tcrb, pair_msa_dir, "tcrb")
    mhc_a3m  = build_msa_for_chain(cfg, mhc,  pair_msa_dir, "mhc")

    yml_text = make_yaml(
        base_dir=base_dir,
        tcra_seq=tcra,
        tcrb_seq=tcrb,
        pep=pep,
        mhc_seq=mhc,
        tcra_msa=tcra_a3m,
        tcrb_msa=tcrb_a3m,
        mhc_msa=mhc_a3m,
    )
    yml_path.write_text(yml_text)

    return pair_id, True, ""



def process_split_parallel_resume(
    df: pd.DataFrame,
    split_name: str,
    base_dir: Path,
    yaml_dir: Path,
    msa_root: Path,
    cfg: MSAConfig,
    *,
    pad: int = 3,
    resume: bool = True,
    skip_if_yaml_exists: bool = True,
    skip_if_msas_exist: bool = False,
    max_workers: Optional[int] = None,
    mani_dir: Optional[Path] = None,
) -> pd.DataFrame:

    yaml_dir.mkdir(parents=True, exist_ok=True)
    if mani_dir is not None:
        mani_dir.mkdir(parents=True, exist_ok=True)

    if max_workers is None:
        ncpu = os.cpu_count() or 8
        max_workers = max(2, min(6, ncpu // max(1, cfg.cpu_threads)))

    manifest_rows = []
    df2 = df.reset_index(drop=True)

    # --- build manifest rows (pair_id-stable)
    for i, row in df2.iterrows():
        pair_id = get_pair_id(row, i=i, pad=pad)

        pep  = normalise_seq_for_yaml(row.get("Peptide", ""))
        mhc  = normalise_seq_for_yaml(row.get("HLA_sequence", ""))
        tcra = normalise_seq_for_yaml(row.get("TCRa", ""))
        tcrb = normalise_seq_for_yaml(row.get("TCRb", ""))

        yml_path = yaml_dir / f"{pair_id}.yaml"
        try:
            yml_rel = str(yml_path.relative_to(base_dir))
        except ValueError:
            yml_rel = str(yml_path)

        manifest_rows.append({
            "pair_id": pair_id,
            "yaml_path": yml_rel,
            "pep_len": len(pep),
            "tcra_len": len(tcra),
            "tcrb_len": len(tcrb),
            "hla_len": len(mhc),
        })

    # --- build todo list (resume-aware)
    todo = []
    for i, row in df2.iterrows():
        pair_id = get_pair_id(row, i=i, pad=pad)

        yml_path = yaml_dir / f"{pair_id}.yaml"
        pair_msa_dir = msa_root / split_name / pair_id

        if resume:
            if skip_if_yaml_exists and _yaml_done(yml_path):
                continue
            if skip_if_msas_exist and _msas_done(pair_msa_dir):
                continue

        todo.append((i, row.to_dict(), split_name, str(base_dir), str(yaml_dir), str(msa_root), cfg, pad))

    print(f"[{split_name}] total rows: {len(df2)} | to process: {len(todo)} | workers: {max_workers}")

    n_ok = 0
    n_fail = 0
    if todo:
        with ProcessPoolExecutor(max_workers=max_workers) as ex:
            futures = [ex.submit(_process_one_row, t) for t in todo]
            for fut in as_completed(futures):
                try:
                    pair_id, wrote, err = fut.result()
                    n_ok += int(wrote)
                except Exception as e:
                    n_fail += 1
                    print("[ERR]", repr(e))

    print(f"[{split_name}] completed: {n_ok} | failed: {n_fail} | skipped: {len(df2) - len(todo)}")

    manifest_df = pd.DataFrame(manifest_rows)
    if mani_dir is not None:
        manifest_df.to_csv(mani_dir / f"{split_name}_manifest.csv", index=False)

    return manifest_df




In [None]:
for col in ["Peptide", "HLA_sequence", "TCRa", "TCRb"]:
    train_df[col] = train_df[col].apply(clean_seq)
    val_df[col]   = val_df[col].apply(clean_seq)
    test_df[col]  = test_df[col].apply(clean_seq)


cfg = MSAConfig(
    out_root=OUT_ROOT,                 # (not used in functions, but fine to keep)
    db_combined=DB_COMBINED,
    dbs=DBS,
    verbose=VERBOSE,
    keep_intermediates=KEEP_INTERMEDIATES,
    jack_iters=JACK_ITERS,
    evalue=EVALUE,
    cpu_threads=CPU_THREADS,
    max_seqs=MAX_SEQS,
    id_thr=ID_THR_DEFAULT,
    cov_thr_tcr=COV_THR_TCR,
    cov_thr_mhc=COV_THR_MHC,
)



train_manifest = process_split_parallel_resume(
    train_df, "train", BASE_DIR, TRAIN_DIR, MSA_ROOT, cfg,
    resume=True,
    skip_if_yaml_exists=True,
    max_workers=6,
)



val_manifest = process_split_parallel_resume(
    val_df, "val", BASE_DIR, VAL_DIR, MSA_ROOT, cfg,
    resume=True, skip_if_yaml_exists=True
)
test_manifest = process_split_parallel_resume(
    test_df, "test", BASE_DIR, TEST_DIR, MSA_ROOT, cfg,
    resume=True, skip_if_yaml_exists=True
)


In [16]:
cfg = MSAConfig(
    out_root=OUT_ROOT,                 # (not used in functions, but fine to keep)
    db_combined=DB_COMBINED,
    dbs=DBS,
    verbose=VERBOSE,
    keep_intermediates=KEEP_INTERMEDIATES,
    jack_iters=JACK_ITERS,
    evalue=EVALUE,
    cpu_threads=CPU_THREADS,
    max_seqs=MAX_SEQS,
    id_thr=ID_THR_DEFAULT,
    cov_thr_tcr=COV_THR_TCR,
    cov_thr_mhc=COV_THR_MHC,
)


In [None]:


train_manifest = process_split_parallel_resume(
    train_df, "train", BASE_DIR, TRAIN_DIR, MSA_ROOT, cfg,
    resume=True,
    skip_if_yaml_exists=True,
    max_workers=6,
)




[train] total rows: 32691 | to process: 3907 | workers: 6
[train] completed: 3907 | failed: 0 | skipped: 28784


In [18]:
val_manifest = process_split_parallel_resume(
    val_df, "val", BASE_DIR, VAL_DIR, MSA_ROOT, cfg,
    resume=True, skip_if_yaml_exists=True
)
test_manifest = process_split_parallel_resume(
    test_df, "test", BASE_DIR, TEST_DIR, MSA_ROOT, cfg,
    resume=True, skip_if_yaml_exists=True
)


[val] total rows: 1968 | to process: 1968 | workers: 2
[val] completed: 1968 | failed: 0 | skipped: 0
[test] total rows: 2214 | to process: 2214 | workers: 2
[test] completed: 2214 | failed: 0 | skipped: 0


Fix 'bad' YAML files that have 'UNK' in them

In [19]:
def yaml_has_unk(yml_path: Path) -> bool:
    if not yml_path.exists():
        return False
    try:
        txt = yml_path.read_text()
    except Exception:
        return False
    return (
        "sequence: UNK" in txt
        or "sequence: UNKNOWN" in txt
        or "sequence: NA" in txt
    )


In [20]:
def rewrite_bad_yaml(yml_path: Path) -> bool:
    """
    Rewrite a YAML file in-place if it contains UNK/UNKNOWN sequences.
    Returns True if rewritten, False otherwise.
    """
    if not yaml_has_unk(yml_path):
        return False

    try:
        lines = yml_path.read_text().splitlines()
    except Exception as e:
        print(f"[WARN] Could not read {yml_path}: {e}")
        return False

    new_lines = []
    i = 0
    n = len(lines)

    while i < n:
        line = lines[i]

        # detect start of a protein block
        if line.strip() == "- protein:":
            block = [line]
            i += 1
            while i < n and not lines[i].startswith("- protein:"):
                block.append(lines[i])
                i += 1

            # decide whether to keep this block
            block_text = "\n".join(block)
            if (
                "sequence: UNK" in block_text
                or "sequence: UNKNOWN" in block_text
                or "sequence: NA" in block_text
            ):
                # drop this entire protein block
                continue
            else:
                new_lines.extend(block)
        else:
            new_lines.append(line)
            i += 1

    # Write back only if something actually changed
    new_text = "\n".join(new_lines).rstrip() + "\n"
    yml_path.write_text(new_text)
    return True


In [21]:
def rewrite_bad_yamls_in_dir(yaml_dir: Path) -> None:
    yamls = sorted(yaml_dir.glob("*.yaml"))
    n_bad = 0
    n_fixed = 0

    for yml in yamls:
        if yaml_has_unk(yml):
            n_bad += 1
            if rewrite_bad_yaml(yml):
                n_fixed += 1

    print(f"[CLEANUP] {yaml_dir}")
    print(f"  bad YAMLs detected: {n_bad}")
    print(f"  rewritten:          {n_fixed}")


In [22]:
rewrite_bad_yamls_in_dir(TRAIN_DIR)
rewrite_bad_yamls_in_dir(VAL_DIR)
rewrite_bad_yamls_in_dir(TEST_DIR)


[CLEANUP] /home/natasha/multimodal_model/data/train
  bad YAMLs detected: 4845
  rewritten:          4845
[CLEANUP] /home/natasha/multimodal_model/data/val
  bad YAMLs detected: 230
  rewritten:          230
[CLEANUP] /home/natasha/multimodal_model/data/test
  bad YAMLs detected: 220
  rewritten:          220


In [None]:
# DONE

# TO DO:

# Remove YAML files where both TCRa abd TCR b are missing - so ones that only have two proteins tagged in them
# Check that YAML file names are the same as in the paid_id in the manifest for Boltz runs and correct lengths in model 
# Save manifest files


In [None]:
# Remove YAML files where both TCRa and TCRb are missing

from pathlib import Path
import shutil
import yaml

def _load_yaml(yml_path: Path) -> dict:
    return yaml.safe_load(yml_path.read_text())

def _extract_protein_seq(doc: dict, pid: str) -> str:
    # returns "" if missing
    seqs = doc.get("sequences", []) or []
    for item in seqs:
        prot = (item or {}).get("protein", {}) or {}
        if str(prot.get("id", "")).strip() == pid:
            s = prot.get("sequence", "")
            return "" if s is None else str(s).strip()
    return ""

def remove_no_tcr_yamls(
    yaml_dir: Path,
    *,
    quarantine_dir: Path | None = None,
    dry_run: bool = True,
) -> dict:
    yamls = sorted(yaml_dir.glob("*.yaml"))
    n_total = 0
    n_removed = 0
    bad_parse = 0

    if quarantine_dir is not None:
        quarantine_dir.mkdir(parents=True, exist_ok=True)

    for yml in yamls:
        n_total += 1
        try:
            doc = _load_yaml(yml)
        except Exception:
            bad_parse += 1
            continue

        a = _extract_protein_seq(doc, "A")
        b = _extract_protein_seq(doc, "B")

        no_tcr = (len(a) == 0) and (len(b) == 0)
        if not no_tcr:
            continue

        n_removed += 1
        if dry_run:
            continue

        if quarantine_dir is not None:
            shutil.move(str(yml), str(quarantine_dir / yml.name))
        else:
            yml.unlink(missing_ok=True)

    summary = {
        "yaml_dir": str(yaml_dir),
        "total_yaml": n_total,
        "bad_parse": bad_parse,
        "no_tcr_found": n_removed,
        "action": "dry_run" if dry_run else ("moved" if quarantine_dir else "deleted"),
    }
    print(summary)
    return summary


In [26]:
TRAIN_DIR = Path("/home/natasha/multimodal_model/data/train")
VAL_DIR   = Path("/home/natasha/multimodal_model/data/val")
TEST_DIR  = Path("/home/natasha/multimodal_model/data/test")

# 1) Dry-run first (recommended)
remove_no_tcr_yamls(TRAIN_DIR, dry_run=True)
remove_no_tcr_yamls(VAL_DIR, dry_run=True)
remove_no_tcr_yamls(TEST_DIR, dry_run=True)

# 2) Then actually move them somewhere (safer than delete)
remove_no_tcr_yamls(TRAIN_DIR, quarantine_dir=TRAIN_DIR/"_quarantine_no_tcr", dry_run=False)
remove_no_tcr_yamls(VAL_DIR,   quarantine_dir=VAL_DIR/"_quarantine_no_tcr",   dry_run=False)
remove_no_tcr_yamls(TEST_DIR,  quarantine_dir=TEST_DIR/"_quarantine_no_tcr",  dry_run=False)


{'yaml_dir': '/home/natasha/multimodal_model/data/train', 'total_yaml': 32691, 'bad_parse': 0, 'no_tcr_found': 1409, 'action': 'dry_run'}
{'yaml_dir': '/home/natasha/multimodal_model/data/val', 'total_yaml': 1968, 'bad_parse': 0, 'no_tcr_found': 16, 'action': 'dry_run'}
{'yaml_dir': '/home/natasha/multimodal_model/data/test', 'total_yaml': 2214, 'bad_parse': 0, 'no_tcr_found': 18, 'action': 'dry_run'}
{'yaml_dir': '/home/natasha/multimodal_model/data/train', 'total_yaml': 32691, 'bad_parse': 0, 'no_tcr_found': 1409, 'action': 'moved'}
{'yaml_dir': '/home/natasha/multimodal_model/data/val', 'total_yaml': 1968, 'bad_parse': 0, 'no_tcr_found': 16, 'action': 'moved'}
{'yaml_dir': '/home/natasha/multimodal_model/data/test', 'total_yaml': 2214, 'bad_parse': 0, 'no_tcr_found': 18, 'action': 'moved'}


{'yaml_dir': '/home/natasha/multimodal_model/data/test',
 'total_yaml': 2214,
 'bad_parse': 0,
 'no_tcr_found': 18,
 'action': 'moved'}

##### Function to write manifest files manually (if needed)

In [27]:
import pandas as pd
from pathlib import Path
import yaml

def write_manifest_from_yamls(
    yaml_dir: Path,
    *,
    base_dir: Path | None = None,
    split_name: str | None = None,
    mani_dir: Path | None = None,
    df: pd.DataFrame | None = None,
    label_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Creates a manifest exactly matching YAMLs in yaml_dir.

    If df is provided, it will try to attach label columns by mapping pair_id -> row index
    (pair_000, pair_001, ...). This assumes your original pair_id numbering came from df row order.
    """
    yamls = sorted(yaml_dir.glob("*.yaml"))
    rows = []
    bad_parse = 0

    for yml in yamls:
        pair_id = yml.stem  # "pair_000123" etc.

        try:
            doc = yaml.safe_load(yml.read_text())
        except Exception:
            bad_parse += 1
            continue

        a = _extract_protein_seq(doc, "A")
        b = _extract_protein_seq(doc, "B")
        c = _extract_protein_seq(doc, "C")
        d = _extract_protein_seq(doc, "D")

        # yaml path in manifest: relative to base_dir if provided; otherwise absolute
        if base_dir is not None:
            try:
                yaml_path = str(yml.relative_to(base_dir))
            except ValueError:
                yaml_path = str(yml)
        else:
            yaml_path = str(yml)

        row = {
            "pair_id": pair_id,
            "yaml_path": yaml_path,
            "pep_len": len(c),
            "tcra_len": len(a),
            "tcrb_len": len(b),
            "hla_len": len(d),
        }

        # Optional: join labels from original df by decoding pair_id index
        if df is not None and label_cols:
            try:
                idx = int(pair_id.split("_")[1])
                for col in label_cols:
                    row[col] = df.iloc[idx][col]
            except Exception:
                for col in label_cols:
                    row[col] = None

        rows.append(row)

    manifest_df = pd.DataFrame(rows)
    print(f"[MANIFEST] {yaml_dir} | yamls={len(yamls)} | rows={len(manifest_df)} | bad_parse={bad_parse}")

    if mani_dir is not None:
        mani_dir.mkdir(parents=True, exist_ok=True)
        if split_name is None:
            split_name = yaml_dir.name
        out = mani_dir / f"{split_name}_manifest.csv"
        manifest_df.to_csv(out, index=False)
        print(f"[MANIFEST] wrote: {out}")

    return manifest_df


In [28]:
BASE_DIR = Path("/home/natasha/multimodal_model")
MANI_DIR = BASE_DIR / "manifests"

train_manifest = write_manifest_from_yamls(TRAIN_DIR, base_dir=BASE_DIR, split_name="train", mani_dir=MANI_DIR)
val_manifest   = write_manifest_from_yamls(VAL_DIR,   base_dir=BASE_DIR, split_name="val",   mani_dir=MANI_DIR)
test_manifest  = write_manifest_from_yamls(TEST_DIR,  base_dir=BASE_DIR, split_name="test",  mani_dir=MANI_DIR)


[MANIFEST] /home/natasha/multimodal_model/data/train | yamls=31282 | rows=31282 | bad_parse=0
[MANIFEST] wrote: /home/natasha/multimodal_model/manifests/train_manifest.csv
[MANIFEST] /home/natasha/multimodal_model/data/val | yamls=1952 | rows=1952 | bad_parse=0
[MANIFEST] wrote: /home/natasha/multimodal_model/manifests/val_manifest.csv
[MANIFEST] /home/natasha/multimodal_model/data/test | yamls=2196 | rows=2196 | bad_parse=0
[MANIFEST] wrote: /home/natasha/multimodal_model/manifests/test_manifest.csv


In [None]:
train_manifest.head(10)


Unnamed: 0,pair_id,yaml_path,pep_len,tcra_len,tcrb_len,hla_len
0,pair_000,data/train/pair_000.yaml,9,110,114,365
1,pair_002,data/train/pair_002.yaml,10,112,114,365
2,pair_003,data/train/pair_003.yaml,10,112,114,365
3,pair_004,data/train/pair_004.yaml,9,114,115,365
4,pair_005,data/train/pair_005.yaml,10,113,115,365
5,pair_006,data/train/pair_006.yaml,10,113,115,365
6,pair_007,data/train/pair_007.yaml,10,113,115,365
7,pair_008,data/train/pair_008.yaml,10,113,115,365
8,pair_009,data/train/pair_009.yaml,10,113,115,365
9,pair_010,data/train/pair_010.yaml,10,114,119,365


In [None]:
def write_manifest_from_df(df: pd.DataFrame, split_name: str, base_dir: Path, yaml_dir: Path, mani_dir: Path, *, pad: int = 3) -> pd.DataFrame:
    mani_dir.mkdir(parents=True, exist_ok=True)

    manifest_rows = []
    df2 = df.reset_index(drop=True)

    for i, row in df2.iterrows():
        pair_id = f"pair_{i:0{pad}d}"

        pep  = normalise_seq_for_yaml(row.get("Peptide", ""))
        mhc  = normalise_seq_for_yaml(row.get("HLA_sequence", ""))
        tcra = normalise_seq_for_yaml(row.get("TCRa", ""))
        tcrb = normalise_seq_for_yaml(row.get("TCRb", ""))

        yml_path = yaml_dir / f"{pair_id}.yaml"
        try:
            yml_rel = str(yml_path.relative_to(base_dir))
        except ValueError:
            yml_rel = str(yml_path)

        manifest_rows.append({
            "pair_id": pair_id,
            "yaml_path": yml_rel,
            "pep_len": len(pep),
            "tcra_len": len(tcra),
            "tcrb_len": len(tcrb),
            "hla_len": len(mhc),
        })

    manifest_df = pd.DataFrame(manifest_rows)
    manifest_df.to_csv(mani_dir / f"{split_name}_manifest.csv", index=False)
    return manifest_df


Add an explicit 'pair_id' column to the dataframes

In [None]:
# DONE

# What I need to do:

# I need to rewrite the code so it skips the alpha or beta chain if it's missing - check this for the manifest file creation, would it spot if there was no chain to log as length = 0?
# I need to also then decide whether or not to remove the whole row and therefore data point if the TCR is completely missing

In [None]:
# MSA_ROOT = BASE_DIR / "data" / "raw" / "MSA" / "jackhmmer_msas"

# train_manifest = process_split(train_df, "train", BASE_DIR, TRAIN_DIR, MSA_ROOT, cfg)
# val_manifest   = process_split(val_df,   "val",   BASE_DIR, VAL_DIR,   MSA_ROOT, cfg)
# test_manifest  = process_split(test_df,  "test",  BASE_DIR, TEST_DIR,  MSA_ROOT, cfg)

# train_manifest.to_csv(MANI_DIR / "train_manifest.csv", index=False)
# val_manifest.to_csv(MANI_DIR / "val_manifest.csv", index=False)
# test_manifest.to_csv(MANI_DIR / "test_manifest.csv", index=False)
