This script:
- Takes YAML and manifest files
- Runs Boltz on those files


Inputs:
- Directory containing YAML files needed to run Boltz
- Manifest file path for those YAML files

Outputs:
- Boltz directory containing runs of YAML files


Run Boltz for Train/Test/Val in Chunk Directories

In [None]:
from pathlib import Path
import subprocess, shlex
import pandas as pd

# === Configure paths ===
BASE_DIR = Path("/home/natasha/multimodal_model")

RUN_ROOT = BASE_DIR / "outputs"
RUN_ROOT.mkdir(parents=True, exist_ok=True)

BOLTZ_OUT_TRAIN = RUN_ROOT / "train"
BOLTZ_OUT_VAL   = RUN_ROOT / "val"
BOLTZ_OUT_TEST  = RUN_ROOT / "test"
BOLTZ_OUT_TRAIN.mkdir(parents=True, exist_ok=True)
BOLTZ_OUT_VAL.mkdir(parents=True, exist_ok=True)
BOLTZ_OUT_TEST.mkdir(parents=True, exist_ok=True)

YAML_DIR_TRAIN = BASE_DIR / "data" / "train"
YAML_DIR_VAL   = BASE_DIR / "data" / "val"
YAML_DIR_TEST  = BASE_DIR / "data" / "test"

# Where your chunks live (created by your chunking helper)
TRAIN_CHUNKS_ROOT = YAML_DIR_TRAIN / "_chunks"
TEST_CHUNKS_ROOT  = YAML_DIR_TEST / "_chunks"

# One GPU => run chunks sequentially
NPROC = 1

BOLTZ_CMD_TEMPLATE = (
    "conda run -n boltz-env --no-capture-output boltz predict {input_path} "
    "--out_dir {outdir} "
    "--accelerator gpu "
    "--devices 1 "
    "--model boltz2 "
    "--recycling_steps 1 "
    "--sampling_steps 10 "
    "--diffusion_samples 1 "
    "--max_parallel_samples 1 "
    "--max_msa_seqs 64 "
    "--num_subsampled_msa 34 "
    # IMPORTANT: remove --override once you're in production/resume mode
    "--write_embeddings"
)

# --override (only use when want to rerun on all values)
# also reduced num sampled msa to 34 and also number  of recycling steps to 10 (from 20), to try and speed things up

def run_cli(input_path: Path, outdir: Path) -> int:
    """
    input_path can be:
      - a single YAML file, or
      - a directory containing many YAMLs (chunk)
    """
    input_path = Path(input_path).resolve()
    outdir = Path(outdir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)

    cmd = BOLTZ_CMD_TEMPLATE.format(input_path=str(input_path), outdir=str(outdir))
    print("CMD:", cmd)

    with open(outdir / "stdout.log", "w") as so, open(outdir / "stderr.log", "w") as se:
        proc = subprocess.run(
            shlex.split(cmd),
            stdout=so,
            stderr=se,
            text=True,
            cwd=str(BASE_DIR),
        )

    print("Return code:", proc.returncode)
    return proc.returncode


This notebook takes the manifest files, the YAML files and the csv files and matches the pair_id across all three sources to ensure consolidating tagging across all data sources

In [None]:
from pathlib import Path
import re
import pandas as pd
import yaml

BASE_DIR = Path("/home/natasha/multimodal_model")
DATA_DIR = BASE_DIR / "data"
OUT_PATH = BASE_DIR / "manifests" / "pair_table.csv"   # or wherever you want
OUT_PATH.parent.mkdir(parents=True, exist_ok=True)


In [None]:
def _collect_sequence_entries(obj):
    """
    Recursively walk a loaded YAML object and collect any dicts that look like
    { ..., "sequence": <string>, ... }.
    Returns list of dicts with keys: name, chain, kind, sequence (as available).
    """
    found = []

    def walk(x):
        if isinstance(x, dict):
            if "sequence" in x and isinstance(x["sequence"], (str, type(None))):
                entry = {
                    "name": x.get("name") or x.get("id") or x.get("chain") or x.get("label"),
                    "kind": x.get("type") or x.get("kind") or x.get("molecule") or x.get("entity"),
                    "chain": x.get("chain") or x.get("chain_id") or x.get("asym_id"),
                    "sequence": x.get("sequence"),
                }
                found.append(entry)
            for v in x.values():
                walk(v)
        elif isinstance(x, list):
            for v in x:
                walk(v)

    walk(obj)
    return found


def _normalise_seq(s):
    if s is None:
        return None
    s = str(s).strip()
    if s == "" or s.lower() in {"none", "null", "nan"}:
        return None
    return s


def _assign_roles(entries):
    """
    Determine (peptide, tcra, tcrb, hla) from collected sequence entries.

    Priority:
      1) explicit 'name' contains hints (tcr a/b, peptide, hla/mhc)
      2) fallback to length heuristics:
         - peptide: <= 30
         - hla: >= 250 (often 300-400)
         - tcra/tcrb: 60-200
    """
    # clean and filter
    cleaned = []
    for e in entries:
        seq = _normalise_seq(e.get("sequence"))
        if seq is None:
            continue
        cleaned.append({**e, "sequence": seq, "len": len(seq), "name_l": (e.get("name") or "").lower()})

    # --- name-based picks ---
    pep = None
    hla = None
    tcra = None
    tcrb = None

    def pick_first(pred):
        for e in cleaned:
            if pred(e):
                return e["sequence"]
        return None

    # Peptide
    pep = pick_first(lambda e: any(k in e["name_l"] for k in ["pep", "peptide", "antigen"]))

    # HLA/MHC
    hla = pick_first(lambda e: any(k in e["name_l"] for k in ["hla", "mhc", "class i", "class_i", "class-i"]))

    # TCR alpha/beta
    tcra = pick_first(lambda e: any(k in e["name_l"] for k in ["tcra", "tcr_a", "tcr alpha", "alpha chain", "alpha_chain"]))
    tcrb = pick_first(lambda e: any(k in e["name_l"] for k in ["tcrb", "tcr_b", "tcr beta", "beta chain", "beta_chain"]))

    # --- fallback to length heuristics if missing ---
    # If peptide not set, choose the shortest <=30
    if pep is None:
        pep_cands = [e for e in cleaned if e["len"] <= 30]
        pep = min(pep_cands, key=lambda e: e["len"])["sequence"] if pep_cands else None

    # If HLA not set, choose the longest >=250 (or just absolute longest if none >=250)
    if hla is None:
        hla_cands = [e for e in cleaned if e["len"] >= 250]
        if hla_cands:
            hla = max(hla_cands, key=lambda e: e["len"])["sequence"]
        elif cleaned:
            hla = max(cleaned, key=lambda e: e["len"])["sequence"]

    # Remaining candidates for TCRa/TCRb: 60-200, excluding chosen pep/hla by identity
    remaining = [e for e in cleaned if e["sequence"] not in {pep, hla} and 60 <= e["len"] <= 220]

    # If tcra/tcrb still missing, pick two longest from remaining
    if (tcra is None) or (tcrb is None):
        remaining_sorted = sorted(remaining, key=lambda e: e["len"], reverse=True)
        if tcra is None and len(remaining_sorted) >= 1:
            tcra = remaining_sorted[0]["sequence"]
        if tcrb is None and len(remaining_sorted) >= 2:
            tcrb = remaining_sorted[1]["sequence"]

    return pep, tcra, tcrb, hla


In [None]:
def _infer_split_and_chunk(yaml_path: Path):
    p = str(yaml_path)
    split = None
    chunk = None

    # split inference
    if "/data/train/" in p:
        split = "train"
    elif "/data/val/" in p:
        split = "val"
    elif "/data/test/" in p:
        split = "test"
    else:
        split = "unknown"

    # chunk inference
    m = re.search(r"/_chunks/(chunk_\d{3})/", p)
    if m:
        chunk = m.group(1)
    return split, chunk


def build_pair_table(data_dir: Path) -> pd.DataFrame:
    yamls = sorted(data_dir.rglob("*.yaml"))
    rows = []
    bad = []

    for y in yamls:
        pair_id = y.stem  # expects pair_XXXX
        split, chunk = _infer_split_and_chunk(y)

        try:
            obj = yaml.safe_load(y.read_text())
        except Exception as e:
            bad.append((y, f"YAML parse error: {e}"))
            continue

        entries = _collect_sequence_entries(obj)
        pep, tcra, tcrb, hla = _assign_roles(entries)

        # If any are None, track as bad (you can decide whether to keep)
        if pep is None or hla is None:
            bad.append((y, f"Missing pep or hla after parsing (pep={pep is not None}, hla={hla is not None})"))
            continue

        rows.append({
            "pair_id": pair_id,
            "yaml_path": str(y),
            "split": split,
            "chunk": chunk,
            "pep_seq": pep,
            "tcra_seq": tcra,
            "tcrb_seq": tcrb,
            "hla_seq": hla,
            "pep_len": len(pep) if pep else None,
            "tcra_len": len(tcra) if tcra else None,
            "tcrb_len": len(tcrb) if tcrb else None,
            "hla_len": len(hla) if hla else None,
        })

    df = pd.DataFrame(rows).sort_values(["split", "chunk", "pair_id"], na_position="last").reset_index(drop=True)

    print(f"Found YAMLs: {len(yamls)}")
    print(f"Parsed OK : {len(df)}")
    print(f"Bad/skip  : {len(bad)}")
    if bad:
        print("\nFirst 10 skipped:")
        for p, reason in bad[:10]:
            print(" -", p, "=>", reason)

    return df


pair_table = build_pair_table(DATA_DIR)
pair_table.to_csv(OUT_PATH, index=False)
print(f"\nWrote: {OUT_PATH}")
pair_table.head()


Found YAMLs: 70539
Parsed OK : 70525
Bad/skip  : 14

First 10 skipped:
 - /home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_10061.yaml => YAML parse error: [Errno 2] No such file or directory: '/home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_10061.yaml'
 - /home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_10863.yaml => YAML parse error: [Errno 2] No such file or directory: '/home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_10863.yaml'
 - /home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_11083.yaml => YAML parse error: [Errno 2] No such file or directory: '/home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_11083.yaml'
 - /home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_11244.yaml => YAML parse error: [Errno 2] No such file or directory: '/home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_11244.yaml'
 - /home/natasha/multimodal_model/data/train/_chunks/chunk_000/pair_1

Unnamed: 0,pair_id,yaml_path,split,chunk,pep_seq,tcra_seq,tcrb_seq,hla_seq,pep_len,tcra_len,tcrb_len,hla_len
0,pair_000,/home/natasha/multimodal_model/data/test/_chun...,test,chunk_000,TSTLQEQIGW,EAGVTQFPSHSVIEKGQTVTLRCDPISGHDNLYWYRRVMGKEIKFL...,GNSVTQMEGPVTLSEEAFLTINCTYTATGYPSLFWYVQYPGEGLQL...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,10,117.0,112.0,362
1,pair_001,/home/natasha/multimodal_model/data/test/_chun...,test,chunk_000,GSLSPELRPIF,GEDVEQSLFLSVREGDSSVINCTYTDSSSTYLYWYKQEPGAGLQLL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,11,111.0,111.0,362
2,pair_002,/home/natasha/multimodal_model/data/test/_chun...,test,chunk_000,GTIRPEIPDYF,GEDVEQSLFLSVREGDSSVINCTYTDSSSTYLYWYKQEPGAGLQLL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,11,111.0,111.0,362
3,pair_003,/home/natasha/multimodal_model/data/test/_chun...,test,chunk_000,KAFSPEVIPMF,GEDVEQSLFLSVREGDSSVINCTYTDSSSTYLYWYKQEPGAGLQLL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,11,111.0,111.0,362
4,pair_004,/home/natasha/multimodal_model/data/test/_chun...,test,chunk_000,KSLTPEVRGYW,GEDVEQSLFLSVREGDSSVINCTYTDSSSTYLYWYKQEPGAGLQLL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,11,111.0,111.0,362


Take pair_id table and build pair_id in files

In [None]:
import pandas as pd
from pathlib import Path
from datetime import datetime

PAIR_TABLE_PATH = Path("/home/natasha/multimodal_model/manifests/pair_table.csv")

def norm_seq(s):
    if pd.isna(s):
        return ""
    return "".join([c for c in str(s).strip().upper() if c.isalpha()])

pair_table = pd.read_csv(PAIR_TABLE_PATH)

# normalised join keys
pair_table["Peptide_n"] = pair_table["pep_seq"].map(norm_seq)
pair_table["TCRa_n"] = pair_table["tcra_seq"].fillna("").map(norm_seq)
pair_table["TCRb_n"] = pair_table["tcrb_seq"].fillna("").map(norm_seq)
pair_table["HLAseq_n"] = pair_table["hla_seq"].map(norm_seq)

pair_table_keyed = pair_table[["pair_id","split","chunk","yaml_path","Peptide_n","TCRa_n","TCRb_n","HLAseq_n"]].copy()

print("pair_table rows:", len(pair_table_keyed), "unique pair_ids:", pair_table_keyed["pair_id"].nunique())


  pair_table = pd.read_csv(PAIR_TABLE_PATH)


pair_table rows: 70525 unique pair_ids: 32679


In [None]:
def tag_df_with_pair_id(df_path: Path, split_name: str) -> None:
    df = pd.read_csv(df_path)

    # Make join keys from df (adapt column names if yours differ)
    df["Peptide_n"] = df["Peptide"].map(norm_seq)
    df["TCRa_n"] = df["TCRa"].fillna("").map(norm_seq)
    df["TCRb_n"] = df["TCRb"].fillna("").map(norm_seq)
    df["HLAseq_n"] = df["HLA_sequence"].map(norm_seq)

    # Restrict pair_table to the same split to reduce accidental collisions
    pt = pair_table_keyed[pair_table_keyed["split"] == split_name].copy()

    merged = df.merge(
        pt,
        how="left",
        on=["Peptide_n","TCRa_n","TCRb_n","HLAseq_n"],
        suffixes=("", "_pt"),
        validate="m:1"  # each df row should map to at most one pair_id
    )

    # counts
    n_total = len(merged)
    n_tagged = merged["pair_id"].notna().sum()
    n_missing = n_total - n_tagged

    print(f"\n[{split_name}] {df_path}")
    print("  total rows :", n_total)
    print("  tagged     :", n_tagged)
    print("  missing    :", n_missing)

    # Backup then write updated CSV
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = df_path.with_suffix(df_path.suffix + f".bak_{ts}")
    df_path.rename(backup_path)
    print("  backup ->", backup_path)

    # Keep original columns, but add pair_id (and optionally yaml_path/chunk)
    out = merged.copy()
    # drop helper cols
    out = out.drop(columns=["Peptide_n","TCRa_n","TCRb_n","HLAseq_n"], errors="ignore")

    # Optional: keep yaml_path + chunk for easier debugging
    # out already has yaml_path/chunk from pair_table

    out.to_csv(df_path, index=False)
    print("  wrote  ->", df_path)

    # If anything is missing, save a diagnostic file
    if n_missing > 0:
        miss_path = df_path.with_suffix(df_path.suffix + f".missing_pair_id_{ts}.csv")
        out[out["pair_id"].isna()].to_csv(miss_path, index=False)
        print("  missing rows saved ->", miss_path)


In [None]:
tag_df_with_pair_id(Path("/home/natasha/multimodal_model/data/train/train_df.csv"), "train")
tag_df_with_pair_id(Path("/home/natasha/multimodal_model/data/val/val_df.csv"), "val")
tag_df_with_pair_id(Path("/home/natasha/multimodal_model/data/test/test_df.csv"), "test")


MergeError: Merge keys are not unique in right dataset; not a many-to-one merge

In [None]:
import pandas as pd
from pathlib import Path
from datetime import datetime

def norm_seq(s):
    if pd.isna(s):
        return ""
    return "".join([c for c in str(s).strip().upper() if c.isalpha()])

def tag_df_with_pair_id_allow_dupes(df_path: Path, split_name: str, pair_table: pd.DataFrame) -> None:
    df = pd.read_csv(df_path)

    # build df keys
    df["_pep"]  = df["Peptide"].map(norm_seq)
    df["_a"]    = df["TCRa"].fillna("").map(norm_seq)
    df["_b"]    = df["TCRb"].fillna("").map(norm_seq)
    df["_hla"]  = df["HLA_sequence"].map(norm_seq)

    # subset pair_table to split, and key it
    pt = pair_table[pair_table["split"] == split_name].copy()
    pt["_pep"] = pt["pep_seq"].map(norm_seq)
    pt["_a"]   = pt["tcra_seq"].fillna("").map(norm_seq)
    pt["_b"]   = pt["tcrb_seq"].fillna("").map(norm_seq)
    pt["_hla"] = pt["hla_seq"].map(norm_seq)

    # include lengths as additional key to reduce collisions
    pt["_pep_len"]  = pt["pep_len"]
    pt["_a_len"]    = pt["tcra_len"].fillna(0).astype(int)
    pt["_b_len"]    = pt["tcrb_len"].fillna(0).astype(int)
    pt["_hla_len"]  = pt["hla_len"]

    df["_pep_len"]  = df["_pep"].str.len()
    df["_a_len"]    = df["_a"].str.len()
    df["_b_len"]    = df["_b"].str.len()
    df["_hla_len"]  = df["_hla"].str.len()

    key_cols = ["_pep","_a","_b","_hla","_pep_len","_a_len","_b_len","_hla_len"]

    # group both sides by key, then assign pair_ids within each key group by order
    df["pair_id"] = pd.NA

    pt_groups = {k: g["pair_id"].tolist() for k, g in pt.groupby(key_cols, dropna=False)}
    used = set()

    missing_keys = 0
    overfull_keys = 0

    for k, idxs in df.groupby(key_cols, dropna=False).groups.items():
        candidates = pt_groups.get(k, [])
        if not candidates:
            missing_keys += len(idxs)
            continue

        # remove already used
        candidates = [c for c in candidates if c not in used]

        if len(candidates) < len(idxs):
            # not enough candidates to assign uniquely
            overfull_keys += (len(idxs) - len(candidates))

        # assign as many as we can
        for row_i, pid in zip(list(idxs), candidates):
            df.at[row_i, "pair_id"] = pid
            used.add(pid)

    n_total = len(df)
    n_tagged = df["pair_id"].notna().sum()
    n_missing = n_total - n_tagged

    print(f"\n[{split_name}] {df_path}")
    print("  total rows :", n_total)
    print("  tagged     :", n_tagged)
    print("  missing    :", n_missing)
    print("  missing_keys_rows:", missing_keys)
    print("  overfull_key_rows:", overfull_keys)

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = df_path.with_suffix(df_path.suffix + f".bak_{ts}")
    df_path.rename(backup_path)
    print("  backup ->", backup_path)

    # drop helper cols
    df_out = df.drop(columns=key_cols, errors="ignore")
    df_out.to_csv(df_path, index=False)
    print("  wrote  ->", df_path)

    if n_missing > 0:
        miss_path = df_path.with_suffix(df_path.suffix + f".missing_pair_id_{ts}.csv")
        df_out[df_out["pair_id"].isna()].to_csv(miss_path, index=False)
        print("  missing rows saved ->", miss_path)


In [None]:
pair_table = pd.read_csv("/home/natasha/multimodal_model/manifests/pair_table.csv")

tag_df_with_pair_id_allow_dupes(Path("/home/natasha/multimodal_model/data/train/train_df.csv"), "train", pair_table)
tag_df_with_pair_id_allow_dupes(Path("/home/natasha/multimodal_model/data/val/val_df.csv"), "val", pair_table)
tag_df_with_pair_id_allow_dupes(Path("/home/natasha/multimodal_model/data/test/test_df.csv"), "test", pair_table)


  pair_table = pd.read_csv("/home/natasha/multimodal_model/manifests/pair_table.csv")


ParserError: Error tokenizing data. C error: EOF inside string starting at row 11522

In [None]:
import pandas as pd
from pathlib import Path

PAIR_TABLE_PATH = Path("/home/natasha/multimodal_model/manifests/pair_table.csv")

OUT_TRAIN = Path("/home/natasha/multimodal_model/data/train/train_df.fixed.csv")
OUT_VAL   = Path("/home/natasha/multimodal_model/data/val/val_df.fixed.csv")
OUT_TEST  = Path("/home/natasha/multimodal_model/data/test/test_df.fixed.csv")

pt = pd.read_csv(PAIR_TABLE_PATH)

def write_fixed(split: str, out_path: Path):
    sub = pt[pt["split"] == split].copy()

    fixed = pd.DataFrame({
        "pair_id": sub["pair_id"],
        "Peptide": sub["pep_seq"].fillna(""),
        "TCRa": sub["tcra_seq"].fillna(""),
        "TCRb": sub["tcrb_seq"].fillna(""),
        "HLA_sequence": sub["hla_seq"].fillna(""),
        "yaml_path": sub["yaml_path"],
        "chunk": sub["chunk"],
        # keep lengths because your model wants them
        "pep_len": sub["pep_len"],
        "tcra_len": sub["tcra_len"],
        "tcrb_len": sub["tcrb_len"],
        "hla_len": sub["hla_len"],
    })

    fixed.to_csv(out_path, index=False)
    print(f"[{split}] wrote {len(fixed):,} rows -> {out_path}")

write_fixed("train", OUT_TRAIN)
write_fixed("val",   OUT_VAL)
write_fixed("test",  OUT_TEST)


  pt = pd.read_csv(PAIR_TABLE_PATH)


[train] wrote 63,945 rows -> /home/natasha/multimodal_model/data/train/train_df.fixed.csv
[val] wrote 1,968 rows -> /home/natasha/multimodal_model/data/val/val_df.fixed.csv
[test] wrote 4,410 rows -> /home/natasha/multimodal_model/data/test/test_df.fixed.csv


In [2]:
# RUN TRAIN CHUNKS

def list_chunk_dirs(chunks_root: Path):
    chunks_root = Path(chunks_root).resolve()
    if not chunks_root.exists():
        raise FileNotFoundError(f"Chunks root not found: {chunks_root}")

    # chunk_000, chunk_001, ...
    chunk_dirs = sorted([p for p in chunks_root.iterdir() if p.is_dir()])
    if not chunk_dirs:
        raise ValueError(f"No chunk directories found in: {chunks_root}")
    return chunk_dirs

train_chunk_dirs = list_chunk_dirs(TRAIN_CHUNKS_ROOT)

for chunk_dir in train_chunk_dirs:
    chunk_name = chunk_dir.name
    outdir = BOLTZ_OUT_TRAIN / chunk_name
    print(f"\n=== TRAIN {chunk_name} ===")
    rc = run_cli(chunk_dir, outdir)
    if rc != 0:
        print(f"[STOP] Train chunk failed: {chunk_name}. See logs in {outdir}")
        break



=== TRAIN chunk_000 ===
CMD: conda run -n boltz-env --no-capture-output boltz predict /home/natasha/multimodal_model/data/train/_chunks/chunk_000 --out_dir /home/natasha/multimodal_model/outputs/train/chunk_000 --accelerator gpu --devices 1 --model boltz2 --recycling_steps 1 --sampling_steps 20 --diffusion_samples 1 --max_parallel_samples 1 --max_msa_seqs 64 --num_subsampled_msa 64 --override --write_embeddings
Return code: 0

=== TRAIN chunk_001 ===
CMD: conda run -n boltz-env --no-capture-output boltz predict /home/natasha/multimodal_model/data/train/_chunks/chunk_001 --out_dir /home/natasha/multimodal_model/outputs/train/chunk_001 --accelerator gpu --devices 1 --model boltz2 --recycling_steps 1 --sampling_steps 20 --diffusion_samples 1 --max_parallel_samples 1 --max_msa_seqs 64 --num_subsampled_msa 64 --override --write_embeddings
Return code: 0

=== TRAIN chunk_002 ===
CMD: conda run -n boltz-env --no-capture-output boltz predict /home/natasha/multimodal_model/data/train/_chunks/c

In [None]:
# RUN TEST CHUNKS

test_chunk_dirs = list_chunk_dirs(TEST_CHUNKS_ROOT)

for chunk_dir in test_chunk_dirs:
    chunk_name = chunk_dir.name
    outdir = BOLTZ_OUT_TEST / chunk_name
    print(f"\n=== TEST {chunk_name} ===")
    rc = run_cli(chunk_dir, outdir)
    if rc != 0:
        print(f"[STOP] Test chunk failed: {chunk_name}. See logs in {outdir}")
        break


In [None]:
# RUN VAL CHUNK

print("\n=== VAL (single directory) ===")
rc = run_cli(YAML_DIR_VAL, BOLTZ_OUT_VAL / "val_all")
if rc != 0:
    print(f"[FAIL] Val run failed. See logs in {BOLTZ_OUT_VAL / 'val_all'}")


Creating Symlink Folders + Troubleshooting

In [None]:
from pathlib import Path
import os
import math
import shutil
from typing import Literal, Optional

def chunk_yaml_directory(
    src_dir: Path,
    *,
    chunk_size: int = 2000,
    chunk_prefix: str = "chunk_",
    dst_parent: Optional[Path] = None,
    mode: Literal["symlink", "hardlink", "copy", "move"] = "symlink",
    pattern: str = "*.yaml",
    overwrite: bool = False,
) -> Path:
    """
    Split a directory of YAML files into chunk subfolders.

    By default, creates symlinks (fast, no duplication, does not break existing paths).

    Args:
        src_dir: Directory containing YAMLs (e.g., .../data/train).
        chunk_size: Number of YAMLs per chunk folder.
        chunk_prefix: Folder name prefix (chunk_000, chunk_001, ...).
        dst_parent: Where to create chunk folders. Defaults to src_dir / "_chunks".
        mode: One of {"symlink","hardlink","copy","move"}.
        pattern: Glob pattern for YAMLs.
        overwrite: If True, overwrites existing links/files inside chunk dirs.

    Returns:
        Path to the chunk root directory.
    """
    src_dir = Path(src_dir).resolve()
    if not src_dir.exists():
        raise FileNotFoundError(f"src_dir does not exist: {src_dir}")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be > 0")

    yamls = sorted(src_dir.glob(pattern))
    if not yamls:
        raise ValueError(f"No files matching {pattern} found in {src_dir}")

    # Put chunks under src_dir/_chunks by default (keeps original directory intact)
    chunk_root = (dst_parent or (src_dir / "_chunks")).resolve()
    chunk_root.mkdir(parents=True, exist_ok=True)

    n = len(yamls)
    n_chunks = math.ceil(n / chunk_size)

    def _place(src: Path, dst: Path):
        if dst.exists() or dst.is_symlink():
            if overwrite:
                dst.unlink()
            else:
                return  # keep existing
        if mode == "symlink":
            os.symlink(src, dst)
        elif mode == "hardlink":
            os.link(src, dst)
        elif mode == "copy":
            shutil.copy2(src, dst)
        elif mode == "move":
            shutil.move(str(src), str(dst))
        else:
            raise ValueError(f"Unknown mode: {mode}")

    for i in range(n_chunks):
        chunk_dir = chunk_root / f"{chunk_prefix}{i:03d}"
        chunk_dir.mkdir(parents=True, exist_ok=True)

        start = i * chunk_size
        end = min((i + 1) * chunk_size, n)
        for src in yamls[start:end]:
            dst = chunk_dir / src.name
            _place(src, dst)

    return chunk_root


def make_train_test_chunks(
    base_data_dir: Path,
    *,
    train_subdir: str = "train",
    test_subdir: str = "test",
    chunk_size: int = 2000,
    mode: Literal["symlink", "hardlink", "copy", "move"] = "symlink",
) -> tuple[Path, Path]:
    """
    Create chunk folders for train and test YAML directories.
    """
    base_data_dir = Path(base_data_dir).resolve()

    train_dir = base_data_dir / train_subdir
    test_dir  = base_data_dir / test_subdir

    train_chunks = chunk_yaml_directory(train_dir, chunk_size=chunk_size, mode=mode)
    test_chunks  = chunk_yaml_directory(test_dir,  chunk_size=chunk_size, mode=mode)

    return train_chunks, test_chunks


In [None]:
BASE_DATA = Path("/home/natasha/multimodal_model/data")
train_chunks_root, test_chunks_root = make_train_test_chunks(
    BASE_DATA,
    chunk_size=2000,
    mode="symlink",  # recommended
)
print("Train chunks:", train_chunks_root)
print("Test chunks:", test_chunks_root)


Find 'bad' YAML files (Files with missing sequences)

In [17]:
from pathlib import Path
import re

NULL_SEQ_PATTERNS = [
    re.compile(r"^\s*sequence:\s*(null|None)\s*$", re.IGNORECASE),
    re.compile(r"^\s*sequence:\s*$"),  # sequence: (empty)
]

def find_yaml_with_null_sequences(folder: Path):
    folder = Path(folder)
    bad = []
    for p in folder.glob("*.yaml"):
        try:
            lines = p.read_text(errors="ignore").splitlines()
        except Exception:
            bad.append((p, "read_error"))
            continue
        for i, line in enumerate(lines):
            if any(rx.match(line) for rx in NULL_SEQ_PATTERNS):
                bad.append((p, f"null/empty sequence at line {i+1}: {line.strip()}"))
                break
    return bad

# Example:
bad = find_yaml_with_null_sequences(YAML_DIR_TRAIN)  # or a chunk dir
print("Bad YAMLs:", len(bad))
for p, reason in bad[:20]:
    print(p.name, "->", reason)


Bad YAMLs: 14
pair_8118.yaml -> null/empty sequence at line 17: sequence:
pair_11302.yaml -> null/empty sequence at line 17: sequence:
pair_11083.yaml -> null/empty sequence at line 17: sequence:
pair_9184.yaml -> null/empty sequence at line 17: sequence:
pair_9558.yaml -> null/empty sequence at line 17: sequence:
pair_10061.yaml -> null/empty sequence at line 17: sequence:
pair_11244.yaml -> null/empty sequence at line 13: sequence:
pair_8054.yaml -> null/empty sequence at line 17: sequence:
pair_9005.yaml -> null/empty sequence at line 17: sequence:
pair_10863.yaml -> null/empty sequence at line 17: sequence:
pair_9004.yaml -> null/empty sequence at line 17: sequence:
pair_8117.yaml -> null/empty sequence at line 17: sequence:
pair_9559.yaml -> null/empty sequence at line 17: sequence:
pair_8397.yaml -> null/empty sequence at line 17: sequence:


In [18]:
bad_2 = find_yaml_with_null_sequences(YAML_DIR_VAL)
print("Bad YAMLs:", len(bad_2))
for p, reason in bad_2[:20]:
    print(p.name, "->", reason)

bad_3 = find_yaml_with_null_sequences(YAML_DIR_TEST)
print("Bad YAMLs:", len(bad_3))

Bad YAMLs: 0
Bad YAMLs: 0


Remove 'BAD' YAML files (Only need to do this once)

In [19]:
from pathlib import Path
import pandas as pd
import shutil
from datetime import datetime

BASE_DIR = Path("/home/natasha/multimodal_model")

DATA_DIR = BASE_DIR / "data"
MANI_DIR = BASE_DIR / "manifests"

YAML_DIR_TRAIN = DATA_DIR / "train"
YAML_DIR_VAL   = DATA_DIR / "val"
YAML_DIR_TEST  = DATA_DIR / "test"

TRAIN_CHUNKS_ROOT = YAML_DIR_TRAIN / "_chunks"
TEST_CHUNKS_ROOT  = YAML_DIR_TEST / "_chunks"

MANIFEST_PATHS = {
    "train": MANI_DIR / "train_manifest.csv",
    "val":   MANI_DIR / "val_manifest.csv",
    "test":  MANI_DIR / "test_manifest.csv",
}

BAD_YAMLS = {
    "pair_8118.yaml",
    "pair_11302.yaml",
    "pair_11083.yaml",
    "pair_9184.yaml",
    "pair_9558.yaml",
    "pair_10061.yaml",
    "pair_11244.yaml",
    "pair_8054.yaml",
    "pair_9005.yaml",
    "pair_10863.yaml",
    "pair_9004.yaml",
    "pair_8117.yaml",
    "pair_9559.yaml",
    "pair_8397.yaml",
}


In [20]:
def delete_files_if_present(paths):
    deleted = []
    missing = []
    for p in paths:
        if p.exists() or p.is_symlink():
            try:
                p.unlink()
                deleted.append(p)
            except Exception as e:
                print(f"[WARN] Could not delete {p}: {e}")
        else:
            missing.append(p)
    return deleted, missing

# 1) delete from main YAML dirs
main_targets = []
for d in [YAML_DIR_TRAIN, YAML_DIR_VAL, YAML_DIR_TEST]:
    for name in BAD_YAMLS:
        main_targets.append(d / name)

deleted_main, missing_main = delete_files_if_present(main_targets)

print(f"Deleted from main dirs: {len(deleted_main)}")
print(f"Missing from main dirs (fine): {len(missing_main)}")

# 2) delete from chunk dirs (train/test only)
chunk_targets = []
for chunks_root in [TRAIN_CHUNKS_ROOT, TEST_CHUNKS_ROOT]:
    if chunks_root.exists():
        for name in BAD_YAMLS:
            chunk_targets.extend(chunks_root.rglob(name))

deleted_chunks, missing_chunks = delete_files_if_present(chunk_targets)

print(f"Deleted from chunks: {len(deleted_chunks)}")


Deleted from main dirs: 14
Missing from main dirs (fine): 28
Deleted from chunks: 0


In [21]:
def strip_bad_from_manifest(csv_path: Path, bad_names: set[str]) -> int:
    df = pd.read_csv(csv_path)
    if "yaml_path" not in df.columns:
        raise ValueError(f"{csv_path} does not have a 'yaml_path' column. Columns: {list(df.columns)}")

    before = len(df)
    basenames = df["yaml_path"].astype(str).apply(lambda x: Path(x).name)
    df2 = df[~basenames.isin(bad_names)].copy()
    after = len(df2)
    removed = before - after

    # backup then overwrite
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = csv_path.with_suffix(csv_path.suffix + f".bak_{ts}")
    shutil.copy2(csv_path, backup_path)

    df2.to_csv(csv_path, index=False)
    print(f"{csv_path.name}: removed {removed} rows (backup: {backup_path.name})")
    return removed

total_removed = 0
for split, path in MANIFEST_PATHS.items():
    if path.exists():
        total_removed += strip_bad_from_manifest(path, BAD_YAMLS)
    else:
        print(f"[WARN] Manifest not found for {split}: {path}")

print("Total rows removed across manifests:", total_removed)


train_manifest.csv: removed 14 rows (backup: train_manifest.csv.bak_20260119_153914)
val_manifest.csv: removed 0 rows (backup: val_manifest.csv.bak_20260119_153914)
test_manifest.csv: removed 0 rows (backup: test_manifest.csv.bak_20260119_153914)
Total rows removed across manifests: 14


In [22]:
# Confirm no manifest references remain
for split, path in MANIFEST_PATHS.items():
    if not path.exists():
        continue
    df = pd.read_csv(path)
    basenames = df["yaml_path"].astype(str).apply(lambda x: Path(x).name)
    hits = sorted(set(basenames) & BAD_YAMLS)
    print(split, "remaining bad references:", hits)


train remaining bad references: []
val remaining bad references: []
test remaining bad references: []


Old Code - Before Added Chunking Folders

In [4]:
from pathlib import Path
import subprocess, shutil, os
import pandas as pd  # you use pd in cell 3
import os, re, textwrap
import torch
import pytorch_lightning as pl
import subprocess, shlex, sys, os, re
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np


# === Configure paths ===
BASE_DIR = Path("/home/natasha/multimodal_model") 

MANI_DIR  = BASE_DIR / "manifests"

# change these
MANIFEST_PATH_TRAIN = MANI_DIR / "train_manifest.csv"
MANIFEST_PATH_VAL   = MANI_DIR / "val_manifest.csv"
MANIFEST_PATH_TEST  = MANI_DIR / "test_manifest.csv"

YAML_DIR_TRAIN = BASE_DIR / "data" / "train"
YAML_DIR_VAL   = BASE_DIR / "data" / "val"
YAML_DIR_TEST  = BASE_DIR / "data" / "test"


# Boltz output directory

RUN_ROOT = BASE_DIR / "outputs"
RUN_ROOT.mkdir(parents=True, exist_ok=True)

# Boltz output directories split by train/val/test
BOLTZ_OUT_TRAIN = RUN_ROOT / "train"
BOLTZ_OUT_VAL   = RUN_ROOT / "val"
BOLTZ_OUT_TEST  = RUN_ROOT / "test"

BOLTZ_OUT_TRAIN.mkdir(parents=True, exist_ok=True)
BOLTZ_OUT_VAL.mkdir(parents=True, exist_ok=True)
BOLTZ_OUT_TEST.mkdir(parents=True, exist_ok=True)




Boltz Command and Execution Functions

In [5]:
# Choose one: "CLI" (command line) or "API" (inâ€‘process Python)
RUN_MODE = "CLI"          # change to "CLI" to use local CLI
NPROC = 5                # keep 1 for strictly step-by-step; increase later if desired

# --- CLI config (if RUN_MODE == "CLI") ---
# This calls your local repo code via `python -m ...`. Edit to match your runner.
# Use full path to boltz from tcr-multimodal environment
BOLTZ_CMD_TEMPLATE = (
    "conda run -n boltz-env --no-capture-output boltz predict {yaml} "
    "--out_dir {outdir} "
    "--accelerator gpu "
    "--devices 1 "
    "--model boltz2 "
    "--recycling_steps 1 "
    "--sampling_steps 20 "
    "--diffusion_samples 1 "
    "--max_parallel_samples 1 "
    "--max_msa_seqs 64 "
    "--num_subsampled_msa 64 "
    "--override "
    "--write_embeddings"
)

# removed --no_kernels - supposedly boltz has default kernals it uses to speed up the intensive operations

# --- API config (if RUN_MODE == "API") ---
# Adjust import and call to your environment. The key is save_z=True.
API_IMPORT = "from boltz.predict import boltz_predict"
API_CALL   = "boltz_predict"            # callable name
API_KWARGS = {"save_z": True}           # ensure 'z' is saved


In [None]:
# trying to offload RAM memory
def run_cli(yaml_path: Path, outdir: Path):
    outdir.mkdir(parents=True, exist_ok=True)
    cmd = BOLTZ_CMD_TEMPLATE.format(yaml=str(yaml_path), outdir=str(outdir))
    print("CMD:", cmd)
    with open(outdir / "stdout.log", "w") as so, open(outdir / "stderr.log", "w") as se:
        proc = subprocess.run(
            shlex.split(cmd),
            stdout=so,
            stderr=se,
            text=True,
            cwd=str(BASE_DIR),
        )
    print("Return code:", proc.returncode)
    return proc.returncode


def run_one(yaml_rel_path, YAML_DIR, BOLTZ_OUT_ROOT):
    yaml_path = (YAML_DIR / Path(yaml_rel_path).name).resolve()
    pair_id   = yaml_path.stem
    outdir    = BOLTZ_OUT_ROOT / pair_id
    if RUN_MODE.upper() == "CLI":
        rc = run_cli(yaml_path, outdir)
    else:
        rc = run_api(yaml_path, outdir)
    return pair_id, rc


def execute_boltz_runs(manifest_df, YAML_DIR, BOLTZ_OUT_ROOT):
    if NPROC == 1:
        results = []
        for _, row in manifest_df.iterrows():
            results.append(run_one(row["yaml_path"], YAML_DIR, BOLTZ_OUT_ROOT))
    else:
        results = []
        with ThreadPoolExecutor(max_workers=NPROC) as ex:
            futs = [ex.submit(run_one, y, YAML_DIR, BOLTZ_OUT_ROOT) for y in manifest_df["yaml_path"]]
            for fut in as_completed(futs):
                results.append(fut.result())

    print("Completed:", sum(rc==0 for _, rc in results), "/", len(results), "successes")
    results[:5]
    return


In [9]:
test = pd.read_csv(MANIFEST_PATH_TRAIN)
test = test.head(1)
test = execute_boltz_runs(test, YAML_DIR_TRAIN, BOLTZ_OUT_TRAIN)

Unnamed: 0,pair_id,yaml_path,pep_len,tcra_len,tcrb_len,hla_len
0,pair_000,data/train/pair_000.yaml,9,110,114,365
