This notebook takes the manifest files, the YAML files and the CSV files, and matches `pair_id` across all three sources to ensure consistent tagging across every data source.

In [37]:
import pandas as pd
from pathlib import Path
import yaml

BASE = Path("/home/natasha/multimodal_model")
# Destination of the consolidated pair table (one row per pair YAML).
OUT = BASE / "manifests" / "pair_table.canonical.csv"

# Directory searched recursively for pair_*.yaml, per split.
# train/test are read from their _chunks view; val is read from its split root.
SPLITS = dict(
    train=BASE / "data" / "train" / "_chunks",
    val=BASE / "data" / "val",
    test=BASE / "data" / "test" / "_chunks",
)

def parse_yaml_sequences(yaml_path: Path):
    """Extract the four chain sequences of one pair YAML.

    Every entry of the top-level ``sequences`` list is expected to hold a
    ``protein`` mapping with an ``id`` (chain letter) and a ``sequence``.
    Chain letters follow the YAML writer's convention:
    A = TCR alpha, B = TCR beta, C = peptide, D = HLA.
    Entries without an id are skipped; a repeated id overwrites the earlier
    entry; a chain that is absent comes back as "".

    Raises AttributeError when the document is not a mapping (e.g. an empty
    file, for which ``yaml.safe_load`` returns None).

    Returns:
        tuple (pep, tcra, tcrb, hla) of strings.
    """
    doc = yaml.safe_load(yaml_path.read_text())
    by_chain = {}
    for entry in doc.get("sequences", []):
        protein = entry.get("protein", {})
        chain_id = protein.get("id")
        if not chain_id:
            continue
        by_chain[chain_id] = protein.get("sequence", "")

    # Output order is (peptide, alpha, beta, HLA) -> chain letters C, A, B, D.
    pep, tcra, tcrb, hla = (by_chain.get(letter, "") for letter in "CABD")
    return pep, tcra, tcrb, hla

rows = []
bad = []
for split, root in SPLITS.items():
    # rglob recurses below `root`. Each SPLITS root is intended to be the single
    # canonical view of its split (train/test: _chunks; val: the split root).
    # Nothing is filtered here, so duplicate pair_ids are checked after the loop.
    yamls = sorted(root.rglob("pair_*.yaml"))
    print(split, "yamls:", len(yamls))
    for yml in yamls:
        try:
            pep, tcra, tcrb, hla = parse_yaml_sequences(yml)
            rows.append({
                "pair_id": yml.stem,  # pair_123.yaml -> pair_123
                "split": split,
                "yaml_path": str(yml.relative_to(BASE)),
                "pep_seq": pep,
                "tcra_seq": tcra,
                "tcrb_seq": tcrb,
                "hla_seq": hla,
                "pep_len": len(pep),
                "tcra_len": len(tcra),
                "tcrb_len": len(tcrb),
                "hla_len": len(hla),
            })
        except Exception as e:
            # Best-effort scan: record the failing file and keep going; reported below.
            bad.append((split, str(yml), repr(e)))

# Explicit columns so an empty scan still yields a frame with the expected schema
# (a column-less pd.DataFrame([]) would raise KeyError in sort_values).
PAIR_TABLE_COLUMNS = [
    "pair_id", "split", "yaml_path",
    "pep_seq", "tcra_seq", "tcrb_seq", "hla_seq",
    "pep_len", "tcra_len", "tcrb_len", "hla_len",
]
pt = pd.DataFrame(rows, columns=PAIR_TABLE_COLUMNS).sort_values(["split", "pair_id"])

# pair_id is just the file stem: the same pair reachable from two split roots
# (train/test leakage) or a stray copy would silently create duplicate keys.
dup_ids = pt.loc[pt["pair_id"].duplicated(keep=False)]
if not dup_ids.empty:
    print("WARNING: duplicate pair_id rows:", len(dup_ids))
    print(dup_ids[["pair_id", "split", "yaml_path"]].head(10).to_string(index=False))

OUT.parent.mkdir(parents=True, exist_ok=True)
pt.to_csv(OUT, index=False)
print("Wrote:", OUT)
print("Bad:", len(bad))
if bad:
    print("First 5 bad:", bad[:5])


train yamls: 31166
val yamls: 1960
test yamls: 2190
Wrote: /home/natasha/multimodal_model/manifests/pair_table.canonical.csv
Bad: 0


In [38]:
import pandas as pd
from pathlib import Path

PAIR_TABLE = Path("/home/natasha/multimodal_model").joinpath("manifests", "pair_table.canonical.csv")
OUT_DIR = Path("/home/natasha/multimodal_model").joinpath("data")

# Canonical pair table written by the previous cell; write_split() filters it per split.
pt = pd.read_csv(PAIR_TABLE)

def write_split(split: str, pair_table=None, out_dir=None):
    """Write ``<out_dir>/<split>/<split>_df_clean.csv`` for one split.

    The CSV has the columns the dataset classes consume:
    pair_id, Peptide, HLA_sequence, TCR_full. TCR_full is the TCR alpha
    sequence followed directly by the TCR beta sequence; a missing chain
    (read back from CSV as NaN) contributes "".

    Args:
        split: value of the pair table's ``split`` column to export.
        pair_table: pair table DataFrame; defaults to the module-level ``pt``.
        out_dir: root output directory; defaults to the module-level ``OUT_DIR``.

    Returns:
        Path of the CSV that was written.
    """
    if pair_table is None:
        pair_table = pt
    if out_dir is None:
        out_dir = OUT_DIR

    df = pair_table[pair_table["split"] == split]
    tcr_full = df["tcra_seq"].fillna("").astype(str) + df["tcrb_seq"].fillna("").astype(str)

    out = pd.DataFrame({
        "pair_id": df["pair_id"],
        "Peptide": df["pep_seq"],
        "HLA_sequence": df["hla_seq"],
        "TCR_full": tcr_full,
    })

    out_path = Path(out_dir) / split / f"{split}_df_clean.csv"
    out_path.parent.mkdir(parents=True, exist_ok=True)  # split dir may not exist yet
    out.to_csv(out_path, index=False)
    print(split, "->", out_path, "rows:", len(out))
    return out_path

# Export one clean CSV per split from the canonical pair table.
for split_name in ("train", "val", "test"):
    write_split(split_name)


train -> /home/natasha/multimodal_model/data/train/train_df_clean.csv rows: 31166
val -> /home/natasha/multimodal_model/data/val/val_df_clean.csv rows: 1960
test -> /home/natasha/multimodal_model/data/test/test_df_clean.csv rows: 2190


In [39]:
import pandas as pd
from pathlib import Path

BASE = Path("/home/natasha/multimodal_model")


def _summarize_clean(split):
    """Print columns, row/unique-id counts and missing Peptide/HLA flags for one split."""
    clean = pd.read_csv(BASE / "data" / split / f"{split}_df_clean.csv")
    print(split, clean.columns.tolist())
    print("rows:", len(clean), "unique pair_id:", clean["pair_id"].nunique())
    print("any missing peptide?", clean["Peptide"].isna().any(),
          "any missing HLA?", clean["HLA_sequence"].isna().any())
    print("----")


for split in ["train", "val", "test"]:
    _summarize_clean(split)


train ['pair_id', 'Peptide', 'HLA_sequence', 'TCR_full']
rows: 31166 unique pair_id: 31166
any missing peptide? False any missing HLA? False
----
val ['pair_id', 'Peptide', 'HLA_sequence', 'TCR_full']
rows: 1960 unique pair_id: 1960
any missing peptide? False any missing HLA? False
----
test ['pair_id', 'Peptide', 'HLA_sequence', 'TCR_full']
rows: 2190 unique pair_id: 2190
any missing peptide? False any missing HLA? False
----


In [40]:
import pandas as pd
from pathlib import Path

BASE = Path("/home/natasha/multimodal_model")


def audit(split):
    """Print how many HLA_sequence values are missing in ``<split>_df_clean.csv``.

    Two counts are printed: NaN cells, and cells that are empty/whitespace once
    NaN is treated as "". The second count therefore includes the first.
    (pandas reads empty CSV fields as NaN by default.)
    """
    clean = pd.read_csv(BASE / "data" / split / f"{split}_df_clean.csv")
    hla = clean["HLA_sequence"]
    nan_count = hla.isna().sum()
    blank_count = hla.fillna("").astype(str).str.strip().eq("").sum()

    print(f"\n[{split}] rows={len(clean)}")
    print("  HLA_sequence NaN:", nan_count)
    print("  HLA_sequence empty/blank:", blank_count)


for split in ("train", "val", "test"):
    audit(split)



[train] rows=31166
  HLA_sequence NaN: 0
  HLA_sequence empty/blank: 0

[val] rows=1960
  HLA_sequence NaN: 0
  HLA_sequence empty/blank: 0

[test] rows=2190
  HLA_sequence NaN: 0
  HLA_sequence empty/blank: 0


In [41]:
# Identify pair_ids whose HLA sequence is missing or blank (removal happens in the following cells)

import pandas as pd
from pathlib import Path

BASE = Path("/home/natasha/multimodal_model")


def bad_pair_ids_from_clean(split: str):
    """Return the set of pair_ids in ``<split>_df_clean.csv`` whose HLA_sequence is NaN or blank."""
    clean = pd.read_csv(BASE / "data" / split / f"{split}_df_clean.csv")
    hla_blank = clean["HLA_sequence"].fillna("").astype(str).str.strip() == ""
    return set(clean.loc[hla_blank, "pair_id"].tolist())


bad_train, bad_val, bad_test = (bad_pair_ids_from_clean(s) for s in ("train", "val", "test"))

# Union over all splits; later cells use this to prune the pair table and manifests.
BAD = bad_train | bad_val | bad_test
print("Bad counts:", {"train": len(bad_train), "val": len(bad_val), "test": len(bad_test)})
print("Total BAD:", len(BAD))


Bad counts: {'train': 0, 'val': 0, 'test': 0}
Total BAD: 0


In [42]:
# Quarantine (move aside) the YAML files of the pair_ids flagged with a bad HLA

import os, shutil, time
from pathlib import Path
import pandas as pd

BASE = Path("/home/natasha/multimodal_model")
PAIR_TABLE_PATH = BASE.joinpath("manifests", "pair_table.canonical.csv")
pt = pd.read_csv(PAIR_TABLE_PATH)  # lookup: (split, pair_id) -> yaml_path

# Each run writes into its own timestamped folder so different runs stay separate.
ts = time.strftime("%Y%m%d_%H%M%S")
QUAR = BASE.joinpath("data", "_quarantine_missing_hla", ts)
QUAR.mkdir(parents=True, exist_ok=True)

def quarantine_yaml_for_pair(split: str, pair_id: str, pair_table=None, base=None, quar_dir=None):
    """Move the YAML of one (split, pair_id) into the quarantine folder.

    The YAML location comes from the pair table's ``yaml_path`` column. If that
    path is a symlink, the file it points to is moved (when it exists) and the
    symlink itself is deleted; a regular file is simply moved.

    Args:
        split, pair_id: row to look up in the pair table.
        pair_table: pair table DataFrame; defaults to the module-level ``pt``.
        base: project root that relative yaml paths are resolved against;
            defaults to ``BASE``.
        quar_dir: quarantine root; files land in ``quar_dir / split``.
            Defaults to ``QUAR``.

    Returns:
        (status, source_path, destination_path). status is one of
        "no_row_in_pair_table", "yaml_missing", "symlink_removed", "file_moved";
        the paths are strings, or None where not applicable.
    """
    if pair_table is None:
        pair_table = pt
    if base is None:
        base = BASE
    if quar_dir is None:
        quar_dir = QUAR

    rows = pair_table[(pair_table["split"] == split) & (pair_table["pair_id"] == pair_id)]
    if rows.empty:
        return ("no_row_in_pair_table", None, None)

    # yaml_path is normally stored relative to the project root (data/train/_chunks/...).
    # Resolve every relative path against base -- not only "data/..." ones -- so the
    # result never depends on the notebook's current working directory.
    y = Path(str(rows.iloc[0]["yaml_path"]))
    if not y.is_absolute():
        y = Path(base) / y

    # Path.exists() follows symlinks and is False for a dangling link; also accept
    # is_symlink() so broken links reach the cleanup branch instead of lingering.
    if not (y.exists() or y.is_symlink()):
        return ("yaml_missing", str(y), None)

    dest_dir = Path(quar_dir) / split
    dest_dir.mkdir(parents=True, exist_ok=True)

    if y.is_symlink():
        target = y.resolve()
        # Move the *target* YAML (the canonical file), then remove the symlink.
        moved_target = None
        if target.exists():
            dest_target = dest_dir / target.name
            shutil.move(str(target), str(dest_target))
            moved_target = str(dest_target)
        y.unlink()  # remove the symlink itself
        return ("symlink_removed", str(y), moved_target)

    dest = dest_dir / y.name
    shutil.move(str(y), str(dest))
    return ("file_moved", str(y), str(dest))

# Quarantine the YAML of every flagged pair_id, split by split, in sorted pair_id order.
# Each record is (split, pair_id, status, source_path, destination_path).
moved = [
    (split, pid, *quarantine_yaml_for_pair(split, pid))
    for split, badset in (("train", bad_train), ("val", bad_val), ("test", bad_test))
    for pid in sorted(badset)
]

# Summary: number of pairs that ended in each status.
status_counts = {}
for _split, _pid, status, *_paths in moved:
    status_counts[status] = status_counts.get(status, 0) + 1
print("Quarantine status counts:", status_counts)
print("Quarantine folder:", QUAR)
print("Quarantine folder:", QUAR)


Quarantine status counts: {}
Quarantine folder: /home/natasha/multimodal_model/data/_quarantine_missing_hla/20260120_145751


In [43]:
# Remove bad pair_ids from the pair table and split manifests, then regenerate the clean dataframes

import shutil, time
import pandas as pd
from pathlib import Path

BASE = Path("/home/natasha/multimodal_model")
MANI_DIR = BASE.joinpath("manifests")
ts = time.strftime("%Y%m%d_%H%M%S")  # suffix shared by every backup written in this cell

# Hidden dependency: BAD is the union of bad pair_ids built by the earlier
# bad-HLA scan cell, which must have been run first.
BAD_ALL = BAD

def backup_file(p: Path, stamp=None):
    """Copy ``p`` to a sibling ``<name>.bak_<stamp>`` using ``shutil.copy2``.

    Args:
        p: file to back up.
        stamp: suffix for the backup name; defaults to the module-level ``ts``.

    Returns:
        Path of the backup, or None when ``p`` does not exist (nothing to back up;
        previously this raised FileNotFoundError, e.g. on a first-time regenerate).
    """
    if stamp is None:
        stamp = ts
    p = Path(p)
    if not p.exists():
        return None
    # e.g. pair_table.canonical.csv -> pair_table.canonical.csv.bak_<stamp>
    b = p.with_suffix(p.suffix + f".bak_{stamp}")
    shutil.copy2(p, b)
    return b

# 1) Drop BAD pair_ids from pair_table.canonical.csv (after backing it up).
pt_path = MANI_DIR / "pair_table.canonical.csv"
pt = pd.read_csv(pt_path)
backup_file(pt_path)
keep = ~pt["pair_id"].isin(BAD_ALL)
pt2 = pt.loc[keep].copy()
pt2.to_csv(pt_path, index=False)
print("pair_table:", len(pt), "->", len(pt2))

# 2) Apply the same filter to every per-split manifest that exists.
for split in ["train", "val", "test"]:
    mpath = MANI_DIR / f"{split}_manifest.csv"
    if not mpath.exists():
        continue
    m = pd.read_csv(mpath)
    backup_file(mpath)
    m2 = m.loc[~m["pair_id"].isin(BAD_ALL)].copy()
    m2.to_csv(mpath, index=False)
    print(f"{split}_manifest:", len(m), "->", len(m2))

# 3) Re-read the filtered table from disk; it is the single source of truth
#    for regenerating the *_df_clean.csv files below.
pt = pd.read_csv(pt_path)

def write_split_clean(split: str, pair_table=None, base=None):
    """Regenerate ``data/<split>/<split>_df_clean.csv`` from the filtered pair table.

    Columns: pair_id, Peptide, HLA_sequence, TCR_full. TCR_full is the TCR alpha
    sequence followed directly by the TCR beta sequence; a missing chain
    contributes "" -- the same rule write_split() uses -- so regenerating does
    not change TCR_full for pairs that lack a chain. An existing CSV is backed
    up with backup_file() before being overwritten.

    Args:
        split: value of the pair table's ``split`` column to export.
        pair_table: pair table DataFrame; defaults to the module-level ``pt``.
        base: project root; defaults to ``BASE``.
    """
    if pair_table is None:
        pair_table = pt
    if base is None:
        base = BASE

    df = pair_table[pair_table["split"] == split]
    # Empty chains come back from CSV as NaN. Fill with "" rather than a placeholder
    # such as "X": one X would be concatenated into the TCR sequence as a fake
    # residue and would differ from the first export.
    tcr_full = df["tcra_seq"].fillna("").astype(str) + df["tcrb_seq"].fillna("").astype(str)

    out = pd.DataFrame({
        "pair_id": df["pair_id"],
        "Peptide": df["pep_seq"],
        "HLA_sequence": df["hla_seq"],
        "TCR_full": tcr_full,
    })

    out_path = Path(base) / f"data/{split}/{split}_df_clean.csv"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists():  # nothing to back up on a first-time write
        backup_file(out_path)
    out.to_csv(out_path, index=False)
    print(split, "clean rows:", len(out), "unique pair_id:", out["pair_id"].nunique())

# Rebuild every split's clean CSV from the filtered pair table.
for split_name in ("train", "val", "test"):
    write_split_clean(split_name)


pair_table: 35316 -> 35316
train_manifest: 31154 -> 31154
val_manifest: 1938 -> 1938
test_manifest: 2182 -> 2182
train clean rows: 31166 unique pair_id: 31166
val clean rows: 1960 unique pair_id: 1960
test clean rows: 2190 unique pair_id: 2190


In [44]:
import pandas as pd
from pathlib import Path

BASE = Path("/home/natasha/multimodal_model")
# Post-filter pair table: the reference for expected rows and YAML paths below.
pt = pd.read_csv(BASE.joinpath("manifests", "pair_table.canonical.csv"))

def count_yamls(split: str, pair_table=None, base=None) -> int:
    """Count how many of ``split``'s ``yaml_path`` entries exist on disk.

    Args:
        split: value of the pair table's ``split`` column to check.
        pair_table: pair table DataFrame (ground truth); defaults to ``pt``.
        base: project root that relative paths are resolved against; defaults
            to ``BASE``. Absolute paths are used unchanged.

    Returns:
        Number of existing files. Path.exists() follows symlinks, so a dangling
        symlink counts as missing.
    """
    if pair_table is None:
        pair_table = pt
    if base is None:
        base = BASE
    paths = pair_table.loc[pair_table["split"] == split, "yaml_path"]
    # Joining onto an absolute path yields that absolute path, so this resolves every
    # relative path against base (not only "data/..." ones) and never uses the CWD.
    return sum((Path(base) / Path(str(yrel))).exists() for yrel in paths)

for split in ["train", "val", "test"]:
    clean = pd.read_csv(BASE / "data" / split / f"{split}_df_clean.csv")
    manifest_path = BASE / "manifests" / f"{split}_manifest.csv"

    # Label -> value, in print order; the manifest is optional (None when absent).
    report = {
        " df_clean rows:": len(clean),
        " manifest rows:": pd.read_csv(manifest_path).shape[0] if manifest_path.exists() else None,
        " pair_table rows:": pt[pt["split"] == split].shape[0],
        " yaml exists:": count_yamls(split),
        " missing HLA:": (clean["HLA_sequence"].fillna("").astype(str).str.strip() == "").sum(),
    }

    print(f"\n[{split}]")
    for label, value in report.items():
        print(label, value)



[train]
 df_clean rows: 31166
 manifest rows: 31154
 pair_table rows: 31166
 yaml exists: 31166
 missing HLA: 0

[val]
 df_clean rows: 1960
 manifest rows: 1938
 pair_table rows: 1960
 yaml exists: 1960
 missing HLA: 0

[test]
 df_clean rows: 2190
 manifest rows: 2182
 pair_table rows: 2190
 yaml exists: 2190
 missing HLA: 0
