# FSDP Test Add‑on — `train_prefix_qwen_fsdp.py` (Patched)

This notebook mounts Google Drive, installs pinned dependencies, imports the patched trainer, and runs smoke tests (a collator check and a single forward pass) so any tracebacks appear inline.

In [1]:
#@title Imports
# Shared imports for the notebook (Drive is mounted in the next cell, not here).
# NOTE(review): snapshot_download, pickle, pandas, tqdm and rprint are not used in the cells
# visible here — drop them if nothing later in the notebook needs them.
# NOTE(review): huggingface_hub is imported here, *before* the install cell pins
# huggingface_hub==0.35.0, so this kernel keeps whatever version was loaded first until the
# runtime is restarted.
from pathlib import Path
from huggingface_hub import snapshot_download
import os, json, pickle, pandas as pd
from tqdm import tqdm
from rich import print as rprint

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

%cd /content/drive/MyDrive/LLM/Bioreasoner/testing_pipelines

from pathlib import Path
BASE_DIR = Path("/content/drive/MyDrive/LLM/Bioreasoner/data/hf/proteinDT")
OUT_DIR  = BASE_DIR / "sft_test_demo"
print(f"Using Google Drive folder as BASE_DIR: {BASE_DIR}")


In [None]:
# Check GPU
!nvidia-smi

# Fresh pip + libs (PyTorch CUDA 12.1 build + matching libs)
%pip -q install --upgrade pip
%pip install -q --index-url https://download.pytorch.org/whl/cu126 \
  torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0
%pip -q install transformers==4.56.1 huggingface_hub==0.35.0 tqdm safetensors

In [None]:
# --- Version & import sanity checks ---
# Print the versions actually loaded in this kernel, then confirm ESM can be imported.
import torch, transformers, huggingface_hub

for label, module in (
    ("torch            :", torch),
    ("transformers     :", transformers),
    ("huggingface_hub  :", huggingface_hub),
):
    print(label, module.__version__)

# EsmForMaskedLM is expected at the transformers top level on 4.56.1. If that import fails,
# try the modeling module directly to tell an export problem apart from a missing model.
try:
    from transformers import AutoTokenizer, EsmForMaskedLM
except Exception as top_level_err:
    print("❌ Top-level EsmForMaskedLM import failed:", repr(top_level_err))
    try:
        from transformers.models.esm.modeling_esm import EsmForMaskedLM as _E
    except Exception as direct_err:
        print("❌ Direct modeling_esm import failed too:", repr(direct_err))
    else:
        print("✅ Direct modeling_esm import OK (fallback)")
else:
    print("✅ Top-level EsmForMaskedLM import OK")

In [None]:

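# Superseded by the pinned install cell above (torch 2.8.0 / cu126, transformers==4.56.1);
# kept for reference only. If you ever re-enable it, keep version specifiers quoted —
# %pip runs through the shell, so an unquoted `>=` is treated as output redirection.
# %pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# %pip -q install "transformers>=4.43.0" peft accelerate datasets tqdm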


In [None]:

import os, sys
print('CWD:', os.getcwd())
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())
!ls -la | head -n 40


In [None]:

# Re-execute the trainer module so edits to train_prefix_qwen_fsdp.py on Drive take effect
# without a kernel restart, then bind the names the following cells use.
from importlib import reload
import train_prefix_qwen_fsdp as tpq

tpq = reload(tpq)

from train_prefix_qwen_fsdp import (
    BigProteinQwen,
    CollateCfg,
    PadAndMaskCollator,
    JsonlStream,
    train as fsdp_train,
)
print('Imported OK')


In [None]:

# --- Model identifiers ---------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# One ESM-2 checkpoint id, passed as both protein_config and structure_config below.
PROTEIN_CONFIG = STRUCTURE_CONFIG = "facebook/esm2_t12_35M_UR50D"

# --- ProTrek checkpoint & slots: edit for your Drive layout ---------------------
PROTREK_CKPT = "/content/drive/MyDrive/LLM/Bioreasoner/protrek/weights/ProTrek_35M/ProTrek_35M.pt"
# Forwarded as prot_slot / stru_slot; presumably they pick which ProTrek sub-module weights
# feed each encoder — TODO confirm against BigProteinQwen.
PROT_SLOT, STRU_SLOT = 1, 3

# --- Prefix / projection settings -----------------------------------------------
SINGLE_TOKEN_PREFIX, PREFIX_LEN = False, 4
PROJ_HID, DROPOUT = 1024, 0.10
DTYPE = "bf16"   # alternatives: "fp32", "fp16", "auto"

# --- Data / batching --------------------------------------------------------------
MAX_LEN, BATCH_SIZE = 256, 2

print("Configs ready")


In [None]:

# Tiny JSONL fixture: the same two records are written as both the train and the val split.
# The collator check below reads the train split from sft_data/ in the working directory.
import json, os  # re-imported so this cell also runs on its own

toy = [
    {"prompt":"Describe the likely function of this protein.",
     "response":"This appears to be an enzyme with possible hydrolase activity.",
     "aa_seq":"MKTFFVAIATGAFSATA","stru_str":None},
    {"prompt":"What domain might this protein contain?",
     "response":"Likely contains a Rossmann-like fold domain.",
     "aa_seq":"MGDVEKGKKIFIMKCSQCHTVEKGGKHKTGPNLHGLFGRKTGQAP",
     "stru_str":"ACDEFGHIKLMNPQRSTVWY"}
]


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines (one ``json.dumps`` object per line).

    The file is opened with an explicit UTF-8 encoding so the output does not depend on
    the platform's default locale.
    """
    with open(path, "w", encoding="utf-8") as fh:
        for rec in records:
            fh.write(json.dumps(rec) + "\n")


os.makedirs("sft_data", exist_ok=True)
for split in ("train", "val"):
    _write_jsonl(f"sft_data/{split}_tiny.jsonl", toy)
print("Wrote toy jsonl")


In [None]:

# Collator smoke test: collate the first two toy records and print each batch field's
# shape/dtype (tensors) or type/length (other values).
from transformers import AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Fall back to EOS as the padding token when the tokenizer ships without one.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

collate = PadAndMaskCollator(CollateCfg(tokenizer=tok, max_len=MAX_LEN))

# zip() with range(2) stops after at most two lines of the file.
with open("sft_data/train_tiny.jsonl","r") as f:
    rows = [json.loads(line) for _, line in zip(range(2), f)]

batch = collate(rows)
for name, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(name, tuple(value.shape), value.dtype)
        continue
    print(name, type(value), len(value) if isinstance(value, list) else "")
print("labels[0][:40]:", batch["labels"][0][:40].tolist())


In [None]:

# Forward pass
device = "cuda" if torch.cuda.is_available() else "cpu"

big = BigProteinQwen(
    model_name=MODEL_NAME,
    protein_config=PROTEIN_CONFIG,
    structure_config=STRUCTURE_CONFIG,
    protrek_ckpt=PROTREK_CKPT,
    prot_slot=PROT_SLOT,
    stru_slot=STRU_SLOT,
    single_token_prefix=SINGLE_TOKEN_PREFIX,
    prefix_len=PREFIX_LEN,
    proj_hid=PROJ_HID,
    dropout=DROPOUT,
    train_encoders=False,
    dtype_str=DTYPE,
).to(device)

for k in ("input_ids","attention_mask","labels"):
    batch[k] = batch[k].to(device)

with torch.no_grad():
    out = big(**batch)
print("Forward OK. loss:", float(out.loss.detach().cpu()))
