## Deduplicate RBPs

In [None]:
# %% RBP-only dedup (drop-in replacement for old script outputs)
# Produces exactly the same files/format as the original:
#   dedup_pwm_out/
#     - training_keep_indices.txt
#     - training_drop_indices.txt
#     - training_map.csv                  (orig_col,cluster_id,representative_col)
#     - training_duplicate_clusters_pwm.json
#
# No RNA dedup. RNAs are only used to build PWMs for duplicate-RBP decisions.

import os, json, csv
import numpy as np

# ----------------------------- paths -----------------------------
# Relative paths: resolved against the current working directory.
RBP_FILE    = "training_RBPs2.txt"    # 200 protein sequences (1 per line)
SEQ_FILE    = "training_seqs.txt"     # 120,678 RNA sequences (T/U -> U)
MATRIX_FILE = "training_data2.txt"    # 120,678 x 200 float matrix (whitespace)
OUT_DIR     = "dedup_pwm_out"         # output folder; main() creates it if missing

# PWM construction from probe k-mers
KMER_K        = 6      # k-mer length; each PWM has KMER_K positions
PWM_SOFTMAX_B = 4.0    # beta: multiplies the mean-centred per-position scores before exp()
PWM_PSEUDO    = 1e-3   # added to every exponentiated score before row normalisation

# Selection rule: "consensus" (paper-style) or "ic" (old behavior)
# main() lower-cases the value; anything other than "consensus" uses IC-only ranking.
REP_SELECTION = "consensus"

# ----------------------------- helpers -----------------------------
def read_lines(path, normalize_rna=False):
    """Read one sequence per line from *path*.

    Every line is stripped and upper-cased; lines that are empty after
    stripping are dropped. With ``normalize_rna=True`` each ``T`` is
    rewritten to ``U``. Returns the sequences as a list, in file order.
    """
    with open(path, 'r') as handle:
        cleaned = (raw.strip().upper() for raw in handle)
        kept = [text for text in cleaned if text]
    if normalize_rna:
        kept = [text.replace('T','U') for text in kept]
    return kept

def find_duplicate_clusters(seq_list):
    """Group the indices of identical entries in *seq_list*.

    Returns a list of index lists, one per value that occurs more than
    once. Groups are ordered by first occurrence and the indices inside a
    group are ascending.
    """
    positions_by_seq = {}
    for position, sequence in enumerate(seq_list):
        positions_by_seq.setdefault(sequence, []).append(position)
    # dicts preserve insertion order, so groups come out by first occurrence
    return [members for members in positions_by_seq.values() if len(members) > 1]

NUC2I = {'A':0,'C':1,'G':2,'U':3}

def encode_kmers_of_seq(seq, k=6):
    """Encode every length-*k* window of *seq* as a base-4 integer.

    Digits come from ``NUC2I`` (A=0, C=1, G=2, U=3) with the first base of
    the window as the most significant digit. Returns an int32 array with
    ``len(seq) - k + 1`` codes. An empty array is returned when the
    sequence contains any character outside ``NUC2I`` or is shorter than *k*.
    """
    nothing = np.empty(0, dtype=np.int32)
    digits = [NUC2I.get(ch, -1) for ch in seq]
    if any(d < 0 for d in digits):
        return nothing
    n_windows = len(digits) - k + 1
    if n_windows <= 0:
        return nothing

    top_weight = 4**(k-1)
    value = 0
    for d in digits[:k]:
        value = value*4 + d

    out = np.empty(n_windows, dtype=np.int32)
    out[0] = value
    # rolling update: drop the leading digit, shift left, append the new one
    for w in range(1, n_windows):
        value = (value - digits[w-1]*top_weight)*4 + digits[w+k-1]
        out[w] = value
    return out

def precompute_probe_kmers(rna_list, k=6):
    """Encode the k-mers of every probe and index them in one flat array.

    Returns ``(flat, offsets, per_probe_counts, counts_by_code)``:
      flat             - all k-mer codes, probe after probe
      offsets          - int64, ``len(rna_list) + 1`` entries; probe ``i``
                         owns ``flat[offsets[i]:offsets[i+1]]``
      per_probe_counts - number of k-mers contributed by each probe
      counts_by_code   - occurrences of each of the ``4**k`` codes
    """
    per_probe = [encode_kmers_of_seq(seq, k) for seq in rna_list]
    if per_probe:
        flat = np.concatenate(per_probe)
    else:
        flat = np.empty(0, dtype=np.int32)

    offsets = np.zeros(len(per_probe) + 1, dtype=np.int64)
    offsets[1:] = np.cumsum([len(c) for c in per_probe], dtype=np.int64)
    per_probe_counts = offsets[1:] - offsets[:-1]
    counts_by_code = np.bincount(flat, minlength=4**k)
    return flat, offsets, per_probe_counts, counts_by_code

def read_matrix_columns_txt(path, usecols):
    """Load selected columns of a whitespace-delimited numeric text file.

    Blank lines are skipped. Returns ``(M, cols_sorted)`` where
    ``cols_sorted`` is ``sorted(usecols)`` and ``M`` is a float32 array
    with one row per non-blank line and one column per entry of
    ``cols_sorted``, in that order.
    """
    cols_sorted = sorted(usecols)
    rows = []
    with open(path, 'r') as fh:
        for raw in fh:
            fields = raw.strip().split()
            if fields:
                rows.append([float(fields[c]) for c in cols_sorted])

    # transpose the row records into per-column sequences
    if rows:
        columns = zip(*rows)
    else:
        columns = ([] for _ in cols_sorted)
    M = np.stack([np.array(col, dtype=np.float32) for col in columns], axis=1)
    return M, cols_sorted

def pwm_from_kmer_scores(avg_scores, counts_by_code, k=6, beta=4.0, pseudocount=1e-3):
    """Turn per-k-mer average scores into a ``(k, 4)`` position weight matrix.

    Only k-mers with a positive count take part. For each position and
    base, the count-weighted mean score of the k-mers carrying that base
    there is computed. Each position's four means are centred, passed
    through ``exp(beta * x) + pseudocount`` and normalised to sum to 1.

    ``avg_scores`` and ``counts_by_code`` are indexed by k-mer code and
    must have ``4**k`` entries.
    """
    n_codes = 4**k
    assert avg_scores.shape[0] == n_codes
    observed = counts_by_code > 0
    kmer_ids = np.nonzero(observed)[0]
    occurrences = counts_by_code[observed].astype(np.float64)
    scores = avg_scores[observed].astype(np.float64)
    weighted = occurrences * scores

    numer = np.zeros((k,4), dtype=np.float64)
    denom = np.zeros((k,4), dtype=np.float64)
    for pos in range(k):
        # base-4 digit of every observed code at this position (0 = leftmost)
        digit = (kmer_ids // 4**(k-1-pos)) % 4
        for b in range(4):
            chosen = (digit == b)
            if not chosen.any():
                continue
            numer[pos, b] = weighted[chosen].sum()
            denom[pos, b] = occurrences[chosen].sum()

    mean = np.divide(numer, np.maximum(denom, 1e-12))
    centred = mean - mean.mean(axis=1, keepdims=True)
    boosted = np.exp(beta * centred) + pseudocount
    return boosted / boosted.sum(axis=1, keepdims=True)

def info_content_bits(pwm, bg=0.25, eps=1e-9):
    """Total information content of *pwm* in bits against a uniform background.

    ``eps`` is added to every entry and rows are re-normalised before
    ``sum(P * log2(P / bg))`` is taken over all positions and bases.
    """
    smoothed = pwm + eps
    probs = smoothed / smoothed.sum(axis=1, keepdims=True)
    log_ratio = np.log2(probs) - np.log2(bg)
    return float(np.sum(probs * log_ratio))

# --------- NEW: paper-style PWM alignment metrics (max-corr / min-sKL) ----------
def _overlap_slices(L1, L2, shift):
    # shift>0: Q is moved right; compare P[p] with Q[p-shift]
    start_p = max(0, shift)
    end_p   = min(L1, L2 + shift)
    if end_p - start_p <= 0: return None
    start_q = start_p - shift
    end_q   = start_q + (end_p - start_p)
    return slice(start_p, end_p), slice(start_q, end_q)

def pwm_maxcorr(P, Q):
    """Best Pearson correlation between two PWMs over all relative shifts.

    For every shift with a non-empty overlap, the overlapping rows of both
    matrices are flattened (positions x bases) and correlated. Shifts where
    either side has zero variance are skipped. The result starts at 0.0, so
    a negative correlation is never returned.

    NOTE(review): one-position overlaps (4 values) are eligible and can
    yield high correlations by chance - confirm that is intended.
    """
    rows_p, rows_q = P.shape[0], Q.shape[0]
    top = 0.0
    for shift in range(1 - rows_q, rows_p):
        window = _overlap_slices(rows_p, rows_q, shift)
        if window is None:
            continue
        p_sl, q_sl = window
        a = P[p_sl].reshape(-1)
        b = Q[q_sl].reshape(-1)
        mean_a, mean_b = a.mean(), b.mean()
        sd_a, sd_b = a.std(), b.std()
        if sd_a == 0 or sd_b == 0:
            continue
        r = float(np.dot(a-mean_a, b-mean_b) / (len(a)*sd_a*sd_b))
        top = max(top, r)
    return top

def pwm_min_sKL(P, Q, eps=1e-12):
    """Smallest symmetric KL divergence between two PWMs over all shifts.

    For each shift the overlapping rows get ``eps`` added and are
    re-normalised; the value is ``0.5 * (KL(p||q) + KL(q||p))`` summed over
    the overlapping positions. Returns 0.0 if no finite value is found.

    NOTE(review): the sum is not divided by the overlap width, so short
    overlaps tend to score lower - confirm this matches the intended metric.
    """
    rows_p, rows_q = P.shape[0], Q.shape[0]
    lowest = float("inf")
    for shift in range(1 - rows_q, rows_p):
        window = _overlap_slices(rows_p, rows_q, shift)
        if window is None:
            continue
        p = P[window[0]] + eps
        q = Q[window[1]] + eps
        p = p / p.sum(axis=1, keepdims=True)
        q = q / q.sum(axis=1, keepdims=True)
        log_p, log_q = np.log(p), np.log(q)
        kl_pq = np.sum(p * (log_p - log_q))
        kl_qp = np.sum(q * (log_q - log_p))
        divergence = float(0.5 * (kl_pq + kl_qp))
        if divergence < lowest:
            lowest = divergence
    if np.isfinite(lowest):
        return lowest
    return 0.0
# -------------------------------------------------------------------

# ----------------------------- main flow -----------------------------
def main():
    """Deduplicate identical training RBPs and write the old-style outputs.

    Steps:
      1. Read protein sequences and group columns with identical sequences.
         If there are no duplicates, write only the keep list and the map
         CSV, then return.
      2. Encode all k-mers of the RNA probes.
      3. Read the intensity columns that belong to duplicate groups.
      4. Build one PWM per duplicate column, compute pairwise aligned
         metrics, and pick one representative per group (consensus
         correlation with IC tie-break, or IC only).
      5. Write keep/drop index lists, the mapping CSV and the JSON report
         into ``OUT_DIR``.

    Raises:
        FileNotFoundError: if any of the three input files is missing.
        ValueError: if the matrix row count differs from the probe count.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for required in (RBP_FILE, SEQ_FILE, MATRIX_FILE):
        if not os.path.exists(required):
            raise FileNotFoundError(f"Missing {required}")
    os.makedirs(OUT_DIR, exist_ok=True)

    print("[1] Loading proteins & finding duplicate RBPs…")
    train_rbps = read_lines(RBP_FILE, normalize_rna=False)
    dup_clusters = find_duplicate_clusters(train_rbps)
    if not dup_clusters:
        print("   No duplicate protein sequences found — nothing to deduplicate.")
        keep_idx = list(range(len(train_rbps)))
        with open(os.path.join(OUT_DIR, "training_keep_indices.txt"), "w") as f:
            for i in keep_idx:
                f.write(f"{i}\n")
        with open(os.path.join(OUT_DIR, "training_map.csv"), "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["orig_col","cluster_id","representative_col"])
            for i in keep_idx:
                w.writerow([i, -1, i])
        # mimic old behavior: no drop file / cluster json when no dups
        print(f"   Wrote outputs to {OUT_DIR}/ ; exiting.")
        return

    print(f"   Found {len(dup_clusters)} duplicate groups (covering {sum(len(c) for c in dup_clusters)} columns).")

    print("[2] Precomputing k-mers over all RNA probes (k=%d)…" % KMER_K)
    rna_probes = read_lines(SEQ_FILE, normalize_rna=True)
    codes_flat, _offsets, per_probe_counts, counts_by_code = precompute_probe_kmers(rna_probes, k=KMER_K)
    N = len(rna_probes)
    # repeat_idx[j] = index of the probe that k-mer j came from
    repeat_idx = np.repeat(np.arange(N, dtype=np.int32), per_probe_counts)
    print(f"   Probes: {N:,}, total {codes_flat.size:,} {KMER_K}-mers.")

    # gather all duplicate columns
    all_dup_cols = sorted({i for grp in dup_clusters for i in grp})
    print(f"[3] Reading intensities for {len(all_dup_cols)} duplicate columns…")
    Y, cols_sorted = read_matrix_columns_txt(MATRIX_FILE, all_dup_cols)  # (N, K_dup)
    # Rows of Y are indexed by probe number below; a mismatch would either
    # raise an IndexError or silently pair intensities with the wrong probes.
    if Y.shape[0] != N:
        raise ValueError(
            f"{MATRIX_FILE} has {Y.shape[0]} non-empty rows but {SEQ_FILE} has {N} probes; "
            "rows must align one-to-one."
        )
    col_to_pos = {c: j for j, c in enumerate(cols_sorted)}

    print("[4] Building PWMs & selecting representatives (paper-style consensus)…")
    cluster_reports = []
    rep_set = set()

    for cid, grp in enumerate(dup_clusters):
        col_info = []
        pwms = {}
        # build PWM for each member
        for col in grp:
            y = Y[:, col_to_pos[col]].astype(np.float32)   # (N,)
            y_rep = y[repeat_idx]                          # expand per-probe to per-kmer weights
            sums = np.bincount(codes_flat, weights=y_rep, minlength=4**KMER_K)
            avg  = sums / np.maximum(counts_by_code, 1)
            pwm  = pwm_from_kmer_scores(avg, counts_by_code, k=KMER_K,
                                        beta=PWM_SOFTMAX_B, pseudocount=PWM_PSEUDO)
            ic   = info_content_bits(pwm)
            col_info.append({"col": int(col), "IC_bits": float(ic)})
            pwms[int(col)] = pwm

        # pairwise metrics (paper-style)
        pairwise = []
        for i in range(len(grp)):
            for j in range(i+1, len(grp)):
                ci, cj = int(grp[i]), int(grp[j])
                r = pwm_maxcorr(pwms[ci], pwms[cj])
                d = pwm_min_sKL(pwms[ci], pwms[cj])
                pairwise.append({
                    "col_i": ci, "col_j": cj,
                    "maxcorr": float(r),
                    "min_sKL": float(d)
                })

        # consensus score per member = mean of pairwise maxcorr to others
        if REP_SELECTION.lower() == "consensus" and len(grp) > 1:
            corr_map = {}
            for p in pairwise:
                corr_map.setdefault(p["col_i"], []).append(p["maxcorr"])
                corr_map.setdefault(p["col_j"], []).append(p["maxcorr"])
            for rec in col_info:
                rec["consensus_corr_mean"] = float(np.mean(corr_map.get(rec["col"], [1.0])))
            # choose by consensus, tie-break by IC
            col_info.sort(key=lambda d: (d.get("consensus_corr_mean", 0.0), d["IC_bits"]), reverse=True)
        else:
            # old behavior (IC only)
            col_info.sort(key=lambda d: d["IC_bits"], reverse=True)

        rep_col = col_info[0]["col"]
        rep_set.add(rep_col)

        cluster_reports.append({
            "cluster_id": cid,
            "members": [int(c) for c in grp],
            "representative_col": int(rep_col),
            "members_sorted": col_info,        # includes IC and (if used) consensus_corr_mean
            "pairwise_pwm_stats": pairwise     # includes maxcorr and min_sKL (paper-style)
        })

    # keep = non-duplicates + representatives
    all_cols_set = set(range(len(train_rbps)))
    dup_cols_set = set(all_dup_cols)
    nondup_cols  = sorted(all_cols_set - dup_cols_set)
    keep_cols    = sorted(nondup_cols + list(rep_set))
    drop_cols    = sorted(dup_cols_set - rep_set)

    print(f"\n[5] Writing EXACT old-style outputs to {OUT_DIR}/ …")

    # keep
    with open(os.path.join(OUT_DIR, "training_keep_indices.txt"), "w") as f:
        for c in keep_cols:
            f.write(f"{c}\n")

    # drop
    with open(os.path.join(OUT_DIR, "training_drop_indices.txt"), "w") as f:
        for c in drop_cols:
            f.write(f"{c}\n")

    # map
    with open(os.path.join(OUT_DIR, "training_map.csv"), "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["orig_col","cluster_id","representative_col"])
        for c in nondup_cols:
            w.writerow([c, -1, c])
        for cid, rep in enumerate(cluster_reports):
            repc = rep["representative_col"]
            for c in rep["members"]:
                w.writerow([c, cid, repc])

    # json report (same filename/key schema as old; now richer payload is fine)
    with open(os.path.join(OUT_DIR, "training_duplicate_clusters_pwm.json"), "w") as f:
        json.dump({
            "k_mer": KMER_K,
            "note": ("Representative chosen by cluster-consensus (mean aligned Pearson) "
                     "with IC tie-break; JSON also reports aligned maxcorr and min sKL."),
            "selection": REP_SELECTION,
            "clusters": cluster_reports
        }, f, indent=2)

    print("[Done]")
    print(f" - keep indices: {OUT_DIR}/training_keep_indices.txt  (len={len(keep_cols)})")
    print(f" - drop indices: {OUT_DIR}/training_drop_indices.txt  (len={len(drop_cols)})")
    print(f" - mapping csv : {OUT_DIR}/training_map.csv")
    print(f" - report json : {OUT_DIR}/training_duplicate_clusters_pwm.json")

if __name__ == "__main__":
    main()


## ViennaRNA + tools download (skip if you don't want to calculate secondary structure)

In [None]:
# Install the ViennaRNA package from PyPI.
# NOTE(review): the feature module below calls the RNAfold / RNAsubopt / RNAplfold
# executables via subprocess - confirm this install puts them on PATH; otherwise
# rely on the apt-get cell that follows.
!pip install ViennaRNA


In [None]:
# -y answers the confirmation prompt; a notebook cell has no interactive stdin,
# so without it `apt-get install` cannot proceed.
!sudo apt-get update
!sudo apt-get install -y vienna-rna

## Code to calculate secondary structure

In [None]:
%%writefile rna_struct_features_jupyter.py
# This cell is written to rna_struct_features_jupyter.py by the %%writefile magic above.

# Generate per-sequence NPZ files with PHIME [L,5] (and optional extras) using ViennaRNA.
# Designed to be imported or run inside a notebook.

import os, re, glob, time, tempfile, shutil, subprocess, sys, json
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Tuple, Optional, Dict

import numpy as np

# --------------------- small utils ---------------------

# Matches the leading run of dot-bracket characters ('(', ')', '.') at the start of a string.
_DOT_RE = re.compile(r'^[().]+')

def check_vienna(verbose: bool = True) -> bool:
    """Report whether RNAfold, RNAsubopt and RNAplfold can be launched.

    Each tool is started with ``-h``; a tool counts as missing if starting
    it raises. With ``verbose=True`` the missing tools and an install hint
    are printed to stderr. Returns True only when none is missing.
    """
    def _cannot_launch(tool):
        try:
            subprocess.run([tool, "-h"], capture_output=True, check=False)
        except Exception:
            return True
        return False

    missing = [tool for tool in ("RNAfold","RNAsubopt","RNAplfold") if _cannot_launch(tool)]
    if verbose and missing:
        for line in (
            "❌ Missing: " + ", ".join(missing),
            "   Install on Debian/Ubuntu:",
            "   sudo apt-get -qq update && sudo apt-get install -y vienna-rna",
        ):
            print(line, file=sys.stderr)
    return not missing

def _run(cmd, inp: Optional[str] = None, cwd: Optional[str] = None, env: Optional[Dict[str,str]] = None):
    env2 = os.environ.copy()
    env2.update(env or {})
    # Determinism + no thread oversubscription
    env2.setdefault("VRNA_RANDSEED", "42")
    env2.setdefault("OMP_NUM_THREADS", "1")
    env2.setdefault("OPENBLAS_NUM_THREADS", "1")
    env2.setdefault("MKL_NUM_THREADS", "1")
    return subprocess.run(cmd, input=inp, text=True, capture_output=True, check=False, cwd=cwd, env=env2)

def _to_rna(seq: str) -> str:
    return seq.strip().upper().replace("T","U")

def _parse_structs(stdout: str, L: int) -> List[str]:
    """Collect the dot-bracket strings of length *L* found in *stdout*.

    Each line is stripped and its leading run of dot-bracket characters is
    taken; runs whose length is not exactly *L* are ignored. Order follows
    the input lines.
    """
    hits = (_DOT_RE.match(line.strip()) for line in stdout.splitlines())
    return [hit.group(0) for hit in hits if hit and len(hit.group(0)) == L]

def _dot_from_any_output(text: str, L: int) -> str:
    """Return the first dot-bracket string of length *L* in *text*, else ""."""
    for line in text.splitlines():
        hit = _DOT_RE.match(line.strip())
        # the pattern needs at least one character, so L == 0 never matches
        if hit is not None and len(hit.group(0)) == L:
            return hit.group(0)
    return ""

# ----------- MFE & dot-bracket -> one-hot PHIME -----------

def rnafold_mfe(sequence: str, temp_c: float = 37.0) -> Tuple[str, float]:
    """Run ``RNAfold --noPS`` on one sequence and return ``(dot_bracket, energy)``.

    The structure is the first dot-bracket string of the sequence's length
    in the tool's stdout ("" if none). The energy is the first number that
    appears anywhere in stdout, or 0.0 when there is none.
    """
    rna = _to_rna(sequence)
    proc = _run(['RNAfold','--noPS','-T',f'{temp_c:.2f}'], f">seq\n{rna}\n")
    structure = _dot_from_any_output(proc.stdout, len(rna))
    number = re.search(r'[-+]?\d+(?:\.\d+)?', proc.stdout)
    energy = 0.0 if not number else float(number.group(0))
    return structure, energy

def _pair_table(db: str) -> List[int]:
    L=len(db); pt=[0]*(L+1); st=[]
    for i,ch in enumerate(db,1):
        if ch=='(':
            st.append(i)
        elif ch==')' and st:
            j=st.pop(); pt[i]=j; pt[j]=i
    return pt

def _kids(pt, i, j):
    kids=[]; p=i+1
    while p<j:
        q=pt[p]
        if q==0: p+=1
        elif i<p<q<j: kids.append((p,q)); p=q+1
        else: p+=1
    return kids

def _anno_pair(pt,i,j,H,I,M):
    """Label the unpaired bases inside the loop closed by pair ``(i, j)``.

    The loop kind follows from the number of directly enclosed pairs:
    none -> hairpin (H), one -> interior/bulge (I), two or more ->
    multiloop (M). Unpaired positions between the closing pair and its
    children get +1 in the matching array (0-based index), and every child
    pair is annotated recursively. H, I and M are modified in place.
    """
    children = _kids(pt,i,j)
    if not children:
        target = H
    elif len(children) == 1:
        target = I
    else:
        target = M

    left = i
    for (k,l) in children:
        # unpaired stretch in front of this child belongs to the current loop
        for t in range(left+1,k):
            if pt[t]==0:
                target[t-1] += 1
        _anno_pair(pt,k,l,H,I,M)
        left = l
    # stretch after the last child (the whole interior when there is none)
    for t in range(left+1,j):
        if pt[t]==0:
            target[t-1] += 1

def _tops(db: str):
    st=[]; out=[]
    for idx,ch in enumerate(db,1):
        if ch=='(':
            st.append(idx)
        elif ch==')':
            if not st: continue
            i=st.pop(); j=idx
            if len(st)==0: out.append((i,j))
    return out

def onehot_phime_from_db(db: str)->np.ndarray:
    """Convert a dot-bracket string into a ``[L, 5]`` array of loop labels.

    Columns are P (paired), H (hairpin), I (interior/bulge), M (multiloop)
    and E (external). Paired positions get 1 in P. Unpaired positions
    outside every top-level pair get 1 in E. Unpaired positions inside a
    pair are labelled by ``_anno_pair``.
    """
    n = len(db)
    pt = _pair_table(db)
    paired = np.zeros(n)
    hairpin = np.zeros(n)
    interior = np.zeros(n)
    multi = np.zeros(n)
    external = np.zeros(n)

    for pos in range(1, n+1):
        if pt[pos] != 0:
            paired[pos-1] += 1

    left = 0
    for (i,j) in _tops(db):
        for t in range(left+1, i):
            if pt[t] == 0:
                external[t-1] += 1
        _anno_pair(pt, i, j, hairpin, interior, multi)
        left = j
    for t in range(left+1, n+1):
        if pt[t] == 0:
            external[t-1] += 1

    return np.vstack([paired, hairpin, interior, multi, external]).T

# ---------------- RNAsubopt ensemble (adaptive) ----------------

def rnasubopt_structs(sequence: str, samples: int, temp_c: float = 37.0) -> List[str]:
    """Sample secondary structures with ``RNAsubopt -p``.

    Four invocations are tried in order (with and without ``-T`` /
    ``--stochBT``, FASTA or plain input). The first one whose stdout yields
    at least one dot-bracket string of the sequence's length wins. Returns
    an empty list if none does.
    """
    rna = _to_rna(sequence)
    n = len(rna)
    count = str(samples)
    temp = f'{temp_c:.2f}'
    fasta = f">seq\n{rna}\n"
    plain = rna+"\n"
    attempts = (
        (['RNAsubopt','-p',count,'-T',temp,'--stochBT'], fasta),
        (['RNAsubopt','-p',count,'--stochBT'],           fasta),
        (['RNAsubopt','-p',count,'-T',temp],             plain),
        (['RNAsubopt','-p',count],                       plain),
    )
    for argv, stdin_text in attempts:
        found = _parse_structs(_run(argv, stdin_text).stdout, n)
        if found:
            return found
    return []

def subopt_loop_props_adaptive(sequence: str, temp_c: float = 37.0,
                               batch: int = 200, max_samp: int = 1500,
                               tol: float = 1e-2) -> np.ndarray:
    """Return the mean PHIME annotation over one batch of RNAsubopt samples.

    ``batch`` structures are requested and their one-hot ``[L, 5]``
    annotations are averaged into a float32 array. If no structure can be
    sampled, every position is labelled external (E = 1).

    NOTE(review): ``max_samp`` and ``tol`` are accepted but not used; only
    a single batch is drawn, so no adaptive refinement takes place.
    """
    rna = _to_rna(sequence)
    n = len(rna)

    sampled = rnasubopt_structs(rna, samples=batch, temp_c=temp_c)
    if not sampled:
        fallback = np.zeros((n,5), np.float32)
        fallback[:,4] = 1.0
        return fallback

    total = np.zeros((n,5), dtype=float)
    for db in sampled:
        total += onehot_phime_from_db(db)
    return (total / float(len(sampled))).astype(np.float32)

# ------------- RNAplfold: parse u_1..u_U --------------

def _parse_lunp_matrix(path: str, L: int, U: int, strict: bool) -> Optional[np.ndarray]:
    mat = np.full((L, U), np.nan, dtype=float)
    row_idx = 0
    with open(path, 'r') as f:
        for ln in f:
            ln = ln.strip()
            if not ln or ln.startswith('#') or ln.startswith('>'):
                continue
            toks = ln.split()
            nums=[]
            for t in toks:
                try: nums.append(float(t))
                except: pass
            if not nums:
                continue
            if len(nums) == U + 1:
                i = int(round(nums[0]))
                if 1 <= i <= L:
                    mat[i-1, :] = nums[1:U+1]
            elif len(nums) >= U:
                if row_idx < L:
                    mat[row_idx, :] = nums[-U:]
                    row_idx += 1
    if np.isnan(mat).any():
        filled = np.count_nonzero(~np.isnan(mat))
        if filled < L * max(1, U//2):
            if strict:
                raise RuntimeError(f"RNAplfold parse incomplete for {path} (got {filled}/{L*U})")
            return None
        for u in range(U):
            col = mat[:,u]
            mask = np.isnan(col)
            if mask.any():
                idx = np.where(~mask)[0]
                if len(idx)==0: continue
                last = col[idx[0]]
                for i in range(L):
                    if not mask[i]:
                        last = col[i]
                    else:
                        col[i] = last
                mat[:,u] = col
    return np.clip(mat, 0.0, 1.0).astype(np.float32)

def rnaplfold_unpaired_multi(sequence: str, temp_c: float = 37.0,
                             U: int = 9, W: Optional[int]=None,
                             strict: bool = False) -> Optional[np.ndarray]:
    """Run RNAplfold in a scratch directory and return its ``[L, U]`` unpaired matrix.

    ``W`` (falling back to the sequence length when falsy) is passed as
    both ``-W`` and ``-L``. The first ``*_lunp`` output file that parses to
    the expected shape is returned; files named ``phime_*_lunp`` are
    preferred. Returns None when nothing usable is produced, or raises
    RuntimeError if ``strict`` is true. The scratch directory is always
    removed.
    """
    rna = _to_rna(sequence)
    n = len(rna)
    window = W or n
    workdir = tempfile.mkdtemp(prefix="plfold_")
    try:
        argv = ['RNAplfold','-u',str(U),'-W',str(window),'-L',str(window),'-d','2','-T',f'{temp_c:.2f}',
                '--auto-id','--id-prefix','phime','-c','0']
        _run(argv, f">seq\n{rna}\n", cwd=workdir)

        found = sorted(glob.glob(os.path.join(workdir, "phime_*_lunp")))
        if not found:
            found = sorted(glob.glob(os.path.join(workdir, "*_lunp")))

        for path in found:
            parsed = _parse_lunp_matrix(path, n, U, strict=strict)
            if parsed is not None and parsed.shape == (n, U):
                return parsed
        if strict:
            raise RuntimeError("RNAplfold output not found or unparsable.")
        return None
    finally:
        shutil.rmtree(workdir, ignore_errors=True)

# -------------------- Combined best-match --------------------

def phime_plfold_calibrated(sequence: str,
                            temp_c: float = 37.0,
                            batch: int = 200,
                            max_samp: int = 1500,
                            tol: float = 1e-2,
                            U: int = 9,
                            smoothing: float = 0.02,
                            strict_plfold: bool = False,
                            return_u: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Combine sampled loop-type proportions with RNAplfold unpaired probabilities.

    For each position:
      * P = 1 - u1, where u1 is the single-base unpaired probability
        (first column of the RNAplfold matrix).
      * The unpaired mass u1 is split over H, I, M, E in proportion to the
        sampled ensemble (``subopt_loop_props_adaptive``).
      * If the ensemble has no unpaired mass there, a prior is used: the
        MFE structure's loop type when the MFE leaves the base unpaired,
        otherwise external. The prior is blended with a uniform share
        controlled by ``smoothing``.
      * The row is re-normalised to sum to 1 (all-zero rows become E = 1).

    If RNAplfold yields nothing, the clipped, row-normalised ensemble
    average is returned instead.

    Returns:
        ``(S, u)``. ``S`` is the float32 ``[L, 5]`` PHIME matrix. ``u`` is
        the ``[L, U]`` RNAplfold matrix when ``return_u`` is true and it is
        available, otherwise None. A 2-tuple is always returned.
    """
    seq = _to_rna(sequence)
    L = len(seq)
    subopt = subopt_loop_props_adaptive(seq, temp_c=temp_c, batch=batch, max_samp=max_samp, tol=tol)  # [L,5]
    u_all = rnaplfold_unpaired_multi(seq, temp_c=temp_c, U=U, W=L, strict=strict_plfold)
    if u_all is None:
        S = np.clip(subopt, 0, 1).astype(np.float32)
        s = S.sum(axis=1, keepdims=True)
        s[s == 0] = 1.0              # avoid dividing all-zero rows by zero
        S = (S/s).astype(np.float32)
        return S, None

    u1 = u_all[:, 0]
    db_mfe, _ = rnafold_mfe(seq, temp_c=temp_c)
    mfe = onehot_phime_from_db(db_mfe) if db_mfe else None

    out = np.zeros((L, 5), dtype=float)
    eps = 1e-6
    for i in range(L):
        out[i, 0] = max(0.0, 1.0 - float(u1[i]))  # P
        u_sub = float(subopt[i, 1:].sum())
        if u_sub < eps:
            prior = np.zeros(4, dtype=float)  # H,I,M,E
            if mfe is not None and mfe[i, 0] < 0.5:
                k = int(np.argmax(mfe[i, 1:]))
                prior[k] = 1.0
            else:
                prior[3] = 1.0  # external
            alpha = 1.0 - float(smoothing)
            # alpha*prior + smoothing*0.25 sums to 1 over the four loop types
            out[i, 1:] = float(u1[i]) * (alpha*prior + smoothing*0.25)
        else:
            props = subopt[i, 1:] / u_sub
            out[i, 1:] = props * float(u1[i])
        rs = out[i].sum()
        if rs <= 0:
            out[i, 4] = 1.0
        elif abs(rs - 1.0) > 1e-8:
            out[i] /= rs

    S = out.astype(np.float32)
    # Honour return_u: previously both branches returned u_all regardless.
    return S, (u_all if return_u else None)

# -------------------------- Conversions --------------------------

def phime_to_plum(S_phime: np.ndarray) -> np.ndarray:
    """Collapse a ``[L, 5]`` PHIME matrix into ``[L, 4]`` PLUM channels.

    Output columns: P (paired), L (hairpin), U (external, treated as
    unstructured) and M (interior + multiloop). Rows are re-normalised to
    sum to 1 and returned as float32.
    """
    paired = S_phime[:, 0]
    hairpin = S_phime[:, 1]
    unstructured = S_phime[:, 4]                   # external -> unstructured
    inner_and_multi = S_phime[:, 2] + S_phime[:, 3]
    stacked = np.stack([paired, hairpin, unstructured, inner_and_multi], axis=-1)
    row_totals = stacked.sum(axis=1, keepdims=True) + 1e-12
    return (stacked / row_totals).astype(np.float32)

def u_window_aggregates(u_all: Optional[np.ndarray], l_start:int=6, l_end:int=9):
    """Per-position min and mean of the unpaired matrix over a window-length range.

    Columns ``l_start .. l_end`` (1-based, inclusive, clamped to the
    available columns) of *u_all* are reduced along axis 1. Returns
    ``(min, mean)`` as float32 vectors, or ``(None, None)`` when *u_all* is
    None or the column range is empty.
    """
    if u_all is None:
        return None, None
    first = max(1, l_start) - 1
    stop = min(u_all.shape[1], l_end)
    if stop <= first:
        return None, None
    window = u_all[:, first:stop]
    lowest = np.min(window, axis=1).astype(np.float32)
    average = np.mean(window, axis=1).astype(np.float32)
    return lowest, average

# -------------------------- Public API (Notebook) --------------------------

def compute_struct_for_sequence(seq: str,
                                temp_c: float = 37.0,
                                U: int = 9,
                                batch: int = 200,
                                max_samp: int = 1500,
                                tol: float = 1e-2,
                                smoothing: float = 0.02,
                                strict_plfold: bool = False) -> Dict[str, np.ndarray]:
    """
    Compute all structure features for one sequence.

    Return dict with:
      PHIME [L,5], PLUM [L,4], u_min6_9 [L], u_mean6_9 [L], L [1], seq [1]

    L is the length of the normalised sequence (stripped, upper-cased,
    T -> U), so every array agrees even if ``seq`` has surrounding
    whitespace. If no RNAplfold matrix is available, ``u_min6_9`` and
    ``u_mean6_9`` are all zeros.
    """
    # Normalise once and use this length everywhere; PHIME is built on it.
    rna = _to_rna(seq)
    n = len(rna)
    S_phime, u_all = phime_plfold_calibrated(
        rna, temp_c=temp_c, batch=batch, max_samp=max_samp, tol=tol,
        U=U, smoothing=smoothing, strict_plfold=strict_plfold, return_u=True
    )
    umin, umean = u_window_aggregates(u_all, 6, 9)
    if umin is None:
        umin = np.zeros((n,), np.float32)
        umean = np.zeros((n,), np.float32)
    out = {
        "PHIME": S_phime.astype(np.float32),
        "PLUM": phime_to_plum(S_phime).astype(np.float32),
        "u_min6_9": umin.astype(np.float32),
        "u_mean6_9": umean.astype(np.float32),
        "L": np.array([n], dtype=np.int32),
        "seq": np.array([rna], dtype=object),
    }
    return out

def _valid_existing_npz(path: str, L: int) -> bool:
    try:
        with np.load(path) as z:
            if "PHIME" not in z: return False
            arr = z["PHIME"]
            return (arr.ndim == 2) and (arr.shape[0] == L) and (arr.shape[1] == 5)
    except Exception:
        return False

def _write_manifest(out_dir: str, n_total: int, dtype: str = "float16"):
    manifest = {
        "key": "PHIME",
        "channels": ["P","H","I","M","E"],
        "pattern": "seq_%06d.npz",
        "count": int(n_total),
        "dtype": dtype,
        "note": "Each NPZ has PHIME[L,5]. Index i corresponds to line i in the input sequences file."
    }
    with open(os.path.join(out_dir, "FORMAT.json"), "w") as f:
        json.dump(manifest, f, indent=2)

def _struct_worker(i: int, s: str, opts: Dict[str, object]):
    """Compute and save the NPZ for sequence *i*; return ``(i, success, message)``.

    Defined at module level, with all settings passed in the plain dict
    *opts*, so that ProcessPoolExecutor can pickle the call. Exceptions are
    caught and reported through the message instead of propagating.
    """
    try:
        path = os.path.join(opts["out_dir"], f"seq_{i:06d}.npz")
        if opts["resume"] and os.path.exists(path) and _valid_existing_npz(path, len(s)):
            return (i, True, "cached")
        d = compute_struct_for_sequence(
            s, temp_c=opts["temp_c"], U=opts["U"], batch=opts["batch"], max_samp=opts["max_samp"],
            tol=opts["tol"], smoothing=opts["smoothing"], strict_plfold=opts["strict_plfold"]
        )
        # any dtype string other than "float16" stores float32
        np_dtype = np.float16 if str(opts["dtype"]) == "float16" else np.float32
        out = {"PHIME": d["PHIME"].astype(np_dtype)}
        if opts["save_plum"]:
            out["PLUM"] = d["PLUM"].astype(np_dtype)
        if opts["save_aggs"]:
            out["u_min6_9"] = d["u_min6_9"].astype(np_dtype)
            out["u_mean6_9"] = d["u_mean6_9"].astype(np_dtype)
        out["L"] = d["L"]
        out["seq"] = d["seq"]
        np.savez_compressed(path, **out)
        return (i, True, "ok")
    except Exception as e:
        return (i, False, str(e))


def _tally_struct_results(results, n: int, verbose_every: int, max_reported: int = 20):
    """Consume ``(i, success, msg)`` tuples, print progress, return ``(ok, fails)``."""
    ok = fails = done = 0
    every = max(1, int(verbose_every))
    for i, success, msg in results:
        done += 1
        if success:
            ok += 1
        else:
            fails += 1
            # show the first few failure reasons instead of discarding them
            if fails <= max_reported:
                print(f"[fail] seq_{i:06d}: {msg}", file=sys.stderr)
        # progress is keyed on items processed, so failures cannot cause repeats
        if done % every == 0:
            print(f"[{done}/{n}] last={msg}")
    return ok, fails


def generate_struct_features(in_path: str, out_dir: str,
                             workers: int = 4,
                             temp_c: float = 37.0,
                             U: int = 9,
                             batch: int = 200,
                             max_samp: int = 1500,
                             tol: float = 1e-2,
                             smoothing: float = 0.02,
                             strict_plfold: bool = False,
                             save_plum: bool = True,
                             save_aggs: bool = True,
                             dtype: str = "float16",
                             start: Optional[int] = None,
                             end: Optional[int] = None,
                             resume: bool = True,
                             verbose_every: int = 1000) -> Dict[str,int]:
    """
    Read sequences (one per line) from `in_path`, write seq_%06d.npz with at least PHIME[L,5].

    Blank lines are skipped. File index ``i`` is the position among the
    non-blank sequences. ``start`` / ``end`` select the half-open shard
    ``[start, end)``. With ``resume`` true, an existing NPZ whose PHIME
    shape matches is left untouched. ``workers == 1`` runs serially in
    this process; larger values use a process pool. ``FORMAT.json`` is
    written after processing.

    Returns summary dict with counts: ok, failed, total, start, end.

    Raises:
        RuntimeError: if the ViennaRNA command-line tools cannot be launched.
    """
    if not check_vienna(verbose=True):
        raise RuntimeError("ViennaRNA tools missing.")
    os.makedirs(out_dir, exist_ok=True)

    # Load all sequences
    with open(in_path, 'r') as f:
        seqs = [s for s in (_to_rna(ln.strip()) for ln in f) if s]
    n_total = len(seqs)

    # Shard
    a = 0 if start is None else max(0, int(start))
    b = n_total if end is None else min(n_total, int(end))
    seqs_shard = seqs[a:b]
    n = len(seqs_shard)
    print(f"Found {n_total} sequences; processing [{a}:{b}) → {n} items. Writing to: {out_dir}")

    opts = {
        "out_dir": out_dir, "resume": resume, "temp_c": temp_c, "U": U,
        "batch": batch, "max_samp": max_samp, "tol": tol, "smoothing": smoothing,
        "strict_plfold": strict_plfold, "save_plum": save_plum,
        "save_aggs": save_aggs, "dtype": dtype,
    }

    t0 = time.perf_counter()
    workers = max(1, int(workers))

    # If workers==1, run serially (useful for restricted notebook envs)
    if workers == 1:
        serial = (_struct_worker(idx, s, opts) for idx, s in enumerate(seqs_shard, start=a))
        ok, fails = _tally_struct_results(serial, n, verbose_every)
    else:
        with ProcessPoolExecutor(max_workers=workers) as ex:
            futs = [ex.submit(_struct_worker, idx, s, opts)
                    for idx, s in enumerate(seqs_shard, start=a)]
            finished = (fut.result() for fut in as_completed(futs))
            ok, fails = _tally_struct_results(finished, n, verbose_every)

    _write_manifest(out_dir, n_total, dtype)
    t1 = time.perf_counter()
    print(f"Done. OK: {ok}, Failed: {fails}. Time: {t1-t0:.1f}s")
    return {"ok": ok, "failed": fails, "total": n, "start": a, "end": b}

def verify_struct_outputs(in_path: str, out_dir: str) -> Dict[str, object]:
    """Check that *out_dir* holds a valid NPZ for every non-blank line of *in_path*.

    An NPZ is valid when ``seq_%06d.npz`` exists and its PHIME array has
    one row per base of the corresponding sequence. Prints a short summary
    and returns ``{"expected", "valid", "missing"}``, where ``missing``
    lists the indices that are absent or invalid.
    """
    with open(in_path, 'r') as fh:
        seqs = [s for s in (_to_rna(ln.strip()) for ln in fh) if s]
    n = len(seqs)

    def _is_good(index, sequence):
        p = os.path.join(out_dir, f"seq_{index:06d}.npz")
        return os.path.exists(p) and _valid_existing_npz(p, len(sequence))

    missing = [i for i, s in enumerate(seqs) if not _is_good(i, s)]
    ok = n - len(missing)
    print(f"Expected {n} NPZs. Found valid: {ok}. Missing/invalid: {len(missing)}.")
    if missing:
        print("First few missing indices:", missing[:20])
    return {"expected": n, "valid": ok, "missing": missing}


In [None]:
import rna_struct_features_jupyter as rsf

# (optional) report whether RNAfold / RNAsubopt / RNAplfold can be launched;
# generate_struct_features() repeats this check and raises if any is missing
rsf.check_vienna()

# generate features for training RNAs: one seq_%06d.npz per non-blank input line
rsf.generate_struct_features(
    in_path="training_seqs.txt",
    out_dir="struct_train",
    workers=1,          # 1 = serial, in-process; larger values use a process pool
    dtype="float16"     # saves disk; 'float32' also fine
)

# (optional) verify all NPZ files exist & look right (PHIME shape matches each sequence)
rsf.verify_struct_outputs("training_seqs.txt", "struct_train")


In [None]:
# Zip the generated files in ./struct_train into struct_train.zip
# (base_dir keeps the "struct_train/" folder as the top-level entry inside the archive)
import shutil

archive_path = shutil.make_archive("struct_train", "zip", root_dir=".", base_dir="struct_train")
print("Created:", archive_path)


In [None]:
# Unzip /content/struct_train.zip into ./struct_train
# NOTE(review): if the archive was made by the zip cell above, it already contains a
# top-level "struct_train/" folder, so the NPZ files land in ./struct_train/struct_train/.
import shutil

shutil.unpack_archive("/content/struct_train.zip", "struct_train", "zip")
print("Extracted to: struct_train")


# Model configs + utils

In [None]:
#!/usr/bin/env python3
"""
RNA–RBP Binding Intensity Predictor (two-tower + ESM-2 ⊕ ProtT5 + cosine bilinear)
===================================================================================

- RBP-disjoint validation (unseen RBPs), same RNA/k-mer universe by default.
- Protein side: frozen ESM-2 and ProtT5 embeddings (mean-pooled), concatenated, cached to .npy.
- Optional fusion hygiene:
    • PROT_SRC_ZSCORE: z-score each source (ESM, ProtT5) using TRAIN RBPs only
    • PROT_SWAPDROP_P: with prob p, drop either source (never both) during TRAIN to encourage complementarity
- Intensities are log1p + lightly clipped, then per-RBP z-scored (train RNAs).
- Loss = 0.2 * Huber + 0.8 * (1 - Pearson per-RBP across RNAs in-batch).
- Head: low-rank cosine bilinear (rank=256) for correlation-friendly scoring.

- NEW: Optional PWM-based dedup of TRAIN RBPs from a JSON report (your clusters & representatives).
  If CFG.DEDUP_PWM_JSON exists, training matrix and RBP list are sliced to representative columns.
  A mapping CSV is written at CFG.DEDUP_WRITE_MAP. Embedding caches use a "_dedup" suffix.

- NEW: Optional RNA secondary-structure features (PHIME: P,H,I,M,E) per base.
  If CFG.RNA_USE_STRUCT=True, TrainData can load per-seq NPZs from CFG.RNA_STRUCT_DIR and
  provide padded tensors via load_struct_batch(...). Later cells will pass these into RNATower.

Author: You + ChatGPT (Aug 2025)
"""

import os, random, time, json, csv
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel

# =============================
# CONFIG (edit here)
# =============================
@dataclass
class Config:
    """Central, editable hyper-parameter and path configuration.

    Every field has a default, so ``Config()`` takes no arguments. A single
    instance (``CFG``) is created right after this class and is read as a
    module-level global by several helpers further down.
    """
    # File paths
    TRAIN_MATRIX_PATH: str = 'training_data2.txt'
    TRAIN_RBPS_PATH: str = 'training_RBPs2.txt'
    TRAIN_SEQS_PATH: str = 'training_seqs.txt'

    TEST_RBPS_PATH: str = 'test_RBPs2.txt'
    TEST_SEQS_PATH: str = 'test_seqs.txt'

    # --- EMA of model parameters ---
    EMA_USE: bool = True          # turn EMA on/off
    EMA_DECAY: float = 0.999      # typical: 0.999–0.9999 (higher = smoother, slower to react)
    EMA_EVAL: bool = True         # use EMA weights for evaluation/checkpoints

    # --- NEW: RNA structure features (PHIME: P,H,I,M,E) ---
    RNA_USE_STRUCT: bool = True                 # enable only after the per-sequence NPZs have been generated
    RNA_STRUCT_DIR: Optional[str] = '/content/struct_train/struct_train'  # folder with seq_000000.npz, seq_000001.npz, ...
    RNA_STRUCT_DIM: int = 5                      # PHIME has 5 channels
    RNA_STRUCT_CACHE: int = 120678                   # LRU capacity (#seqs) for training-time cache

    # --- (optional) training RBP dedup (PWM-based) ---
    # If this JSON exists, we apply column dedup BEFORE computing mu/sd and embeddings.
    # JSON schema: {"clusters":[{"members":[...], "representative_col": int}, ...], ...}
    # Indices can be 0-based (preferred). If they look 1-based, we'll auto-shift.
    DEDUP_PWM_JSON: Optional[str] = 'dedup_pwm_out/training_duplicate_clusters_pwm.json'
    DEDUP_WRITE_MAP: str = 'cache/training_column_map.csv'  # old->new mapping

    # Random seeds
    SEED: int = 1

    # Model dims
    D_MODEL: int = 256           # shared embedding width of both towers
    RNA_VOCAB: str = 'ACGU'      # T will be mapped to U
    RANK: int = 512              # rank of the low-rank bilinear scoring head
    GATE_STRENGTH: float = 0.5   # scale of the protein-conditioned gate in the scoring head (0 disables it)
    # RNA tower
    RNA_USE_TRANSFORMER: bool = True
    RNA_TRANSFORMER_LAYERS: int = 2
    RNA_NHEAD: int = 4
    RNA_DROPOUT: float = 0.3
    RNA_MAX_LEN: int = 64        # safety cap for positional embeddings

    # Training
    BATCH_RBPS: int = 8          # proteins per batch
    BATCH_RNAS: int = 2048       # RNAs per batch
    EPOCHS: int = 50
    STEPS_PER_EPOCH: int = 300   # tune to time budget
    LR: float = 1e-4
    WEIGHT_DECAY: float = 1e-2
    HUBER_DELTA: float = 1.0
    CORR_WEIGHT: float = 0.8     # loss = (1-CORR_WEIGHT)*Huber + CORR_WEIGHT*(1-Pearson)
    MIXED_PRECISION: bool = False

    # Validation setup
    USE_KMER_DISJOINT_RNA: bool = False   # False → all RNAs are used for both train and val
    VAL_RBPS_COUNT: int = 0               # number of RBPs held out at random when VAL_RBPS_INDICES is None
    VAL_RBPS_SEED: int = 176              # seed of the RNG that picks the random held-out RBPs
    VAL_RBPS_INDICES: Optional[List[int]] = None  # can be original indices; will auto-map if dedup applied

    # k-mer-disjoint RNA split (only read when USE_KMER_DISJOINT_RNA is True)
    KMER_K: int = 9                  # length of the central k-mer used to group RNAs
    VAL_KMER_FRACTION: float = 0.2   # fraction of unique central k-mers assigned to validation

    # Target preprocessing
    LOG1P: bool = True           # clip at CLIP_PCTL, then apply log1p
    CLIP_PCTL: float = 99.5      # percentile (over the whole matrix) used as the upper clip value

    # Evaluation / caching
    EVAL_ZSCORE_TARGETS: bool = True
    CACHE_DIR: str = 'cache'
    SAVE_EVERY: int = 2

    # ProtT5
    USE_PROTT5: bool = False
    PROTT5_MODEL_ID: str = "Rostlab/prot_t5_xl_half_uniref50-enc"
    PROT_EMB_DIM: int = 1024           # input width of the protein MLP; reportedly overwritten elsewhere when ESM-2 is added — TODO confirm
    PROT_EMB_BATCH: int = 8
    PROT_EMB_CACHE_TRAIN: str = 'cache/prot_emb_train.npy'
    PROT_EMB_CACHE_TEST: str = 'cache/prot_emb_test.npy'   # unused here but kept for completeness

    # ESM-2 (HF names supported; short names via map below)
    USE_ESM2: bool = True
    ESM2_MODEL_ID: str = "facebook/esm2_t48_15B_UR50D"
    ESM_EMB_DIM: int = 1280            # info only; NOTE(review): may not match the hidden size of the configured checkpoint — confirm
    ESM_EMB_BATCH: int = 8
    ESM_EMB_CACHE_TRAIN: str = 'cache/esm_emb_train.npy'
    ESM_EMB_CACHE_TEST: str = 'cache/esm_emb_test.npy'     # unused here but kept for completeness

    # Fusion hygiene / regularization
    PROT_SRC_ZSCORE: bool = True
    PROT_SWAPDROP_P: float = 0.0

    # Prot MLP
    PROT_MLP_HIDDEN: int = 256
    PROT_DROPOUT: float = 0.3
    PROT_MLP_WD: float = 3e-4

# Module-wide configuration instance; several definitions below read it as a global.
CFG = Config()
# Make sure the cache directory exists before anything is saved into it.
os.makedirs(CFG.CACHE_DIR, exist_ok=True)

# =============================
# Reproducibility
# =============================
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    cuDNN is left in benchmark mode with determinism disabled, so GPU kernels
    may still differ between runs even with a fixed seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed every RNG once, then choose the compute device (CUDA when available, else CPU).
set_seed(CFG.SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =============================
# Data loading & helpers
# =============================
def _read_lines(path: str) -> List[str]:
    """Return the whitespace-stripped lines of *path*, skipping blank ones."""
    kept = []
    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                kept.append(text)
    return kept

def load_protein_seqs(path: str) -> List[str]:
    """Read one protein sequence per non-blank line and upper-case each."""
    return list(map(str.upper, _read_lines(path)))

def load_rna_seqs(path: str) -> List[str]:
    """Read one RNA sequence per non-blank line, upper-cased with T rewritten as U."""
    normalized = []
    for line in _read_lines(path):
        normalized.append(line.upper().replace('T', 'U'))
    return normalized

def load_matrix(path: str) -> np.ndarray:
    """Load a whitespace-separated numeric text file as a 2-D float32 array.

    ``ndmin=2`` keeps the result two-dimensional even when the file holds a
    single row or a single column; without it ``np.loadtxt`` squeezes such
    files to 1-D and callers indexing ``shape[1]`` fail.
    """
    return np.loadtxt(path, dtype=np.float32, ndmin=2)

def central_kmer(seq: str, k: int) -> str:
    """Return the k-mer centred in *seq* (left-biased when the split is uneven).

    If *seq* is shorter than *k* the whole sequence is returned. The start
    offset is clamped at 0 because a negative offset would make the slice
    wrap around and return only the tail of the sequence.
    """
    start = max(0, (len(seq) - k) // 2)
    return seq[start:start + k]

def _with_suffix(path: str, suffix: str) -> str:
    """Insert *suffix* between the stem of *path* and its extension (if any)."""
    stem, extension = os.path.splitext(path)
    return f"{stem}{suffix}{extension}"

# =============================
# NEW: PWM-JSON based TRAIN RBP dedup
# =============================
def _maybe_shift_to_zero_based(indices: List[int], M: int) -> List[int]:
    """Heuristically convert 1-based indices to 0-based.

    The list is treated as 1-based only when its smallest value is at least 1
    and its largest value equals M; in that case every index is decremented.
    Otherwise (including for an empty list) the input is returned unchanged.
    """
    if indices:
        lowest, highest = min(indices), max(indices)
        looks_one_based = lowest >= 1 and highest == M
        if looks_one_based:
            return [value - 1 for value in indices]
    return indices

def apply_rbp_dedup_by_json(
    rbp_seqs: List[str],
    Y: np.ndarray,
    json_path: Optional[str],
    map_csv_out: str
) -> Tuple[List[str], np.ndarray, List[int], Dict[str, Any]]:
    """
    Slice RBP columns of Y and rbp_seqs based on PWM cluster JSON.

    Args:
        rbp_seqs: one protein sequence per column of Y.
        Y: matrix whose columns are RBPs (``Y.shape[1]`` is the RBP count).
        json_path: report with ``{"clusters": [{"members": [...],
            "representative_col": int, "cluster_id": ...}, ...]}``.
        map_csv_out: where the old->new column mapping CSV is written.

    Returns:
        (rbp_seqs_dedup, Y_dedup, keep_idx, info). ``keep_idx`` holds the
        original column indices that were kept, ordered by the first column
        that maps to each representative. If the JSON is missing/None the
        inputs are returned unchanged with ``info['applied'] = False`` and no
        CSV is written.

    Raises:
        ValueError: a cluster has no ``representative_col`` or its
            representative lies outside ``[0, M)`` after base correction.

    Index base: indices are treated as 1-based only when, taken over the
    WHOLE report (all members and representatives together), the minimum is
    >= 1 and the maximum equals the column count. The decision is made once
    and applied to every cluster, so members and representatives can never be
    shifted inconsistently. A 1-based report that never mentions the last
    column cannot be detected by this heuristic; prefer 0-based reports.
    """
    M = Y.shape[1]
    if not json_path or not os.path.exists(json_path):
        print("[Dedup] No JSON provided or file missing → skipping.")
        keep_idx = list(range(M))
        info = {"applied": False, "n_clusters": 0, "n_groups_gt1": 0, "dropped": 0}
        return rbp_seqs, Y, keep_idx, info

    with open(json_path, "r") as f:
        report = json.load(f)
    clusters = report.get("clusters", [])

    # Parse every cluster once and collect all indices for a single base decision.
    parsed = []
    all_indices: List[int] = []
    for pos, c in enumerate(clusters):
        rep_raw = c.get("representative_col")
        if rep_raw is None:
            raise ValueError(
                f"[Dedup] Cluster #{pos} in {json_path} has no 'representative_col'."
            )
        members = [int(m) for m in c.get("members", [])]
        rep = int(rep_raw)
        parsed.append((c.get("cluster_id", -1), members, rep))
        all_indices.extend(members)
        all_indices.append(rep)

    looks_one_based = bool(all_indices) and min(all_indices) >= 1 and max(all_indices) == M
    offset = 1 if looks_one_based else 0

    # Build rep mapping: default i->i
    rep_of = list(range(M))
    cluster_id_of = [-1] * M
    group_sizes = []
    for cluster_id, members, rep in parsed:
        rep -= offset
        if not 0 <= rep < M:
            # A negative index would silently wrap and an index >= M would crash
            # much later while slicing, so fail early with a clear message.
            raise ValueError(
                f"[Dedup] representative_col {rep + offset} is outside the {M} available columns."
            )
        group_sizes.append(len(members))
        for m in members:
            m -= offset
            if 0 <= m < M:
                rep_of[m] = rep
                cluster_id_of[m] = cluster_id

    # Keep representatives in order of the first column that maps to them.
    # Dict insertion order already equals that order because j only increases.
    first_seen = {}
    for j in range(M):
        first_seen.setdefault(rep_of[j], j)
    keep_idx = list(first_seen)

    dropped = M - len(keep_idx)
    n_groups_gt1 = sum(1 for g in group_sizes if g > 1)

    # Build old->new mapping (after dedup)
    new_pos = {old_rep: i for i, old_rep in enumerate(keep_idx)}
    out_dir = os.path.dirname(map_csv_out)
    if out_dir:
        # os.makedirs('') raises, so only create a directory when the path has one.
        os.makedirs(out_dir, exist_ok=True)
    with open(map_csv_out, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["old_col", "cluster_id", "representative_old_col", "kept_as_new_col", "dropped"])
        for j in range(M):
            rep = rep_of[j]
            kept = (j == rep)
            w.writerow([
                j,
                cluster_id_of[j],
                rep,
                (new_pos[rep] if kept else ""),
                int(not kept)
            ])

    print("[Dedup] Applied PWM JSON:")
    print(f"   Original RBPs: {M} → Kept representatives: {len(keep_idx)} (dropped {dropped})")
    print(f"   Clusters total: {len(clusters)} | with >1 member: {n_groups_gt1}")
    print(f"   Mapping CSV: {map_csv_out}")

    # Slice data
    Y2 = Y[:, keep_idx].copy()
    rbp2 = [rbp_seqs[j] for j in keep_idx]
    info = {
        "applied": True,
        "keep_idx": keep_idx,
        "rep_of": rep_of,
        "cluster_id_of": cluster_id_of,
        "n_clusters": len(clusters),
        "n_groups_gt1": n_groups_gt1,
        "dropped": dropped
    }
    return rbp2, Y2, keep_idx, info

# =============================
# ProtT5 embedding (frozen)
# =============================
def load_or_compute_prott5_embeddings(
    seqs: List[str], cache_path: str, model_id: str, batch_size: int, device: torch.device
) -> np.ndarray:
    """Return one mean-pooled ProtT5 embedding per sequence as a float32 array.

    A cached ``.npy`` at *cache_path* is reused when its row count equals
    ``len(seqs)`` (only the count is compared). Otherwise the encoder is run
    without gradients in batches of *batch_size*, hidden states are averaged
    over the positions selected by the attention mask, and the result is
    saved to *cache_path*.
    """
    n_seqs = len(seqs)

    if os.path.exists(cache_path):
        cached = np.load(cache_path)
        if cached.shape[0] == n_seqs:
            print(f"[ProtT5] Loaded cached embeddings: {cache_path} {cached.shape}")
            return cached.astype(np.float32)
        print(f"[ProtT5] Cache size mismatch ({cached.shape[0]} vs {n_seqs}), recomputing...")

    print(f"[ProtT5] Computing embeddings with {model_id} ...")
    tokenizer = T5Tokenizer.from_pretrained(model_id, legacy=True)
    # Half precision on GPU, full precision on CPU.
    weights_dtype = torch.float16 if device.type == 'cuda' else torch.float32
    encoder = T5EncoderModel.from_pretrained(model_id, torch_dtype=weights_dtype).to(device).eval()

    pooled_batches = []
    with torch.no_grad():
        for start in range(0, n_seqs, batch_size):
            stop = start + batch_size
            # The tokenizer is fed residues separated by single spaces.
            spaced = [" ".join(s) for s in seqs[start:stop]]
            enc = tokenizer(spaced, padding=True, return_tensors="pt").to(device)
            hidden = encoder(**enc).last_hidden_state            # [B, L, 1024]
            keep = enc["attention_mask"].unsqueeze(-1)           # [B, L, 1]
            mean_vec = (hidden * keep).sum(1) / keep.sum(1)      # [B, 1024]
            pooled_batches.append(mean_vec.float().cpu().numpy())
            print(f"  ProtT5 {min(stop, n_seqs)}/{n_seqs}")

    emb = np.concatenate(pooled_batches, axis=0).astype(np.float32)
    np.save(cache_path, emb)
    print(f"[ProtT5] Saved embeddings to {cache_path} with shape {emb.shape}")

    # Drop the model references before releasing cached GPU memory.
    del encoder, tokenizer
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return emb

# =============================
# ESM-2 embedding (frozen, HF)
# =============================
# Short ESM-2 checkpoint names -> Hugging Face hub ids.
_HF_ESM_MAP = {
    short_name: f"facebook/{short_name}"
    for short_name in (
        "esm2_t33_650M_UR50D",
        "esm2_t36_3B_UR50D",
        "esm2_t48_15B_UR50D",
    )
}


def _resolve_esm_id(model_id: str) -> str:
    """Map a short ESM-2 name to its hub id; full or unknown ids pass through unchanged."""
    if model_id.startswith("facebook/"):
        return model_id
    return _HF_ESM_MAP.get(model_id, model_id)

def load_or_compute_esm2_embeddings(
    seqs: List[str], cache_path: str, model_id: str, batch_size: int, device: torch.device
) -> np.ndarray:
    """Return one mean-pooled ESM-2 embedding per sequence as a float32 array.

    A cached ``.npy`` at *cache_path* is reused when its row count equals
    ``len(seqs)`` (only the count is compared). Otherwise sequences are
    tokenized (truncated to 1024 tokens), encoded without gradients, and each
    sequence is averaged over its attended tokens with the first and last
    attended token excluded. The result is saved to *cache_path*.
    """
    total = len(seqs)

    if os.path.exists(cache_path):
        cached = np.load(cache_path)
        if cached.shape[0] == total:
            print(f"[ESM2] Loaded cached embeddings: {cache_path} {cached.shape}")
            return cached.astype(np.float32)
        print(f"[ESM2] Cache size mismatch ({cached.shape[0]} vs {total}), recomputing...")

    hf_model_name = _resolve_esm_id(model_id)
    print(f"[ESM2] Computing embeddings with {hf_model_name} ...")

    on_cuda = device.type == 'cuda'
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
    model = AutoModel.from_pretrained(
        hf_model_name,
        torch_dtype=(torch.float16 if on_cuda else torch.float32)
    )
    model = model.to(device).eval()
    print(f"[ESM2] Model embedding dimension: {model.config.hidden_size}")

    batch_embeddings = []
    with torch.no_grad():
        for batch_no, start in enumerate(range(0, total, batch_size)):
            stop = start + batch_size
            inputs = tokenizer(
                seqs[start:stop], padding=True, truncation=True, max_length=1024, return_tensors="pt"
            ).to(device)
            hidden = model(**inputs).last_hidden_state                  # [B, L, D]
            token_counts = inputs["attention_mask"].sum(dim=1).tolist()  # attended tokens per sequence

            rows = []
            for row, n_tok in enumerate(token_counts):
                n_tok = int(n_tok)
                # exclude <cls> and <eos>; with <= 2 tokens only the first one is dropped
                end = n_tok - 1 if n_tok > 2 else n_tok
                residues = hidden[row, 1:end, :]
                rows.append(residues.mean(0).float().cpu().numpy())
            batch_embeddings.append(np.stack(rows, axis=0))

            print(f"  ESM2 {min(stop, total)}/{total}")
            # Release cached GPU memory every 10th batch.
            if on_cuda and batch_no % 10 == 0:
                torch.cuda.empty_cache()

    emb = np.concatenate(batch_embeddings, axis=0).astype(np.float32)
    np.save(cache_path, emb)
    print(f"[ESM2] Saved embeddings to {cache_path} with shape {emb.shape}")
    del model
    if on_cuda:
        torch.cuda.empty_cache()
    return emb

# =============================
# Train data holder
# =============================
class TrainData:
    """Holds full training tensors and index splits; also saves μ,σ for z-scoring.

    After construction the instance exposes:
      - rna_seqs, prot_seqs, Y: sequences and the (N_rna, N_rbp) target matrix
        (after optional RBP dedup and log1p/clipping);
      - train_idx / val_idx: RNA row indices; train_rbps / val_rbps: RBP column indices;
      - mu, sd: per-RBP statistics over the train RNA rows (also saved to CACHE_DIR);
      - load_struct_batch(...): padded per-base structure features.
    """
    def __init__(self, cfg: "Config"):
        self.cfg = cfg
        print("\n[Load] Reading training files ...")
        rna_seqs = load_rna_seqs(cfg.TRAIN_SEQS_PATH)
        Y = load_matrix(cfg.TRAIN_MATRIX_PATH).astype(np.float32)   # (N_rna, N_rbp)
        prot_seqs = load_protein_seqs(cfg.TRAIN_RBPS_PATH)

        assert Y.shape[0] == len(rna_seqs), "Rows of matrix must equal number of RNA sequences"
        assert Y.shape[1] == len(prot_seqs), "Columns of matrix must equal number of RBP sequences"
        orig_M = Y.shape[1]

        # ---- NEW: apply PWM-JSON dedup on TRAIN RBPs (columns) ----
        prot_seqs, Y, keep_idx, dedup_info = apply_rbp_dedup_by_json(
            prot_seqs, Y, cfg.DEDUP_PWM_JSON, cfg.DEDUP_WRITE_MAP
        )
        self.dedup_info = dedup_info
        self.keep_idx = keep_idx  # old-col indices kept

        if cfg.LOG1P:
            # Clip at a global percentile of the whole matrix, then compress with log1p.
            clip = np.percentile(Y, cfg.CLIP_PCTL)
            Y = np.log1p(np.clip(Y, None, clip))
            print(f"[Preproc] Applied log1p with clip@{cfg.CLIP_PCTL}p (<= {clip:.3f})")

        self.rna_seqs, self.Y, self.prot_seqs = rna_seqs, Y, prot_seqs
        self.N_rna, self.N_rbp = len(rna_seqs), len(prot_seqs)
        print(f"    RNAs: {self.N_rna:,} | RBPs: {self.N_rbp} (from {orig_M} before dedup)")

        # ---- NEW: structure feature loader state ----
        self.struct_dir = cfg.RNA_STRUCT_DIR if (cfg.RNA_USE_STRUCT and cfg.RNA_STRUCT_DIR) else None
        self._struct_cache = OrderedDict()
        self._struct_cache_cap = int(cfg.RNA_STRUCT_CACHE)

        # RNA split
        if cfg.USE_KMER_DISJOINT_RNA:
            print(f"[Split] Building {cfg.KMER_K}-mer-disjoint RNA split ...")
            kmers = [central_kmer(s, cfg.KMER_K) for s in rna_seqs]
            uniq = sorted(set(kmers))
            rng = np.random.default_rng(cfg.SEED); rng.shuffle(uniq)
            n_val_k = int(len(uniq) * cfg.VAL_KMER_FRACTION)
            val_kset = set(uniq[:n_val_k])
            # Boolean membership mask; unlike the zip(*...) idiom it also works for empty input.
            in_val = np.array([km in val_kset for km in kmers], dtype=bool)
            self.train_idx = np.flatnonzero(~in_val).astype(np.int64)
            self.val_idx   = np.flatnonzero(in_val).astype(np.int64)
            print(f"    Train RNAs: {len(self.train_idx):,} | Val RNAs: {len(self.val_idx):,} | Unique {cfg.KMER_K}-mers: {len(uniq):,}")
        else:
            print("[Split] Using ALL RNAs for train & val (same k-mer universe as test) ...")
            self.train_idx = np.arange(self.N_rna, dtype=np.int64)
            self.val_idx   = self.train_idx.copy()
            print(f"    RNAs used: train={len(self.train_idx):,} | val={len(self.val_idx):,}")

        # RBP split
        print("[Split] Building RBP-disjoint validation ...")
        all_rbps = np.arange(self.N_rbp)

        if cfg.VAL_RBPS_INDICES is not None:
            # If user supplied original indices, map them via representative mapping
            orig_inds = np.array(cfg.VAL_RBPS_INDICES, dtype=int)
            if orig_inds.size and (orig_inds.min() < 0 or orig_inds.max() >= orig_M):
                # Out-of-range values would either wrap around silently or crash much later.
                raise ValueError(
                    f"VAL_RBPS_INDICES must lie in [0, {orig_M}); got {orig_inds.tolist()}"
                )
            if dedup_info.get("applied"):
                rep_of = np.array(dedup_info["rep_of"], dtype=int)
                keep = np.array(self.keep_idx, dtype=int)
                # Map each original to its representative, then to new position
                rep_mapped = rep_of[orig_inds]
                # new index = position in keep list
                pos = {old_rep: i for i, old_rep in enumerate(keep)}
                val_rbps = np.array([pos[r] for r in rep_mapped if r in pos], dtype=int)
                val_rbps = np.unique(val_rbps)
                if len(val_rbps) < len(orig_inds):
                    print(f"[Split] Warning: some supplied VAL_RBPS_INDICES collapsed under dedup (unique {len(val_rbps)}).")
            else:
                val_rbps = np.unique(orig_inds)
        else:
            rng = np.random.default_rng(cfg.VAL_RBPS_SEED)
            n_val = min(cfg.VAL_RBPS_COUNT, self.N_rbp)
            val_rbps = rng.choice(all_rbps, size=n_val, replace=False)

        self.train_rbps = np.setdiff1d(all_rbps, val_rbps).astype(np.int64)
        self.val_rbps   = np.array(val_rbps, dtype=np.int64)
        print(f"    Train RBPs: {len(self.train_rbps)} | Val RBPs: {len(self.val_rbps)}")
        print(f"    Held-out RBP indices (after dedup if applied): {self.val_rbps.tolist()}")

        # Per-RBP z-score stats from TRAIN RNAs
        print("[Norm] Computing per-RBP z-score stats from TRAIN RNAs ...")
        Y_train = Y[self.train_idx, :]
        self.mu = Y_train.mean(axis=0).astype(np.float32)
        # Near-constant columns get sd=1 so later division cannot blow up.
        self.sd = Y_train.std(axis=0).astype(np.float32); self.sd[self.sd < 1e-6] = 1.0
        # The passed cfg may point at a directory that has not been created yet.
        os.makedirs(cfg.CACHE_DIR, exist_ok=True)
        np.save(os.path.join(cfg.CACHE_DIR, 'mu_train.npy'), self.mu)
        np.save(os.path.join(cfg.CACHE_DIR, 'sd_train.npy'), self.sd)

    # ---- structure feature helpers (PHIME) ----
    def _struct_path(self, idx: int) -> str:
        """Path of the NPZ holding structure features for RNA row *idx*."""
        return os.path.join(self.struct_dir, f"seq_{idx:06d}.npz")

    @staticmethod
    def _fallback_struct(length: int, channels: int) -> np.ndarray:
        """Default features: all zeros except channel index 4 (E in PHIME order) set to 1."""
        S = np.zeros((length, channels), np.float32)
        S[:, 4] = 1.0
        return S

    def _load_struct_one(self, idx: int, L_expected: int) -> np.ndarray:
        """Return an (L_expected, C) float32 feature array for RNA row *idx*.

        Reads the "PHIME" array from the NPZ, or converts a 4-column "PLUM"
        array. On any load failure (or when no structure dir is configured)
        the fallback features are used. Results are truncated/zero-padded to
        L_expected rows and kept in an LRU cache keyed by (idx, L_expected).
        """
        C = self.cfg.RNA_STRUCT_DIM
        if not self.struct_dir:
            return self._fallback_struct(L_expected, C)

        key = (int(idx), int(L_expected))
        cached = self._struct_cache.get(key)
        if cached is not None:
            self._struct_cache.move_to_end(key)   # mark as most recently used
            return cached

        S = None
        src_name = "fallback"
        try:
            with np.load(self._struct_path(idx)) as npz:
                if "PHIME" in npz:
                    S = np.array(npz["PHIME"], dtype=np.float32)  # (L,5)
                    src_name = "PHIME"
                elif "PLUM" in npz:
                    P, Lc, Uc, Mc = (np.array(npz["PLUM"], dtype=np.float32).T)
                    # The single M column is split evenly into the I and M channels.
                    I = 0.5 * Mc
                    M2 = Mc - I
                    S = np.stack([P, Lc, I, M2, Uc], axis=-1).astype(np.float32)
                    src_name = "PLUM→PHIME"
        except Exception as exc:
            # Best-effort loader: keep training on fallback features, but say so once
            # so a wrong RNA_STRUCT_DIR does not go unnoticed.
            S = None
            if not getattr(self, "_struct_warned", False):
                print(f"[Struct] Warning: could not load {self._struct_path(idx)} ({exc!r}); using fallback features.")
                self._struct_warned = True

        loaded = S is not None
        if not loaded:
            S = self._fallback_struct(L_expected, C)

        # One-time confirmation print *only* if a real NPZ was loaded
        if loaded and not getattr(self, "_struct_confirm_printed", False):
            print(f"[Struct] Loaded {src_name} for seq_{idx:06d} from {self._struct_path(idx)}; shape={S.shape}")
            self._struct_confirm_printed = True

        # Length fixups: truncate or zero-pad rows to L_expected.
        if S.shape[0] > L_expected:
            S = S[:L_expected]
        elif S.shape[0] < L_expected:
            pad = np.zeros((L_expected - S.shape[0], S.shape[1]), np.float32)
            S = np.vstack([S, pad])

        # Cache put, evicting the least recently used entry when over capacity.
        self._struct_cache[key] = S
        if len(self._struct_cache) > self._struct_cache_cap:
            self._struct_cache.popitem(last=False)
        return S

    def load_struct_batch(self, indices: np.ndarray, max_len: Optional[int] = None) -> torch.Tensor:
        """
        Build a [B, Lmax, C] tensor aligned to tokenize_rna_batch padding.
        Lmax is max_len when given (and non-zero), else the longest sequence in the batch.
        Sequences longer than Lmax are truncated; shorter ones are zero-padded.
        Returns CPU tensor (float32). Caller can .to(device).
        """
        C = self.cfg.RNA_STRUCT_DIM
        if indices is None or len(indices) == 0:
            return torch.zeros((0, 0, C), dtype=torch.float32)
        lens = [len(self.rna_seqs[int(i)]) for i in indices]
        Lmax = max_len or max(lens)
        out = np.zeros((len(indices), Lmax, C), np.float32)
        for b, (i, L) in enumerate(zip(indices, lens)):
            S = self._load_struct_one(int(i), L)
            n = min(L, Lmax)          # an explicit max_len may be shorter than the sequence
            out[b, :n, :] = S[:n]
        return torch.from_numpy(out)

# =============================
# Tokenization & batching
# =============================
# ---- RNA vocab with a real PAD id (single source of truth) ----
# PAD gets the first id after the real tokens, so it can never collide with a nucleotide id.
PAD_ID = len(CFG.RNA_VOCAB)                # = 4 when CFG.RNA_VOCAB == 'ACGU'
# Character -> token id, in the order the characters appear in CFG.RNA_VOCAB.
RNA_TO_IDX = {ch: i for i, ch in enumerate(CFG.RNA_VOCAB)}  # real tokens are 0..3

def tokenize_rna_batch(seqs: List[str]) -> torch.Tensor:
    """Encode RNA strings as a right-padded LongTensor of shape [len(seqs), longest].

    Positions past the end of a sequence hold PAD_ID. Characters missing from
    RNA_TO_IDX are encoded as id 0.
    """
    width = max((len(seq) for seq in seqs), default=0)
    tokens = torch.full((len(seqs), width), fill_value=PAD_ID, dtype=torch.long)  # start as all-PAD
    for row, seq in enumerate(seqs):
        if not seq:
            continue
        encoded = [RNA_TO_IDX.get(base, 0) for base in seq]  # unknowns → 'A' (0), fine for RNAcompete
        tokens[row, :len(encoded)] = torch.tensor(encoded, dtype=torch.long)
    return tokens

class BatchSampler:
    """Yields (rbp_idx_list, rna_idx_list) for each training step.

    Batch sizes come from the cfg passed to the constructor (BATCH_RBPS,
    BATCH_RNAS) and are capped at the number of available train RBPs / RNAs,
    because sampling without replacement cannot return more items than exist.
    """
    def __init__(self, data: "TrainData", cfg: "Config"):
        self.data, self.cfg = data, cfg
        self.rng = np.random.default_rng(cfg.SEED)

    def sample(self) -> Tuple[np.ndarray, np.ndarray]:
        """Draw one batch of distinct RBP column indices and distinct RNA row indices (int64)."""
        rbp_pool, rna_pool = self.data.train_rbps, self.data.train_idx
        n_rbps = min(int(self.cfg.BATCH_RBPS), len(rbp_pool))
        n_rnas = min(int(self.cfg.BATCH_RNAS), len(rna_pool))
        rbps = self.rng.choice(rbp_pool, size=n_rbps, replace=False)
        rnas = self.rng.choice(rna_pool, size=n_rnas, replace=False)
        return rbps.astype(np.int64), rnas.astype(np.int64)


## model components

In [None]:
# =============================
# Model components
# =============================
class PositionalEncoding(nn.Module):
    """Learned positional embedding added to the input along dimension 1.

    Positions at or beyond ``max_len`` reuse the last embedding row instead of
    indexing out of bounds.
    """
    def __init__(self, dim: int, max_len: int):
        super().__init__()
        self.pe = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        last_slot = self.pe.num_embeddings - 1
        # Over-long inputs share the final position embedding.
        positions = torch.arange(x.size(1), device=x.device).clamp_max(last_slot)
        return x + self.pe(positions).unsqueeze(0)

class ConvBlock(nn.Module):
    """Residual 1-D conv block: conv → GELU → LayerNorm (over channels) → dropout → add → mask.

    Operates on [B, D, L] input; the padding keeps L unchanged for odd kernels.
    """
    def __init__(self, dim: int, kernel: int, dilation: int, dropout: float):
        super().__init__()
        half_span = (kernel - 1) // 2
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel,
            padding=half_span * dilation,
            dilation=dilation,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, mask_1d: torch.Tensor) -> torch.Tensor:
        branch = F.gelu(self.conv(x))                              # [B, D, L]
        # LayerNorm normalizes the last axis, so move channels there and back.
        branch = self.ln(branch.transpose(1, 2)).transpose(1, 2)
        summed = x + self.dropout(branch)
        # Zero out padded positions so they cannot leak into later layers.
        keep = mask_1d.unsqueeze(1).to(summed.dtype)
        return summed * keep

class GatedPooling(nn.Module):
    """Attention pooling over positions with a learnable Gaussian bias toward the sequence centre.

    Each position's score is a linear projection plus ``alpha`` times a
    Gaussian log-prior centred on the middle of the valid (unmasked) span.
    Masked positions receive zero attention weight.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, 1)
        self.alpha = nn.Parameter(torch.tensor(0.2))
        self.log_sigma = nn.Parameter(torch.log(torch.tensor(6.0)))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        logits = self.proj(x).squeeze(-1)                        # [B,L]
        n_batch, n_pos = logits.size()

        # Valid length per row: mask count, or the full length when no mask is given.
        if mask is None:
            valid_len = torch.full((n_batch,), n_pos, device=x.device)
        else:
            valid_len = mask.sum(1)
        valid_len = valid_len.clamp(min=1).float().unsqueeze(1)
        midpoint = (valid_len - 1) / 2

        positions = torch.arange(n_pos, device=x.device).float().unsqueeze(0).expand(n_batch, -1)
        width = torch.exp(self.log_sigma) + 1e-6
        centre_prior = -(positions - midpoint).pow(2) / (2 * width * width)
        logits = logits + self.alpha * centre_prior

        if mask is not None:
            lowest = torch.finfo(logits.dtype).min
            logits = logits.masked_fill(~mask.bool(), lowest)

        # Softmax in float32, then cast back to the input dtype.
        weights = torch.softmax(logits.float(), dim=1).to(x.dtype)
        return torch.einsum('bl,bld->bd', weights, x)

class ProtMLP(nn.Module):
    """Two-layer MLP (Linear → GELU → LayerNorm → Dropout → Linear) followed by a LayerNorm.

    Linear weights are Xavier-uniform initialised with zero biases; LayerNorms
    start at weight 1 / bias 0.
    """
    def __init__(self, in_dim: int, hidden: int, out_dim: int, p: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(p),
            nn.Linear(hidden, out_dim),
        )
        self.out_norm = nn.LayerNorm(out_dim)
        self._init()

    def _init(self):
        """Re-initialise every Linear and LayerNorm inside this module."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.net(x)
        return self.out_norm(hidden)

class RNATower(nn.Module):
    """
    RNA encoder: embedding + positions → (optional structure fusion) → three dilated
    conv blocks → 9-wide motif conv → (optional) transformer → gated pooling → LayerNorm.

    Tokens: LongTensor [B, L]
    Optional structure (PHIME): FloatTensor [B, L, 5] where channels = [P,H,I,M,E].
    If provided, structure is projected to model dim and added residually to token embeddings.
    Returns one vector of size D_MODEL per sequence.
    """
    def __init__(self, cfg: Config):
        super().__init__()
        dim = cfg.D_MODEL
        drop_p = cfg.RNA_DROPOUT
        vocab_size = len(CFG.RNA_VOCAB)
        # One extra embedding row is reserved for the PAD token.
        self.embed = nn.Embedding(vocab_size + 1, dim, padding_idx=PAD_ID)
        self.pos = PositionalEncoding(dim, cfg.RNA_MAX_LEN)

        # --- structure fusion (light residual) ---
        self.use_struct = bool(getattr(cfg, "RNA_USE_STRUCT", False))
        if not self.use_struct:
            self.struct_proj = None
        else:
            self.struct_proj = nn.Linear(getattr(cfg, "RNA_STRUCT_DIM", 5), dim, bias=True)
            self.struct_ln = nn.LayerNorm(dim)
            self.struct_drop = nn.Dropout(drop_p)
            self.struct_scale = nn.Parameter(torch.tensor(1.0))

        # Growing kernel/dilation pairs widen the receptive field block by block.
        self.conv1 = ConvBlock(dim, kernel=5,  dilation=1, dropout=drop_p)
        self.conv2 = ConvBlock(dim, kernel=9,  dilation=2, dropout=drop_p)
        self.conv3 = ConvBlock(dim, kernel=13, dilation=4, dropout=drop_p)
        self.k9 = nn.Conv1d(dim, dim, kernel_size=9, padding=4, bias=False)
        self.k9_gamma = nn.Parameter(torch.tensor(0.5))

        if not cfg.RNA_USE_TRANSFORMER:
            self.tf = None
        else:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=cfg.RNA_NHEAD, dim_feedforward=dim*4,
                dropout=drop_p, batch_first=True
            )
            self.tf = nn.TransformerEncoder(encoder_layer, num_layers=cfg.RNA_TRANSFORMER_LAYERS)

        self.pool = GatedPooling(dim)
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, tokens: torch.Tensor, struct: Optional[torch.Tensor] = None) -> torch.Tensor:
        valid = (tokens != PAD_ID)                              # [B,L], True on real tokens
        h = self.pos(self.embed(tokens))                        # [B,L,D]
        h = h * valid.unsqueeze(-1).to(h.dtype)

        # Inject structure if available + enabled
        if self.use_struct and struct is not None and self.struct_proj is not None:
            # struct: [B,L,5]; padded rows are re-zeroed by the mask multiply below
            feat = self.struct_proj(struct.to(h.dtype))         # [B,L,D]
            feat = self.struct_drop(self.struct_ln(F.gelu(feat)))
            h = h + self.struct_scale * feat
            # keep masking clean
            h = h * valid.unsqueeze(-1).to(h.dtype)

        c = h.transpose(1, 2)                                   # [B,D,L]
        for conv_block in (self.conv1, self.conv2, self.conv3):
            c = conv_block(c, valid)

        motif = F.gelu(self.k9(c))
        motif = motif * valid.unsqueeze(1).to(motif.dtype)
        c = c + self.k9_gamma * motif

        h = c.transpose(1, 2)                                   # back to [B,L,D]
        if self.tf is not None:
            h = self.tf(h, src_key_padding_mask=~valid)
        pooled = self.pool(h, valid)
        return self.out_norm(pooled)

class GatedBilinearLowRankCosine(nn.Module):
    """
    Low-rank bilinear scorer on L2-normalised projections with a protein-conditioned gate.

    Protein and RNA vectors are projected to `rank` dims and unit-normalised; the
    protein side is then rescaled per axis by (1 + s * tanh(G e_p)) before the
    dot product. The gate is residual & zero-centered, so with gate_strength=0
    the score is the plain cosine-style product plus a scalar bias.
    """
    def __init__(self, dim: int, rank: int, gate_strength: float = 0.5):
        super().__init__()
        self.U = nn.Linear(dim, rank, bias=False)   # protein -> rank
        self.V = nn.Linear(dim, rank, bias=False)   # RNA -> rank
        self.G = nn.Linear(dim, rank, bias=True)    # protein -> gate params
        self.bias = nn.Parameter(torch.zeros(1))
        self.gate_strength = gate_strength

    def forward(self, e_p: torch.Tensor, e_r: torch.Tensor) -> torch.Tensor:
        # e_p: [Bp, D], e_r: [Br, D]
        prot_axes = F.normalize(self.U(e_p), dim=1)                   # [Bp, R]
        rna_axes = F.normalize(self.V(e_r), dim=1)                    # [Br, R]
        gate = 1.0 + self.gate_strength * torch.tanh(self.G(e_p))     # [Bp, R], within 1 ± gate_strength
        scores = torch.matmul(prot_axes * gate, rna_axes.t())         # [Bp, Br]
        return scores + self.bias


class TwoTowerModel(nn.Module):
    """Two-tower scorer: an RNA encoder and a protein-embedding MLP joined by a low-rank gated head.

    Proteins enter as precomputed embedding vectors of width cfg.PROT_EMB_DIM;
    RNAs enter as token ids (plus optional per-base structure features).
    """
    def __init__(self, cfg: Config):
        super().__init__()
        self.rna = RNATower(cfg)
        # Maps protein embeddings into the same D_MODEL space as the RNA tower output.
        self.prot_proj = ProtMLP(cfg.PROT_EMB_DIM, cfg.PROT_MLP_HIDDEN, cfg.D_MODEL, cfg.PROT_DROPOUT)
        self.score = GatedBilinearLowRankCosine(cfg.D_MODEL, cfg.RANK, cfg.GATE_STRENGTH)

    def encode_rna(self, rna_tokens: torch.Tensor, rna_struct: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode a batch of RNA token ids (and optional structure features) with the RNA tower."""
        return self.rna(rna_tokens, rna_struct)

    def project_prot(self, prot_vecs: torch.Tensor) -> torch.Tensor:
        """Project precomputed protein embeddings through the protein MLP."""
        return self.prot_proj(prot_vecs)

    def forward_scores_vecs(
        self,
        prot_vecs: torch.Tensor,
        rna_tokens: torch.Tensor,
        rna_struct: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Score every protein against every RNA in the batch; returns a [n_proteins, n_rnas] matrix."""
        e_p = self.project_prot(prot_vecs)
        e_r = self.encode_rna(rna_tokens, rna_struct)
        return self.score(e_p, e_r)


## training code

In [None]:
# =============================
# Losses & metrics
# =============================
class PearsonLoss(nn.Module):
    """1 - Pearson, averaged across protein rows in the mini-batch.

    Correlation is computed along dim 1 of each row; `eps` is added to both
    sums of squares so constant rows do not divide by zero.
    """
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred_c = pred - pred.mean(dim=1, keepdim=True)
        target_c = target - target.mean(dim=1, keepdim=True)
        covariance = (pred_c * target_c).sum(dim=1)
        pred_ss = pred_c.pow(2).sum(dim=1) + self.eps
        target_ss = target_c.pow(2).sum(dim=1) + self.eps
        corr = covariance / torch.sqrt(pred_ss * target_ss)
        return (1.0 - corr).mean()

def huber_loss(pred: torch.Tensor, target: torch.Tensor, delta: float) -> torch.Tensor:
    """Mean Huber loss between *pred* and *target* with transition point *delta*."""
    return F.huber_loss(input=pred, target=target, reduction="mean", delta=delta)

def pearson_numpy(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation of two arrays; returns 0.0 when either one is constant."""
    a_centered = a - a.mean()
    b_centered = b - b.mean()
    cross = (a_centered * b_centered).sum()
    # The tiny constant keeps the division defined for zero-variance input.
    scale = np.sqrt((a_centered * a_centered).sum() * (b_centered * b_centered).sum()) + 1e-12
    return float(cross / scale)

# =============================
# Training / evaluation helpers
# =============================
def build_param_groups(model: nn.Module):
    """Split the trainable parameters of ``model`` into three optimizer groups.

    Returns a dict with the keys
      * ``prot_decay`` - ``weight`` of Linear/Conv1d modules whose qualified
        module name starts with ``"prot_proj"``;
      * ``decay``      - ``weight`` of every other Linear/Conv1d, plus any
        trainable parameter the module scan did not pick up;
      * ``no_decay``   - the remaining Linear/Conv1d parameters (e.g. biases)
        and all LayerNorm/BatchNorm1d parameters.
    Parameters with ``requires_grad=False`` are left out entirely.
    """
    groups = {"prot_decay": [], "decay": [], "no_decay": []}

    for mod_name, mod in model.named_modules():
        own_trainable = [
            (pname, p)
            for pname, p in mod.named_parameters(recurse=False)
            if p.requires_grad
        ]
        if isinstance(mod, (nn.Linear, nn.Conv1d)):
            weight_bucket = "prot_decay" if mod_name.startswith("prot_proj") else "decay"
            for pname, p in own_trainable:
                bucket = weight_bucket if pname == "weight" else "no_decay"
                groups[bucket].append(p)
        elif isinstance(mod, (nn.LayerNorm, nn.BatchNorm1d)):
            groups["no_decay"].extend(p for _, p in own_trainable)

    # Anything not owned by the module types above (bare nn.Parameter, embeddings, ...)
    # falls back to the regular weight-decay group.
    already_grouped = {id(p) for bucket in groups.values() for p in bucket}
    leftovers = [
        p for p in model.parameters()
        if p.requires_grad and id(p) not in already_grouped
    ]
    groups["decay"].extend(leftovers)
    return groups

# --- NEW: Exponential Moving Average of model params ---
import copy

class ModelEMA:
    """
    Exponential moving average of a model's state (parameters and buffers).

    The average lives in `module`, a deep copy of the tracked model that is kept
    in eval mode with gradients disabled on all of its parameters.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999, device: Optional[torch.device] = None):
        self.decay = float(decay)
        self.num_updates = 0
        # Default to wherever the tracked model currently lives.
        target_device = next(model.parameters()).device if device is None else device
        shadow = copy.deepcopy(model).to(target_device).eval()
        for param in shadow.parameters():
            param.requires_grad_(False)
        self.module = shadow

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Fold the current state of `model` into the average: ema = decay * ema + (1 - decay) * model."""
        self.num_updates += 1
        live_state = model.state_dict()
        keep, blend = self.decay, 1.0 - self.decay
        for key, averaged in self.module.state_dict().items():
            current = live_state[key]
            if torch.is_floating_point(current):
                averaged.mul_(keep).add_(current, alpha=blend)
            else:
                # integer/bool entries are not averaged, just mirrored
                averaged.copy_(current)

    def to(self, device: torch.device):
        """Move the averaged copy to `device`."""
        self.module.to(device)

    def state_dict(self):
        """Return the averaged weights together with `decay` and the update counter."""
        snapshot = {"module": self.module.state_dict()}
        snapshot["decay"] = self.decay
        snapshot["num_updates"] = self.num_updates
        return snapshot

    def load_state_dict(self, state):
        """Restore from `state_dict()` output; a missing `decay` keeps the current value, a missing counter resets to 0."""
        self.module.load_state_dict(state["module"])
        restored_decay = state.get("decay", self.decay)
        restored_count = state.get("num_updates", 0)
        self.decay = float(restored_decay)
        self.num_updates = int(restored_count)


class Trainer:
    """Training driver: owns the protein embeddings, the model (plus optional EMA copy),
    the optimizer and the AMP grad scaler for one run.

    Relies on names defined elsewhere in this file: the module-level ``device``,
    ``TrainData``, ``Config``, ``BatchSampler``, ``tokenize_rna_batch``,
    ``_with_suffix`` and the ``load_or_compute_*_embeddings`` helpers.
    """

    def __init__(self, data: TrainData, cfg: Config):
        """Load protein embeddings, build model/EMA/optimizer and compute z-score statistics.

        Side effect: ``cfg.PROT_EMB_DIM`` is overwritten with the summed width of
        the enabled embedding sources before the model is constructed.

        Raises:
            ValueError: when no embedding source contributes any feature
                (neither ``cfg.USE_PROTT5`` nor ``cfg.USE_ESM2`` is set).
        """
        self.data, self.cfg = data, cfg

        # Choose cache names (avoid overwriting non-dedup caches)
        cache_suffix = "_dedup" if data.dedup_info.get("applied") else ""
        pt5_cache = _with_suffix(cfg.PROT_EMB_CACHE_TRAIN, cache_suffix)
        esm_cache = _with_suffix(cfg.ESM_EMB_CACHE_TRAIN, cache_suffix)

        # --- Protein embeddings (ProtT5 and/or ESM-2) ---
        self.pt5_all = None
        self.esm_all = None
        total_dim = 0

        if cfg.USE_PROTT5:
            pt5_np = load_or_compute_prott5_embeddings(
                data.prot_seqs, pt5_cache, cfg.PROTT5_MODEL_ID, cfg.PROT_EMB_BATCH, device
            )
            self.pt5_all = torch.tensor(pt5_np, dtype=torch.float32, device=device)  # [N_rbp, D_pt5]
            total_dim += self.pt5_all.shape[1]

        if cfg.USE_ESM2:
            esm_np = load_or_compute_esm2_embeddings(
                data.prot_seqs, esm_cache, cfg.ESM2_MODEL_ID, cfg.ESM_EMB_BATCH, device
            )
            self.esm_all = torch.tensor(esm_np, dtype=torch.float32, device=device)  # [N_rbp, D_esm]
            total_dim += self.esm_all.shape[1]

        if total_dim == 0:
            raise ValueError("At least one of USE_PROTT5 or USE_ESM2 must be True.")
        cfg.PROT_EMB_DIM = total_dim

        used = [name for name, arr in (("ESM2", self.esm_all), ("ProtT5", self.pt5_all)) if arr is not None]
        print(f"[Prot] Using {', '.join(used)} → PROT_EMB_DIM={cfg.PROT_EMB_DIM}")

        # Build model (uses cfg.PROT_EMB_DIM)
        self.model = TwoTowerModel(cfg).to(device)

        # --- EMA (safe defaults if not in Config) ---
        ema_use   = bool(getattr(cfg, "EMA_USE", True))
        ema_decay = float(getattr(cfg, "EMA_DECAY", 0.999))
        self.ema = ModelEMA(self.model, decay=ema_decay) if ema_use else None

        # Train-only source-wise standardization stats (computed over data.train_rbps rows)
        if cfg.PROT_SRC_ZSCORE:
            tr = torch.tensor(self.data.train_rbps, device=device)

            if self.pt5_all is not None:
                self.pt5_mu = self.pt5_all[tr].mean(0, keepdim=True)
                self.pt5_sd = self.pt5_all[tr].std(0, keepdim=True).clamp_min(1e-6)
            else:
                self.pt5_mu = self.pt5_sd = None

            if self.esm_all is not None:
                self.esm_mu = self.esm_all[tr].mean(0, keepdim=True)
                self.esm_sd = self.esm_all[tr].std(0, keepdim=True).clamp_min(1e-6)
            else:
                self.esm_mu = self.esm_sd = None
        else:
            self.pt5_mu = self.pt5_sd = self.esm_mu = self.esm_sd = None

        # Optimizer with parameter groups
        groups = build_param_groups(self.model)
        print(f"[Opt] Param groups — prot_decay: {sum(p.numel() for p in groups['prot_decay']):,} | "
              f"decay: {sum(p.numel() for p in groups['decay']):,} | "
              f"no_decay: {sum(p.numel() for p in groups['no_decay']):,}")

        self.opt = torch.optim.AdamW([
            {"params": groups["prot_decay"], "lr": cfg.LR, "weight_decay": cfg.PROT_MLP_WD},
            {"params": groups["decay"],      "lr": cfg.LR, "weight_decay": cfg.WEIGHT_DECAY},
            {"params": groups["no_decay"],   "lr": cfg.LR, "weight_decay": 0.0},
        ])

        # AMP (new API); the scaler is a pass-through when disabled
        self.scaler = torch.amp.GradScaler('cuda', enabled=cfg.MIXED_PRECISION and device.type == 'cuda')
        self.pearson_loss = PearsonLoss()
        self.best_median = -1.0
        self.checkpoint_path = os.path.join(cfg.CACHE_DIR, 'best_model.pt' + ('.dedup' if cache_suffix else ''))

    def _prep_prot(self, rbp_idx: torch.Tensor, training: bool) -> torch.Tensor:
        """Return per-RBP protein features (concat of available sources in [ESM, ProtT5] order).

        Each source is z-scored with the train statistics when ``cfg.PROT_SRC_ZSCORE``
        is set. With ``training=True`` and both sources present, "swapdrop" zeroes
        one source per row with probability ``cfg.PROT_SWAPDROP_P``.

        Raises:
            RuntimeError: if no embedding source is loaded.
        """
        # ESM2 branch (optional)
        esm = None
        if self.esm_all is not None:
            esm = self.esm_all[rbp_idx]
            if self.cfg.PROT_SRC_ZSCORE and (self.esm_mu is not None) and (self.esm_sd is not None):
                esm = (esm - self.esm_mu) / self.esm_sd

        # ProtT5 branch (optional)
        pt5 = None
        if self.pt5_all is not None:
            pt5 = self.pt5_all[rbp_idx]
            if self.cfg.PROT_SRC_ZSCORE and (self.pt5_mu is not None) and (self.pt5_sd is not None):
                pt5 = (pt5 - self.pt5_mu) / self.pt5_sd

        # Sanity: at least one must be present (guarded in __init__, but keep it here too)
        if esm is None and pt5 is None:
            raise RuntimeError("No protein embeddings available: set USE_ESM2 and/or USE_PROTT5 to True.")

        # Optional swapdrop (only meaningful if both streams exist)
        if training and self.cfg.PROT_SWAPDROP_P > 0 and (esm is not None) and (pt5 is not None):
            p = self.cfg.PROT_SWAPDROP_P
            B = esm.size(0)
            m_esm = (torch.rand(B, 1, device=esm.device) < p)
            m_pt5 = (torch.rand(B, 1, device=pt5.device) < p)
            # Avoid dropping both: if both true, drop only ESM (keep ProtT5)
            both = m_esm & m_pt5
            m_pt5 = m_pt5 & ~both
            esm = esm.masked_fill(m_esm, 0.0)
            pt5 = pt5.masked_fill(m_pt5, 0.0)

        # Concatenate in [ESM, ProtT5] order when both exist
        if esm is not None and pt5 is not None:
            return torch.cat([esm, pt5], dim=1)
        return esm if esm is not None else pt5

    def _maybe_struct_batch(self, rna_idx: np.ndarray) -> Optional[torch.Tensor]:
        """Fetch structure features for ``rna_idx`` if available & enabled.

        Returns ``None`` when ``cfg.RNA_USE_STRUCT`` is off, when the data object
        has no ``load_struct_batch`` or when that loader returns ``None``.
        NumPy results are wrapped as tensors; the device move happens at the call site.
        """
        if not getattr(self.cfg, "RNA_USE_STRUCT", False):
            return None
        if not hasattr(self.data, "load_struct_batch"):
            return None
        S = self.data.load_struct_batch(rna_idx)  # expected torch.FloatTensor or np.ndarray
        if S is None:
            return None
        if isinstance(S, np.ndarray):
            S = torch.from_numpy(S)
        return S

    def _save_checkpoint(self, path: str, **extra) -> None:
        """Write model weights, EMA state and config (plus ``extra`` entries) to ``path``.

        The parent directory is created first, so a missing cache folder cannot
        make ``torch.save`` fail after an epoch of training.
        """
        payload = {
            'model': self.model.state_dict(),
            'ema': (self.ema.state_dict() if self.ema is not None else None),
            'cfg': self.cfg.__dict__,
        }
        payload.update(extra)
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(payload, path)

    def sample_batch(self, sampler: BatchSampler):
        """Draw one training batch.

        Returns ``(prot_vecs, rna_tokens, rna_struct, Yz)`` where ``Yz`` holds the
        targets standardized per RBP with ``data.mu`` / ``data.sd``.
        """
        rbp_idx_np, rna_idx = sampler.sample()
        rbp_idx = torch.tensor(rbp_idx_np, device=device)
        prot_vecs = self._prep_prot(rbp_idx, training=True)                          # [Bp, in_dim]
        rna_tokens = tokenize_rna_batch([self.data.rna_seqs[i] for i in rna_idx]).to(device)
        rna_struct = self._maybe_struct_batch(rna_idx)
        if rna_struct is not None:
            rna_struct = rna_struct.to(device)

        mu = torch.tensor(self.data.mu[rbp_idx_np], device=device).unsqueeze(1)      # [Bp,1]
        sd = torch.tensor(self.data.sd[rbp_idx_np], device=device).unsqueeze(1)      # [Bp,1]
        Y = torch.tensor(self.data.Y[np.ix_(rna_idx, rbp_idx_np)], device=device).t()# [Bp,Br]
        return prot_vecs, rna_tokens, rna_struct, (Y - mu) / sd

    def train(self):
        """Run ``cfg.EPOCHS`` epochs of ``cfg.STEPS_PER_EPOCH`` steps each.

        After every epoch the model is evaluated; a new best median r overwrites
        ``self.checkpoint_path`` and every ``cfg.SAVE_EVERY`` epochs a numbered
        checkpoint is written to ``cfg.CACHE_DIR``.
        """
        sampler = BatchSampler(self.data, self.cfg)
        use_amp = self.cfg.MIXED_PRECISION and device.type == 'cuda'
        for epoch in range(1, self.cfg.EPOCHS + 1):
            self.model.train(); losses = []; t0 = time.time()
            for step in range(1, self.cfg.STEPS_PER_EPOCH + 1):
                prot_vecs, rna_tokens, rna_struct, Yz = self.sample_batch(sampler)
                self.opt.zero_grad(set_to_none=True)
                with torch.amp.autocast('cuda', enabled=use_amp):
                    S = self.model.forward_scores_vecs(prot_vecs, rna_tokens, rna_struct)  # [Bp, Br]
                    loss = (1 - self.cfg.CORR_WEIGHT) * huber_loss(S, Yz, delta=self.cfg.HUBER_DELTA) \
                           + self.cfg.CORR_WEIGHT * self.pearson_loss(S, Yz)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.opt); self.scaler.update()

                # --- EMA update after optimizer step ---
                if self.ema is not None:
                    self.ema.update(self.model)

                losses.append(loss.item())
                if step % 50 == 0:
                    print(f"Epoch {epoch} Step {step}/{self.cfg.STEPS_PER_EPOCH} | loss={np.mean(losses[-50:]):.4f}")

            med_r, mean_r = self.evaluate()
            print(f"[Eval] Epoch {epoch} | median r={med_r:.4f} | mean r={mean_r:.4f} | time={time.time()-t0:.1f}s")

            if med_r > self.best_median:
                self.best_median = med_r
                self._save_checkpoint(self.checkpoint_path, median_r=med_r)
                print(f"[Save] New best median r={med_r:.4f}. Saved to {self.checkpoint_path}")

            if epoch % self.cfg.SAVE_EVERY == 0:
                self._save_checkpoint(
                    os.path.join(self.cfg.CACHE_DIR, f'checkpoint_epoch_{epoch}.pt'),
                    epoch=epoch,
                )

    @torch.no_grad()
    def evaluate(self, use_tta=True, n_augments=5, noise_scale=0.01) -> Tuple[float, float]:
        """Per-RBP Pearson across RNAs; returns ``(median r, mean r)``.

        Uses the validation RBPs/RNAs when present; if no held-out RBPs exist the
        TRAIN RBPs/RNAs are evaluated instead of returning 0. The EMA copy is
        scored when it exists and ``cfg.EMA_EVAL`` (default True) allows it.
        With ``use_tta`` the scores of ``n_augments`` passes are averaged: the
        first pass is clean, the rest add Gaussian noise of std ``noise_scale``
        to the protein features.
        """
        # Decide which split to use
        use_val = hasattr(self.data, "val_rbps") and len(self.data.val_rbps) > 0
        rbp_idx_np = self.data.val_rbps if use_val else self.data.train_rbps
        rna_idx_np = self.data.val_idx   if (use_val and hasattr(self.data, "val_idx")) else self.data.train_idx
        split_name = "VAL" if use_val else "TRAIN"
        print(f"[Eval] Using {split_name} split (RBPs={len(rbp_idx_np)}, RNAs={len(rna_idx_np)})")

        # Pick model (EMA or current)
        use_ema_for_eval = bool(getattr(self.cfg, "EMA_EVAL", True))
        eval_model = self.ema.module if (self.ema is not None and use_ema_for_eval) else self.model
        eval_model.eval()

        # Tokenize RNAs for this split
        tokens = tokenize_rna_batch([self.data.rna_seqs[i] for i in rna_idx_np])

        # Optional structure: same loader the training batches use
        struct = self._maybe_struct_batch(rna_idx_np)

        # Encode RNAs
        rna_emb = self._encode_rna_in_batches(tokens, struct, model=eval_model)  # [N_rna, D]

        # Protein features
        rbp_idx_t = torch.tensor(rbp_idx_np, device=device)
        base = self._prep_prot(rbp_idx_t, training=False)  # [N_rbp, in_dim]

        # TTA (optional)
        if use_tta:
            all_scores = []
            for k in range(n_augments):
                prot_vecs = base if k == 0 else base + torch.randn_like(base) * noise_scale
                e_p = eval_model.project_prot(prot_vecs)
                all_scores.append(eval_model.score(e_p, rna_emb).detach().cpu().numpy())
            S = np.mean(all_scores, axis=0)  # [N_rbp, N_rna]
        else:
            e_p = eval_model.project_prot(base)
            S = eval_model.score(e_p, rna_emb).detach().cpu().numpy()

        # Targets for the same split
        Y = self.data.Y[np.ix_(rna_idx_np, rbp_idx_np)]
        if self.cfg.EVAL_ZSCORE_TARGETS:
            mu = self.data.mu[rbp_idx_np][None, :]
            sd = self.data.sd[rbp_idx_np][None, :]
            Y = (Y - mu) / sd
        Y = Y.T  # [N_rbp, N_rna]

        # Per-RBP Pearson (one row of S/Y per RBP)
        rs = np.array(
            [pearson_numpy(s_row, y_row) for s_row, y_row in zip(S, Y)],
            dtype=np.float32,
        )
        return float(np.median(rs)), float(rs.mean())

    @torch.no_grad()
    def _encode_rna_in_batches(
        self,
        tokens: torch.Tensor,
        struct: Optional[torch.Tensor] = None,
        batch: int = 2048,
        model: Optional[nn.Module] = None,
    ) -> torch.Tensor:
        """
        tokens: [N, L]
        struct: optional [N, L, 5] aligned & zero-padded; may be on CPU.
        model: which model to use for encoding (EMA or current); defaults to self.model
        returns: [N, D] on `device`
        """
        # Explicit None test: module truthiness is not a reliable "was it passed" signal.
        if model is None:
            model = self.model
        model.eval(); outs = []
        N = tokens.size(0)
        for i in range(0, N, batch):
            tok = tokens[i:i+batch].to(device)
            st = None
            if struct is not None:
                st = struct[i:i+batch]
                if isinstance(st, np.ndarray):
                    st = torch.from_numpy(st)
                st = st.to(device)
            outs.append(model.encode_rna(tok, st))
        return torch.cat(outs, dim=0)

# =============================
# Main
# =============================
def main():
    """Build the dataset and trainer from the global CFG, train, then report the best result."""
    trainer = Trainer(TrainData(CFG), CFG)
    trainer.train()
    best = trainer.best_median
    print("\n[Info] Best validation median r = {:.4f}".format(best))
    print(f"Best checkpoint: {trainer.checkpoint_path}")


if __name__ == '__main__':
    main()


## calculate test rbp embedding

In [None]:
# %% Colab cell: ESM-2 embeddings for test RBPs (no CLI args) + index sidecar
# If needed: !pip -q install transformers

import os, sys, gc, numpy as np, torch
from transformers import AutoTokenizer, AutoModel

# ----------------- CONFIG -----------------
# Input: one protein sequence per line (read_lines skips blank lines and upper-cases).
RBP_FILE    = "test_RBPs2.txt"                    # one protein per line
# Outputs: the embedding matrix and a sidecar listing the row index of each embedding.
OUT_PATH    = "cache/esm_emb_test.npy"            # embeddings (N, D)
INDEX_PATH  = "cache/esm_emb_test.index.txt"      # 0-based index per row (one int per line)
# Model id handed to AutoTokenizer/AutoModel.from_pretrained.
MODEL_ID    = "facebook/esm2_t48_15B_UR50D"       # 15B
# When True, a RuntimeError whose message contains "out of memory" triggers a retry with
# FALLBACK_MODEL at twice BATCH_SIZE; when False the error is re-raised.
FALLBACK_ON_OOM = False
FALLBACK_MODEL  = "facebook/esm2_t33_650M_UR50D"  # fallback if OOM
BATCH_SIZE  = 4                                   # tune if you see OOM
# When False and OUT_PATH already holds exactly N rows, the script only rewrites the
# index sidecar and exits instead of recomputing.
OVERWRITE   = False                               # set True to force recompute
# Passed to the tokenizer as max_length together with truncation=True.
MAX_LEN     = 1024                                # tokenizer cap (ESM-2 limit)
# ------------------------------------------

def read_lines(path):
    """Return the non-blank lines of *path*, stripped of surrounding whitespace and upper-cased."""
    cleaned = []
    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                cleaned.append(text.upper())
    return cleaned

def pick_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def pick_dtype(device):
    """Choose the model dtype: float32 off-GPU; on CUDA bfloat16 if supported, else float16."""
    if device.type != "cuda":
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

def load_model(model_id, dtype, device):
    """Load tokenizer and model for *model_id*; the model is moved to *device* and set to eval mode.

    Returns ``(tokenizer, model)``.
    """
    print(f"[ESM2] Loading {model_id} (dtype={dtype}, device={device}) ...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoder = AutoModel.from_pretrained(model_id, torch_dtype=dtype)
    encoder = encoder.to(device)
    encoder = encoder.eval()
    print(f"[ESM2] Hidden size: {encoder.config.hidden_size}")
    return tokenizer, encoder

def clear_cuda():
    """Call ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()``; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

@torch.no_grad()
def compute_embeddings(seqs, tokenizer, model, device, batch_size, max_len):
    """Embed each sequence as the mean of its last-layer token states.

    Sequences are tokenized in chunks of ``batch_size`` (padded, truncated to
    ``max_len``). For a row with ``n`` attended tokens the mean runs over
    positions ``1 .. n-2`` (first and last attended token dropped) when
    ``n > 2``, otherwise over positions ``1 .. n-1``.

    Returns a float32 array with one row per sequence.
    """
    total = len(seqs)
    chunks = []
    for start in range(0, total, batch_size):
        batch_no = start // batch_size
        encoded = tokenizer(
            seqs[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)
        hidden = model(**encoded).last_hidden_state          # [B, L, D]
        attended = encoded["attention_mask"]                 # [B, L]

        rows = []
        for row in range(hidden.size(0)):
            n_tokens = int(attended[row].sum().item())
            # drop <cls>/<eos> if both present
            stop = n_tokens - 1 if n_tokens > 2 else n_tokens
            residue_states = hidden[row, 1:stop, :]
            rows.append(residue_states.mean(0).float().cpu().numpy())
        chunks.append(np.stack(rows, axis=0))

        if batch_no % 10 == 0:
            print(f"[ESM2] {min(start+batch_size, total)}/{total}")
        if device.type == "cuda" and batch_no % 8 == 0:
            clear_cuda()
    return np.concatenate(chunks, axis=0).astype(np.float32)

# ---- run ----
def _ensure_parent_dir(path):
    """Create the directory that will hold *path*; a bare filename needs nothing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_index_sidecar(path, n_rows):
    """Write 0..n_rows-1, one integer per line (row i of the embeddings <-> line i of RBP_FILE)."""
    with open(path, "w") as f:
        f.writelines(f"{i}\n" for i in range(n_rows))


_ensure_parent_dir(OUT_PATH)
_ensure_parent_dir(INDEX_PATH)
seqs = read_lines(RBP_FILE)
N = len(seqs)
print(f"[Load] RBPs: {N}")

if os.path.exists(OUT_PATH) and not OVERWRITE:
    E_existing = None
    try:
        E_existing = np.load(OUT_PATH)
    except Exception as err:
        # Best effort: an unreadable cache is not fatal, it is recomputed below.
        print(f"[Warn] Could not read cached embeddings at {OUT_PATH} ({err}); recomputing.")
    if E_existing is not None and E_existing.ndim >= 1 and E_existing.shape[0] == N:
        # Still (re)write the index file to be safe/consistent
        _write_index_sidecar(INDEX_PATH, N)
        print(f"[Skip] Existing embeddings match rows at {OUT_PATH} ({E_existing.shape}).")
        print(f"[Info] Wrote index sidecar: {INDEX_PATH} (0-based, one per line)")
        raise SystemExit
    if E_existing is not None:
        print(f"[Info] Cached embeddings {E_existing.shape} do not match {N} RBPs; recomputing.")

device = pick_device()
dtype  = pick_dtype(device)
print(f"[Device] {device}")

# Pre-bind both names: an OOM raised inside load_model would otherwise leave `model`
# undefined and turn the cleanup below into a NameError.
tokenizer = model = None
use_fallback = False
try:
    tokenizer, model = load_model(MODEL_ID, dtype, device)
    E = compute_embeddings(seqs, tokenizer, model, device, BATCH_SIZE, MAX_LEN)
except RuntimeError as e:
    if "out of memory" in str(e).lower() and FALLBACK_ON_OOM:
        use_fallback = True
    else:
        raise

if use_fallback:
    # Runs after the except block so the exception (whose traceback still references
    # the failed model's frames) is gone before memory is reclaimed.
    print("\n[OOM] GPU ran out of memory with 15B. Falling back to", FALLBACK_MODEL)
    tokenizer = model = None
    gc.collect()
    clear_cuda()
    tokenizer, model = load_model(FALLBACK_MODEL, dtype, device)
    E = compute_embeddings(seqs, tokenizer, model, device, BATCH_SIZE*2, MAX_LEN)

np.save(OUT_PATH, E)
print(f"[Done] Saved embeddings: {OUT_PATH} with shape {E.shape}")

# Sidecar: 0-based index for each embedding row (row i ↔ line i in RBP_FILE)
_write_index_sidecar(INDEX_PATH, N)
print(f"[Done] Saved index sidecar: {INDEX_PATH} (0-based, one per line)")


In [None]:
import numpy as np, pathlib

E = np.load("cache/esm_emb_test.npy")                  # [N, D]
# atleast_1d: np.loadtxt yields a 0-d array when the file holds a single value,
# which would break len(idx) below.
idx = np.atleast_1d(np.loadtxt("cache/esm_emb_test.index.txt", dtype=int))  # [N]

print(E.shape, idx.shape)           # e.g., (44, D) (44,) — D depends on the ESM-2 model used
assert (idx == np.arange(len(idx))).all()   # mapping is 0..N-1

# Optional: verify line counts vs your RBP file
with open("test_RBPs2.txt") as rbp_fh:      # context manager closes the handle
    N_file = sum(1 for line in rbp_fh if line.strip())
assert E.shape[0] == N_file == idx.size


## prediction code

In [None]:
#!/usr/bin/env python3
# predict_streaming_standalone.py (with TTA)
# Streams RNAs, loads PHIME per batch, uses ESM-2 test embeddings, writes RBP201.. files.

import os, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List

# ==============================
# Hardcoded paths & knobs
# ==============================
RBP_FILE   = "test_RBPs2.txt"              # one protein per line (44 lines)
RNA_FILE   = "test_seqs.txt"               # one RNA per line (same order as outputs)
# Same file name the training Trainer uses for its best checkpoint when dedup was applied.
CKPT_PATH  = "cache/best_model.pt.dedup"   # your final checkpoint
# load_struct_slice reads seq_{i:06d}.npz from here, using a "PHIME" array or, failing that, "PLUM".
STRUCT_DIR = "struct_test"                 # where seq_000000.npz ... live (PHIME)
ESM_TEST   = "cache/esm_emb_test.npy"      # precomputed ESM-2 embeddings for test RBPs
# load_esm_only z-scores the test embeddings with this file's per-column mean/std,
# but only if cfg.PROT_SRC_ZSCORE is set AND the file exists.
ESM_TRAIN  = "cache/esm_emb_train.npy"     # optional: train z-score (if enabled during training)

OUT_DIR    = "predictions"                 # output folder
BASE_OFFSET = 200                          # → RBP201.txt .. RBP(200+P).txt
RNA_BATCH   = 4096                         # RNA batch size for streaming
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FMT = "%.8f"                               # eight decimals

# --- TTA knobs (match training evaluate) ---
# Values equal the defaults of Trainer.evaluate (use_tta=True, n_augments=5, noise_scale=0.01).
TTA_USE   = True
TTA_N     = 5
TTA_NOISE = 0.01                           # noise std in (z-scored) ESM feature space

# ==============================
# Minimal config & constants
# ==============================
@dataclass
class Config:
    """Model hyper-parameters needed to rebuild the network for prediction.

    The values below are only defaults: ``load_checkpoint`` overwrites every field
    whose name also appears in the checkpoint's saved ``cfg`` dict and ignores
    saved keys that are not declared here.
    """
    # (training cfg fields we need; values are overlaid from the checkpoint's cfg)
    RNA_USE_STRUCT: bool = True       # build/use the structure-feature projection in RNATower
    RNA_STRUCT_DIM: int = 5           # input width of that projection
    RNA_MAX_LEN: int = 64             # rows in the learned position table (later positions reuse the last row)
    RNA_VOCAB: str = 'ACGU'           # token id = index in this string; pad id = len(RNA_VOCAB)

    D_MODEL: int = 256                # embedding width of both towers
    RANK: int = 512                   # output width of the scoring head's U/V/G projections
    PROT_EMB_DIM: int = 1280          # input width of the protein MLP

    RNA_USE_TRANSFORMER: bool = True  # add a TransformerEncoder after the conv stack
    RNA_NHEAD: int = 4
    RNA_TRANSFORMER_LAYERS: int = 2
    RNA_DROPOUT: float = 0.3          # dropout for structure branch, conv blocks and transformer layers

    GATE_STRENGTH: float = 0.5        # scale of the tanh gate in the scoring head

    USE_ESM2: bool = True
    USE_PROTT5: bool = False          # load_esm_only asserts ESM2-only
    PROT_SRC_ZSCORE: bool = True      # z-score test embeddings with train statistics (see load_esm_only)

    PROT_MLP_HIDDEN: int = 256
    PROT_DROPOUT: float = 0.3

# ==============================
# Helpers
# ==============================
def read_lines(path: str, normalize_rna: bool=False) -> List[str]:
    """Read *path* and return its non-blank lines, stripped and upper-cased.

    With ``normalize_rna=True`` every ``T`` is additionally rewritten to ``U``.
    """
    with open(path) as handle:
        kept = [text.upper() for text in (raw.strip() for raw in handle) if text]
    if normalize_rna:
        kept = [text.replace("T", "U") for text in kept]
    return kept

def build_vocab(cfg: Config):
    """Return ``(char -> id, pad_id)``: ids are positions in ``cfg.RNA_VOCAB``, pad id is its length."""
    alphabet = cfg.RNA_VOCAB
    char_to_id = dict(zip(alphabet, range(len(alphabet))))
    return char_to_id, len(alphabet)

def tokenize_rna_batch(cfg: Config, seqs: List[str]) -> torch.Tensor:
    """Encode sequences as a right-padded LongTensor of shape [len(seqs), longest sequence].

    Characters missing from the vocabulary are encoded as id 0 (the first
    vocabulary symbol); unused trailing positions hold the pad id.
    """
    vocab, pad_id = build_vocab(cfg)
    width = max((len(seq) for seq in seqs), default=0)
    tokens = torch.full((len(seqs), width), fill_value=pad_id, dtype=torch.long)
    for row, seq in enumerate(seqs):
        if not seq:
            continue  # empty sequence: the row stays all padding
        encoded = [vocab.get(ch, 0) for ch in seq]  # unknowns→A(0)
        tokens[row, :len(encoded)] = torch.tensor(encoded, dtype=torch.long)
    return tokens

# ==============================
# Model (mirrors your training shapes)
# ==============================
class PositionalEncoding(nn.Module):
    """Learned absolute position embedding added to the input along dim 1.

    Positions at or beyond ``max_len`` reuse the last table row (index clamp).
    """

    def __init__(self, dim: int, max_len: int):
        super().__init__()
        self.pe = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        last_row = self.pe.num_embeddings - 1
        positions = torch.arange(x.size(1), device=x.device).clamp_max(last_row)
        return x + self.pe(positions.unsqueeze(0))

class ConvBlock(nn.Module):
    """Residual dilated Conv1d block on [B, D, L] tensors.

    Pipeline: conv -> GELU -> LayerNorm over D -> dropout -> residual add,
    after which the positions where ``mask_1d`` is false are zeroed.
    """

    def __init__(self, dim: int, kernel: int, dilation: int, dropout: float):
        super().__init__()
        half_width = (kernel - 1) // 2
        self.conv = nn.Conv1d(
            dim, dim,
            kernel_size=kernel,
            padding=half_width * dilation,
            dilation=dilation,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, mask_1d: torch.Tensor) -> torch.Tensor:
        branch = F.gelu(self.conv(x))                    # [B,D,L]
        branch = self.ln(branch.transpose(1, 2))         # LayerNorm acts on the last dim: [B,L,D]
        branch = self.dropout(branch.transpose(1, 2))    # back to [B,D,L]
        summed = x + branch
        keep = mask_1d.unsqueeze(1).to(summed.dtype)
        return summed * keep

class GatedPooling(nn.Module):
    """Attention pooling over positions with a learned preference for the sequence centre.

    Each position gets a learned score plus ``alpha * (-(pos - centre)^2 / (2 sigma^2))``,
    where the centre is the midpoint of the valid length; masked positions are
    excluded before the softmax. Input [B, L, D] -> output [B, D].
    """

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, 1)
        self.alpha = nn.Parameter(torch.tensor(0.2))
        self.log_sigma = nn.Parameter(torch.log(torch.tensor(6.0)))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        logits = self.proj(x).squeeze(-1)            # [B,L]
        n_rows, n_pos = logits.size()

        # Valid length per row (full length without a mask), never below 1.
        if mask is None:
            valid = torch.full((n_rows,), n_pos, device=x.device)
        else:
            valid = mask.sum(1)
        valid = valid.clamp(min=1).float().unsqueeze(1)

        positions = torch.arange(n_pos, device=x.device).float().unsqueeze(0).expand(n_rows, -1)
        sq_dist = (positions - (valid - 1) / 2).pow(2)
        width = torch.exp(self.log_sigma) + 1e-6
        logits = logits + self.alpha * (- sq_dist / (2 * width * width))

        if mask is not None:
            logits = logits.masked_fill(~mask.bool(), torch.finfo(logits.dtype).min)
        weights = torch.softmax(logits.float(), dim=1).to(x.dtype)
        return torch.einsum('bl,bld->bd', weights, x)

class ProtMLP(nn.Module):
    """Protein projection: Linear -> GELU -> LayerNorm -> Dropout -> Linear, then an output LayerNorm.

    Linear weights are Xavier-uniform initialised with zero biases; LayerNorms
    start at weight 1 / bias 0.
    """

    def __init__(self, in_dim: int, hidden: int, out_dim: int, p: float):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(p),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)
        self.out_norm = nn.LayerNorm(out_dim)
        self._init()

    def _init(self):
        for layer in self.modules():
            if isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.net(x)
        return self.out_norm(hidden)

class RNATower(nn.Module):
    """RNA encoder: token (+ optional structure) embedding -> three dilated conv blocks
    -> extra k=9 conv branch -> optional Transformer -> gated pooling -> LayerNorm.

    Maps tokens [B, L] (and optional structure features [B, L, RNA_STRUCT_DIM])
    to one vector of width ``cfg.D_MODEL`` per sequence.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.vocab, self.pad_id = build_vocab(cfg)
        width = cfg.D_MODEL
        n_symbols = len(cfg.RNA_VOCAB)
        # one extra embedding row for the padding id
        self.embed = nn.Embedding(n_symbols + 1, width, padding_idx=self.pad_id)
        self.pos = PositionalEncoding(width, cfg.RNA_MAX_LEN)

        self.use_struct = bool(getattr(cfg, "RNA_USE_STRUCT", False))
        if self.use_struct:
            struct_dim = getattr(cfg, "RNA_STRUCT_DIM", 5)
            self.struct_proj = nn.Linear(struct_dim, width, bias=True)
            self.struct_ln = nn.LayerNorm(width)
            self.struct_drop = nn.Dropout(cfg.RNA_DROPOUT)
            self.struct_scale = nn.Parameter(torch.tensor(1.0))
        else:
            self.struct_proj = None

        # (kernel, dilation) = (5, 1), (9, 2), (13, 4)
        self.conv1 = ConvBlock(width, 5, 1, cfg.RNA_DROPOUT)
        self.conv2 = ConvBlock(width, 9, 2, cfg.RNA_DROPOUT)
        self.conv3 = ConvBlock(width, 13, 4, cfg.RNA_DROPOUT)
        self.k9 = nn.Conv1d(width, width, kernel_size=9, padding=4, bias=False)
        self.k9_gamma = nn.Parameter(torch.tensor(0.5))

        if cfg.RNA_USE_TRANSFORMER:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=width,
                nhead=cfg.RNA_NHEAD,
                dim_feedforward=width*4,
                dropout=cfg.RNA_DROPOUT,
                batch_first=True,
            )
            self.tf = nn.TransformerEncoder(encoder_layer, num_layers=cfg.RNA_TRANSFORMER_LAYERS)
        else:
            self.tf = None

        self.pool = GatedPooling(width)
        self.out_norm = nn.LayerNorm(width)

    def forward(self, tokens: torch.Tensor, struct: Optional[torch.Tensor] = None) -> torch.Tensor:
        valid = (tokens != self.pad_id)                     # [B,L], True on real tokens
        h = self.pos(self.embed(tokens))                    # [B,L,D]
        h = h * valid.unsqueeze(-1).to(h.dtype)

        # Structure branch: only when enabled at build time AND features were passed in.
        if self.use_struct and struct is not None and self.struct_proj is not None:
            s_feat = F.gelu(self.struct_proj(struct.to(h.dtype)))
            s_feat = self.struct_drop(self.struct_ln(s_feat))
            h = h + self.struct_scale * s_feat
            h = h * valid.unsqueeze(-1).to(h.dtype)

        c = h.transpose(1, 2)                               # [B,D,L]
        for block in (self.conv1, self.conv2, self.conv3):
            c = block(c, valid)
        motif = F.gelu(self.k9(c))
        motif = motif * valid.unsqueeze(1).to(motif.dtype)
        c = c + self.k9_gamma * motif
        h = c.transpose(1, 2)                               # [B,L,D]

        if self.tf is not None:
            h = self.tf(h, src_key_padding_mask=~valid)
        pooled = self.pool(h, valid)
        return self.out_norm(pooled)

class GatedBilinearLowRankCosine(nn.Module):
    """Pairwise scoring head.

    Protein and RNA embeddings are projected to ``rank`` dimensions and
    L2-normalised; the protein side is multiplied by ``1 + gate_strength * tanh(G(e_p))``
    before the dot product with every RNA vector. A scalar bias is added.
    Input [Bp, dim] and [Br, dim] -> scores [Bp, Br].
    """

    def __init__(self, dim: int, rank: int, gate_strength: float = 0.5):
        super().__init__()
        self.U = nn.Linear(dim, rank, bias=False)   # protein -> rank
        self.V = nn.Linear(dim, rank, bias=False)   # RNA -> rank
        self.G = nn.Linear(dim, rank, bias=True)    # protein -> gate
        self.bias = nn.Parameter(torch.zeros(1))
        self.gate_strength = gate_strength

    def forward(self, e_p: torch.Tensor, e_r: torch.Tensor) -> torch.Tensor:
        prot_dir = F.normalize(self.U(e_p), dim=1)                   # [Bp,R]
        rna_dir = F.normalize(self.V(e_r), dim=1)                    # [Br,R]
        gate = 1.0 + self.gate_strength * torch.tanh(self.G(e_p))    # [Bp,R]
        gated_prot = prot_dir * gate
        return gated_prot @ rna_dir.t() + self.bias

class TwoTowerModel(nn.Module):
    """Inference-side two-tower model: RNA encoder, protein projection and scoring head.

    The attribute names (``rna``, ``prot_proj``, ``score``) are the prefixes of the
    state_dict keys that ``load_checkpoint`` loads with ``strict=True``.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.rna = RNATower(cfg)
        self.prot_proj = ProtMLP(
            cfg.PROT_EMB_DIM,
            cfg.PROT_MLP_HIDDEN,
            cfg.D_MODEL,
            cfg.PROT_DROPOUT,
        )
        self.score = GatedBilinearLowRankCosine(
            cfg.D_MODEL,
            cfg.RANK,
            cfg.GATE_STRENGTH,
        )

    def encode_rna(self, tokens: torch.Tensor, struct: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the RNA tower on a token batch (plus optional structure features)."""
        rna_embedding = self.rna(tokens, struct)
        return rna_embedding

    def project_prot(self, prot_vecs: torch.Tensor) -> torch.Tensor:
        """Map protein feature vectors through the protein MLP."""
        prot_embedding = self.prot_proj(prot_vecs)
        return prot_embedding

# ==============================
# I/O + structure + checkpoint
# ==============================
def load_checkpoint(ckpt_path: str, device: torch.device):
    """Rebuild the model from a training checkpoint and return ``(cfg, model)``.

    The config starts from the local ``Config`` defaults; every key of the
    checkpoint's ``cfg`` dict that names a ``Config`` field overrides it, all
    other keys are ignored. EMA weights are loaded when the checkpoint has an
    ``ema`` dict containing ``module``; otherwise the raw ``model`` weights are used.
    The returned model is on ``device`` and in eval mode.
    """
    ckpt = torch.load(ckpt_path, map_location=device)
    saved_cfg = ckpt.get("cfg", {}) or {}

    # Start from local defaults, then overlay only known fields
    cfg = Config()
    known = set(Config.__dataclass_fields__.keys())
    applied = [key for key in saved_cfg if key in known]
    skipped = [key for key in saved_cfg if key not in known]
    for key in applied:
        setattr(cfg, key, saved_cfg[key])
    if applied:
        print(f"[Cfg] Applied {len(applied)} fields from checkpoint.")
    if skipped:
        print(f"[Cfg] Skipped {len(skipped)} unknown fields (expected in standalone).")

    model = TwoTowerModel(cfg).to(device).eval()
    ema_state = ckpt.get("ema")
    has_ema = isinstance(ema_state, dict) and "module" in ema_state
    weights = ema_state["module"] if has_ema else ckpt["model"]
    model.load_state_dict(weights, strict=True)
    print("[Load] EMA weights." if has_ema else "[Load] Raw weights.")
    return cfg, model

def load_esm_only(cfg: Config, n_rbps: int) -> torch.Tensor:
    """Load the test ESM protein embeddings as a float32 tensor on ``DEVICE``.

    When ``cfg.PROT_SRC_ZSCORE`` is truthy, the test embeddings are
    standardised per feature with the mean / std of the *training* embeddings
    in ``ESM_TRAIN`` (std values below 1e-6 are replaced by 1 to avoid
    division by ~0).  If that file is missing the embeddings are returned
    un-normalised and a warning is printed, because the model then receives
    inputs on a different scale than the config asks for.

    Args:
        cfg: model config; must select ESM2 and not ProtT5.
        n_rbps: expected number of rows (one per RBP) in ``ESM_TEST``.

    Returns:
        Tensor built from the ``ESM_TEST`` array, moved to ``DEVICE``.

    Raises:
        AssertionError: if the config is not ESM2-only or the row count
            does not match ``n_rbps``.
    """
    assert cfg.USE_ESM2 and not cfg.USE_PROTT5, "This script expects ESM2 only."
    E = np.load(ESM_TEST).astype(np.float32)
    assert E.shape[0] == n_rbps, f"ESM test rows ({E.shape[0]}) != RBPs ({n_rbps})"
    if getattr(cfg, "PROT_SRC_ZSCORE", False):
        if os.path.exists(ESM_TRAIN):
            T = np.load(ESM_TRAIN).astype(np.float32)
            mu = T.mean(0, keepdims=True)
            sd = T.std(0, keepdims=True)
            sd[sd < 1e-6] = 1.0  # constant features: leave them centred but unscaled
            E = (E - mu) / sd
            print("[Prot] Applied train z-score to ESM test embeddings.")
        else:
            # Previously skipped without notice; keep the best-effort
            # behaviour but make the missing normalisation visible.
            print(f"[Prot][WARN] cfg.PROT_SRC_ZSCORE is set but {ESM_TRAIN} was not found; "
                  "using un-normalized ESM test embeddings.")
    print(f"[Prot] ESM test embeddings: {E.shape}")
    return torch.from_numpy(E).to(DEVICE)

def load_struct_slice(start_idx: int, seqs: List[str]) -> torch.Tensor:
    """Return [B, Lmax, 5] PHIME for RNA lines [start_idx : start_idx+B).

    For the b-th sequence the profile is read from
    ``STRUCT_DIR/seq_{start_idx + b:06d}.npz``:

    * a ``"PHIME"`` array is used as-is;
    * otherwise a ``"PLUM"`` array is unpacked column-wise as (P, Lh, Uu, Mm)
      and converted to five channels ``[P, Lh, Mm/2, Mm/2, Uu]``;
    * if the file is missing, unreadable, or has neither key, every position
      gets ``[0, 0, 0, 0, 1]``.

    Profiles longer than the sequence are truncated, shorter ones are padded
    with all-zero rows, and positions past each sequence's length stay zero.
    """
    def _read_profile(npz_path):
        # Any failure while reading is treated exactly like a missing file.
        try:
            with np.load(npz_path) as archive:
                if "PHIME" in archive:
                    return np.array(archive["PHIME"], np.float32)
                if "PLUM" in archive:
                    P, Lh, Uu, Mm = (np.array(archive["PLUM"], np.float32).T)
                    first_half = 0.5 * Mm
                    second_half = Mm - first_half
                    channels = [P, Lh, first_half, second_half, Uu]
                    return np.stack(channels, axis=-1).astype(np.float32)
        except Exception:
            return None
        return None

    n_seqs = len(seqs)
    longest = max((len(s) for s in seqs), default=0)
    batch = np.zeros((n_seqs, longest, 5), np.float32)
    for i, seq in enumerate(seqs, start=start_idx):
        row = i - start_idx
        seq_len = len(seq)
        profile = _read_profile(os.path.join(STRUCT_DIR, f"seq_{i:06d}.npz"))
        if profile is None:
            # Fallback: all mass on the last channel.
            profile = np.zeros((seq_len, 5), np.float32)
            profile[:, 4] = 1.0
        n_rows = profile.shape[0]
        if n_rows > seq_len:
            profile = profile[:seq_len]
        elif n_rows < seq_len:
            padding = np.zeros((seq_len - n_rows, 5), np.float32)
            profile = np.vstack([profile, padding])
        batch[row, :seq_len, :] = profile
    return torch.from_numpy(batch)

# ==============================
# Main (no sys.argv)
# ==============================
def main():
    """Score every RNA against every RBP and write one prediction file per RBP.

    Pipeline:
      1. Check that all input paths exist and create ``OUT_DIR``.
      2. Load the RBP / RNA sequences, the checkpoint and the ESM embeddings.
      3. Project the protein embeddings once up front.  With ``TTA_USE`` this
         is done ``TTA_N`` times: the first copy is unperturbed, the others
         get Gaussian noise scaled by ``TTA_NOISE``.
      4. Stream the RNAs in batches of ``RNA_BATCH``, score each batch against
         all proteins (averaging over the TTA copies) and append one
         ``FMT``-formatted score per line to ``RBP{BASE_OFFSET + i + 1}.txt``,
         preserving RNA order.

    Raises:
        FileNotFoundError: if an input file or ``STRUCT_DIR`` is missing.
        ValueError: if ``TTA_USE`` is set with ``TTA_N < 1``.
    """
    # Local import: only this function needs it.
    from contextlib import ExitStack

    # Sanity checks
    for p in [RBP_FILE, RNA_FILE, CKPT_PATH, ESM_TEST]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing required path: {p}")
    if not os.path.isdir(STRUCT_DIR):
        raise FileNotFoundError(f"Missing struct dir: {STRUCT_DIR}")
    if TTA_USE and TTA_N < 1:
        # Fail fast with a clear message instead of an IndexError later on.
        raise ValueError(f"TTA_USE is set but TTA_N={TTA_N}; need at least 1.")
    os.makedirs(OUT_DIR, exist_ok=True)

    rbps = read_lines(RBP_FILE, normalize_rna=False)
    rnas = read_lines(RNA_FILE, normalize_rna=True)
    print(f"[Load] RBPs: {len(rbps)} | RNAs: {len(rnas)} | Device: {DEVICE}")

    cfg, model = load_checkpoint(CKPT_PATH, DEVICE)
    use_struct = getattr(cfg, "RNA_USE_STRUCT", False)
    # NOTE(review): unreachable today because the unconditional STRUCT_DIR
    # check above already raised; kept in case that check is later relaxed.
    if use_struct and not os.path.isdir(STRUCT_DIR):
        raise RuntimeError("cfg.RNA_USE_STRUCT=True but STRUCT_DIR not found.")

    # --- Protein side: ESM only; build the (TTA) projections once ---
    esm = load_esm_only(cfg, len(rbps))            # [P, Din]; z-scoring is handled in load_esm_only
    with torch.no_grad():
        if TTA_USE:
            e_p_list = [
                model.project_prot(esm if k == 0 else (esm + TTA_NOISE * torch.randn_like(esm)))
                for k in range(TTA_N)
            ]                                      # each [P, D]
        else:
            e_p_list = [model.project_prot(esm)]   # single view, [P, D]
    P = e_p_list[0].size(0)
    n_views = len(e_p_list)

    N = len(rnas)
    # ExitStack closes every opened file even if opening a later file or
    # scoring a batch raises, so no handles leak and buffered data is flushed.
    with ExitStack() as stack, torch.no_grad():
        files = [
            stack.enter_context(open(os.path.join(OUT_DIR, f"RBP{BASE_OFFSET + i + 1}.txt"), "w"))
            for i in range(P)
        ]

        # Stream RNAs in order
        for a in range(0, N, RNA_BATCH):
            b = min(a + RNA_BATCH, N)
            batch = rnas[a:b]
            tok = tokenize_rna_batch(cfg, batch).to(DEVICE)
            st = load_struct_slice(a, batch).to(DEVICE) if use_struct else None
            e_r = model.encode_rna(tok, st)                # [B, D]

            # Sum the scores over all protein views, then average under TTA.
            S_acc = None
            for e_p in e_p_list:
                S_k = model.score(e_p, e_r).detach().cpu().numpy()   # [P, B]
                S_acc = S_k if S_acc is None else (S_acc + S_k)
            S = S_acc / float(n_views) if TTA_USE else S_acc

            # Append lines (preserve RNA order)
            for j, fh in enumerate(files):
                fh.write("".join(FMT % v + "\n" for v in S[j]))
            if ((a // RNA_BATCH) % 10) == 0:
                print(f"[Prog] {b}/{N} ({b/N:.1%})")

    print(f"[Done] Wrote {P} files in {OUT_DIR}/  (RBP{BASE_OFFSET+1}.txt .. RBP{BASE_OFFSET+P}.txt)")

# Script entry point: run the full prediction export when executed directly
# (not when this code is imported as a module).
if __name__ == "__main__":
    main()


In [None]:
# Unpack /content/struct_test.zip into ./struct_test
# (the third argument forces the "zip" format instead of inferring it from the extension)
import shutil

shutil.unpack_archive("/content/struct_test.zip", "struct_test", "zip")
print("Extracted to: struct_test")


In [None]:
!zip -r -q predictions.zip predictions

In [None]:
import os, glob, zipfile
import re

src = "predictions"           # directory holding the per-RBP *.txt prediction files
out = "predictions_split9"    # directory receiving the partial archives
group = 9                     # number of prediction files per archive


def natural_sort_key(path):
    """Return a case-insensitive sort key that orders embedded numbers numerically.

    Only the file name is considered.  It is split into alternating text and
    digit runs, so ``RBP2.txt`` sorts before ``RBP10.txt`` (a plain string
    sort puts ``RBP10`` first).
    """
    pieces = re.split(r"(\d+)", os.path.basename(path).lower())
    # With one capturing group, re.split puts the digit runs at odd indices.
    return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]


os.makedirs(out, exist_ok=True)

# Natural sort so each archive holds a contiguous numeric range of RBP files.
files = sorted(glob.glob(os.path.join(src, "*.txt")), key=natural_sort_key)

for i in range(0, len(files), group):
    part = i//group + 1
    zpath = os.path.join(out, f"predictions_{part:03d}.zip")
    with zipfile.ZipFile(zpath, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for f in files[i:i+group]:
            # Store files flat (no directory prefix) inside the archive.
            zf.write(f, arcname=os.path.basename(f))
    print("wrote", zpath)
