
# BigProtein-Qwen2.5 — Step‑by‑Step Test Notebook (Colab)
This notebook lets you **test each component** of the protein‑conditioned Qwen2.5 pipeline *before* running full training.  
It mirrors the main script logic, but runs **function‑by‑function** so you can see errors early with clear tracebacks.

> **Files expected in the working directory** (upload or mount a folder containing them):  
> - `bigmodel_joint_train.py`  
> - `protein_encoder.py`  
> - `structure_encoder.py`
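
Before importing anything heavy, a quick preflight check can fail fast if one of these files is missing. This is a minimal sketch that assumes the three files sit directly in the current working directory. The helper name `missing_files` is hypothetical.

```python
from pathlib import Path

# File names taken from the list above
REQUIRED_FILES = ("bigmodel_joint_train.py", "protein_encoder.py", "structure_encoder.py")

def missing_files(required=REQUIRED_FILES, base_dir="."):
    """Return the required file names that are not regular files under base_dir."""
    base = Path(base_dir)
    return [name for name in required if not (base / name).is_file()]

print("Missing files:", missing_files() or "none")
```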


In [1]:
#@title Common imports (Drive is mounted in the next cell)
from pathlib import Path
from huggingface_hub import snapshot_download
import os, json, pickle, pandas as pd
from tqdm import tqdm
from rich import print as rprint

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

%cd /content/drive/MyDrive/LLM/Bioreasoner/testing_pipelines

from pathlib import Path
BASE_DIR = Path.cwd()   # absolute path of the folder we just %cd'ed into
OUT_DIR  = BASE_DIR / "sft_test_demo"
print(f"Using Google Drive folder as BASE_DIR: {BASE_DIR}")


Mounted at /content/drive
/content/drive/MyDrive/LLM/Bioreasoner/testing_pipelines
Using Google Drive folder as BASE_DIR: /content/drive/MyDrive/LLM/Bioreasoner/data/hf/proteinDT



## 0) Runtime & Installs
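If you're on Google Colab, run the cells in this section to install the pinned dependencies. If `torch` or `transformers` was already imported in this session, restart the runtime after installing so the new versions are actually loaded.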
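
After the installs, a small check can confirm that the pins actually took effect. This is a minimal sketch that uses the standard-library `importlib.metadata`. The `PINS` dict restates the versions pinned in the install cells, and `check_pins` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install cells of this section
PINS = {
    "torch": "2.8.0",
    "torchvision": "0.23.0",
    "torchaudio": "2.8.0",
    "transformers": "4.56.1",
    "huggingface_hub": "0.35.0",
}

def check_pins(pins=PINS, get_version=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    bad = {}
    for pkg, expected in pins.items():
        try:
            installed = get_version(pkg)
        except PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "+cu126" when comparing
        if installed is None or installed.split("+", 1)[0] != expected:
            bad[pkg] = (expected, installed)
    return bad

print(check_pins() or "All pins satisfied.")
```

Stripping the `+cu126` local tag lets the check accept the CUDA build reported in the version check further down.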


In [3]:
# --- Make sure no local "transformers" folder shadows the pip package ---
import sys, os, glob
if os.path.isdir("transformers"):
    print("⚠️ Local 'transformers' directory detected in CWD; this will shadow the pip package.")
    print("   -> rename or remove it before continuing.")
else:
    print("No local 'transformers' directory in CWD. Good to go.")

No local 'transformers' directory in CWD. Good to go.


In [4]:
# --- Clean out conflicting wheels so we can pin exactly what we want ---
%pip uninstall -y -q torch torchvision torchaudio xformers transformers tokenizers huggingface_hub

# --- Install PyTorch 2.8.0 (CUDA 12.6 wheels) + matching torchvision/torchaudio ---
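# (CPU-only runtime: point --index-url at https://download.pytorch.org/whl/cpu instead;
#  don't delete the whole line, since it also holds the `%pip install -q` command)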
%pip install -q --index-url https://download.pytorch.org/whl/cu126 \
  torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0

# --- Install Transformers & HF Hub (your requested versions) ---
%pip install -q "transformers==4.56.1" "huggingface_hub==0.35.0"

# --- Extras you used elsewhere ---
%pip install -q peft accelerate datasets tqdm rich pandas

[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 5.1.0 requires huggingface-hub>=0.20.0, which is not installed.
sentence-transformers 5.1.0 requires transformers<5.0.0,>=4.41.0, which is not installed.
timm 1.0.19 requires huggingface_hub, which is not installed.
accelerate 1.10.1 requires huggingface_hub>=0.21.0, which is not installed.
torchtune 0.6.1 requires huggingface_hub[hf_transfer], which is not installed.
torchtune 0.6.1 requires tokenizers, which is not installed.
peft 0.17.1 requires huggingface_hub>=0.25.0, which is not installed.
peft 0.17.1 requires transformers, which is not installed.[0m[31m
[0m

In [5]:
# --- Install your encoder requirements, but ignore any Torch/Transformers pins inside ---
import re, pathlib, textwrap, sys

REQ_FILE = pathlib.Path("requirements_encoders.txt")
if REQ_FILE.exists():
    lines = REQ_FILE.read_text().splitlines()
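    # Drop lines that would force Torch/TV/TA/Transformers downgrades.
    # Match the exact package name followed by a version spec, extra, marker, comment, URL,
    # or end of line, so hyphenated packages such as "torch-geometric" are kept.
    keep = [
        ln for ln in lines
        if not re.match(r'^\s*(torch|torchvision|torchaudio|transformers)\s*(?:[<>=!~;\[#@]|$)', ln, flags=re.IGNORECASE)
    ]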
    tmp = pathlib.Path("/content/encoders_req_no_torch.txt")
    tmp.write_text("\n".join(keep) + "\n")
    print("Installing sanitized encoder requirements (Torch/Transformers lines removed):")
    for ln in lines:
        if ln not in keep and ln.strip():
            print("  - removed:", ln)
    %pip install -q -r {str(tmp)}
else:
    print("No requirements_encoders.txt found; skipping.")

Installing sanitized encoder requirements (Torch/Transformers lines removed):
  - removed: torch>=2.0.1,<2.3
  - removed: transformers>=4.38,<5
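
A bare `\b` after the package name also matches hyphenated distributions such as `torch-geometric`, because `-` counts as a word boundary. Below is a minimal standalone sketch of a stricter, name-exact filter. `drop_pinned` and the sample requirement lines are hypothetical.

```python
import re

BLOCKED = ("torch", "torchvision", "torchaudio", "transformers")

# Exact package name, then a version specifier, extra, env marker, comment, URL (@), or end of line
_PATTERN = re.compile(
    r"^\s*(?:%s)\s*(?:[<>=!~;\[#@]|$)" % "|".join(map(re.escape, BLOCKED)),
    flags=re.IGNORECASE,
)

def drop_pinned(lines):
    """Split requirement lines into (kept, removed) using an exact-name match."""
    kept, removed = [], []
    for ln in lines:
        (removed if _PATTERN.match(ln) else kept).append(ln)
    return kept, removed

kept, removed = drop_pinned([
    "torch>=2.0.1,<2.3",
    "transformers>=4.38,<5",
    "torch-geometric==2.5.0",
    "torchvision",
    "fair-esm",
])
print("kept   :", kept)
print("removed:", removed)
```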


In [6]:
# --- Version & import sanity checks ---
import torch, transformers, huggingface_hub
print("torch            :", torch.__version__)
print("transformers     :", transformers.__version__)
print("huggingface_hub  :", huggingface_hub.__version__)

# Top-level ESM import should work on 4.56.1
try:
    from transformers import AutoTokenizer, EsmForMaskedLM
    print("✅ Top-level EsmForMaskedLM import OK")
except Exception as e:
    print("❌ Top-level EsmForMaskedLM import failed:", repr(e))
    # Fallback check (direct module path)
    try:
        from transformers.models.esm.modeling_esm import EsmForMaskedLM as _E
        print("✅ Direct modeling_esm import OK (fallback)")
    except Exception as ee:
        print("❌ Direct modeling_esm import failed too:", repr(ee))

torch            : 2.8.0+cu126
transformers     : 4.56.1
huggingface_hub  : 0.35.0
✅ Top-level EsmForMaskedLM import OK




## 1) Loading Encoder Checkpoints

In [7]:
#@title Reload local modules (encoders + training script)
import importlib, os, torch

import protein_encoder as protein_encoder_mod
import structure_encoder as structure_encoder_mod
import bigmodel_joint_train as train_mod

def reload_all():
    importlib.reload(protein_encoder_mod)
    importlib.reload(structure_encoder_mod)
    importlib.reload(train_mod)
    print("Reloaded modules.")

reload_all()
print("Torch:", torch.__version__)


Reloaded modules.
Torch: 2.8.0+cu126


In [8]:
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)

transformers: 4.56.1
torch: 2.8.0+cu126


In [9]:
# ==== Robust loader for ParameterList-style ProTrek checkpoints ====
from pathlib import Path
from collections import defaultdict, Counter
import importlib, torch

# --- Your paths ---
PROTEIN_CONFIG   = Path("/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/esm2_t12_35M_UR50D")
STRUCTURE_CONFIG = Path("/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/foldseek_t12_35M")
CKPT             = Path("/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/ProTrek_35M.pt")

# --- Import/reload your encoder classes ---
import protein_encoder as protein_encoder_mod
import structure_encoder as structure_encoder_mod
importlib.reload(protein_encoder_mod)
importlib.reload(structure_encoder_mod)
ProteinEncoder   = protein_encoder_mod.ProteinEncoder
StructureEncoder = structure_encoder_mod.StructureEncoder

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# --- Build fresh (random-init) encoders from configs only ---
prot_enc = ProteinEncoder(str(PROTEIN_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()
stru_enc = StructureEncoder(str(STRUCTURE_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()

# --- Utility: cosine similarity between two batches of embeddings ---
@torch.no_grad()
def cosine_sim(a, b):
    a = torch.nn.functional.normalize(a, dim=-1)
    b = torch.nn.functional.normalize(b, dim=-1)
    return float((a*b).sum(dim=-1).mean().cpu())

# --- Make baseline outputs BEFORE loading (to prove they change) ---
@torch.no_grad()
def get_prot_repr(model):
    seqs = ["MKTFFVAIATGAFSATA", "MGDVEKGKKIFIMKCSQCHTVEK"]  # toy AA
    return model.get_repr(seqs, batch_size=2, verbose=False).to("cpu")

@torch.no_grad()
def get_stru_repr(model):
    seqs = ["acdefghiklmnpqrstvwy", "acdefghi"]  # toy 3Di (foldseek) strings
    return model.get_repr(seqs, batch_size=2, verbose=False).to("cpu")

prot_before = get_prot_repr(prot_enc)
stru_before = get_stru_repr(stru_enc)

# --- Load checkpoint & locate the real state_dict (handles common wrappers) ---
raw = torch.load(str(CKPT), map_location="cpu")

def locate_state_dict(obj):
    if isinstance(obj, dict):
        for k in ("model", "state_dict", "model_state_dict", "weights", "params"):
            if k in obj and isinstance(obj[k], dict) and any(torch.is_tensor(v) for v in obj[k].values()):
                return obj[k]
        if any(torch.is_tensor(v) for v in obj.values()):
            return obj
    return obj  # fallback

sd = locate_state_dict(raw)
print(f"Checkpoint tensors: {len(sd)}")

# --- Split by ParameterList slot: top-level token before the first dot, if it's an integer ---
slots = defaultdict(dict)
for k, v in sd.items():
    head = k.split(".", 1)[0]
    if head.isdigit():
        slots[int(head)][k[len(head)+1:]] = v  # strip "N." prefix
    else:
        slots[None][k] = v  # non-slotted (rare)

print("Detected slots:", sorted([s for s in slots.keys() if s is not None]),
      "(None present)" if None in slots else "")

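# --- Pick best slot for each encoder by # of exact key matches against target state_dict keys ---
# NOTE: the protein and structure encoders share the same ESM-35M architecture, so their slots have
# identical key names and this heuristic ties (both report slot 1 in the output below). It is
# diagnostic only; the actual load uses ProTrek's fixed slot order (1 = protein, 3 = structure).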
def best_slot_for(module, slots_dict):
    tgt = set(module.state_dict().keys())
    best = (None, 0)
    for s, sub in slots_dict.items():
        if s is None:  # skip un-slotted container for this matching
            continue
        hits = sum(1 for k in sub.keys() if k in tgt)
        if hits > best[1]:
            best = (s, hits)
    return best  # (slot, hits)

prot_slot, prot_hits = best_slot_for(prot_enc, slots)
stru_slot, stru_hits = best_slot_for(stru_enc, slots)

print(f"Candidate slot for ProteinEncoder : {prot_slot} (exact-key hits={prot_hits})")
print(f"Candidate slot for StructureEncoder: {stru_slot} (exact-key hits={stru_hits})")

# --- Load with strict=False and report matched/missing/unexpected + preview of matched keys ---
def load_from_slot(module, slot_idx, tag):
    sub = slots.get(slot_idx, {})
    tgt_keys = set(module.state_dict().keys())
    matched_keys = sorted([k for k in sub.keys() if k in tgt_keys])
    missing, unexpected = module.load_state_dict(sub, strict=False)
    print(f"\n[{tag}] slot={slot_idx}  loaded_tensors={len(sub)}")
    print(f"  matched={len(matched_keys)}  missing={len(missing)}  unexpected={len(unexpected)}")
    print("  matched keys (first 20):", matched_keys[:20])
    if missing:
        print("  missing (first 12):", list(missing)[:12])
    if unexpected:
        print("  unexpected (first 12):", list(unexpected)[:12])

load_from_slot(prot_enc, 1, "ProteinEncoder")
load_from_slot(stru_enc, 3, "StructureEncoder")

# --- Prove it changed: cosine(before, after) ---
prot_after = get_prot_repr(prot_enc)
stru_after = get_stru_repr(stru_enc)

print("\nCosine(protein  before vs after) :", cosine_sim(prot_before, prot_after))
print("Cosine(structure before vs after) :", cosine_sim(stru_before, stru_after))
print("\nDone.")

Device: cuda
Checkpoint tensors: 633
Detected slots: [0, 1, 2, 3] 
Candidate slot for ProteinEncoder : 1 (exact-key hits=215)
Candidate slot for StructureEncoder: 1 (exact-key hits=215)

[ProteinEncoder] slot=1  loaded_tensors=216
  matched=215  missing=0  unexpected=1
  matched keys (first 20): ['model.esm.embeddings.word_embeddings.weight', 'model.esm.encoder.emb_layer_norm_after.bias', 'model.esm.encoder.emb_layer_norm_after.weight', 'model.esm.encoder.layer.0.LayerNorm.bias', 'model.esm.encoder.layer.0.LayerNorm.weight', 'model.esm.encoder.layer.0.attention.LayerNorm.bias', 'model.esm.encoder.layer.0.attention.LayerNorm.weight', 'model.esm.encoder.layer.0.attention.output.dense.bias', 'model.esm.encoder.layer.0.attention.output.dense.weight', 'model.esm.encoder.layer.0.attention.self.key.bias', 'model.esm.encoder.layer.0.attention.self.key.weight', 'model.esm.encoder.layer.0.attention.self.query.bias', 'model.esm.encoder.layer.0.attention.self.query.weight', 'model.esm.encoder.laye

ParameterList indexing: 0 = temp, 1 = protein, 2 = text, 3 = structure

(This order is hard-coded by ProTrek; use it as-is, no fix needed.)
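
Because the order is fixed, slicing the flat checkpoint into per-encoder state dicts reduces to a prefix split on the leading slot index. The minimal sketch below works on any mapping: tensor values in practice, plain strings in the toy example. `PROTREK_SLOTS`, `split_slots` and `encoder_state` are hypothetical names that restate the order above.

```python
from collections import defaultdict

# Slot order as noted above (ProTrek's ParameterList layout)
PROTREK_SLOTS = {"temp": 0, "protein": 1, "text": 2, "structure": 3}

def split_slots(state_dict):
    """Group 'N.rest.of.key' entries into {N: {'rest.of.key': value}}.

    Keys whose first dotted component is not an integer go under None.
    """
    slots = defaultdict(dict)
    for key, value in state_dict.items():
        head, _, rest = key.partition(".")
        if head.isdigit():
            slots[int(head)][rest] = value
        else:
            slots[None][key] = value
    return dict(slots)

def encoder_state(state_dict, name):
    """Return the sub-state-dict for a named slot, e.g. 'protein' (empty dict if absent)."""
    return split_slots(state_dict).get(PROTREK_SLOTS[name], {})

toy_ckpt = {
    "1.model.esm.embeddings.word_embeddings.weight": "protein-W",
    "3.model.esm.embeddings.word_embeddings.weight": "structure-W",
}
print(encoder_state(toy_ckpt, "structure"))
```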

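Interpretation tips:

  • If `matched` is large (here 215, i.e. every key the encoder expects) and the cosine similarity is well below 0.99, the weights changed, so the checkpoint was very likely loaded.

  • If `matched=0` and the cosine similarity is ≈ 1.0, nothing changed: the encoder is still random.

  • An identical-parameter count (`identical_params`, not computed in the cell above) would count tensors that are exactly equal before vs. after loading; lower is better (more weights changed).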
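
The identical-parameter count from the last tip can be sketched as follows. First, snapshot the state dict before loading. Then count the tensors that are still exactly equal afterwards. `count_identical_params` is a hypothetical helper. Its equality function defaults to `torch.equal`, imported lazily so the helper itself stays framework-agnostic.

```python
def count_identical_params(before, after, equal=None):
    """Count keys present in both state dicts whose values are exactly equal.

    Returns (n_identical, n_shared). With real tensors, snapshot `before` via
    {k: v.detach().clone() for k, v in module.state_dict().items()} prior to loading.
    """
    if equal is None:
        import torch  # lazy import: only needed when comparing real tensors
        equal = torch.equal
    shared = sorted(before.keys() & after.keys())
    n_identical = sum(1 for k in shared if equal(before[k], after[k]))
    return n_identical, len(shared)

# Usage with the encoders above (not executed here):
# before = {k: v.detach().clone() for k, v in prot_enc.state_dict().items()}
# load_from_slot(prot_enc, 1, "ProteinEncoder")
# print(count_identical_params(before, prot_enc.state_dict()))
```

Deterministic buffers such as the rotary `inv_freq` tensors listed among the expected keys will always count as identical, so the count rarely drops to zero even after a successful load.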

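The checkpoint's position-id buffer is currently not loaded into the encoder. It is very likely the single `unexpected` key above (`embeddings.position_ids`): dropping that key in the hard-coded-slot cell below brings `unexpected` to 0. ESM-2 uses rotary position embeddings, so this buffer holds no learned weights. It is most likely reported as unexpected because recent `transformers` versions register it as a non-persistent buffer that is not part of `state_dict()`. The current model does not use it, so I will ignore it unless further evidence shows that it actually matters.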
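
A small guard can keep this ignored buffer from hiding real mismatches. It asserts that every unexpected key is on an allow-list. This is a minimal sketch. `IGNORABLE_SUFFIXES` and `check_load_report` are hypothetical, and the allow-list contains only the position-id buffer discussed above.

```python
IGNORABLE_SUFFIXES = ("embeddings.position_ids",)

def check_load_report(missing, unexpected, ignorable=IGNORABLE_SUFFIXES):
    """Raise if load_state_dict(strict=False) reported anything beyond the allow-list."""
    real_unexpected = [k for k in unexpected if not k.endswith(ignorable)]
    problems = []
    if missing:
        problems.append(f"missing={list(missing)[:5]}")
    if real_unexpected:
        problems.append(f"unexpected={real_unexpected[:5]}")
    if problems:
        raise RuntimeError("State-dict mismatch: " + "; ".join(problems))
    return True

# Usage (not executed here):
# missing, unexpected = prot_enc.load_state_dict(sub, strict=False)
# check_load_report(missing, unexpected)
```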

In [10]:
# ===== Check what ESM expects vs. what the checkpoint provides (and optionally load) =====
from collections import Counter
from pathlib import Path
import torch, re, importlib

# --- config paths (already set above normally) ---
# PROTEIN_CONFIG, STRUCTURE_CONFIG, CKPT

# --- (re)import encoders to be safe ---
import protein_encoder as protein_encoder_mod
import structure_encoder as structure_encoder_mod
importlib.reload(protein_encoder_mod)
importlib.reload(structure_encoder_mod)
ProteinEncoder   = protein_encoder_mod.ProteinEncoder
StructureEncoder = structure_encoder_mod.StructureEncoder

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Build fresh (random) encoders to inspect what keys THEY expect
prot = ProteinEncoder(str(PROTEIN_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()
stru = StructureEncoder(str(STRUCTURE_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()

prot_keys = list(prot.state_dict().keys())
stru_keys = list(stru.state_dict().keys())

def prefix_hist(keys, title=""):
    print(f"\n== {title} ==")
    print(f"Total keys: {len(keys)}")
    print("Sample keys:", keys[:25])
    def head1(k): return k.split(".", 1)[0]
    def head2(k):
        parts = k.split(".")
        return ".".join(parts[:2]) if len(parts)>=2 else parts[0]
    c1 = Counter(head1(k) for k in keys)
    c2 = Counter(head2(k) for k in keys)
    print("\nTop-level prefixes:")
    for k,v in c1.most_common(12): print(f"  {k:28s} {v}")
    print("\nTwo-token prefixes:")
    for k,v in c2.most_common(12): print(f"  {k:28s} {v}")

prefix_hist(prot_keys, "ProteinEncoder expected keys (target)")
prefix_hist(stru_keys, "StructureEncoder expected keys (target)")

# ---- Load checkpoint and locate the real state dict inside it ----
CKPT = Path(CKPT)
raw = torch.load(str(CKPT), map_location="cpu")

def locate_state_dict(obj, max_depth=3):
    if isinstance(obj, dict):
        # common containers
        for k in ("state_dict","model_state_dict","weights","params","model"):
            if k in obj and isinstance(obj[k], dict) and any(torch.is_tensor(v) for v in obj[k].values()):
                return obj[k]
        if any(torch.is_tensor(v) for v in obj.values()):
            return obj
        if max_depth > 0:
            for v in obj.values():
                if isinstance(v, dict):
                    sd = locate_state_dict(v, max_depth=max_depth-1)
                    if sd is not None: return sd
    return None

sd0 = locate_state_dict(raw) if isinstance(raw, dict) else raw
if sd0 is None:
    print("\n❌ Could not locate a tensor state_dict in checkpoint; top-level keys:",
          list(raw.keys())[:20] if isinstance(raw, dict) else type(raw))
    raise SystemExit

ckpt_keys = list(sd0.keys())
prefix_hist(ckpt_keys, f"Checkpoint keys found in: {CKPT.name}")

# ---- Try to slice protein/structure parts by plausible prefixes ----
def strip_prefix(d, prefix):
    if prefix and not any(k.startswith(prefix) for k in d): return d
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k,v in d.items()}

def extract_slice(sd, candidates):
    sd = strip_prefix(strip_prefix(sd, "model."), "module.")
    best, best_pref = {}, None
    for pref in candidates:
        if pref == "":
            continue  # "" selects the whole checkpoint and would always win on size; keep it as a last resort
        sub = {k[len(pref):]: v for k, v in sd.items() if k.startswith(pref)}
        if len(sub) > len(best):
            best, best_pref = sub, pref
    if not best and "" in candidates:
        best, best_pref = dict(sd), ""
    return best_pref, best

# widen if needed after you see the histogram
# "1." and "3." are ProTrek's ParameterList slots (1 = protein, 3 = structure)
prot_prefix_candidates = (
    "1.",
    "protein_encoder.", "module.protein_encoder.", "model.protein_encoder.",
    "protein.", "seq_encoder.", "sequence_encoder.",  # extra guesses
    ""
)
stru_prefix_candidates = (
    "3.",
    "structure_encoder.", "module.structure_encoder.", "model.structure_encoder.",
    "structure.", "struct_encoder.", "foldseek_encoder.",                    # extra guesses
    ""
)

prot_pref, prot_sub = extract_slice(sd0, prot_prefix_candidates)
stru_pref, stru_sub = extract_slice(sd0, stru_prefix_candidates)

print(f"\nProtein slice prefix chosen: {repr(prot_pref)}  (#tensors={len(prot_sub)})")
print(f"Structure slice prefix chosen: {repr(stru_pref)} (#tensors={len(stru_sub)})")

# ---- Show a few sub-keys (post-slicing) ----
print("\nProtein sub-keys sample:", list(prot_sub.keys())[:20])
print("Structure sub-keys sample:", list(stru_sub.keys())[:20])

# ---- Build mapping functions and pick the one with most hits ----
def choose_mapping(sub_keys, target_keys):
    # conservative + esm-aware mappers
    def as_is(k): return k
    def add_model(k):
        if k.startswith(("model.","out.")): return k
        return "model."+k
    def esm_to_model_esm(k):
        if k.startswith(("model.","out.")): return k
        # common ESM bits to sit under model.esm.*
        head = k.split(".",1)[0]
        if head in ("esm","encoder","embeddings","contact_head"):
            return "model.esm."+k if not k.startswith("esm.") else "model."+k   # handle "esm.*" already complete
        if head in ("lm_head",):
            return "model."+k
        return "model."+k

    candidates = {
        "as_is": as_is,
        "add_model": add_model,
        "esm_to_model_esm": esm_to_model_esm,
    }
    best_name, best_fn, best_hits = "as_is", as_is, 0
    tgt = set(target_keys)
    for name, fn in candidates.items():
        hits = sum(1 for k in sub_keys if fn(k) in tgt)
        if hits > best_hits:
            best_name, best_fn, best_hits = name, fn, hits
    return best_name, best_fn, best_hits

prot_map_name, prot_map_fn, prot_hits = choose_mapping(set(prot_sub.keys()), prot_keys)
stru_map_name, stru_map_fn, stru_hits = choose_mapping(set(stru_sub.keys()), stru_keys)

print(f"\nBest mapping for PROTEIN: {prot_map_name}  (hits={prot_hits}/{len(prot_sub)})")
print(f"Best mapping for STRUCT  : {stru_map_name}  (hits={stru_hits}/{len(stru_sub)})")

# ---- Optionally apply (dry-run or real) ----
DO_LOAD = True  # set False to only analyze keys without loading

def remap_and_load(module, sub, map_fn, tag):
    remapped = {map_fn(k): v for k,v in sub.items()}
    missing, unexpected = module.load_state_dict(remapped, strict=False)
    print(f"[{tag}] loaded_tensors={len(remapped)}  missing={len(missing)}  unexpected={len(unexpected)}")
    if missing:    print("  • missing (first 12):   ", list(missing)[:12])
    if unexpected: print("  • unexpected (first 12):", list(unexpected)[:12])

if DO_LOAD:
    remap_and_load(prot, prot_sub, prot_map_fn, "ProteinEncoder")
    remap_and_load(stru, stru_sub, stru_map_fn, "StructureEncoder")

# ---- Concrete proof: compare outputs before/after on same inputs ----
@torch.no_grad()
def cosine_sim(a, b):
    a = torch.nn.functional.normalize(a, dim=-1)
    b = torch.nn.functional.normalize(b, dim=-1)
    return float((a*b).sum(dim=-1).mean().cpu())

@torch.no_grad()
def get_prot(prot_model):
    seqs = ["MKTFFVAIATGAFSATA", "MGDVEKGKKIFIMKCSQCHTVEK"]
    return prot_model.get_repr(seqs, batch_size=2, verbose=False).to("cpu")

@torch.no_grad()
def get_stru(stru_model):
    seqs = ["acdefghiklmnpqrstvwy", "acdefghi"]
    return stru_model.get_repr(seqs, batch_size=2, verbose=False).to("cpu")

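# Fresh random baseline encoders.
# Caveat: these are *independent* random inits, so their outputs differ from prot/stru even if
# nothing was loaded; cosine < 1 does not by itself prove loading. The exact count below does.
prot_rand = ProteinEncoder(str(PROTEIN_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()
stru_rand = StructureEncoder(str(STRUCTURE_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()

prot_before, prot_after = get_prot(prot_rand), get_prot(prot)
stru_before, stru_after = get_stru(stru_rand), get_stru(stru)

print("\nCosine(protein random vs loaded)  :", cosine_sim(prot_before, prot_after))
print("Cosine(struct  random vs loaded)  :", cosine_sim(stru_before, stru_after))

@torch.no_grad()
def count_loaded(module, sub, map_fn):
    """Count checkpoint tensors that now sit, bit-for-bit, in the module's state_dict."""
    tgt = module.state_dict()
    return sum(1 for k, v in sub.items()
               if map_fn(k) in tgt and torch.equal(tgt[map_fn(k)].cpu(), v.cpu()))

print("Exact tensors loaded (protein):", count_loaded(prot, prot_sub, prot_map_fn), "/", len(prot_sub))
print("Exact tensors loaded (struct) :", count_loaded(stru, stru_sub, stru_map_fn), "/", len(stru_sub))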

Device: cuda

== ProteinEncoder expected keys (target) ==
Total keys: 215
Sample keys: ['model.esm.embeddings.word_embeddings.weight', 'model.esm.encoder.layer.0.attention.self.query.weight', 'model.esm.encoder.layer.0.attention.self.query.bias', 'model.esm.encoder.layer.0.attention.self.key.weight', 'model.esm.encoder.layer.0.attention.self.key.bias', 'model.esm.encoder.layer.0.attention.self.value.weight', 'model.esm.encoder.layer.0.attention.self.value.bias', 'model.esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'model.esm.encoder.layer.0.attention.output.dense.weight', 'model.esm.encoder.layer.0.attention.output.dense.bias', 'model.esm.encoder.layer.0.attention.LayerNorm.weight', 'model.esm.encoder.layer.0.attention.LayerNorm.bias', 'model.esm.encoder.layer.0.intermediate.dense.weight', 'model.esm.encoder.layer.0.intermediate.dense.bias', 'model.esm.encoder.layer.0.output.dense.weight', 'model.esm.encoder.layer.0.output.dense.bias', 'model.esm.encoder.layer.0.Layer
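
The slice-and-remap idea of this cell reduces to a small search. Try a few prefix rewrites and keep the one whose renamed keys hit the most target keys. This is a minimal pure-Python sketch of that idea. `best_rewrite` and the candidate rewrites are hypothetical, and the example keys are shortened.

```python
def best_rewrite(ckpt_keys, target_keys, rewrites):
    """Pick the (strip, add) prefix rewrite that maps the most checkpoint keys onto target keys.

    Keys starting with `strip` have it removed and `add` prepended; other keys are skipped.
    Ties keep the earlier rewrite. Returns (rewrite, hits, {ckpt_key: target_key}).
    """
    targets = set(target_keys)
    best = (None, 0, {})
    for strip, add in rewrites:
        mapping = {}
        for k in ckpt_keys:
            if not k.startswith(strip):
                continue
            new_k = add + k[len(strip):]
            if new_k in targets:
                mapping[k] = new_k
        if len(mapping) > best[1]:
            best = ((strip, add), len(mapping), mapping)
    return best

ckpt = ["1.model.esm.embeddings.word_embeddings.weight",
        "1.model.esm.encoder.layer.0.attention.self.query.weight",
        "3.model.esm.embeddings.word_embeddings.weight"]
target = ["model.esm.embeddings.word_embeddings.weight",
          "model.esm.encoder.layer.0.attention.self.query.weight"]
rewrite, hits, mapping = best_rewrite(ckpt, target, [("", ""), ("1.", ""), ("3.", ""), ("1.model.", "model.")])
print(rewrite, hits)
```

Scoring by hits against the target keys, rather than by slice size, means the empty prefix cannot win just because it selects the whole checkpoint.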


## 3) Quick Configs (Edit Me)
Set your small model variants here. You can pass either a **local path** to a config/checkpoint or a **Hugging Face ID**.
For first tests, pick small backbones to speed things up (e.g., `Qwen/Qwen2.5-0.5B-Instruct`, `facebook/esm2_t12_35M_UR50D`).
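
Either form is accepted, so a tiny helper can tell the two apart before anything is loaded. This is a minimal sketch. It assumes that anything existing on disk is a local path and that a Hub ID has the form `org/name`. `classify_source` is hypothetical; the real loaders in `transformers` do their own resolution.

```python
import re
from pathlib import Path

_HUB_ID = re.compile(r"^[A-Za-z0-9][\w.-]*/[\w.-]+$")

def classify_source(name):
    """Return 'local' if name exists on disk, 'hub' if it looks like org/name, else 'unknown'."""
    if Path(name).expanduser().exists():
        return "local"
    if _HUB_ID.match(name):
        return "hub"
    return "unknown"

for name in ("Qwen/Qwen2.5-0.5B-Instruct", "facebook/esm2_t12_35M_UR50D"):
    print(name, "->", classify_source(name))
```

An absolute path that does not exist comes back as `unknown` rather than `hub`, which usually means a Drive folder is not mounted.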


In [11]:

# === LLM & Encoders ===
MODEL_NAME         = "Qwen/Qwen2.5-0.5B-Instruct"   # Small-ish for Colab testing
PROTEIN_CONFIG     = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/esm2_t12_35M_UR50D"  # or local path to a config/checkpoint
STRUCTURE_CONFIG   = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/foldseek_t12_35M"  # for structure encoder (uses same ESM family for demo)

# === Prefix/Proj ===
SINGLE_TOKEN_PREFIX = False     # True -> 1 token; False -> soft prefix of length PREFIX_LEN
PREFIX_LEN          = 4
PROJ_HID            = 1024
DROPOUT             = 0.10

# === Training toggles ===
USE_LORA            = False
TRAIN_ENCODERS      = False    # True = end-to-end; False = freeze encoders
FREEZE_PROTEIN      = False    # only used if TRAIN_ENCODERS=True
FREEZE_STRUCTURE    = False    # only used if TRAIN_ENCODERS=True
GRAD_CHECKPOINT     = False

# === Misc ===
DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN             = 512
BSZ                 = 2
ACCUM               = 1
LR                  = 5e-5
WARMUP_RATIO        = 0.03
EPOCHS              = 1
OUTPUT_DIR          = "runs/colab_smoketest"
LOG_EVERY           = 1

print("Device:", DEVICE)


Device: cuda
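
With toy settings like these, it is worth checking the optimizer-step arithmetic by hand, because it decides how many warmup steps `WARMUP_RATIO` yields. This is a minimal sketch under two assumptions: the optimizer steps once per `ACCUM` batches, with a partial final group rounded up, and the warmup step count is rounded up from the ratio. The training script may round differently. `schedule_steps` is hypothetical.

```python
import math

def schedule_steps(n_examples, bsz, accum, epochs, warmup_ratio):
    """Return (batches_per_epoch, optimizer_steps_total, warmup_steps)."""
    batches_per_epoch = math.ceil(n_examples / bsz)         # last partial batch kept
    steps_per_epoch = math.ceil(batches_per_epoch / accum)  # partial accumulation group still steps
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return batches_per_epoch, total_steps, warmup_steps

# 5 toy examples (the size of the validated JSONL below), BSZ=2, ACCUM=1, EPOCHS=1, WARMUP_RATIO=0.03
print(schedule_steps(5, 2, 1, 1, 0.03))  # -> (3, 3, 1)
```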



## 4) Minimal JSONL Toy Data
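
Each line of the SFT file is one JSON object. Judging from the validator in this section, `prompt` and `response` are required strings. `aa_seq` (uppercase amino acids) and `stru_str` (lowercase 3Di-like tokens) are optional. Below is a minimal sketch that writes such a file. The record contents are made-up toy values, not real annotations, and `write_jsonl` is hypothetical.

```python
import json
from pathlib import Path

def write_jsonl(records, path):
    """Write one JSON object per line; return the number of records written."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
    return len(records)

toy_records = [
    {
        "prompt": "Describe the likely function of this protein.",
        "response": "A short toy answer used only for smoke testing.",
        "aa_seq": "MKTFFVAIATGAFSATA",     # uppercase amino acids
        "stru_str": "dvvllvvcvvpvdvvsv",   # lowercase 3Di-like string (made up)
    },
    {"prompt": "Text-only example with no protein.", "response": "Also fine."},
]

# Usage (not executed here): write_jsonl(toy_records, "sft_test_demo/toy.jsonl")
```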



In [12]:
# --- PATHS: edit these four if needed ---
SFT_JSONL = "/content/drive/MyDrive/LLM/Bioreasoner/data/hf/proteinDT/sft_jsonl/sft_test_sample.jsonl"  # copy your toy file here
PROTEIN_CONFIG   = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/esm2_t12_35M_UR50D"
STRUCTURE_CONFIG = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/foldseek_t12_35M"
CKPT             = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/ProTrek_35M.pt"  # ProTrek .pt

from pathlib import Path
for p in [SFT_JSONL, PROTEIN_CONFIG, STRUCTURE_CONFIG, CKPT]:
    print(str(p), "OK" if Path(p).exists() else "MISSING")

/content/drive/MyDrive/LLM/Bioreasoner/data//hf/proteinDT/sft_jsonl/sft_test_sample.jsonl OK
/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/esm2_t12_35M_UR50D OK
/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/foldseek_t12_35M OK
/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/ProTrek_35M.pt OK


In [13]:
import json, re
from statistics import mean

AA_RE  = re.compile(r"^[ACDEFGHIKLMNPQRSTVWY]+$")   # uppercase 20 AA
FS_RE  = re.compile(r"^[a-z]+$")                    # loose: lowercase letters

def validate_jsonl(path, max_show=10):
    ok, issues = 0, []
    lens_prompt, lens_response = [], []
    with open(path, "r") as f:
        for i, line in enumerate(f, 1):
            try:
                obj = json.loads(line)
            except Exception as e:
                issues.append((i, f"bad json: {e}")); continue

            # required
            if "prompt" not in obj or "response" not in obj:
                issues.append((i, "missing prompt/response")); continue
            if not isinstance(obj["prompt"], str) or not isinstance(obj["response"], str):
                issues.append((i, "prompt/response not string")); continue

            # optional aa_seq / stru_str
            if "aa_seq" in obj and obj["aa_seq"] is not None:
                if not AA_RE.match(obj["aa_seq"]):
                    issues.append((i, "aa_seq contains non-standard chars"))
            if "stru_str" in obj and obj["stru_str"] is not None:
                if not FS_RE.match(obj["stru_str"]):
                    issues.append((i, "stru_str must be lowercase alpha (3Di-like)"))

            lens_prompt.append(len(obj["prompt"]))
            lens_response.append(len(obj["response"]))
            ok += 1

    print(f"✔ valid lines: {ok}")
    if lens_prompt and lens_response:
        print(f" avg prompt len: {int(mean(lens_prompt))} chars | avg response len: {int(mean(lens_response))} chars")
    if issues:
        print(f"⚠ found {len(issues)} issues (showing first {max_show}):")
        for row, msg in issues[:max_show]:
            print(f"  line {row}: {msg}")
    else:
        print("No schema/character issues found.")

validate_jsonl(SFT_JSONL)

✔ valid lines: 5
 avg prompt len: 184 chars | avg response len: 259 chars
No schema/character issues found.
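
`FS_RE` above accepts any lowercase letters. Foldseek's 3Di alphabet has 20 states, written with the same letters as the 20 amino acids (lowercase in this notebook's convention). Foldseek assigns one state per residue, so when `aa_seq` and `stru_str` come from the same structure their lengths should match. This is a minimal sketch of a stricter per-record check under those assumptions. `check_pair` is hypothetical.

```python
import re

AA_STRICT = re.compile(r"^[ACDEFGHIKLMNPQRSTVWY]+$")
DI_STRICT = re.compile(r"^[acdefghiklmnpqrstvwy]+$")  # 3Di states, lowercase

def check_pair(aa_seq, stru_str):
    """Return a list of issues for one record's aa_seq / stru_str pair (empty list = OK)."""
    issues = []
    if aa_seq and not AA_STRICT.match(aa_seq):
        issues.append("aa_seq has non-standard residues")
    if stru_str and not DI_STRICT.match(stru_str):
        issues.append("stru_str has letters outside the 3Di alphabet")
    if aa_seq and stru_str and len(aa_seq) != len(stru_str):
        issues.append(f"length mismatch: aa_seq={len(aa_seq)} stru_str={len(stru_str)}")
    return issues
```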


In [14]:
# --- Cell 3 (hard-coded slots: protein=1, structure=3) ---
import importlib, json, torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Import your encoders
import protein_encoder as protein_encoder_mod
import structure_encoder as structure_encoder_mod
importlib.reload(protein_encoder_mod)
importlib.reload(structure_encoder_mod)
ProteinEncoder   = protein_encoder_mod.ProteinEncoder
StructureEncoder = structure_encoder_mod.StructureEncoder

# Build encoders from configs ONLY (random init)
prot_enc = ProteinEncoder(str(PROTEIN_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()
stru_enc = StructureEncoder(str(STRUCTURE_CONFIG), out_dim=1024, load_pretrained=False).to(DEVICE).eval()

# Load checkpoint and slice by ParameterList slots
raw = torch.load(str(CKPT), map_location="cpu")
sd  = raw["model"] if (isinstance(raw, dict) and "model" in raw) else (
      raw.get("state_dict", raw) if isinstance(raw, dict) else raw)

# Build slot dict: {slot_index: {param_key_without_prefix: tensor}}
slots = {}
for k, v in sd.items():
    head = k.split(".", 1)[0]
    if head.isdigit():
        slots.setdefault(int(head), {})[k[len(head)+1:]] = v

# Hard-code slots (adjust if your checkpoint order differs)
PROT_SLOT = 1   # protein encoder lives here in your .pt
STRU_SLOT = 3   # structure encoder lives here in your .pt
print("Available slots in checkpoint:", sorted(slots.keys()))

if PROT_SLOT not in slots or STRU_SLOT not in slots:
    raise KeyError(
        f"Hard-coded slots not found in checkpoint. "
        f"Have slots={sorted(slots.keys())}, need PROT_SLOT={PROT_SLOT}, STRU_SLOT={STRU_SLOT}."
        # If needed, you can re-enable the tiny auto-finder:
        # def best_slot_for(module):
        #     tgt = set(module.state_dict().keys())
        #     return max(slots.items(), key=lambda kv: sum(1 for k in kv[1] if k in tgt))[0]
        # PROT_SLOT = best_slot_for(prot_enc); STRU_SLOT = best_slot_for(stru_enc)
    )

# Optionally drop harmless extras (e.g., position_ids when using RoPE)
def drop_extras(sd_sub: dict):
    bad = [k for k in sd_sub if "embeddings.position_ids" in k]
    for k in bad: sd_sub.pop(k)
    return sd_sub

prot_sub = drop_extras(dict(slots[PROT_SLOT]))
stru_sub = drop_extras(dict(slots[STRU_SLOT]))

# Report matches, then load with strict=False
def report_and_load(module, sub, tag):
    tgt = set(module.state_dict().keys())
    matched_keys = sorted([k for k in sub if k in tgt])
    missing, unexpected = module.load_state_dict(sub, strict=False)
    print(f"\n[{tag}] slot loaded: matched={len(matched_keys)}  missing={len(missing)}  unexpected={len(unexpected)}")
    print("  matched (first 20):", matched_keys[:20])
    if missing:    print("  missing (first 12):   ", list(missing)[:12])
    if unexpected: print("  unexpected (first 12):", list(unexpected)[:12])

report_and_load(prot_enc, prot_sub, "ProteinEncoder")
report_and_load(stru_enc, stru_sub, "StructureEncoder")

# Quick forward on toy JSONL (ensure 'toy' exists; otherwise read it)
try:
    toy
except NameError:
    with open(SFT_JSONL) as f:
        toy = [json.loads(x) for x in f if x.strip()]

aa_list   = [ex.get("aa_seq") or ""    for ex in toy]
stru_list = [ex.get("stru_str") or ""  for ex in toy]

with torch.no_grad():
    prot_vecs = prot_enc.get_repr([s for s in aa_list if s],   batch_size=8, verbose=False) if any(aa_list) else None
    stru_vecs = stru_enc.get_repr([s for s in stru_list if s], batch_size=8, verbose=False) if any(stru_list) else None

print("\nEmbeddings:")
print("  Protein vecs:", None if prot_vecs is None else tuple(prot_vecs.shape))
print("  Struct  vecs:", None if stru_vecs is None else tuple(stru_vecs.shape))

Device: cuda
Available slots in checkpoint: [0, 1, 2, 3]

[ProteinEncoder] slot loaded: matched=215  missing=0  unexpected=0
  matched (first 20): ['model.esm.embeddings.word_embeddings.weight', 'model.esm.encoder.emb_layer_norm_after.bias', 'model.esm.encoder.emb_layer_norm_after.weight', 'model.esm.encoder.layer.0.LayerNorm.bias', 'model.esm.encoder.layer.0.LayerNorm.weight', 'model.esm.encoder.layer.0.attention.LayerNorm.bias', 'model.esm.encoder.layer.0.attention.LayerNorm.weight', 'model.esm.encoder.layer.0.attention.output.dense.bias', 'model.esm.encoder.layer.0.attention.output.dense.weight', 'model.esm.encoder.layer.0.attention.self.key.bias', 'model.esm.encoder.layer.0.attention.self.key.weight', 'model.esm.encoder.layer.0.attention.self.query.bias', 'model.esm.encoder.layer.0.attention.self.query.weight', 'model.esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'model.esm.encoder.layer.0.attention.self.value.bias', 'model.esm.encoder.layer.0.attention.self.value

In [15]:
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

tok = AutoTokenizer.from_pretrained(MODEL_NAME)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Tip: prefer dtype= over torch_dtype= (newer API wording)
llm = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float32,  # fp32 for the smoke test; the projector below matches the LLM's dtype
).to(DEVICE).eval()

print("LLM hidden size:", llm.config.hidden_size)
print("Pad token id:", tok.pad_token_id, "| EOS id:", tok.eos_token_id)

# Smoke test: plain forward on text
batch = tok([toy[0]["prompt"] + "\n\n" + toy[0]["response"]],
            return_tensors="pt", padding=True).to(DEVICE)

with torch.no_grad():
    out = llm(**batch, return_dict=True)  # default: no hidden states
print("LLM logits:", tuple(out.logits.shape))  # (B, T, vocab)

# If you ALSO want last hidden states:
with torch.no_grad():
    out_h = llm(**batch, output_hidden_states=True, return_dict=True)
print("LLM last hidden:", tuple(out_h.hidden_states[-1].shape))  # (B, T, H)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


LLM hidden size: 896
Pad token id: 151643 | EOS id: 151645
LLM logits: (1, 97, 151936)
LLM last hidden: (1, 97, 896)


Cell 5 — Projector + prefix injection (single-token prefix by default)

In [16]:
# ==== Projector + prefix injection with proper padding & dtype alignment ====
import torch
import torch.nn as nn

PREFIX_LEN = 1
D_PROT, D_STRU = 1024, 1024
D_IN = D_PROT + D_STRU
D_HID = llm.config.hidden_size

# Capture the LLM's parameter dtype (float32 here, since the model was loaded with dtype=torch.float32)
MODEL_DTYPE = next(llm.parameters()).dtype

projector = nn.Sequential(
    nn.Linear(D_IN, D_HID),
    nn.SiLU(),
    nn.Linear(D_HID, D_HID * PREFIX_LEN)
).to(DEVICE, dtype=MODEL_DTYPE)  # NEW: ensure projector params are same dtype as LLM

@torch.no_grad()
def encode_protein_pair(aa_seq: str, stru_str: str):
    if aa_seq:
        prot = prot_enc.get_repr([aa_seq], batch_size=1, verbose=False)[0].to(DEVICE)
    else:
        prot = torch.zeros(D_PROT, device=DEVICE)
    if stru_str:
        stru = stru_enc.get_repr([stru_str], batch_size=1, verbose=False)[0].to(DEVICE)
    else:
        stru = torch.zeros(D_STRU, device=DEVICE)
    return torch.cat([prot, stru], dim=-1)  # returns float32; we'll cast before projector

def build_batch(examples, max_len=1024):
    we = llm.get_input_embeddings()
    pad_id = tok.pad_token_id
    eos_id = tok.eos_token_id

    enc_prompts   = tok([e["prompt"] for e in examples], add_special_tokens=False)
    enc_responses = tok([e["response"] + tok.eos_token for e in examples], add_special_tokens=False)

    text_token_ids, prompt_lens = [], []
    for i in range(len(examples)):
        ids_p = enc_prompts["input_ids"][i]
        ids_r = enc_responses["input_ids"][i]
        ids   = (ids_p + ids_r)[:max_len]
        text_token_ids.append(ids)
        prompt_lens.append(min(len(ids_p), len(ids)))

    T_max = max(len(ids) for ids in text_token_ids) if text_token_ids else 0

    inputs_embeds_list, attention_mask_list, labels_list = [], [], []

    for i, ex in enumerate(examples):
        ids = text_token_ids[i]
        t_i = len(ids)

        # Text embeds -> cast to model dtype
        ids_tensor = torch.tensor(ids, device=DEVICE).unsqueeze(0)
        text_embeds = we(ids_tensor).to(MODEL_DTYPE)  # NEW

        # Labels: mask prompt, train on response
        L = [-100]*prompt_lens[i] + ids[prompt_lens[i]:]

        # Protein prefix: cast input to projector dtype before matmul
        pvec = encode_protein_pair(ex.get("aa_seq") or "", ex.get("stru_str") or "")
        pvec = pvec.to(MODEL_DTYPE)  # NEW: match projector/LLM dtype
        pref = projector(pvec.unsqueeze(0)).view(1, PREFIX_LEN, D_HID)  # already MODEL_DTYPE

        # Concat prefix + text (both MODEL_DTYPE)
        combined = torch.cat([pref, text_embeds], dim=1)  # (1, P+t_i, H)
        att = torch.tensor([1]*PREFIX_LEN + [1]*t_i, device=DEVICE, dtype=torch.long).unsqueeze(0)
        lab = torch.tensor([-100]*PREFIX_LEN + L, device=DEVICE, dtype=torch.long).unsqueeze(0)

        # Pad to (P + T_max)
        pad_steps = (PREFIX_LEN + T_max) - combined.size(1)
        if pad_steps > 0:
            pad_vec = we(torch.tensor([[pad_id]], device=DEVICE)).to(MODEL_DTYPE)  # NEW
            pad_block = pad_vec.expand(1, pad_steps, D_HID)
            combined = torch.cat([combined, pad_block], dim=1)
            att = torch.cat([att, torch.zeros(1, pad_steps, device=DEVICE, dtype=torch.long)], dim=1)
            lab = torch.cat([lab, torch.full((1, pad_steps), -100, device=DEVICE, dtype=torch.long)], dim=1)

        inputs_embeds_list.append(combined)
        attention_mask_list.append(att)
        labels_list.append(lab)

    inputs_embeds  = torch.cat(inputs_embeds_list, dim=0)                # (B, L, H)
    attention_mask = torch.cat(attention_mask_list, dim=0)               # (B, L)
    labels         = torch.cat(labels_list, dim=0)                       # (B, L)
    # safety: ensure final dtype matches model
    inputs_embeds  = inputs_embeds.to(MODEL_DTYPE)                       # NEW
    return inputs_embeds, attention_mask, labels
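For scale, here is a quick parameter count of the projector defined above. It is a sketch that assumes the layer shapes in this cell (D_IN = 1024 + 1024, D_HID = 896 from the smoke-test printout, PREFIX_LEN = 1); `projector_param_count` is a hypothetical helper. In the notebook itself, `sum(p.numel() for p in projector.parameters())` should give the same figure.

```python
def projector_param_count(d_in: int, d_hid: int, prefix_len: int) -> int:
    """Count weights + biases of Linear(d_in, d_hid) -> SiLU -> Linear(d_hid, d_hid * prefix_len)."""
    first = d_in * d_hid + d_hid                      # Linear 1: weight matrix + bias
    out = d_hid * prefix_len                          # Linear 2 output width
    second = d_hid * out + out                        # Linear 2: weight matrix + bias
    return first + second                             # SiLU has no parameters

n = projector_param_count(d_in=1024 + 1024, d_hid=896, prefix_len=1)
print(n)  # 2639616
```

At roughly 2.6M parameters the projector is tiny next to the 0.5B-parameter LLM, and the second layer grows linearly with PREFIX_LEN if you try longer prefixes.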

In [18]:

# Build a tiny batch (all toy examples)
inputs_embeds, attention_mask, labels = build_batch(toy, max_len=256)
print("inputs_embeds :", tuple(inputs_embeds.shape))
print("attention_mask:", tuple(attention_mask.shape))
print("labels        :", tuple(labels.shape))

inputs_embeds : (5, 99, 896)
attention_mask: (5, 99)
labels        : (5, 99)
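Each row of the batch follows the layout [prefix | prompt | response + EOS | padding]. The pure-Python sketch below mirrors the attention and label rules of `build_batch` for a single row, so they can be checked without loading any model; `layout_row` and the toy token ids are hypothetical.

```python
IGNORE = -100  # label value skipped by the loss (PyTorch / HF ignore_index convention)

def layout_row(prompt_ids, response_ids, prefix_len, total_len):
    """Return (attention_mask, labels) for one right-padded row, as build_batch lays it out."""
    n_text = len(prompt_ids) + len(response_ids)
    n_pad = total_len - prefix_len - n_text
    assert n_pad >= 0, "total_len is too small for this row"
    # prefix and text are attended; padding is not
    attention = [1] * (prefix_len + n_text) + [0] * n_pad
    # only the response (which already ends with EOS) is supervised
    labels = [IGNORE] * (prefix_len + len(prompt_ids)) + list(response_ids) + [IGNORE] * n_pad
    return attention, labels

att, lab = layout_row([11, 12, 13], [21, 22, 99], prefix_len=1, total_len=9)
print(att)  # [1, 1, 1, 1, 1, 1, 1, 0, 0]
print(lab)  # [-100, -100, -100, -100, 21, 22, 99, -100, -100]
```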


In [19]:
import torch

def inspect_batch(inputs_embeds, attention_mask, labels, prefix_len=1):
    B, L, H = inputs_embeds.shape
    print(f"B={B}, L={L}, H={H}")
    for i in range(B):
        att_on   = int(attention_mask[i].sum().item())
        n_pad    = L - att_on
        n_ign    = int((labels[i] == -100).sum().item())
        n_sup    = L - n_ign  # supervised tokens
        print(f"sample {i}: att_on={att_on}  pad={n_pad}  ignored={n_ign}  supervised={n_sup}")

        # spot-check masking: prefix tokens ignored & attended
        assert (labels[i, :prefix_len] == -100).all()
        assert (attention_mask[i, :prefix_len] == 1).all()

        # padded area (if any) should be att=0 and labels=-100
        if att_on < L:
            assert (attention_mask[i, att_on:] == 0).all()
            assert (labels[i, att_on:] == -100).all()

inspect_batch(inputs_embeds, attention_mask, labels, prefix_len=1)

B=5, L=99, H=896
sample 0: att_on=99  pad=0  ignored=38  supervised=61
sample 1: att_on=95  pad=4  ignored=42  supervised=57
sample 2: att_on=96  pad=3  ignored=41  supervised=58
sample 3: att_on=92  pad=7  ignored=45  supervised=54
sample 4: att_on=85  pad=14  ignored=52  supervised=47
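The counts above decompose consistently: ignored labels = prefix + prompt + padding, and attended positions = prefix + prompt + supervised response. The check below recomputes the prompt length from the printed numbers, assuming PREFIX_LEN = 1 as in Cell 5; `prompt_tokens` is a hypothetical helper.

```python
PREFIX_LEN = 1
# (att_on, pad, ignored, supervised) copied from the inspect_batch printout above
rows = [(99, 0, 38, 61), (95, 4, 42, 57), (96, 3, 41, 58), (92, 7, 45, 54), (85, 14, 52, 47)]

def prompt_tokens(att_on, pad, ignored, supervised, prefix_len=PREFIX_LEN):
    # ignored labels = prefix slots + prompt tokens + padding slots
    n_prompt = ignored - prefix_len - pad
    # attended positions = prefix + prompt + supervised response (incl. EOS)
    assert att_on == prefix_len + n_prompt + supervised
    return n_prompt

print([prompt_tokens(*r) for r in rows])  # [37, 37, 37, 37, 37]
```

Every sample comes out at 37 prompt tokens, which is consistent with the toy examples sharing one prompt string; only the response lengths (and hence the padding) differ.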


Cell 6 — Full forward (masked SFT loss) + tiny training loop

In [20]:
from torch.optim import AdamW
import torch

llm.train(); projector.train()

# Tip: use a small LR (e.g., 1e-5) when the LLM is also updated, as here; for projector-only training, ~1e-3 is a reasonable starting point.
optimizer = AdamW(list(projector.parameters()) + list(llm.parameters()), lr=1e-5)

def step_once_fp32(examples):
    inputs_embeds, attention_mask, labels = build_batch(examples, max_len=256)

    # ensure there are supervised tokens
    if int((labels != -100).sum()) == 0:
        print("⚠️ No supervised tokens — skipping.")
        return float("nan")

    out = llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
    loss = out.loss
    if not torch.isfinite(loss):
        print("⚠️ Non-finite loss — skipping.")
        return float("nan")

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(list(projector.parameters()) + list(llm.parameters()), 1.0)
    optimizer.step()
    return float(loss.detach().cpu())

batch1 = toy[:3]
batch2 = toy[3:]

loss1 = step_once_fp32(batch1)
loss2 = step_once_fp32(batch2)

llm.eval(); projector.eval()
print(f"loss step1: {loss1:.4f} | step2: {loss2:.4f}")

loss step1: 3.5683 | step2: 3.6503
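`llm(..., labels=labels)` computes a next-token loss: Hugging Face causal-LM heads shift the labels internally, so the logits at position t are scored against `labels[t + 1]`, and positions labeled -100 are skipped. Below is a stdlib-only sketch of that computation, using a hypothetical helper `shifted_masked_nll` that takes the mean over non-ignored targets. The library's exact reduction may differ across versions.

```python
import math

IGNORE = -100

def shifted_masked_nll(logits, labels):
    """Mean next-token NLL: logits[t] is scored against labels[t + 1]; IGNORE targets are skipped.

    logits: T rows, each a list of V unnormalized scores
    labels: T ints (token ids or IGNORE)
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE:
            continue
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[target]                              # -log softmax(row)[target]
        count += 1
    return total / count if count else float("nan")

uniform = [[0.0] * 4 for _ in range(5)]
print(shifted_masked_nll(uniform, [-100, -100, 1, 2, 3]))  # log(4) ≈ 1.386
```

With uniform logits every target costs log(V), which makes a handy sanity check. By the same yardstick, the step losses of about 3.6 reported above sit well below log(151936) ≈ 11.93, the loss of a uniform guess over Qwen's vocabulary.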


Cell 7 — Quick generation test (prefix-conditioned)

In [21]:
llm.eval(); projector.eval()
gen_ex = toy[0]  # pick one
prompt = gen_ex["prompt"]

# build prefix + prompt embeddings for this sample (inference only, so no autograd graph)
with torch.no_grad():
    pvec = encode_protein_pair(gen_ex.get("aa_seq") or "", gen_ex.get("stru_str") or "")
    pref = projector(pvec.to(MODEL_DTYPE).unsqueeze(0))  # (1, H*P); cast to projector dtype
    pref = pref.view(1, PREFIX_LEN, D_HID)

    # tokenize prompt only (no response)
    enc = tok(prompt, return_tensors="pt").to(DEVICE)
    we  = llm.get_input_embeddings()
    text = we(enc["input_ids"])                          # (1, T, H)

inputs_embeds  = torch.cat([pref, text], dim=1)
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long, device=DEVICE)

gen = llm.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    max_new_tokens=64,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
    eos_token_id=tok.eos_token_id,
    pad_token_id=tok.pad_token_id
)

# With inputs_embeds and no input_ids, generate() returns only the newly generated token ids
# (there are no prompt ids to slice off), so gen[0] can be decoded directly.
print(tok.decode(gen[0], skip_special_tokens=True))

 Aminotransferase, involved in protein degradation or proteasomal processing. The catalytic residues suggest an allosteric effect. The overall topology and sequence indicate it is part of a membrane-bound complex. This suggests an intracellular role. [Source] http://www. biochem-yr. org
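The sampling arguments used here (`temperature=0.8`, `top_p=0.9`) first sharpen the next-token distribution and then restrict sampling to the smallest set of most likely tokens whose probability mass reaches `top_p`. The sketch below (hypothetical `top_p_filter`, stdlib only) shows that idea for a single decoding step. It assumes temperature is applied before top-p, and its tie and boundary handling is not guaranteed to match transformers exactly.

```python
import math

def top_p_filter(logits, temperature=0.8, top_p=0.9):
    """Return {token_id: prob} after temperature scaling + nucleus (top-p) filtering, renormalized."""
    scaled = [x / temperature for x in logits]          # temperature < 1 sharpens the distribution
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]            # stable softmax numerators
    z = sum(exps)
    ranked = sorted(((e / z, i) for i, e in enumerate(exps)), reverse=True)
    kept, cum = [], 0.0
    for p, i in ranked:                                 # always keeps at least one token
        kept.append((i, p))
        cum += p
        if cum >= top_p:
            break
    mass = sum(p for _, p in kept)
    return {i: p / mass for i, p in kept}              # sample the next token from this dict

print(sorted(top_p_filter([2.0, 1.0, 0.0, -5.0], temperature=1.0, top_p=0.9)))  # [0, 1]
```

Lowering the temperature concentrates mass on the top token, so fewer candidates survive the same `top_p` cut.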


In [22]:
# ==== Protein-conditioned natural language generation ====
import torch

llm.eval(); projector.eval()

def build_inputs_with_prefix(prompt: str, aa_seq: str | None, stru_str: str | None):
    """
    Returns inputs_embeds and attention_mask ready for llm.generate(...).
    """
    # 1) tokenize prompt text
    enc = tok(prompt, return_tensors="pt").to(llm.device)
    we  = llm.get_input_embeddings()
    text_embeds = we(enc["input_ids"])                               # (1, T, H)

    # 2) make protein vector -> projector -> soft prefix
    pvec = encode_protein_pair(aa_seq or "", stru_str or "")         # (2048,) float32
    pref = projector(pvec.to(MODEL_DTYPE).unsqueeze(0)).view(1, PREFIX_LEN, D_HID)   # (1, P, H); cast to projector dtype

    # 3) concat prefix + text
    inputs_embeds  = torch.cat([pref, text_embeds], dim=1)           # (1, P+T, H)
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=llm.device)
    return inputs_embeds, attention_mask

@torch.no_grad()
def generate_answer(prompt: str, aa_seq: str | None, stru_str: str | None,
                    max_new_tokens=128, temperature=0.7, top_p=0.95, do_sample=True):
    inputs_embeds, attention_mask = build_inputs_with_prefix(prompt, aa_seq, stru_str)
    out_ids = llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    # when using inputs_embeds, HF returns only the newly generated tokens
    return tok.decode(out_ids[0], skip_special_tokens=True)

# ---- demo on your toy set ----
N = min(3, len(toy))   # change N if you want more
for i in range(N):
    ex = toy[i]
    print(f"\n=== Example {i+1} ===")
    print("Prompt:", ex["prompt"])
    ans = generate_answer(ex["prompt"], ex.get("aa_seq"), ex.get("stru_str"),
                          max_new_tokens=64, temperature=0.8, top_p=0.9, do_sample=True)
    print("Model:", ans)

# (Optional) compare against an unconditioned baseline (no protein prefix)
@torch.no_grad()
def generate_unconditioned(prompt: str, max_new_tokens=128, temperature=0.7, top_p=0.95, do_sample=True):
    enc = tok(prompt, return_tensors="pt").to(llm.device)
    out_ids = llm.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
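    # with input_ids, generate() echoes the prompt first; keep only the new tokens to match generate_answer
    return tok.decode(out_ids[0, enc["input_ids"].shape[1]:], skip_special_tokens=True)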

print("\n--- Unconditioned comparison on example 1 ---")
print(generate_unconditioned(toy[0]["prompt"], max_new_tokens=64, temperature=0.8, top_p=0.9))


=== Example 1 ===
Prompt: You are a professional protein biologist. Based on the amino-acid sequence (and structure if available), write a concise, biologically accurate 2–4 sentence description of the protein.
Model:  This is a highly conserved serine-rich, hydrophobic protein that likely functions as a membrane-localized receptor or transmembrane helix–helix–helix–like domain.Human beings have a long history of using DNA to create life forms (such as the origin of life hypothesis). Based on

=== Example 2 ===
Prompt: You are a professional protein biologist. Based on the amino-acid sequence (and structure if available), write a concise, biologically accurate 2–4 sentence description of the protein.
Model:  This protein shows significant homology to members of the serine/threonine kinase family, including S6K1, and interacts with the C-terminal tail. It likely functions in signal transduction or DNA-binding. The presence of a hydrophobic core around the Ser/Thr motifs suggests a role

In [None]:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

JsonlDataset = train_mod.JsonlDataset
PadAndMaskCollator = train_mod.PadAndMaskCollator
CollateCfg = train_mod.CollateCfg

train_set = JsonlDataset("sft_data/train_tiny.jsonl", tokenizer, max_len=MAX_LEN)
collate = PadAndMaskCollator(CollateCfg(tokenizer=tokenizer, max_len=MAX_LEN))

batch = [train_set[0], train_set[1]]
batch_out = collate(batch)

for k, v in batch_out.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape, v.dtype)
    else:
        print(k, type(v), len(v))

# Labels should be -100 over the prompt and unmasked over the response, including its trailing EOS.
print("Sample labels row (first 50 tokens):", batch_out["labels"][0][:50].tolist())


input_ids torch.Size([2, 24]) torch.int64
attention_mask torch.Size([2, 24]) torch.int64
labels torch.Size([2, 24]) torch.int64
aa_seq <class 'list'> 2
stru_str <class 'list'> 2
Sample labels row (first 50 tokens): [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1986, 7952, 311, 387, 458, 48142, 448, 3204, 6275, 67, 1080, 519, 5702, 13, 151645]
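The printed row can also be checked programmatically. This hypothetical `check_label_row` helper asserts the invariants described in the comment above: a fully masked prompt, then one contiguous supervised span ending in EOS (id 151645 per the tokenizer printout in Cell 4). Trailing padding labeled -100 after EOS is allowed.

```python
IGNORE = -100

def check_label_row(labels, eos_id):
    """Assert masked prompt + one contiguous supervised span ending in EOS; return (prompt_len, sup_len)."""
    sup = [i for i, x in enumerate(labels) if x != IGNORE]
    assert sup, "no supervised tokens"
    first, last = sup[0], sup[-1]
    assert sup == list(range(first, last + 1)), "supervised span is not contiguous"
    assert all(x == IGNORE for x in labels[:first]), "prompt is not fully masked"
    assert labels[last] == eos_id, "response does not end with EOS"
    return first, last - first + 1

# row copied from the collator printout above (24 tokens, no padding)
row = [-100] * 9 + [1986, 7952, 311, 387, 458, 48142, 448, 3204, 6275, 67, 1080, 519, 5702, 13, 151645]
print(check_label_row(row, eos_id=151645))  # (9, 15)
```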
