In [1]:
import os
import numpy as np
import pandas as pd
from collections import Counter

home = "/home/natasha/multimodal_model"
negative_manifest_path = os.path.join(home, "data", "negative_manifests", "boltz_100_manifest.csv")
neg = pd.read_csv(negative_manifest_path)

neg.head(), len(neg)


(    pair_id                          yaml_path  pep_len  tcra_len  tcrb_len  \
 0  pair_000  data/negative_pairs/pair_000.yaml        9       203       241   
 1  pair_001  data/negative_pairs/pair_001.yaml        9       194       243   
 2  pair_002  data/negative_pairs/pair_002.yaml       10       194       243   
 3  pair_003  data/negative_pairs/pair_003.yaml        9       194       243   
 4  pair_004  data/negative_pairs/pair_004.yaml        9       201       244   
 
    hla_len  
 0      365  
 1      365  
 2      365  
 3      365  
 4      365  ,
 100)

In [2]:
def get_emb_path(base_path, split_dir, yaml_path):
    pair_id = os.path.splitext(os.path.basename(yaml_path))[0]
    emb_path = os.path.join(
        base_path,
        "outputs", "boltz_runs",
        split_dir,
        pair_id,
        f"boltz_results_{pair_id}",
        "predictions",
        pair_id,
        f"embeddings_{pair_id}.npz",
    )
    return pair_id, emb_path

# sanity check on first row
pair_id0, path0 = get_emb_path(home, "negatives", neg.loc[0, "yaml_path"])
pair_id0, path0, os.path.exists(path0)


('pair_000',
 '/home/natasha/multimodal_model/outputs/boltz_runs/negatives/pair_000/boltz_results_pair_000/predictions/pair_000/embeddings_pair_000.npz',
 False)

In [3]:
records = []
missing = 0

for i, row in neg.iterrows():
    pair_id, emb_path = get_emb_path(home, "negatives", row["yaml_path"])
    exists = os.path.exists(emb_path)
    if not exists:
        missing += 1
        continue

    rec = {
        "idx": i,
        "pair_id": pair_id,
        "emb_path": emb_path,
        "pep_len": int(row["pep_len"]),
        "tcra_len": int(row["tcra_len"]),
        "tcrb_len": int(row["tcrb_len"]),
        "hla_len": int(row["hla_len"]),
    }
    rec["L_T_manifest"]  = rec["tcra_len"] + rec["tcrb_len"]
    rec["L_PH_manifest"] = rec["pep_len"] + rec["hla_len"]
    rec["L_manifest"]    = rec["L_T_manifest"] + rec["L_PH_manifest"]
    records.append(rec)

df = pd.DataFrame(records)
print("neg rows total:", len(neg))
print("neg rows with embeddings:", len(df))
print("missing embeddings:", missing)
df[["L_T_manifest","L_PH_manifest","L_manifest"]].describe()


neg rows total: 100
neg rows with embeddings: 32
missing embeddings: 68


Unnamed: 0,L_T_manifest,L_PH_manifest,L_manifest
count,32.0,32.0,32.0
mean,433.6875,373.5,807.1875
std,44.137204,1.565763,43.999588
min,265.0,370.0,639.0
25%,441.0,373.0,814.0
50%,442.0,374.0,817.0
75%,449.0,375.0,820.5
max,456.0,375.0,831.0


In [4]:
def safe_load_zshape(npz_path):
    try:
        with np.load(npz_path) as arr:
            z = arr["z"]
        # accept (L,L,d) or (1,L,L,d)
        if z.ndim == 4:
            if z.shape[0] == 1:
                z = z[0]
            else:
                # unexpected; take first item
                z = z[0]
        if z.ndim != 3:
            return None, f"bad_ndim:{z.ndim}"
        L0, L1, dB = z.shape
        if L0 != L1:
            return (L0, L1, dB), "non_square"
        return (L0, L1, dB), "ok"
    except Exception as e:
        return None, f"err:{type(e).__name__}"

shapes = []
statuses = Counter()

for j, rec in df.iterrows():
    shape, status = safe_load_zshape(rec["emb_path"])
    statuses[status] += 1
    if shape is not None:
        shapes.append((rec["idx"], rec["pair_id"], *shape, rec["L_manifest"], rec["L_T_manifest"], rec["L_PH_manifest"]))

print("status counts:", statuses)

shape_df = pd.DataFrame(shapes, columns=["idx","pair_id","L0","L1","dB","L_manifest","L_T_manifest","L_PH_manifest"])
shape_df[["L0","dB","L_manifest","L_T_manifest","L_PH_manifest"]].describe()


status counts: Counter({'ok': 32})


Unnamed: 0,L0,dB,L_manifest,L_T_manifest,L_PH_manifest
count,32.0,32.0,32.0,32.0,32.0
mean,807.1875,128.0,807.1875,433.6875,373.5
std,43.999588,0.0,43.999588,44.137204,1.565763
min,639.0,128.0,639.0,265.0,370.0
25%,814.0,128.0,814.0,441.0,373.0
50%,817.0,128.0,817.0,442.0,374.0
75%,820.5,128.0,820.5,449.0,375.0
max,831.0,128.0,831.0,456.0,375.0


In [5]:
shape_df["L_pad"] = shape_df["L0"]
shape_df["delta_total"] = shape_df["L_manifest"] - shape_df["L_pad"]
shape_df["delta_T"]     = shape_df["L_T_manifest"] - shape_df["L_pad"]  # crude (just to see extremes)

print("How many have L_manifest > L_pad ?", (shape_df["delta_total"] > 0).sum(), "out of", len(shape_df))
shape_df.sort_values("delta_total", ascending=False).head(15)


How many have L_manifest > L_pad ? 0 out of 32


Unnamed: 0,idx,pair_id,L0,L1,dB,L_manifest,L_T_manifest,L_PH_manifest,L_pad,delta_total,delta_T
0,2,pair_002,812,812,128,812,437,375,812,0,-375
1,4,pair_004,819,819,128,819,445,374,819,0,-374
2,10,pair_010,819,819,128,819,448,371,819,0,-371
3,13,pair_013,822,822,128,822,448,374,822,0,-374
4,21,pair_021,823,823,128,823,449,374,823,0,-374
5,28,pair_028,823,823,128,823,449,374,823,0,-374
6,30,pair_030,818,818,128,818,444,374,818,0,-374
7,32,pair_032,814,814,128,814,441,373,814,0,-373
8,35,pair_035,814,814,128,814,441,373,814,0,-373
9,38,pair_038,814,814,128,814,441,373,814,0,-373


In [6]:
shape_counts = shape_df["L_pad"].value_counts().head(20)
shape_counts


L_pad
819    4
817    4
816    3
814    3
822    3
813    2
820    2
823    2
812    1
831    1
815    1
818    1
830    1
824    1
807    1
639    1
642    1
Name: count, dtype: int64

In [8]:
print("max L_T_manifest:", int(shape_df["L_T_manifest"].max()))
print("max L_PH_manifest:", int(shape_df["L_PH_manifest"].max()))

max L_T_manifest: 456
max L_PH_manifest: 375


In [9]:
import os
import pandas as pd

home = "/home/natasha/multimodal_model"
negative_manifest_path = os.path.join(home, "data", "negative_manifests", "boltz_100_manifest.csv")

neg = pd.read_csv(negative_manifest_path)

def get_emb_path(base_path, split_dir, yaml_path):
    pair_id = os.path.splitext(os.path.basename(yaml_path))[0]
    emb_path = os.path.join(
        base_path,
        "outputs", "boltz_runs",
        split_dir,
        pair_id,
        f"boltz_results_{pair_id}",
        "predictions",
        pair_id,
        f"embeddings_{pair_id}.npz",
    )
    return pair_id, emb_path

keep_mask = []
pair_ids = []
emb_paths = []

for _, row in neg.iterrows():
    pair_id, emb_path = get_emb_path(home, "negatives", row["yaml_path"])
    ok = os.path.exists(emb_path)
    keep_mask.append(ok)
    pair_ids.append(pair_id)
    emb_paths.append(emb_path)

neg["pair_id"] = pair_ids
neg["emb_path"] = emb_paths

neg_complete = neg[keep_mask].copy().reset_index(drop=True)

print("original rows:", len(neg))
print("kept rows:", len(neg_complete))
print("dropped rows:", len(neg) - len(neg_complete))

# Save as a new, explicit manifest (recommended)
negative_manifest_complete_path = os.path.join(
    home, "data", "negative_manifests", "boltz_100_manifest_COMPLETE.csv"
)
neg_complete.drop(columns=["emb_path"], errors="ignore").to_csv(negative_manifest_complete_path, index=False)
print("wrote:", negative_manifest_complete_path)


original rows: 100
kept rows: 32
dropped rows: 68
wrote: /home/natasha/multimodal_model/data/negative_manifests/boltz_100_manifest_COMPLETE.csv


In [7]:
def load_z(npz_path):
    with np.load(npz_path) as arr:
        z = arr["z"]
    if z.ndim == 4 and z.shape[0] == 1:
        z = z[0]
    elif z.ndim == 4:
        z = z[0]
    return z  # (L,L,dB)

def check_partition_consistency(row):
    z = load_z(row["emb_path"])
    L_pad = z.shape[0]
    La, Lb = int(row["L_T_manifest"] - (row["L_T_manifest"] - row["L_T_manifest"])), int(0)  # ignore; we use L_T directly
    L_T  = int(row["L_T_manifest"])
    L_PH = int(row["L_PH_manifest"])
    L    = int(row["L_manifest"])

    # clamp L to available z
    L = min(L, L_pad)
    L_T = min(L_T, L)  # keep TCR first
    L_PH = min(L_PH, L - L_T)

    # slices
    Z = z[:L, :L, :]
    Z_TT = Z[:L_T, :L_T, :]
    Z_TPH = Z[:L_T, L_T:L_T+L_PH, :]
    # just report shapes
    return {
        "pair_id": row["pair_id"],
        "L_pad": L_pad,
        "L": L,
        "L_T": L_T,
        "L_PH": L_PH,
        "Z_TT_i": Z_TT.shape[0],
        "Z_TPH_i": Z_TPH.shape[0],
        "Z_TPH_j": Z_TPH.shape[1],
    }

# run on a few samples with big L_T / known issues
sample = shape_df.sort_values("L_T_manifest", ascending=False).head(10)
checks = [check_partition_consistency(sample.iloc[k]) for k in range(len(sample))]
pd.DataFrame(checks)


KeyError: 'emb_path'

In [None]:
print("A_T rows:", boltz_factoriser.A_T.shape[0])
print("A_PH rows:", boltz_factoriser.A_PH.shape[0])
print("Expected max L_T (manifest):", int(shape_df["L_T_manifest"].max()))
print("Expected max L_PH (manifest):", int(shape_df["L_PH_manifest"].max()))


In [None]:
A_T_rows = boltz_factoriser.A_T.shape[0]
A_PH_rows = boltz_factoriser.A_PH.shape[0]

bad_T = shape_df[shape_df["L_T_manifest"] > A_T_rows].sort_values("L_T_manifest", ascending=False)
bad_PH = shape_df[shape_df["L_PH_manifest"] > A_PH_rows].sort_values("L_PH_manifest", ascending=False)

print("bad_T count:", len(bad_T))
print("bad_PH count:", len(bad_PH))

bad_T.head(5)
