# Multimodal Pruning + LoRA Recovery (T4 Demo)
**Model:** Qwen/Qwen2-VL-2B-Instruct (text+vision)  
**Pruning:** a configurable fraction (15% in this run; try 15–30%) of attention heads _or_ FFN channels (logical masking, shape-preserving)  
**Fine-tune:** LoRA on attention + MLP to recover accuracy  
**Eval:** Toy perplexity & a quick generation check on CIFAR-10 images


# 1) Setup (Installs & GPU check)

In [1]:
!pip -q install "transformers>=4.43.3" "accelerate>=0.32.0" "peft>=0.11.1" "datasets" "pillow" "torchvision"

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"   # or "true" if you prefer

import os, gc, math, random, platform, warnings
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import List

# Report the runtime: torch version, the CUDA version torch reports, and the Python version.
print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| Py:", platform.python_version())
# Name of CUDA device 0 when CUDA is available, otherwise the literal "CPU only".
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")

# Reproducibility: seed Python's `random` and torch's CPU and CUDA generators with one value.
SEED = 42
random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

Torch: 2.3.1+cu121 | CUDA: 12.1 | Py: 3.11.11
GPU: Tesla T4


# 2) Imports & Config

In [2]:
from PIL import Image
import torchvision
import torchvision.transforms as T

from transformers import (
    AutoProcessor, AutoConfig, AutoModelForVision2Seq, get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model

# --- Demo knobs ---
# Hugging Face model id used for the processor, config and weights.
MODEL_NAME   = "Qwen/Qwen2-VL-2B-Instruct"
# Selects which pruning routine the pruning cell dispatches to.
PRUNE_MODE   = "ffn_channels"        # "attn_heads" (GQA-safe) or "ffn_channels"
# Fraction of heads / FFN channels that gets masked.
PRUNE_RATIO  = 0.15                  # try 0.15–0.30
# When False, `dtype` below resolves to torch.float16.
USE_FP32     = False                 # FP16 recommended on T4 (16 GB)

# LoRA / Train knobs (tiny)
LR                = 1e-4             # AdamW learning rate
EPOCHS            = 1
BATCH_SIZE        = 2                # micro-batch size of the training DataLoader
GRAD_ACCUM_STEPS  = 4                # micro-batches accumulated per optimizer step
MAX_LENGTH        = 512              # truncation length passed to the processor
WARMUP_STEPS      = 20               # warmup steps of the linear LR schedule
N_TRAIN, N_VAL    = 800, 200         # example counts; re-assigned to the same values in the dataset cell

# Directory that receives the saved processor, LoRA adapter and merged checkpoint.
OUTPUT_DIR        = "./demo_pruned_lora_qwen2vl"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when one is visible; `dtype` is the precision requested for loading the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype  = torch.float32 if USE_FP32 else torch.float16
print(f"Device: {device} | Dtype: {dtype} | Prune: {PRUNE_MODE} @ {int(PRUNE_RATIO*100)}%")


2025-08-25 00:16:49.783489: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-25 00:16:49.797166: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-25 00:16:49.814412: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-25 00:16:49.819831: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-25 00:16:49.832526: I tensorflow/core/platform/cpu_feature_guar

Device: cuda | Dtype: torch.float16 | Prune: ffn_channels @ 15%


# 3) Load Model & Processor

In [3]:
# Fetch the processor and the model config for MODEL_NAME.
# NOTE(review): trust_remote_code=True is passed to both loaders — confirm it is actually required for this checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
config    = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)

def load_model(dtype):
    """Load MODEL_NAME onto `device`, retrying in FP16 after a CUDA out-of-memory error.

    Args:
        dtype: torch dtype requested for the first load attempt.

    Returns:
        A ``(model, dtype_used)`` tuple; ``dtype_used`` is ``torch.float16`` when
        the fallback path was taken, otherwise the requested ``dtype``.
    """

    def _from_hub(load_dtype):
        # The result is deliberately not bound to a local, so a failed `.to(device)`
        # does not leave a model alive in this frame while the caller retries.
        return AutoModelForVision2Seq.from_pretrained(
            MODEL_NAME, torch_dtype=load_dtype, low_cpu_mem_usage=True, trust_remote_code=True
        ).to(device)

    try:
        return _from_hub(dtype), dtype
    except torch.cuda.OutOfMemoryError:
        print("OOM at requested dtype; falling back to FP16.")
        # Release cached CUDA blocks and collect garbage before the second attempt.
        torch.cuda.empty_cache(); gc.collect()
        return _from_hub(torch.float16), torch.float16

# Load the weights; `dtype` is rebound to the precision load_model actually used.
model, dtype = load_model(dtype)
model.gradient_checkpointing_enable()  # memory saver

# Read the core dimensions from the config, trying an alternate attribute name for each.
hidden_size        = getattr(config, "hidden_size", getattr(config, "hidden_dim", None))
num_heads          = getattr(config, "num_attention_heads", getattr(config, "num_heads", None))
intermediate_size  = getattr(config, "intermediate_size", getattr(config, "ffn_hidden_size", None))
assert hidden_size and num_heads and intermediate_size, "Missing key config dims."
# Per-head width derived from the config (not referenced again in the visible cells).
head_dim = hidden_size // num_heads

# Parameter count in millions, for reporting only.
total_params_m = sum(p.numel() for p in model.parameters())/1e6
print(f"Total params: {total_params_m:.1f}M")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Total params: 2209.0M


# 4) Module Finders (attention & MLP)

In [4]:
def find_attn_modules(module):
    """Yield ``(qualified_name, submodule)`` for each submodule exposing q/k/v/o projections.

    A submodule qualifies when it has all four attributes ``q_proj``, ``k_proj``,
    ``v_proj`` and ``o_proj``. The root module itself is included in the walk.
    """
    required = ("q_proj", "k_proj", "v_proj", "o_proj")
    for path, sub in module.named_modules():
        if not all(hasattr(sub, attr) for attr in required):
            continue
        yield path, sub

def find_mlp_modules(module):
    """Yield ``(qualified_name, submodule)`` for each submodule exposing gated-MLP projections.

    A submodule qualifies when it has all three attributes ``gate_proj``,
    ``up_proj`` and ``down_proj``. The root module itself is included in the walk.
    """
    required = ("gate_proj", "up_proj", "down_proj")
    for path, sub in module.named_modules():
        if not all(hasattr(sub, attr) for attr in required):
            continue
        yield path, sub


# 5) Pruning Utilities (Logical Masking)

In [None]:
@torch.no_grad()
def prune_attention_heads_logical_gqa(attn_mod, ratio: float):
    """Mask the trailing query heads of one attention module, in place.

    Shapes are preserved: the selected heads are disabled by zeroing their rows
    in ``q_proj`` (weight and bias) and the matching input columns of ``o_proj``.
    Heads are chosen by position (the last ones), not by any importance score.

    With grouped-query attention several query heads read the same key/value
    head. A key/value head is therefore zeroed in ``k_proj`` / ``v_proj`` only
    when *every* query head of its group is pruned; otherwise the query heads
    that are kept would lose their keys and values as well.

    Args:
        attn_mod: module with ``q_proj``, ``k_proj``, ``v_proj`` and ``o_proj``
            linear layers. Head counts are read from ``num_heads`` / ``n_heads``
            / ``config.num_attention_heads`` and ``num_key_value_heads`` /
            ``n_kv_heads`` / ``config.num_key_value_heads``; the head width is
            read from ``head_dim`` or derived from ``q_proj``'s row count.
        ratio: fraction of query heads to prune. At least one head is always kept.

    Raises:
        AssertionError: if the head counts cannot be determined, the head width
            cannot be inferred, or the query-head count is not a multiple of the
            key/value-head count.
    """
    cfg = getattr(attn_mod, "config", None)
    n_q = (
        getattr(attn_mod, "num_heads", None)
        or getattr(attn_mod, "n_heads", None)
        or getattr(cfg, "num_attention_heads", None)
    )
    n_kv = (
        getattr(attn_mod, "num_key_value_heads", None)
        or getattr(attn_mod, "n_kv_heads", None)
        or getattr(cfg, "num_key_value_heads", None)
        or n_q
    )
    hd = getattr(attn_mod, "head_dim", None)
    if hd is None:
        q_rows = attn_mod.q_proj.weight.shape[0]
        assert n_q and q_rows % n_q == 0, "Cannot infer head_dim."
        hd = q_rows // n_q
    assert n_q and n_kv and n_q >= 1 and n_kv >= 1
    assert n_q % n_kv == 0, "Expected n_q divisible by n_kv for GQA."

    n_keep_q = max(1, int(n_q * (1.0 - ratio)))
    prune_q = list(range(n_keep_q, n_q))
    if not prune_q:
        return

    # Query heads [g*group, (g+1)*group) all read key/value head g.
    group = max(1, n_q // n_kv)
    pruned_q = set(prune_q)
    prune_kv = [
        g for g in range(n_kv)
        if all(h in pruned_q for h in range(g * group, (g + 1) * group))
    ]

    def rows_for_heads(head_ids):
        # Flat row indices covered by the given heads (hd consecutive rows per head).
        return [r for h in head_ids for r in range(h * hd, (h + 1) * hd)]

    def zero_slices(tensor, dim, positions):
        # Zero the given indices along `dim`; out-of-range indices are ignored.
        valid = [p for p in positions if 0 <= p < tensor.shape[dim]]
        if valid:
            idx = torch.tensor(valid, device=tensor.device, dtype=torch.long)
            tensor.index_fill_(dim, idx, 0)

    rows_q = rows_for_heads(prune_q)
    rows_kv = rows_for_heads(prune_kv)

    # Output rows (and bias entries) of the projections that feed the pruned heads.
    for proj, rows in (
        (attn_mod.q_proj, rows_q),
        (attn_mod.k_proj, rows_kv),
        (attn_mod.v_proj, rows_kv),
    ):
        zero_slices(proj.weight, 0, rows)
        bias = getattr(proj, "bias", None)
        if bias is not None:
            zero_slices(bias, 0, rows)

    # o_proj consumes the concatenated head outputs, so its input columns line up
    # with q_proj's output rows.
    zero_slices(attn_mod.o_proj.weight, 1, rows_q)

In [5]:
@torch.no_grad()
def prune_ffn_channels_logical_mask(mlp_mod: nn.Module, ratio: float, intermediate_size: int):
    """Mask the weakest intermediate channels of a gated MLP, in place.

    Channels are ranked by the L1 norm of their column in ``down_proj.weight``.
    The lowest-ranked ones are disabled by zeroing that column in ``down_proj``
    and the matching rows (and bias entries) in ``up_proj`` and ``gate_proj``.
    Tensor shapes are unchanged.

    Args:
        mlp_mod: module with ``gate_proj``, ``up_proj`` and ``down_proj`` layers.
        ratio: fraction of channels to prune. ``ratio <= 0`` is a no-op; any
            positive ratio prunes at least one channel.
        intermediate_size: kept for interface compatibility. The channel count is
            taken from ``down_proj.weight`` itself, so a module whose width
            differs from this value is still pruned at the requested ratio
            instead of failing inside ``topk``.
    """
    if ratio <= 0:
        return

    down_W = mlp_mod.down_proj.weight  # indexed as [out_features, intermediate channels]
    inter_dim = down_W.shape[1]
    n_prune = min(inter_dim, max(1, int(inter_dim * ratio)))

    # Accumulate the norms in float32 so half-precision weights are ranked without
    # low-precision rounding in the sum.
    col_norms = torch.norm(down_W.float(), p=1, dim=0)
    prune_idx = torch.topk(col_norms, k=n_prune, largest=False).indices

    keep_mask = torch.ones(inter_dim, device=down_W.device, dtype=down_W.dtype)
    keep_mask.index_fill_(0, prune_idx, 0)

    # down_proj: zero columns
    down_W.mul_(keep_mask[None, :])
    # up_proj / gate_proj: zero rows (and the matching bias entries, if any)
    for name in ("up_proj", "gate_proj"):
        proj = getattr(mlp_mod, name)
        proj.weight.mul_(keep_mask[:, None])
        b = getattr(proj, "bias", None)
        if b is not None:
            b.mul_(keep_mask.to(b.dtype))

# 6) Build a Tiny Multimodal Toy Dataset

In [6]:
# Make sure the tokenizer has a pad token; fall back to the EOS token when none is set.
if processor.tokenizer.pad_token_id is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
PAD_ID   = processor.tokenizer.pad_token_id   # fill value for padded input_ids
IGNOREID = -100                               # fill value for padded labels
EOS      = processor.tokenizer.eos_token      # EOS string appended to training answers

# CIFAR-10 train/test splits (downloaded into `root` if missing); the class names serve as answers.
root = "./data_cifar10"
train_raw = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
test_raw  = torchvision.datasets.CIFAR10(root=root, train=False, download=True)
label_names = train_raw.classes

# Every image is resized to a fixed square before being handed to the processor.
resize_to = 448
resize_tf = T.Resize((resize_to, resize_to))

def make_examples(ds, n_take):
    """Turn the first ``n_take`` items of ``ds`` into VQA-style example dicts.

    Each item of ``ds`` is read as an ``(image, class_index)`` pair. The image is
    resized with ``resize_tf`` and the class index is mapped through
    ``label_names`` to obtain the answer text.

    Returns:
        A list of dicts with keys ``image``, ``question``, ``answer`` and
        ``answer_with_eos``.
    """

    def _to_item(idx):
        picture, class_idx = ds[idx]
        resized = resize_tf(picture)
        label = label_names[class_idx]
        return {
            "image": resized,
            "question": "What is in this image?",
            "answer": label,                  # plain label, used for scoring
            "answer_with_eos": label + EOS,   # training target ends with EOS so the model learns to stop
        }

    return [_to_item(i) for i in range(n_take)]

# Same values as in the config cell, repeated next to their point of use.
N_TRAIN, N_VAL = 800, 200
# First N_TRAIN items of the CIFAR-10 train split and first N_VAL items of the test split.
train_items = make_examples(train_raw, N_TRAIN)
val_items   = make_examples(test_raw,  N_VAL)



Files already downloaded and verified
Files already downloaded and verified


In [None]:
def _find_last_subseq(haystack, needle):
    """Return the start index of the last occurrence of list `needle` in list `haystack`, or -1."""
    n_hay, n_needle = len(haystack), len(needle)
    if n_needle == 0 or n_needle > n_hay:
        return -1
    for start in range(n_hay - n_needle, -1, -1):
        if haystack[start:start + n_needle] == needle:
            return start
    return -1


def encode_example_vqa(ex):
    """Encode one VQA example into model inputs with answer-only labels.

    Two chat prompts are rendered: a training prompt (user turn plus the
    assistant answer) that is run through the processor, and a generation
    prompt (user turn plus generation prefix) stored as text for evaluation.

    Labels are ``-100`` everywhere except on the tokens of
    ``ex["answer_with_eos"]``. The answer span is located by token match,
    searching from the END of the sequence: the assistant turn comes last, so
    the last match is the answer even if the same tokens also occur earlier in
    the prompt. If no match is found, the last (up to) 8 tokens are supervised.

    Args:
        ex: dict with keys ``image``, ``question``, ``answer`` and
            ``answer_with_eos``.

    Returns:
        dict with the tensors ``pixel_values``, ``image_grid_thw`` (``None`` if
        the processor did not return it), ``input_ids``, ``attention_mask`` and
        ``labels``, each with a leading size-1 dimension squeezed away, plus the
        evaluation fields ``answer_text``, ``raw_image`` and ``gen_prompt``.
    """
    ignore_index = -100

    def _user_turn():
        # Built fresh for each message list so the two prompts never share objects.
        return {"role": "user", "content": [
            {"type": "image", "image": ex["image"]},
            {"type": "text", "text": ex["question"]},
        ]}

    # Train template: user + assistant(answer_with_eos)
    messages_train = [
        _user_turn(),
        {"role": "assistant", "content": [{"type": "text", "text": ex["answer_with_eos"]}]},
    ]
    # Generation template: user only (+ generation prompt)
    messages_gen = [_user_turn()]

    train_text = processor.apply_chat_template(messages_train, tokenize=False, add_generation_prompt=False)
    gen_text   = processor.apply_chat_template(messages_gen,   tokenize=False, add_generation_prompt=True)

    out = processor(
        text=[train_text],
        images=[ex["image"]],
        return_tensors="pt",
        max_length=MAX_LENGTH,
        padding="longest",
        truncation=True
    )

    input_ids = out["input_ids"]
    ans_ids = processor.tokenizer(
        ex["answer_with_eos"], add_special_tokens=False, return_tensors="pt"
    )["input_ids"][0].tolist()

    # Supervise only the answer tokens; everything else is ignored by the loss.
    labels = torch.full_like(input_ids, ignore_index)
    start = _find_last_subseq(input_ids[0].tolist(), ans_ids)
    if start >= 0:
        end = start + len(ans_ids)
        labels[:, start:end] = input_ids[:, start:end]
    else:
        # fallback: supervise last few tokens
        keep = min(8, input_ids.shape[1])
        labels[:, -keep:] = input_ids[:, -keep:]

    grid = out.get("image_grid_thw", None)

    return {
        "pixel_values": out["pixel_values"].squeeze(0),
        "image_grid_thw": grid.squeeze(0) if grid is not None else None,
        "input_ids": input_ids.squeeze(0),
        "attention_mask": out["attention_mask"].squeeze(0),
        "labels": labels.squeeze(0),
        # for eval
        "answer_text": ex["answer"],
        "raw_image":   ex["image"],
        "gen_prompt":  gen_text,
    }

# Encode every example once, up front; the results are kept in memory as plain lists.
train_encoded = [encode_example_vqa(ex) for ex in train_items]
val_encoded   = [encode_example_vqa(ex) for ex in val_items]

class CIFARVQADataset(Dataset):
    """Map-style dataset that serves items from a pre-encoded sequence as-is."""

    def __init__(self, encoded):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.encoded = encoded

    def __len__(self):
        return len(self.encoded)

    def __getitem__(self, i):
        return self.encoded[i]

# Wrap the encoded lists so a DataLoader can index them.
train_ds = CIFARVQADataset(train_encoded)
val_ds   = CIFARVQADataset(val_encoded)

In [7]:
def pad_1d(seqs, pad_val):
    """Right-pad a list of 1-D tensors to a common length and stack them.

    Args:
        seqs: non-empty list of 1-D tensors.
        pad_val: fill value for the padded positions.

    Returns:
        Tensor of shape ``(len(seqs), longest_length)`` with the dtype of ``seqs[0]``.
    """
    lengths = [t.size(0) for t in seqs]
    padded = torch.full((len(seqs), max(lengths)), pad_val, dtype=seqs[0].dtype)
    for row, (seq, length) in enumerate(zip(seqs, lengths)):
        padded[row, :length] = seq
    return padded

def collate_fn(batch):
    """Combine a list of encoded examples into one batch dict.

    ``pixel_values`` (and ``image_grid_thw`` when every sample has one) are
    stacked along a new leading dimension. ``input_ids``, ``attention_mask`` and
    ``labels`` are right-padded with ``PAD_ID``, ``0`` and ``IGNOREID``
    respectively. The evaluation fields are passed through as Python lists under
    the plural keys ``raw_images`` and ``gen_prompts`` (``answer_text`` keeps
    its name).
    """

    def column(key):
        return [sample[key] for sample in batch]

    grids = column("image_grid_thw")
    every_sample_has_grid = all(g is not None for g in grids)

    return {
        "pixel_values": torch.stack(column("pixel_values"), dim=0),
        "image_grid_thw": torch.stack(grids, dim=0) if every_sample_has_grid else None,
        "input_ids": pad_1d(column("input_ids"), PAD_ID),
        "attention_mask": pad_1d(column("attention_mask"), 0),
        "labels": pad_1d(column("labels"), IGNOREID),
        "answer_text": column("answer_text"),
        "raw_images": column("raw_image"),
        "gen_prompts": column("gen_prompt"),
    }

# Training batches are shuffled; validation keeps dataset order with a fixed batch size of 2.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=2,          shuffle=False, collate_fn=collate_fn)

# 7) Baseline Toy Perplexity

In [None]:
@torch.no_grad()
def eval_loss(model, loader):
    """Return the token-weighted perplexity of ``model`` over ``loader``.

    Despite the name, the return value is ``exp(mean loss)``, not the raw loss.
    Each batch's loss is weighted by its number of supervised targets. That
    count is taken over ``labels[:, 1:]`` on the assumption that the model
    averages its loss over next-token targets (labels shifted by one) —
    TODO confirm against the model's loss implementation.

    Batches without any supervised target are skipped, because their loss is
    undefined. The model is put in eval mode for the duration of the call and
    its previous train/eval mode is restored afterwards, also on error.

    Args:
        model: module called as ``model(**inputs)`` whose output has ``.loss``.
        loader: iterable of batch dicts containing ``labels`` and optionally
            ``pixel_values``, ``input_ids``, ``attention_mask`` and
            ``image_grid_thw``; ``None`` entries are not forwarded.

    Returns:
        Perplexity as a float, or ``nan`` when no supervised token was seen.
    """
    import math

    was_training = model.training
    model.eval()
    total_loss, total_tokens = 0.0, 0
    try:
        for batch in loader:
            moved = {}
            for k, v in batch.items():
                if v is None:
                    continue
                if k == "pixel_values":
                    moved[k] = v.to(device, dtype=dtype)
                elif k in ("input_ids", "attention_mask", "labels", "image_grid_thw"):
                    moved[k] = v.to(device)

            n_tokens = (moved["labels"][:, 1:] != -100).sum().item()
            if n_tokens == 0:
                # Nothing to score in this batch; its loss would be undefined.
                continue

            out = model(**moved)
            total_loss += out.loss.item() * n_tokens
            total_tokens += n_tokens
    finally:
        model.train(was_training)

    if total_tokens == 0:
        return float("nan")
    return math.exp(total_loss / total_tokens)

In [8]:
@torch.no_grad()
def eval_gen_accuracy(model, processor, loader, k_samples=50, max_new_tokens=3):
    """Greedy-generate short answers and return the fraction counted as correct.

    Up to ``k_samples`` examples are taken from ``loader`` in order. For each
    one the stored generation prompt and raw image are re-encoded, at most
    ``max_new_tokens`` tokens are generated, and the first whitespace-separated
    word of the lower-cased output is compared with the gold answer.

    The comparison is lenient: a prediction counts as correct when the gold
    answer is a *substring* of that first word (so "shipp" matches "ship").

    The model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, also on error.

    Args:
        model: model exposing ``generate``.
        processor: processor used to encode prompts and decode outputs.
        loader: iterable of batches with ``raw_images``, ``gen_prompts`` and
            ``answer_text`` lists.
        k_samples: maximum number of examples to score.
        max_new_tokens: generation budget per example.

    Returns:
        ``correct / seen`` as a float; ``0.0`` when no example was seen.
    """
    was_training = model.training
    model.eval()
    correct, seen = 0, 0
    try:
        for batch in loader:
            for img, gp, gold in zip(batch["raw_images"], batch["gen_prompts"], batch["answer_text"]):
                if seen >= k_samples:
                    break

                enc = processor(text=[gp], images=[img], return_tensors="pt")
                # Image features follow the model precision; all other tensors are only moved.
                enc = {
                    k: (v.to(device, dtype=dtype) if k == "pixel_values" else v.to(device))
                    for k, v in enc.items()
                }
                # Greedy decoding. No length_penalty is passed: it only applies to
                # beam-based generation and has no effect with num_beams=1.
                gen_ids = model.generate(
                    **enc,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    num_beams=1,                           # <- avoid beam visual expansion issues
                    pad_token_id=processor.tokenizer.eos_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id,
                    repetition_penalty=1.5,               # discourage "catcatcat..."
                    use_cache=True,
                )
                # Keep only the continuation, i.e. drop the echoed prompt tokens.
                new_tokens = gen_ids[:, enc["input_ids"].shape[1]:]
                text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip().lower()
                pred = text.split()[0] if text else ""
                if gold.lower() in pred:
                    correct += 1
                seen += 1
            if seen >= k_samples:
                break
    finally:
        model.train(was_training)
    return correct / max(1, seen)

# Baseline before any pruning: perplexity over the training loader and generation
# accuracy over the first 40 validation samples.
print("Running quick baseline eval (pre-prune)...")
ppl_train = eval_loss(model, train_loader)
acc_val0  = eval_gen_accuracy(model, processor, val_loader, k_samples=40)
print(f"Train PPL (pre-prune): {ppl_train:.2f} | Val Gen@1 Acc (pre-prune): {acc_val0:.2%}")


Running quick baseline eval (pre-prune)...




Train PPL (pre-prune): 1531.34 | Val Gen@1 Acc (pre-prune): 0.00%


# 8) Apply Pruning (Pick mode via flag)

In [9]:
# Reject unknown modes first, then mask every matching submodule of the chosen kind.
if PRUNE_MODE not in ("attn_heads", "ffn_channels"):
    raise ValueError("PRUNE_MODE must be 'attn_heads' or 'ffn_channels'.")

if PRUNE_MODE == "attn_heads":
    attn_mods = [mod for _, mod in find_attn_modules(model)]
    for mod in attn_mods:
        prune_attention_heads_logical_gqa(mod, PRUNE_RATIO)
    n = len(attn_mods)
    print(f"Pruned heads in {n} attention modules (GQA-safe, mask-based).")
else:
    mlp_mods = [mod for _, mod in find_mlp_modules(model)]
    for mod in mlp_mods:
        prune_ffn_channels_logical_mask(mod, PRUNE_RATIO, intermediate_size)
    n = len(mlp_mods)
    print(f"Pruned channels in {n} MLP modules (mask-based).")

# Re-measure perplexity on the training loader to see how much the masking hurt.
ppl_after_prune = eval_loss(model, train_loader)
print(f"PPL (post-prune, pre-LoRA): {ppl_after_prune:.2f}")

Pruned channels in 28 MLP modules (mask-based).
PPL (post-prune, pre-LoRA): 241760.61


# 9) LoRA Setup (attention + MLP targets)

In [10]:
def collect_lora_targets(m: nn.Module) -> List[str]:
    """Return the sorted projection names present in ``m`` to use as LoRA targets.

    Only ``nn.Linear`` submodules are considered, and only when the last
    component of their dotted name is exactly one of the known projection names.
    A layer such as ``fake_up_proj`` is therefore not picked up merely because
    its name ends in ``up_proj``.

    Returns:
        Sorted list of matching leaf names, or the full default list of
        projection names when nothing matches.
    """
    candidates = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
    leaf_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, mod in m.named_modules()
        if isinstance(mod, nn.Linear)
    }
    return sorted(leaf_names.intersection(candidates)) or candidates

# Discover which projection names exist in the (already pruned) model.
targets = collect_lora_targets(model)
print("LoRA targets:", targets)

# Rank-16 adapters with alpha 32 and 5% dropout on every target module; biases are left untouched.
lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    bias="none", task_type="CAUSAL_LM",
    target_modules=targets
)
# From here on `model` is the PEFT-wrapped model.
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

LoRA targets: ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 18,464,768 || all params: 2,227,450,368 || trainable%: 0.8290


# 10) Tiny Fine-Tune Loop (Mixed Precision)

In [11]:
# Optimise only the parameters that require gradients.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=LR, weight_decay=0.0)
# ceil(): a trailing partial accumulation window still counts as one optimizer step.
steps_per_epoch  = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_training_steps = EPOCHS * steps_per_epoch
sched = get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, num_training_steps)

# Loss scaling and autocast are active only when running in FP16.
use_amp = (dtype == torch.float16)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
model.train()
global_step = 0
n_batches = len(train_loader)

for epoch in range(EPOCHS):
    for step, batch in enumerate(train_loader):
        # Forward only the tensors the model consumes; the eval-only fields stay behind.
        moved = {}
        for k, v in batch.items():
            if k == "pixel_values":
                moved[k] = v.to(device, dtype=dtype)
            elif k in ("input_ids","attention_mask","labels"):
                moved[k] = v.to(device)
        if batch.get("image_grid_thw") is not None:
            moved["image_grid_thw"] = batch["image_grid_thw"].to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(**moved)
            # Divide so the accumulated gradient is the mean over the window.
            loss = out.loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step at the end of every accumulation window AND on the last batch of the
        # epoch. Without the second condition, gradients from a trailing partial
        # window would never be applied and would leak into the next epoch.
        # (A partial window is still divided by GRAD_ACCUM_STEPS, so it is
        # slightly down-weighted.)
        window_full   = (step + 1) % GRAD_ACCUM_STEPS == 0
        is_last_batch = (step + 1) == n_batches
        if window_full or is_last_batch:
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)
            sched.step()
            global_step += 1
            # Reports the un-scaled loss of the most recent micro-batch only.
            print(f"step {global_step} | loss={(loss.item()*GRAD_ACCUM_STEPS):.4f}")

torch.cuda.empty_cache()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


step 1 | loss=12.2511
step 2 | loss=13.1624
step 3 | loss=12.0794
step 4 | loss=11.3681
step 5 | loss=9.6685
step 6 | loss=10.9155
step 7 | loss=10.4755
step 8 | loss=9.2288
step 9 | loss=6.3020
step 10 | loss=6.9667
step 11 | loss=5.9652
step 12 | loss=5.4434
step 13 | loss=4.3216
step 14 | loss=3.4158
step 15 | loss=2.8532
step 16 | loss=5.8924
step 17 | loss=2.5747
step 18 | loss=1.6085
step 19 | loss=3.1715
step 20 | loss=3.4979
step 21 | loss=0.6728
step 22 | loss=0.8974
step 23 | loss=0.5531
step 24 | loss=0.1698
step 25 | loss=0.6673
step 26 | loss=0.3225
step 27 | loss=1.4298
step 28 | loss=0.1493
step 29 | loss=0.0857
step 30 | loss=0.1837
step 31 | loss=0.2181
step 32 | loss=0.0284
step 33 | loss=0.0163
step 34 | loss=0.0201
step 35 | loss=0.0743
step 36 | loss=0.0120
step 37 | loss=0.0068
step 38 | loss=0.0065
step 39 | loss=0.0157
step 40 | loss=0.0751
step 41 | loss=0.0039
step 42 | loss=0.2341
step 43 | loss=0.0139
step 44 | loss=0.0123
step 45 | loss=0.0014
step 46 | los

# 11) Evaluate After LoRA & Save Artifacts

In [12]:
# Re-run both metrics on the LoRA-adapted model (same loaders and sample count as the baseline cell).
ppl_after_lora = eval_loss(model, train_loader)
acc_val1       = eval_gen_accuracy(model, processor, val_loader, k_samples=40)
print(f"PPL (post-LoRA): {ppl_after_lora:.2f} | Val Gen@1 Acc (post-LoRA): {acc_val1:.2%}")

# Save LoRA
# The adapter goes to OUTPUT_DIR/lora_adapter; the processor is saved to OUTPUT_DIR itself.
model.save_pretrained(os.path.join(OUTPUT_DIR, "lora_adapter"))
processor.save_pretrained(OUTPUT_DIR)
print("Saved LoRA adapter to:", os.path.join(OUTPUT_DIR, "lora_adapter"))

# Optional: merge LoRA into base (extra VRAM)
# NOTE(review): the pruning cell zeroed weights once and nothing re-applies those masks later,
# so after merging the pruned rows/columns are presumably no longer zero — confirm if sparsity matters.
print("Merging LoRA into base (uses extra VRAM)...")
merged = model.merge_and_unload()
merged.save_pretrained(os.path.join(OUTPUT_DIR, "merged_full"))
print("Merged checkpoint saved to:", os.path.join(OUTPUT_DIR, "merged_full"))

PPL (post-LoRA): 1.10 | Val Gen@1 Acc (post-LoRA): 100.00%
Saved LoRA adapter to: ./demo_pruned_lora_qwen2vl/lora_adapter
Merging LoRA into base (uses extra VRAM)...
Merged checkpoint saved to: ./demo_pruned_lora_qwen2vl/merged_full


# 12) Quick Inference Demo (Visual QA)

In [13]:
def qwen_vl_infer(model, processor, pil_image, question: str, max_new_tokens=3):
    """Answer one question about one image with greedy decoding.

    The image is resized to 448x448, wrapped in a single-turn chat prompt and
    encoded by the processor. Only the newly generated tokens are decoded.

    Generation runs in eval mode. The evaluation helpers in this notebook leave
    the model in train mode, and with gradient checkpointing enabled the
    training-time warning reports ``use_cache`` being switched off, so calling
    ``generate`` without switching modes would run with training behaviour. The
    model's previous train/eval mode is restored afterwards, also on error.

    Args:
        model: model exposing ``generate``.
        processor: processor providing the chat template, encoding and decoding.
        pil_image: PIL image to ask about.
        question: question text placed after the image in the user turn.
        max_new_tokens: generation budget.

    Returns:
        The decoded continuation, stripped of surrounding whitespace.
    """
    pil_image = pil_image.resize((448, 448))
    messages = [{"role":"user","content":[
        {"type":"image","image": pil_image},
        {"type":"text","text": question}
    ]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")
    # Image features follow the model precision; all other tensors are only moved.
    inputs = {
        k: (v.to(device, dtype=dtype) if k == "pixel_values" else v.to(device))
        for k, v in inputs.items()
    }

    was_training = model.training
    model.eval()
    try:
        # Greedy decoding. No length_penalty is passed: it only applies to
        # beam-based generation and has no effect with num_beams=1.
        gen_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            repetition_penalty=1.5,
            use_cache=True
        )
    finally:
        model.train(was_training)

    # Drop the echoed prompt tokens and decode only the continuation.
    new_tokens = gen_ids[:, inputs["input_ids"].shape[1]:]
    text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return text

def demo_val_predictions(model_to_use, k=5):
    """Print ground truth next to the model's answer for the first ``k`` test images."""
    question = "What is in this image?"
    for idx in range(k):
        picture, class_idx = test_raw[idx]
        ans = label_names[class_idx]
        pred = qwen_vl_infer(model_to_use, processor, picture, question)
        print(f"GT: {ans:<10s} | PRED: {pred}")

# Run the demo against the merged model produced by the save/merge cell.
print("\n--- Demo predictions (merged model) ---")
demo_val_predictions(merged, k=5)


--- Demo predictions (merged model) ---




GT: cat        | PRED: cat
GT: ship       | PRED: shipp
GT: ship       | PRED: shipp
GT: airplane   | PRED: airplane
GT: frog       | PRED: frog
