### Libraries

In [1]:
import os
import math
import signal
import argparse
import yaml
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers.optimization import get_scheduler
from peft import LoraConfig, get_peft_model

In [2]:
from viscy.data.triplet import TripletDataModule
from viscy.transforms import ScaleIntensityRangePercentilesd, NormalizeSampled, Decollated
import pandas as pd
from typing import Dict, Tuple, Any, Iterable, List

### Main

In [3]:
# -----------------------
# Helper functions
# -----------------------
def as_int_or_str(v, default):
    if v is None:
        return default
    if isinstance(v, int):
        return v
    if isinstance(v, str):
        s = v.strip()
        if s.isdigit():
            return int(s)
        return s  # e.g., "auto"
    return v

def to_int(v, default=None):
    if v is None: return default
    if isinstance(v, int): return v
    if isinstance(v, str) and v.strip().isdigit(): return int(v)
    return int(float(v))  # handles "3.0"

def to_float(v, default=None):
    if v is None: return default
    if isinstance(v, (int, float)): return float(v)
    return float(str(v).strip())

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="qwen_config_notebooks.yaml")

# mirror TrainingArguments you shared
parser.add_argument("--epochs", type=int, default=None)                     # num_train_epochs
parser.add_argument("--lr", type=float, default=None)                       # learning_rate
parser.add_argument("--lr_scheduler_type", type=str, default=None,          # lr_scheduler_type
                    choices=["linear","cosine","cosine_with_restarts","polynomial","constant","constant_with_warmup"])
parser.add_argument("--batch_size", type=int, default=None)                 # per_device_train_batch_size
parser.add_argument("--eval_batch_size", type=int, default=None)            # per_device_eval_batch_size
parser.add_argument("--grad_accum", type=int, default=None)                 # gradient_accumulation_steps
parser.add_argument("--weight_decay", type=float, default=None)
parser.add_argument("--logging_fraction", type=float, default=None)         # to derive logging_steps
parser.add_argument("--eval_fraction", type=float, default=None)            # to derive eval_steps
parser.add_argument("--warmup_steps", type=int, default=None)               # optional
parser.add_argument("--warmup_ratio", type=float, default=None)             # optional, if steps not given

# runtime/system
parser.add_argument("--model_id", type=str, default=None)
parser.add_argument("--dataset", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--precision", type=str, default=None, choices=["bf16-mixed","16-mixed","32-true"])
parser.add_argument("--devices", type=int, default=None)
parser.add_argument("--strategy", type=str, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--tf32", action="store_true")
parser.add_argument("--no_tf32", dest="tf32", action="store_false")
parser.set_defaults(tf32=True)

# Single or Multi GPU training
parser.add_argument("--device_map", type=str, default=None,
                help='Hugging Face device_map (e.g., "auto"). Usually leave unset/None when using Lightning.')


# logging / names / output
parser.add_argument("--run_name", type=str, default=None)
parser.add_argument("--logging_dir", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=None)

# GC controls (Trainer: gradient_checkpointing, kwargs)
parser.add_argument("--gradient_checkpointing", action="store_true")
parser.add_argument("--no_gradient_checkpointing", dest="gradient_checkpointing", action="store_false")
parser.set_defaults(gradient_checkpointing=True)
parser.add_argument("--gc_use_reentrant", action="store_true")
parser.add_argument("--gc_no_reentrant", dest="gc_use_reentrant", action="store_false")
parser.set_defaults(gc_use_reentrant=False)  # Qwen-friendly default

# HF Hub
parser.add_argument("--hub_model_id", type=str, default=None)

# ---- LoRA CLI overrides 
parser.add_argument("--lora_r", type=int, default=None)
parser.add_argument("--lora_alpha", type=float, default=None)
parser.add_argument("--lora_dropout", type=float, default=None)
parser.add_argument("--lora_bias", type=str, default=None, choices=["none","all","lora_only"])
parser.add_argument("--lora_task_type", type=str, default=None, help="e.g., CAUSAL_LM")

parser.add_argument("--lora_use_rslora", dest="lora_use_rslora", action="store_true")
parser.add_argument("--lora_no_use_rslora", dest="lora_use_rslora", action="store_false")
parser.set_defaults(lora_use_rslora=None)  # None = not provided on CLI

parser.add_argument("--lora_target_modules", type=str, default=None,
                    help="Comma-separated list e.g. 'q_proj,k_proj,v_proj,o_proj,...'")
parser.add_argument("--lora_modules_to_save", type=str, default=None,
                    help="Comma-separated list e.g. 'lm_head,embed_tokens'")

parser.add_argument("--lora_extras_yaml", type=str, default=None,
                    help="Inline YAML/JSON dict of extra LoraConfig fields (e.g., rank_pattern)")

args = parser.parse_args([])

# Load YAML and merge
with open(args.config, "r") as f:
    config = yaml.safe_load(f)

def pick(key, default=None):
    return getattr(args, key) if getattr(args, key) is not None else config.get(key, default)


# Map your provided defaults
epochs          = to_int(pick("epochs", 3))
lr              = to_float(pick("lr", 1e-4)) 
lr_scheduler    = pick("lr_scheduler_type", "linear")
batch_size      = to_int(pick("batch_size", 1))
eval_batch_size = to_int(pick("eval_batch_size", 1))
grad_accum      = to_int(pick("grad_accum", 4))
weight_decay    = pick("weight_decay", 0.01)
logging_frac    = to_float(pick("logging_fraction", 0.10))
eval_frac       = to_float(pick("eval_fraction", 0.10))
warmup_steps_cfg= pick("warmup_steps", None)
warmup_ratio    = pick("warmup_ratio", None)  # if you decide to use ratio
if warmup_steps_cfg is not None: warmup_steps_cfg = to_int(warmup_steps_cfg)
if warmup_ratio is not None:     warmup_ratio     = to_float(warmup_ratio)

model_id        = pick("model_id")
dataset_name    = pick("dataset")
num_workers     = to_int(pick("num_workers", 4))
precision       = pick("precision", "bf16-mixed")
# devices/strategy with normalization
devices_raw     = pick("devices", 1)           # may be int or "auto"
strategy        = pick("strategy", "auto")
devices         = as_int_or_str(devices_raw, 1)
device_map      = pick("device_map", None)

seed            = pick("seed", 42)
tf32            = args.tf32 if "tf32" in args else config.get("tf32", True)

run_name        = pick("run_name", f"dynacell-{lr}_lr-{epochs}_epochs-{lr_scheduler}_schedule-completions")
logging_dir     = pick("logging_dir", f"./logs/{run_name}")
output_dir      = pick("output_dir", "fine-tuned-model")

gradient_checkpointing = args.gradient_checkpointing if "gradient_checkpointing" in args else config.get("gradient_checkpointing", True)
gc_use_reentrant       = args.gc_use_reentrant if "gc_use_reentrant" in args else config.get("gc_use_reentrant", False)

hub_model_id    = pick("hub_model_id", "shenbaba/Qwen2.5-VLM-3B-dynacell")

 # If Lightning is doing multi-GPU or non-auto strategy, force HF device_map=None
multi_gpu = (isinstance(devices, int) and devices > 1)
non_auto_devices = (isinstance(devices, str) and devices not in (None, "auto"))
if multi_gpu or non_auto_devices or (strategy and strategy != "auto"):
    device_map = None

# --- LoRA config (YAML + CLI overrides via pick-like behavior) ---
lora_from_yaml = config.get("lora", {}) or {}
_default_lora = {
    "r": 32,
    "lora_alpha": 16,
    "use_rslora": True,
    "target_modules": ["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj","mlp.0","mlp.2"],
    "modules_to_save": ["lm_head","embed_tokens"],
    "lora_dropout": 0.1,
    "bias": "none",
    "task_type": "CAUSAL_LM",
}

def _split_csv(s):
    return [x.strip() for x in s.split(",")] if s else None

def pick_lora(field, default=None):
    # CLI first
    cli_val = getattr(args, f"lora_{field}", None)
    if cli_val is not None:
        return cli_val
    # YAML next
    if field in lora_from_yaml and lora_from_yaml[field] is not None:
        return lora_from_yaml[field]
    # Fallback
    return _default_lora.get(field, default)

lora_cfg = {
    "r": to_int(pick_lora("r")),
    "lora_alpha": to_float(pick_lora("lora_alpha")),
    "lora_dropout": to_float(pick_lora("lora_dropout")),
    "use_rslora": pick_lora("use_rslora"),
    "bias": pick_lora("bias"),
    "task_type": pick_lora("task_type"),
    "target_modules": _split_csv(args.lora_target_modules)
                        if args.lora_target_modules is not None
                        else pick_lora("target_modules"),
    "modules_to_save": _split_csv(args.lora_modules_to_save)
                        if args.lora_modules_to_save is not None
                        else pick_lora("modules_to_save"),
}

if args.lora_extras_yaml:
    try:
        extra = yaml.safe_load(args.lora_extras_yaml)
        if isinstance(extra, dict):
            lora_cfg.update(extra)
    except Exception:
        pass


In [5]:
# CSV lookup

def build_annotation_lookup(
    csv_or_df: Any,
    key_cols=("fov_name","track_id","t","parent_id"),
    value_cols=("organelle","predicted_cellstate","predicted_infection"),
    caption_col: str | None = None,
):
    """Return dict[(fov_name,track_id,t,parent_id)] -> {value_cols..., '__caption__': str}"""
    df = pd.read_csv(csv_or_df) if isinstance(csv_or_df, (str, bytes)) else csv_or_df.copy()

    # normalize dtypes for exact matching
    df["fov_name"] = df["fov_name"].astype(str)
    for c in ("track_id","t","parent_id"):
        df[c] = pd.to_numeric(df[c], downcast="integer")

    def make_caption(row):
        if caption_col and caption_col in row and pd.notna(row[caption_col]):
            return str(row[caption_col]).strip()
        org   = str(row.get("organelle","unknown")).strip()
        phase = str(row.get("predicted_cellstate","unknown")).strip()
        inf   = str(row.get("predicted_infection","unknown")).strip()
        return f"{org}; {phase}; {inf}"

    lookup: Dict[Tuple[Any,...], Dict[str,Any]] = {}
    for _, row in df.iterrows():
        key = tuple(row[c] for c in key_cols)
        payload = {c: row[c] for c in value_cols if c in df.columns}
        payload["__caption__"] = make_caption(row)
        lookup[key] = payload
    return lookup


In [6]:
pl.seed_everything(seed, workers=True)

Seed set to 42


42

### Data

In [7]:
data_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr"
tracks_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr"
source_channel =  ["Phase3D", "GFP EX488 EM525-45", "mCherry EX561 EM600-37"]
annot_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/cytospeak_annotations/2025_07_24_annotations.csv"

In [8]:
# build the dict
anno_lookup = build_annotation_lookup(annot_path, caption_col=None)

In [9]:
processor = AutoProcessor.from_pretrained(model_id, use_fast=True) # change this back to use_fast=False if you run into issues
# dataset = load_dataset(dataset_name)

# disable rescaling (you handle scaling yourself)
processor.image_processor.do_rescale = False

# optionally also disable normalization if you already did it
processor.image_processor.do_normalize = False

# disable conversion to RGB
processor.image_processor.do_convert_rgb = False 


In [10]:
processor.image_processor

Qwen2VLImageProcessorFast {
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": false,
  "do_normalize": false,
  "do_rescale": false,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "Qwen2VLImageProcessorFast",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "input_data_format": null,
  "max_pixels": 12845056,
  "merge_size": 2,
  "min_pixels": 3136,
  "patch_size": 14,
  "processor_class": "Qwen2_5_VLProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "return_tensors": null,
  "size": {
    "longest_edge": 12845056,
    "shortest_edge": 3136
  },
  "temporal_patch_size": 2
}

In [10]:
# # Optional: quick shrink for prototyping (same as before)
# resize_image = lambda ex: {"image": ex["image"].resize((ex["image"].width // 4, ex["image"].height // 4))}
# train_dataset = dataset["train"].map(resize_image)
# val_dataset   = dataset["test"].map(resize_image)

### VISCY Dataloader

In [11]:
num_workers

4

In [35]:
print("Setting up data module...")
del dm
dm = TripletDataModule(
    data_path=data_path,
    tracks_path=tracks_path,
    source_channel=source_channel,
    #batch_size=batch_size, # already set in the notebook
    batch_size = 10, # testing
    num_workers=num_workers, # already set in the notebook
    z_range=(0,1),
    initial_yx_patch_size=(256, 256),
    final_yx_patch_size=(160, 160),
    normalizations=[
        NormalizeSampled(
            keys=["Phase3D"], level="fov_statistics", subtrahend="mean", divisor="std"
        ),
        Decollated(
            keys=source_channel
        ),
        ScaleIntensityRangePercentilesd(
            keys=["GFP EX488 EM525-45"], lower=50, upper=99, b_min=0.0, b_max=1.0
        ),
        ScaleIntensityRangePercentilesd(
            keys=["mCherry EX561 EM600-37"], lower=50, upper=99, b_min=0.0, b_max=1.0
        ),

    ],
    return_negative=False
)
dm.prepare_data()
dm.setup("fit")

Setting up data module...


In [36]:
train_loader =dm.train_dataloader()
dm.setup("predict")
test_loader = dm.predict_dataloader()

In [35]:
d = batch['index']

# build lookup set of (fov_name, track_id, t, parent_id)
quadruplets = set(zip(
    d["fov_name"],
    d["track_id"].tolist(),
    d["t"].tolist(),
    d["parent_id"].tolist()
))

# filter DataFrame
batch_annot = annot[annot.apply(
    lambda r: (r["fov_name"], r["track_id"], r["t"], r["parent_id"]) in quadruplets,
    axis=1
)]


In [40]:
batch['anchor'][0].shape

torch.Size([3, 1, 160, 160])

In [63]:
# 1) images -> [B, 3, H, W]
if imgs.ndim == 5 and imgs.shape[2] == 1:
    imgs = imgs.squeeze(2)
assert imgs.ndim == 4 and imgs.shape[1] == 3, f"Expected [B,3,H,W], got {tuple(imgs.shape)}"

B = imgs.shape[0]

In [67]:
# 2) extract quadruplets
fovs = [str(s) for s in idx["fov_name"]]
tids = idx["track_id"].tolist() if torch.is_tensor(idx["track_id"]) else list(idx["track_id"])
ts   = idx["t"].tolist()        if torch.is_tensor(idx["t"])        else list(idx["t"])
pids = idx["parent_id"].tolist()if torch.is_tensor(idx["parent_id"]) else list(idx["parent_id"])

texts: List[str] = []
images_for_qwen: List[List[torch.Tensor]] = []
targets: List[str] = []

In [16]:
class CollatingLoader:
    """Wrap an existing DataLoader and apply `collator` to each batch."""
    def __init__(self, base_loader, collator):
        self.base_loader = base_loader
        self.collator = collator

    def __iter__(self):
        for raw_batch in self.base_loader:
            yield self.collator(raw_batch)

    def __len__(self):
        return len(self.base_loader)

    # Handy passthroughs so PL can read dataset/sampler/batch_size, etc.
    @property
    def dataset(self):
        return self.base_loader.dataset

    @property
    def batch_size(self):
        return getattr(self.base_loader, "batch_size", None)

    @property
    def sampler(self):
        return getattr(self.base_loader, "sampler", None)

    @property
    def batch_sampler(self):
        return getattr(self.base_loader, "batch_sampler", None)

    def __getattr__(self, name):
        # delegate anything else to the base loader (e.g., drop_last, pin_memory, etc.)
        return getattr(self.base_loader, name)


### Computing loss over answers only

In [17]:
# -----------------------
# Custom collator for multimodal Qwen input
# -----------------------
class QwenIndexAnchorCollator:
    """
    Expects a batch:
      batch = {
        "index": {
          "fov_name": list[str],   # len B
          "track_id": Tensor[B],   # int
          "t":        Tensor[B],   # int
          "parent_id":Tensor[B],   # int
          # other fields are ignored
        },
        "anchor": Tensor[B, 3, 1, H, W],  # images
      }

    Produces a dict ready for Qwen forward(**dict), with assistant-only labels.
    """
    def __init__(
        self,
        processor,
        anno_lookup: Dict[Tuple[Any,...], Dict[str,Any]],
        question: str = (
                            "You are given a fluorescence microscopy image.\n\n"
                            "Task: classify three attributes.\n"
                            "Output format: exactly three words separated by single spaces, in this order: "
                            "ORGANELLE PHASE INFECTION\n"
                            "Allowed vocabularies:\n"
                            "- ORGANELLE ∈ {ER, mitochondria, golgi, lysosome, nucleus, stress_granule}\n"
                            "- PHASE ∈ {interphase, mitotic}\n"
                            "- INFECTION ∈ {infected, uninfected}\n"
                            "Rules: no punctuation, no explanations, no quotes, no newlines. "
                            "If uncertain, guess the most likely label from the allowed set."
                        ),
        pad_to_multiple_of: int | None = 8,
        fail_on_missing: bool = True,  # True: error; False: skip unmatched rows
    ):
        self.processor = processor
        self.anno_lookup = anno_lookup
        self.question = question
        self.pad_to_multiple_of = pad_to_multiple_of
        self.fail_on_missing = fail_on_missing

        # cache tokenizer ids we use often
        self.pad_id = getattr(self.processor.tokenizer, "pad_token_id", None)
        self.eos_id = getattr(self.processor.tokenizer, "eos_token_id", None)

    @staticmethod
    def _find_subsequence(seq: torch.Tensor, subseq: torch.Tensor) -> int | None:
        n, m = len(seq), len(subseq)
        if m == 0 or m > n:
            return None
        # naive scan is fine at typical batch sizes
        for s in range(n - m + 1):
            if torch.equal(seq[s:s+m], subseq):
                return s
        return None

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        idx = batch["index"]
        imgs = batch["anchor"]  # [B, 3, 1, H, W]

        # 1) images -> [B, 3, H, W]
        if imgs.ndim == 5 and imgs.shape[2] == 1:
            imgs = imgs.squeeze(2)
        assert imgs.ndim == 4 and imgs.shape[1] == 3, f"Expected [B,3,H,W], got {tuple(imgs.shape)}"

        B = imgs.shape[0]

        # 2) extract quadruplets
        fovs = [str(s) for s in idx["fov_name"]]
        tids = idx["track_id"].tolist() if torch.is_tensor(idx["track_id"]) else list(idx["track_id"])
        ts   = idx["t"].tolist()        if torch.is_tensor(idx["t"])        else list(idx["t"])
        pids = idx["parent_id"].tolist()if torch.is_tensor(idx["parent_id"]) else list(idx["parent_id"])

        texts: List[str] = []
        images_for_qwen: List[List[torch.Tensor]] = []
        targets: List[str] = []

        # 3) per-sample pairing + message building
        kept = 0
        for i in range(B):
            key = (str(fovs[i]), int(tids[i]), int(ts[i]), int(pids[i]))
            row = self.anno_lookup.get(key)
            if row is None:
                if self.fail_on_missing:
                    raise KeyError(f"Missing annotation for quadruplet {key}")
                else:
                    continue

            caption = str(row["__caption__"]).strip()

            # ensure per-image tensor is [3,H,W] (not HxWx3)
            im = imgs[i]
            if im.ndim == 3 and im.shape[0] != 3 and im.shape[-1] == 3:
                im = im.permute(2,0,1)

            # Qwen chat
            messages = [
                {"role": "user", "content": [
                    {"type": "text", "text": self.question},
                    {"type": "image"},
                ]},
                {"role": "assistant", "content": [
                    {"type": "text", "text": caption}
                ]},
            ]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)

            texts.append(text.strip())
            images_for_qwen.append([im])   # list-of-images per sample
            targets.append(caption)
            kept += 1

        if kept == 0:
            raise RuntimeError("After filtering/missing annotations, the batch is empty.")

        # 4) processor (tokenize + patchify)
        out = self.processor(
            text=texts,
            images=images_for_qwen,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        # 5) assistant-only labels
        labels = out["input_ids"].clone()

        # tokenize targets once (batched)
        tgt_tok = self.processor.tokenizer(
            targets, add_special_tokens=False, padding=True, return_tensors="pt"
        )
        for i, input_ids in enumerate(out["input_ids"]):
            tgt_ids = tgt_tok["input_ids"][i]
            if self.pad_id is not None:
                tgt_ids = tgt_ids[tgt_ids != self.pad_id]
            start = self._find_subsequence(input_ids, tgt_ids)
            if start is not None:
                end = start + len(tgt_ids)
                if self.eos_id is not None and end < len(input_ids) and input_ids[end].item() == self.eos_id:
                    end += 1
                labels[i, :start] = -100
                labels[i, end:]   = -100
            else:
                labels[i, :] = -100  # safe fallback if not found

        out["labels"] = labels
        return out

In [18]:
collator = QwenIndexAnchorCollator(processor, anno_lookup, pad_to_multiple_of=8)

In [37]:
# define loaders for lightning
train_loader_qwen = CollatingLoader(train_loader, collator)
test_loader_qwen   = CollatingLoader(test_loader, collator)

In [47]:
# Steps accounting to mirror your Trainer math

#dataset_len = len(train_dataset)
dataset_len = len(train_loader)

steps_per_epoch = math.ceil(dataset_len / (batch_size * max(1, grad_accum)))
total_steps = steps_per_epoch * epochs

# Fractions → concrete steps
logging_steps = max(1, int(total_steps * float(logging_frac)))
eval_steps    = max(1, int(total_steps * float(eval_frac)))

# Warmup resolution
if warmup_steps_cfg is not None:
    warmup_steps = int(warmup_steps_cfg)
elif warmup_ratio is not None:
    warmup_steps = int(total_steps * float(warmup_ratio))
else:
    warmup_steps = 0  # matches your commented-out default

# testing
epochs = 1
del logging_steps
logging_steps = 20
del eval_steps
eval_steps = 20
grad_accum = 1


In [None]:

# # -----------------------
# # Dataloaders
# # -----------------------
# # Pin + persistent workers improve performance on repeated small batches
# train_loader = DataLoader(
#     train_dataset,
#     batch_size=batch_size,
#     shuffle=True,
#     collate_fn=collator,
#     num_workers=num_workers,
#     pin_memory=True,
#     persistent_workers=True,
#     drop_last=True,
# )
# val_loader = DataLoader(
#     val_dataset,
#     batch_size=eval_batch_size,
#     shuffle=False,
#     collate_fn=collator,
#     num_workers=max(1, num_workers // 2),
#     pin_memory=True,
#     persistent_workers=True,
#     drop_last=False,
# )

In [26]:
# -----------------------
# Logging & checkpoints
# -----------------------
# Logger & callbacks (TensorBoard + step-based eval/checkpointing)
logger = TensorBoardLogger(
    save_dir=logging_dir,
    name=run_name,
    default_hp_metric=False  # prevents Lightning adding HP metric noise
)
ckpt_cb = ModelCheckpoint(
    dirpath=output_dir,
    monitor="val_loss",
    mode="min",
    save_top_k=1,               # keep best model only (like save_total_limit=1)
    every_n_train_steps=eval_steps,  # save on same cadence as eval
    filename="step{step}-valloss{val_loss:.4f}",
    auto_insert_metric_name=False,
    save_last=True,
)
lr_cb = LearningRateMonitor(logging_interval="step")

In [39]:
# -----------------------
# LightningModule with LoRA-wrapped Qwen
# -----------------------
class QwenLoraModule(pl.LightningModule):
    """
    LightningModule wrapping:
    - Qwen2.5-VL model with LoRA applied
    - AdamW optimizer with HF scheduler
    - Optional gradient checkpointing for VRAM savings
    """
    def __init__(
        self,
        model_id,
        lr,
        weight_decay,
        lr_scheduler_type,
        warmup_steps,
        num_training_steps,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-8,
        gradient_checkpointing=True,
        gc_use_reentrant=False,  # False avoids Qwen checkpointing bug
        attn_implementation="eager",
        tf32=True,
        lora_cfg=None,
        device_map=None,
    ):
        super().__init__()
        # Save all hparams including lora_cfg (except num_training_steps which is large/dynamic)
        self.save_hyperparameters(ignore=["num_training_steps"])
        self.num_training_steps = num_training_steps

        # Enable TensorFloat32 for faster matmul on Ampere+ GPUs
        if tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        # Processor handles both text & image preprocessing
        self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

        # Load Qwen base model in bf16 for memory savings
        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation=attn_implementation,
            device_map=device_map,
        )

        # --- Robust gradient checkpointing across Transformers versions ---
        if gradient_checkpointing:
            enabled = False
            try:
                # Newer API accepts kwargs dict
                base_model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": gc_use_reentrant}
                )
                enabled = True
            except TypeError:
                pass
            if not enabled:
                try:
                    # Older API: no kwargs
                    base_model.gradient_checkpointing_enable()
                    enabled = True
                except TypeError:
                    # Very old fallbacks
                    if hasattr(base_model, "enable_gradient_checkpointing"):
                        base_model.enable_gradient_checkpointing()
                        enabled = True
                    elif hasattr(base_model, "set_gradient_checkpointing"):
                        base_model.set_gradient_checkpointing(True)
                        enabled = True
            if hasattr(base_model, "enable_input_require_grads"):
                base_model.enable_input_require_grads()
        
        # --- LoRA configuration from YAML ---
        if not isinstance(lora_cfg, dict) or len(lora_cfg) == 0:
            raise ValueError(
                "LoRA configuration is missing. Please provide a 'lora:' section in qwen2vl_config.yaml."
            )
        lora_config = LoraConfig(**lora_cfg)

        # Get LoRA model
        self.model = get_peft_model(base_model, lora_config)

        # Training-friendly defaults
        if hasattr(self.model, "config"):
            self.model.config.use_cache = False
            if getattr(self.model.config, "pad_token_id", None) is None:
                self.model.config.pad_token_id = self.processor.tokenizer.eos_token_id

    def forward(self, **batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        out = self(**batch)
        # Log training loss per step (no epoch avg to match HF behavior)
        self.log("train_loss", out.loss, prog_bar=True, on_step=True, on_epoch=False)
        return out.loss

    def validation_step(self, batch, batch_idx):
        out = self(**batch)
        # Log validation loss averaged over an epoch
        self.log("val_loss", out.loss, prog_bar=True, on_step=False, on_epoch=True)
        return out.loss

    def configure_optimizers(self):
        """
        Set up:
        - AdamW optimizer with HF's beta/eps/weight decay settings
        - LR scheduler from transformers.optimization.get_scheduler
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            betas=(self.hparams.adam_beta1, self.hparams.adam_beta2),
            eps=self.hparams.adam_epsilon,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = get_scheduler(
            name=self.hparams.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.num_training_steps
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # step-based scheduler like HF
                "frequency": 1,
                "name": "lr"
            }
        }


In [42]:
# Module
module = QwenLoraModule(
    model_id=model_id,
    lr=lr,
    weight_decay=weight_decay,
    lr_scheduler_type=lr_scheduler,
    warmup_steps=warmup_steps,
    num_training_steps=total_steps,
    gradient_checkpointing=gradient_checkpointing,
    gc_use_reentrant=gc_use_reentrant,
    attn_implementation="eager",
    tf32=tf32,
    lora_cfg=lora_cfg,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [49]:

del trainer

# Trainer (Lightning: step-based val via val_check_interval)
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=devices,
    precision=precision,
    gradient_clip_val=1.0,
    log_every_n_steps=logging_steps,
    logger=logger,
    callbacks=[ckpt_cb, lr_cb],
    accumulate_grad_batches=grad_accum,
    plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)],
    strategy=strategy,
    val_check_interval=eval_steps,  # "eval_strategy=steps"
    limit_train_batches=200,        # exactly 100 batches per epoch
    limit_val_batches=40, 
)


/hpc/mydata/yasin.senbabaoglu/anaconda/25.3.1/x86_64/envs/qwen2vl/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /hpc/mydata/yasin.senbabaoglu/anaconda/25.3.1/x86_64 ...
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [50]:
trainer.fit(module, train_loader_qwen, test_loader_qwen)

/hpc/mydata/yasin.senbabaoglu/anaconda/25.3.1/x86_64/envs/qwen2vl/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:701: Checkpoint directory /hpc/mydata/yasin.senbabaoglu/projects/qwen_test/notebooks/fine-tuned-model exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/hpc/mydata/yasin.senbabaoglu/anaconda/25.3.1/x86_64/envs/qwen2vl/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type                 | Params | Mode
------------------------------------------------------
0 | model | PeftModelForCausalLM | 4.4 B  | eval
------------------------------------------------------
637 M     Trainable params
3.8 B     Non-trainable params
4.4 B     Total params
17,569.022Total estimated model params size (MB)
0         Modules in train mode
2342      Modules in eval

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/hpc/mydata/yasin.senbabaoglu/anaconda/25.3.1/x86_64/envs/qwen2vl/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:417: `ModelCheckpoint(monitor='val_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'epoch', 'step']. HINT: Did you call `log('val_loss', value)` in the `LightningModule`?


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
