# Qwen2‑VL Demo + Exercises (Starter Notebook)

This notebook includes:
- **Demo**: Load Qwen2‑VL, (optional) prune, LoRA fine‑tune on CIFAR‑10 as VQA, evaluate, simple inference.
- **Exercise A**: Constrained decoding via candidate scoring (log‑prob aggregation).
- **Exercise B**: Pruning ablation (FFN vs Attn heads) measuring perplexity and images/sec.
- **Exercise C**: Prompt/EOS ablation—measure repetition and accuracy effects.

All student work is in clearly marked **TODO** functions.


## 0) Setup (Installs & GPU)

This cell:
- Sets up the environment for the notebook.
- Prints the versions of PyTorch, CUDA, and Python.
- Checks if a GPU is available and prints its name.
- Sets a random seed for reproducibility.

In [3]:

import os, gc, math, random, platform, time, warnings
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import List

os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| Py:", platform.python_version())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")

SEED = 42
random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)


Torch: 2.3.1+cu121 | CUDA: 12.1 | Py: 3.11.11
GPU: Tesla T4


## 1) Imports & Config

This cell:
- Imports libraries for image processing, model configuration, and training.
- Defines key configuration parameters:
  - `MODEL_NAME`: The name of the model to load.
  - `PRUNE_MODE`: The pruning strategy (attention heads or FFN channels).
  - `PRUNE_RATIO`: The percentage of components to prune.
  - `USE_FP32`: Whether to use FP32 precision (default is FP16 for T4 GPUs).
- Sets learning rate, batch size, and other training parameters.
- Creates an output directory for saving results.

In [17]:

from PIL import Image
import torchvision
import torchvision.transforms as T

from transformers import (
    AutoProcessor, AutoConfig, AutoModelForVision2Seq, get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model

# --- Demo knobs ---
MODEL_NAME   = "Qwen/Qwen2-VL-2B-Instruct"
PRUNE_MODE   = "ffn_channels"               # "attn_heads" or "ffn_channels"
PRUNE_RATIO  = 0.15
USE_FP32     = False                        # FP16 recommended on T4

# LoRA / Train knobs (tiny)
LR                = 1e-4
EPOCHS            = 1
BATCH_SIZE        = 2
GRAD_ACCUM_STEPS  = 4
MAX_LENGTH        = 512
WARMUP_STEPS      = 20
N_TRAIN, N_VAL    = 400, 100  # smaller for faster classroom runs

OUTPUT_DIR        = "./demo_pruned_lora_qwen2vl"
os.makedirs(OUTPUT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype  = torch.float32 if USE_FP32 else torch.float16
print(f"Device: {device} | Dtype: {dtype} | Prune: {PRUNE_MODE} @ {int(PRUNE_RATIO*100)}%")

processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
config    = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)




Device: cuda | Dtype: torch.float16 | Prune: ffn_channels @ 15%


## Load Model

This cell:
- Loads the Qwen2-VL model with the specified precision (FP16 or FP32).
- Handles out-of-memory errors by falling back to FP16 if necessary.
- Enables gradient checkpointing to save memory during training.
- Extracts key model dimensions such as hidden size, number of attention heads, and intermediate size.
- Prints the total number of parameters in the model.

In [16]:
def load_model(dtype):
    try:
        m = AutoModelForVision2Seq.from_pretrained(
            MODEL_NAME, torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True
        ).to(device)
        return m, dtype
    except torch.cuda.OutOfMemoryError:
        print("OOM at requested dtype; falling back to FP16.")
        torch.cuda.empty_cache(); gc.collect()
        m = AutoModelForVision2Seq.from_pretrained(
            MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True
        ).to(device)
        return m, torch.float16

model, dtype = load_model(dtype)
model.gradient_checkpointing_enable()

hidden_size        = getattr(config, "hidden_size", getattr(config, "hidden_dim", None))
num_heads          = getattr(config, "num_attention_heads", getattr(config, "num_heads", None))
intermediate_size  = getattr(config, "intermediate_size", getattr(config, "ffn_hidden_size", None))
assert hidden_size and num_heads and intermediate_size, "Missing key config dims."
head_dim = hidden_size // num_heads

total_params_m = sum(p.numel() for p in model.parameters())/1e6
print(f"Total params: {total_params_m:.1f}M")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Total params: 2209.0M


## 2) Pruning Utilities

This section:
- Defines utility functions to locate specific modules in the model:
  - `find_attn_modules`: Finds attention modules with Q, K, V, and O projections.
  - `find_mlp_modules`: Finds MLP modules with gate, up, and down projections.
- Implements pruning utilities for attention heads and FFN channels:
  - `prune_attention_heads_logical_gqa`: Prunes attention heads while maintaining GQA (Grouped Query Attention) constraints.
  - `prune_ffn_channels_logical_mask`: Prunes FFN channels by masking weights based on their norms.

In [5]:

def find_attn_modules(module):
    for name, m in module.named_modules():
        if all(hasattr(m, x) for x in ["q_proj","k_proj","v_proj","o_proj"]):
            yield name, m

def find_mlp_modules(module):
    for name, m in module.named_modules():
        if all(hasattr(m, x) for x in ["gate_proj","up_proj","down_proj"]):
            yield name, m




In [None]:
@torch.no_grad()
def prune_attention_heads_logical_gqa(attn_mod, ratio: float):
    n_q  = getattr(attn_mod, "num_heads", None)         or getattr(attn_mod, "n_heads", None)         or getattr(getattr(attn_mod, "config", None), "num_attention_heads", None)
    n_kv = getattr(attn_mod, "num_key_value_heads", None)         or getattr(attn_mod, "n_kv_heads", None)         or getattr(getattr(attn_mod, "config", None), "num_key_value_heads", None)         or n_q
    hd = getattr(attn_mod, "head_dim", None)
    if hd is None:
        q_rows = attn_mod.q_proj.weight.shape[0]
        assert n_q and q_rows % n_q == 0, "Cannot infer head_dim."
        hd = q_rows // n_q
    assert n_q and n_kv and n_q >= 1 and n_kv >= 1
    assert n_q % n_kv == 0, "Expected n_q divisible by n_kv for GQA."

    n_keep_q = max(1, int(n_q * (1.0 - ratio)))
    prune_q  = list(range(n_keep_q, n_q))
    if not prune_q:
        return

    def rows_for_heads(head_ids, per_head):
        rows = []
        for h in head_ids:
            s = h * per_head
            rows.extend(range(s, s + per_head))
        return rows

    rows_q  = rows_for_heads(prune_q, hd)
    group   = max(1, n_q // n_kv)
    prune_kv = sorted(set(h // group for h in prune_q))
    rows_kv = rows_for_heads(prune_kv, hd)

    # q_proj rows -> mask multiply
    Wq = attn_mod.q_proj.weight
    mask_q = torch.ones(Wq.shape[0], device=Wq.device, dtype=Wq.dtype)
    if rows_q:
        idx_q = torch.tensor([r for r in rows_q if 0 <= r < Wq.shape[0]], device=Wq.device, dtype=torch.long)
        if idx_q.numel() > 0:
            mask_q.index_fill_(0, idx_q, 0)
            Wq.mul_(mask_q[:, None])
            bq = getattr(attn_mod.q_proj, "bias", None)
            if bq is not None: bq.mul_(mask_q.to(bq.dtype))

    # k_proj / v_proj rows
    for pname, rows in (("k_proj", rows_kv), ("v_proj", rows_kv)):
        proj = getattr(attn_mod, pname)
        W    = proj.weight
        mask = torch.ones(W.shape[0], device=W.device, dtype=W.dtype)
        if rows:
            idx  = torch.tensor([r for r in rows if 0 <= r < W.shape[0]], device=W.device, dtype=torch.long)
            if idx.numel() > 0:
                mask.index_fill_(0, idx, 0)
                W.mul_(mask[:, None])
                b = getattr(proj, "bias", None)
                if b is not None: b.mul_(mask.to(b.dtype))

    # o_proj columns corresponding to pruned Q rows
    Wo = attn_mod.o_proj.weight
    col_mask = torch.ones(Wo.shape[1], device=Wo.device, dtype=Wo.dtype)
    if rows_q:
        idx_cols = torch.tensor([c for c in rows_q if 0 <= c < Wo.shape[1]], device=Wo.device, dtype=torch.long)
        if idx_cols.numel() > 0:
            col_mask.index_fill_(0, idx_cols, 0)
            Wo.mul_(col_mask[None, :])

@torch.no_grad()
def prune_ffn_channels_logical_mask(mlp_mod: nn.Module, ratio: float, intermediate_size: int):
    n_prune = max(1, int(intermediate_size * ratio))
    if n_prune <= 0: return
    down_W = mlp_mod.down_proj.weight  # [hidden, inter]
    col_norms = torch.norm(down_W, p=1, dim=0)
    prune_idx = torch.topk(col_norms, k=n_prune, largest=False).indices

    inter_dim = down_W.shape[1]
    keep_mask = torch.ones(inter_dim, device=down_W.device, dtype=down_W.dtype)
    if prune_idx.numel() > 0:
        keep_mask.index_fill_(0, prune_idx, 0)

    # down_proj: zero columns
    mlp_mod.down_proj.weight.mul_(keep_mask[None, :])
    # up_proj / gate_proj: zero rows
    for name in ["up_proj","gate_proj"]:
        getattr(mlp_mod, name).weight.mul_(keep_mask[:, None])
        b = getattr(getattr(mlp_mod, name), "bias", None)
        if b is not None:
            b.mul_(keep_mask.to(b.dtype))

## 3) CIFAR‑10 → VQA Data (EOS‑Supervised Answers)

This section:
- Prepares a toy dataset using CIFAR-10 images and their labels.
- Resizes images to 448x448 pixels to match the model's input size.
- Creates training and validation datasets with a fixed number of examples.
- Each example includes:
  - An image.
  - A question ("What is in this image?").
  - The answer (e.g., "cat") with and without an EOS token.
- Encodes the dataset for visual question answering (VQA) tasks.
- Masks labels to supervise only the answer tokens.

In [6]:

# tokenizer setup
if processor.tokenizer.pad_token_id is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
PAD_ID   = processor.tokenizer.pad_token_id
IGNOREID = -100
EOS      = processor.tokenizer.eos_token

root = "./data_cifar10"
train_raw = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
test_raw  = torchvision.datasets.CIFAR10(root=root, train=False, download=True)
label_names = train_raw.classes

resize_to = 448
resize_tf = T.Resize((resize_to, resize_to))

def make_examples(ds, n_take):
    items = []
    for i in range(n_take):
        img_pil, y = ds[i]
        img = resize_tf(img_pil)
        ans = label_names[y]
        items.append({
            "image": img,
            "question": "What is in this image?",
            "answer": ans,
            "answer_with_eos": ans + EOS,
        })
    return items

train_items = make_examples(train_raw, N_TRAIN)
val_items   = make_examples(test_raw,  N_VAL)

def encode_example_vqa(ex):
    messages_train = [
        {"role":"user","content":[
            {"type":"image","image": ex["image"]},
            {"type":"text","text": ex["question"]}
        ]},
        {"role":"assistant","content":[{"type":"text","text": ex["answer_with_eos"]}]}
    ]
    messages_gen = [
        {"role":"user","content":[
            {"type":"image","image": ex["image"]},
            {"type":"text","text": ex["question"]}
        ]}
    ]
    train_text = processor.apply_chat_template(messages_train, tokenize=False, add_generation_prompt=False)
    gen_text   = processor.apply_chat_template(messages_gen,   tokenize=False, add_generation_prompt=True)

    out = processor(
        text=[train_text],
        images=[ex["image"]],
        return_tensors="pt",
        max_length=MAX_LENGTH,
        padding="longest",
        truncation=True
    )

    def mask_answer_only(input_ids_2d, answer_text):
        full_ids = input_ids_2d[0].tolist()
        ans_ids  = processor.tokenizer(answer_text, add_special_tokens=False, return_tensors="pt")["input_ids"][0].tolist()
        def find_subseq(a, b):
            L, M = len(a), len(b)
            if M == 0 or M > L: return -1
            for i in range(L - M + 1):
                if a[i:i+M] == b: return i
            return -1
        labels = input_ids_2d.clone(); labels[:] = -100
        start = find_subseq(full_ids, ans_ids)
        if start >= 0:
            end = start + len(ans_ids)
            labels[:, start:end] = input_ids_2d[:, start:end]
        else:
            keep = min(8, input_ids_2d.shape[1])
            labels[:, -keep:] = input_ids_2d[:, -keep:]
        return labels

    input_ids = out["input_ids"]
    labels    = mask_answer_only(input_ids, ex["answer_with_eos"])

    return {
        "pixel_values": out["pixel_values"].squeeze(0),
        "image_grid_thw": out.get("image_grid_thw", None).squeeze(0) if out.get("image_grid_thw", None) is not None else None,
        "input_ids": input_ids.squeeze(0),
        "attention_mask": out["attention_mask"].squeeze(0),
        "labels": labels.squeeze(0),
        "answer_text": ex["answer"],
        "raw_image":   ex["image"],
        "gen_prompt":  gen_text,
    }

train_encoded = [encode_example_vqa(ex) for ex in train_items]
val_encoded   = [encode_example_vqa(ex) for ex in val_items]




Files already downloaded and verified
Files already downloaded and verified


In [None]:
class CIFARVQADataset(Dataset):
    def __init__(self, encoded): self.encoded = encoded
    def __len__(self): return len(self.encoded)
    def __getitem__(self, i): return self.encoded[i]

train_ds = CIFARVQADataset(train_encoded)
val_ds   = CIFARVQADataset(val_encoded)

def pad_1d(seqs, pad_val):
    maxlen = max(x.size(0) for x in seqs)
    out = torch.full((len(seqs), maxlen), pad_val, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, :s.size(0)] = s
    return out

def collate_fn(batch):
    out = {}
    out["pixel_values"]   = torch.stack([b["pixel_values"] for b in batch], dim=0)
    grids = [b["image_grid_thw"] for b in batch]
    out["image_grid_thw"] = torch.stack(grids, dim=0) if all(g is not None for g in grids) else None

    out["input_ids"]      = pad_1d([b["input_ids"] for b in batch], PAD_ID)
    out["attention_mask"] = pad_1d([b["attention_mask"] for b in batch], 0)
    out["labels"]         = pad_1d([b["labels"] for b in batch], IGNOREID)

    out["answer_text"] = [b["answer_text"] for b in batch]
    out["raw_images"]  = [b["raw_image"]   for b in batch]
    out["gen_prompts"] = [b["gen_prompt"]  for b in batch]
    return out

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=2,          shuffle=False, collate_fn=collate_fn)

## 4) Baseline Eval Helpers

This section:
- Defines helper functions for evaluating the model:
  - `eval_loss`: Computes the perplexity of the model on the training dataset.
  - `eval_gen_accuracy`: Evaluates the model's generation accuracy on the validation dataset.
- Runs a quick baseline evaluation to measure perplexity and accuracy before pruning.

In [7]:

@torch.no_grad()
def eval_loss(model, loader):
    model.eval()
    total_loss, total_tokens = 0.0, 0
    for batch in loader:
        moved = {}
        for k, v in batch.items():
            if k == "pixel_values" and v is not None:
                moved[k] = v.to(device, dtype=dtype)
            elif k in ("input_ids","attention_mask","labels") and v is not None:
                moved[k] = v.to(device)
        if batch.get("image_grid_thw") is not None:
            moved["image_grid_thw"] = batch["image_grid_thw"].to(device)
        out = model(**moved)
        n_tokens = (moved["labels"] != -100).sum().item()
        total_loss += out.loss.item() * max(1, n_tokens)
        total_tokens += max(1, n_tokens)
    model.train()
    import math
    return math.exp(total_loss / max(1, total_tokens))

@torch.no_grad()
def eval_gen_accuracy(model, processor, loader, k_samples=50, max_new_tokens=3):
    model.eval()
    correct, seen = 0, 0
    for batch in loader:
        for img, gp, gold in zip(batch["raw_images"], batch["gen_prompts"], batch["answer_text"]):
            if seen >= k_samples: break
            enc = processor(text=[gp], images=[img], return_tensors="pt")
            enc = {k: v.to(device) for k, v in enc.items()}
            gen_ids = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=1,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                repetition_penalty=1.5,
                length_penalty=2.0,
                use_cache=True,
            )
            new_tokens = gen_ids[:, enc["input_ids"].shape[1]:]
            text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip().lower()
            pred = text.split()[0] if text else ""
            if gold.lower() in pred:
                correct += 1
            seen += 1
        if seen >= k_samples: break
    model.train()
    return correct / max(1, seen)

print("Quick baseline eval (pre‑prune)…")
ppl_train = eval_loss(model, train_loader)
acc_val0  = eval_gen_accuracy(model, processor, val_loader, k_samples=30)
print(f"Train PPL (pre‑prune): {ppl_train:.2f} ")


Quick baseline eval (pre‑prune)…




Train PPL (pre‑prune): 1504.16 | 


## 5) Apply Pruning

This cell:
- Applies pruning to the model based on the selected mode (`PRUNE_MODE`):
  - `attn_heads`: Prunes attention heads.
  - `ffn_channels`: Prunes FFN channels.
- Evaluates the model's perplexity after pruning but before fine-tuning.

In [8]:

if PRUNE_MODE == "attn_heads":
    n = 0
    for _, attn in find_attn_modules(model):
        prune_attention_heads_logical_gqa(attn, PRUNE_RATIO)
        n += 1
    print(f"Pruned heads in {n} attention modules (mask‑based).")
elif PRUNE_MODE == "ffn_channels":
    n = 0
    for _, mlp in find_mlp_modules(model):
        prune_ffn_channels_logical_mask(mlp, PRUNE_RATIO, intermediate_size)
        n += 1
    print(f"Pruned channels in {n} MLP modules (mask‑based).")
else:
    raise ValueError("PRUNE_MODE must be 'attn_heads' or 'ffn_channels'.")

ppl_after_prune = eval_loss(model, train_loader)
print(f"PPL (post‑prune, pre‑LoRA): {ppl_after_prune:.2f}")


Pruned channels in 28 MLP modules (mask‑based).
PPL (post‑prune, pre‑LoRA): 262912.71


## 6) LoRA Setup & Tiny Fine‑Tune

This section:
- Identifies target modules for LoRA (Low-Rank Adaptation).
- Configures LoRA parameters such as rank, alpha, and dropout.
- Wraps the model with LoRA layers for fine-tuning.
- Implements a fine-tuning loop with mixed precision (FP16).
- Uses gradient accumulation to simulate a larger batch size.
- Updates the optimizer and learning rate scheduler after each step.
- Prints the loss at each step for monitoring.

In [9]:

def collect_lora_targets(m: nn.Module) -> List[str]:
    names = set()
    for n, mod in m.named_modules():
        if isinstance(mod, nn.Linear):
            for key in ["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]:
                if n.endswith(key): names.add(n.split(".")[-1])
    return sorted(list(names)) or ["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]

targets = collect_lora_targets(model)
print("LoRA targets:", targets)

lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    bias="none", task_type="CAUSAL_LM",
    target_modules=targets
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=LR, weight_decay=0.0)
steps_per_epoch  = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_training_steps = EPOCHS * steps_per_epoch
sched = get_linear_schedule_with_warmup(optimizer, WARMUP_STEPS, num_training_steps)

scaler = torch.cuda.amp.GradScaler(enabled=(dtype==torch.float16))
model.train()
global_step = 0

for epoch in range(EPOCHS):
    for step, batch in enumerate(train_loader):
        moved = {}
        for k, v in batch.items():
            if k == "pixel_values":
                moved[k] = v.to(device, dtype=dtype)
            elif k in ("input_ids","attention_mask","labels"):
                moved[k] = v.to(device)
        if batch.get("image_grid_thw") is not None:
            moved["image_grid_thw"] = batch["image_grid_thw"].to(device)

        with torch.cuda.amp.autocast(enabled=(dtype==torch.float16)):
            out = model(**moved)
            loss = out.loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        if (step + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)
            sched.step()
            global_step += 1
            if global_step % 5 == 0:
                print(f"step {global_step} | loss={(loss.item()*GRAD_ACCUM_STEPS):.4f}")

torch.cuda.empty_cache()


LoRA targets: ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 18,464,768 || all params: 2,227,450,368 || trainable%: 0.8290


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


step 5 | loss=10.3545
step 10 | loss=8.7103
step 15 | loss=5.7335
step 20 | loss=2.5249
step 25 | loss=0.4311
step 30 | loss=0.1571
step 35 | loss=0.0465
step 40 | loss=1.3951
step 45 | loss=0.0270
step 50 | loss=0.2318


In [10]:

ppl_after_lora = eval_loss(model, train_loader)
acc_val1       = eval_gen_accuracy(model, processor, val_loader, k_samples=30)
print(f"PPL (post‑LoRA): {ppl_after_lora:.2f} | Val Gen@1 Acc (post‑LoRA): {acc_val1:.2%}")

# Save LoRA + merged
model.save_pretrained(os.path.join(OUTPUT_DIR, "lora_adapter"))
processor.save_pretrained(OUTPUT_DIR)
print("Saved LoRA adapter to:", os.path.join(OUTPUT_DIR, "lora_adapter"))

print("Merging LoRA into base (uses extra VRAM)…")
merged = model.merge_and_unload()
merged.save_pretrained(os.path.join(OUTPUT_DIR, "merged_full"))
print("Merged checkpoint saved to:", os.path.join(OUTPUT_DIR, "merged_full"))


PPL (post‑LoRA): 1.12 | Val Gen@1 Acc (post‑LoRA): 90.00%
Saved LoRA adapter to: ./demo_pruned_lora_qwen2vl/lora_adapter
Merging LoRA into base (uses extra VRAM)…
Merged checkpoint saved to: ./demo_pruned_lora_qwen2vl/merged_full


## 7) Inference Demo (Greedy)

This section:
- Demonstrates the model's ability to answer visual questions after fine-tuning.
- Defines a function to generate answers for images and questions.
- Runs the model on a few examples from the validation dataset and prints the predictions.

In [11]:

def qwen_vl_infer(model, processor, pil_image, question: str, max_new_tokens=3):
    pil_image = pil_image.resize((448, 448))
    messages = [{"role":"user","content":[
        {"type":"image","image": pil_image},
        {"type":"text","text": question}
    ]}]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")
    for k, v in list(inputs.items()):
        if k == "pixel_values":
            inputs[k] = v.to(device, dtype=dtype)
        else:
            inputs[k] = v.to(device)

    gen_ids = model.generate(
        **inputs,
        max_new_tokens=3,
        do_sample=False,
        num_beams=1,
        pad_token_id=processor.tokenizer.eos_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        repetition_penalty=1.5,
        length_penalty=2.0,
        use_cache=True
    )
    new_tokens = gen_ids[:, inputs["input_ids"].shape[1]:]
    text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return text

print("\n--- Demo predictions (merged model) ---")
for i in range(5):
    img_pil, y = test_raw[i]
    ans = label_names[y]
    pred = qwen_vl_infer(merged, processor, img_pil, "What is in this image?")
    print(f"GT: {ans:<10s} | PRED: {pred}")



--- Demo predictions (merged model) ---




GT: cat        | PRED: cat
GT: ship       | PRED: shipship
GT: ship       | PRED: ship
GT: airplane   | PRED: airplane
GT: frog       | PRED: frog


## Exercise A — Constrained Decoding (Candidate Scoring)

This exercise:
- Implements constrained decoding by scoring each candidate answer using log-prob aggregation.
- Defines functions to:
  - `score_answer_logprob`: Compute the log-probability of a candidate answer given the input.
  - `predict_with_candidates`: Select the best candidate answer based on log-probability scores.
- Tests the implementation with a smoke test.


**Goal:** Always answer with a CIFAR‑10 class name by scoring each candidate and choosing the argmax.

Implement:
- `score_answer_logprob(model, processor, pil_img, question, answer)`
- `predict_with_candidates(model, processor, pil_img, question, candidates)`

Hints:
1. Build the user‑only prompt as in the demo (with `add_generation_prompt=True`).
2. Tokenize candidate `answer` with the tokenizer only (no image), `add_special_tokens=False`.
3. Accumulate log‑probs by feeding the next gold token and summing `log_softmax(logits[:, -1])`.


In [15]:
import torch
from torch.nn.functional import log_softmax

CANDIDATES = [c.lower() for c in label_names]

@torch.no_grad()
def score_answer_logprob(model, processor, pil_img, question, answer) -> float:
    """Return total conditional log-prob of 'answer' tokens given (image, question)."""
    # 1) Build base prompt (user: image + question), no generation prompt
    messages = [{"role":"user","content":[
        {"type":"image","image": pil_img.resize((448, 448))},
        {"type":"text","text": question}
    ]}]
    base_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    # 2) Tokenize base once
    enc = processor(text=[base_text], images=[pil_img], return_tensors="pt")
    # (Qwen2-VL sometimes returns image_grid_thw)
    if "image_grid_thw" in enc:
        enc["image_grid_thw"] = enc["image_grid_thw"]

    # Move to device/dtypes model expects
    device = next(model.parameters()).device
    for k, v in list(enc.items()):
        if k == "pixel_values":
            enc[k] = v.to(device, dtype=getattr(model, "dtype", torch.float16))
        else:
            enc[k] = v.to(device)

    # 3) Tokenize the answer WITHOUT specials
    ans_ids = processor.tokenizer(
        answer, add_special_tokens=False, return_tensors="pt"
    )["input_ids"][0].to(device)

    # 4) Teacher-forced scoring loop
    #    We will append tokens step by step and accumulate log p(next_token | context).
    total_logp = 0.0

    input_ids      = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    for t in ans_ids.tolist():
        # Forward with current context
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=enc["pixel_values"],
            image_grid_thw=enc.get("image_grid_thw", None),
            use_cache=False,  # keep it simple/stable
        )
        # Take logits for the last position
        next_logits = out.logits[:, -1, :]                        # [B, V]
        next_logp   = log_softmax(next_logits, dim=-1)[0, t].item()
        total_logp += next_logp

        # Append the gold token to the context (teacher forcing)
        input_ids      = torch.cat([input_ids, torch.tensor([[t]], device=device)], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :1])], dim=1)

    return float(total_logp)


@torch.no_grad()
def predict_with_candidates(model, processor, pil_img, question, candidates) -> str:
    """Score each candidate via teacher forcing and return the best string."""
    best_ans, best_score = None, float("-inf")
    for cand in candidates:
        # Skip empty strings just in case
        if not cand:
            continue
        s = score_answer_logprob(model, processor, pil_img, question, cand)
        if s > best_score:
            best_score, best_ans = s, cand
    return best_ans if best_ans is not None else ""


# Quick smoke (should now run without raising)
try:
    test = predict_with_candidates(
        merged, processor, test_raw[0][0], "What is in this image?", CANDIDATES
    )
    print(f"Exercise A smoke test ran. {test}")
except Exception as e:
    print("Exercise A: unexpected error ->", e)




Exercise A smoke test ran. frog


## Exercise B — Pruning Ablation (Quality & Speed)

This exercise:
- Compares the impact of pruning FFN channels vs attention heads at different ratios.
- Measures:
  - Perplexity on the training dataset.
  - Images per second during inference using the candidate predictor from Exercise A.
- Defines functions to:
  - `apply_pruning`: Apply pruning to the model based on the specified mode and ratio.
  - `measure_decode_ips`: Measure the inference speed in images per second.

In [13]:

SETTINGS = [
    ("ffn_channels", 0.15),
    ("ffn_channels", 0.30),
    ("attn_heads",   0.15),
    ("attn_heads",   0.30),
]

def apply_pruning(model, mode: str, ratio: float):
    """Apply FFN-channel or attention-head pruning in-place (uses helpers above)."""
    if mode == "ffn_channels":
        n = 0
        for name, mlp in find_mlp_modules(model):
            prune_ffn_channels_logical_mask(mlp, ratio, intermediate_size)
            n += 1
        print(f"[Prune] FFN channels @ {ratio:.2f} on {n} MLP blocks.")
    elif mode == "attn_heads":
        n = 0
        for name, attn in find_attn_modules(model):
            prune_attention_heads_logical_gqa(attn, ratio)
            n += 1
        print(f"[Prune] Attention heads @ {ratio:.2f} on {n} attention blocks.")
    else:
        raise ValueError("mode must be 'ffn_channels' or 'attn_heads'")

@torch.no_grad()
def measure_decode_ips(model, processor, val_loader, n_samples=64) -> float:
    """
    Time candidate predictions (using predict_with_candidates from Exercise A)
    over n_samples validation images and return images/sec.
    """
    model.eval()
    device = next(model.parameters()).device

    # Warmup a single pass to stabilize kernels / autotune
    try:
        warm_img = val_loader.dataset.encoded[0]["raw_image"]
    except Exception:
        # fallback: pull first from loader
        warm_batch = next(iter(val_loader))
        warm_img = warm_batch["raw_images"][0]
    _ = predict_with_candidates(model, processor, warm_img, "What is in this image?", CANDIDATES)
    if device.type == "cuda":
        torch.cuda.synchronize(device)

    seen = 0
    t0 = time.perf_counter()
    for batch in val_loader:
        for pil_img in batch["raw_images"]:
            _ = predict_with_candidates(model, processor, pil_img, "What is in this image?", CANDIDATES)
            seen += 1
            if seen >= n_samples:
                break
        if seen >= n_samples:
            break
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    t1 = time.perf_counter()

    elapsed = max(1e-6, t1 - t0)
    return float(seen / elapsed)

ablation_results = []
for mode, ratio in SETTINGS:
    base_model, _ = load_model(dtype)  # fresh base
    base_model.eval()
    apply_pruning(base_model, mode, ratio)
    ppl = eval_loss(base_model, train_loader)
    try:
        ips = measure_decode_ips(base_model, processor, val_loader, n_samples=64)
    except Exception as e:
        ips = float("nan")
        print(f"[{mode}@{ratio}] measure_decode_ips error:", e)
    ablation_results.append({"mode": mode, "ratio": ratio, "ppl": ppl, "img_per_sec": ips})
    print(ablation_results[-1])

ablation_results


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Prune] FFN channels @ 0.15 on 28 MLP blocks.
{'mode': 'ffn_channels', 'ratio': 0.15, 'ppl': 262912.7042795568, 'img_per_sec': 0.4641501000217031}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Prune] FFN channels @ 0.30 on 28 MLP blocks.
{'mode': 'ffn_channels', 'ratio': 0.3, 'ppl': 864006.941955165, 'img_per_sec': 0.46507204081176806}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Prune] Attention heads @ 0.15 on 28 attention blocks.
{'mode': 'attn_heads', 'ratio': 0.15, 'ppl': 2075.9372525803396, 'img_per_sec': 0.4633236150500874}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Prune] Attention heads @ 0.30 on 28 attention blocks.
{'mode': 'attn_heads', 'ratio': 0.3, 'ppl': 2075.9371614799047, 'img_per_sec': 0.46318713102019793}


[{'mode': 'ffn_channels',
  'ratio': 0.15,
  'ppl': 262912.7042795568,
  'img_per_sec': 0.4641501000217031},
 {'mode': 'ffn_channels',
  'ratio': 0.3,
  'ppl': 864006.941955165,
  'img_per_sec': 0.46507204081176806},
 {'mode': 'attn_heads',
  'ratio': 0.15,
  'ppl': 2075.9372525803396,
  'img_per_sec': 0.4633236150500874},
 {'mode': 'attn_heads',
  'ratio': 0.3,
  'ppl': 2075.9371614799047,
  'img_per_sec': 0.46318713102019793}]

## Exercise C — Prompt / EOS Ablation (Repetition vs Accuracy)

This exercise:
- Evaluates the impact of including an EOS token in supervised answers.
- Measures:
  - Repetition rate in generated text.
  - Accuracy of predictions with and without the EOS token.
- Defines functions to:
  - `make_examples_with_eos_flag`: Create datasets with and without EOS tokens.
  - `compute_repetition_rate`: Calculate the repetition rate in generated text.
- Compares the results for both settings.

In [22]:
# ==== Config for this evaluation ====
PROMPT_TEXT = "Answer in 3–8 words. Start with the class name, then one short phrase."
MAX_NEW_TOKENS = 25
DO_SAMPLE = True
TEMPERATURE = 0.8
TOP_P = 0.9
NUM_BEAMS = 1

# ==== Robust repetition metrics ====
from collections import Counter

def ngrams(tokens, n):
    return [" ".join(tokens[i:i+n]) for i in range(0, max(0, len(tokens)-n+1))]

def looks_repetitive_full(t: str, n_max: int = 3, min_repeat: int = 2) -> bool:
    """
    Heuristic repetition detector scanning the whole output:
      - flags any unigram/bigram/trigram repeated >= min_repeat times
      - flags long runs (>=3) of the same token
      - flags collapsed prefix repeats (e.g., 'shipship...')
    """
    s = (t or "").strip().lower()
    if not s:
        return False
    toks = s.split()

    # Runs of the same token (e.g., "cat cat cat")
    run_len = 1
    for i in range(1, len(toks)):
        if toks[i] == toks[i-1]:
            run_len += 1
            if run_len >= 3:
                return True
        else:
            run_len = 1

    # n-gram repeats anywhere
    for n in range(1, n_max+1):
        ng = ngrams(toks, n)
        if not ng:
            continue
        counts = Counter(ng)
        if any(c >= min_repeat for c in counts.values()):
            return True

    # collapsed repetition at the start (no spaces), e.g., "carcar..."
    first = toks[0]
    if first and s.startswith(first + first):
        return True
    if len(toks) >= 2:
        bg = toks[0] + " " + toks[1]
        if s.startswith(bg + bg):
            return True

    return False

def compute_repetition_rate_full(texts, **kw):
    n = max(1, len(texts))
    return sum(1 for t in texts if looks_repetitive_full(t, **kw)) / float(n)

# ==== Accuracy metrics ====
def exact_match(p: str, g: str) -> int:
    return int(p.strip() == g.strip())

def class_in_text(p: str, g: str) -> int:
    # accept if the gold label appears as a standalone token in prediction
    pw = p.lower().split()
    return int(g.lower() in pw)

# ==== Decoding with longer, sampled outputs ====
def decode_set(merged_model, items, k=60):
    eos_id = processor.tokenizer.eos_token_id
    preds, golds, lengths, stopped_with_eos = [], [], [], []

    for ex in items[:k]:
        img = ex["image"]
        ans = ex["answer"]
        msg = [{"role":"user","content":[{"type":"image","image": img},
                                         {"type":"text","text": PROMPT_TEXT}]}]
        gp  = processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        enc = processor(text=[gp], images=[img], return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}

        gen_ids = merged_model.generate(
            **enc,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            num_beams=NUM_BEAMS,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )
        new_tokens = gen_ids[:, enc["input_ids"].shape[1]:]

        # text output
        text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
        preds.append(text)
        golds.append(ans.strip())

        # length + EOS stop detection
        lengths.append(len(text.split()))
        stopped_with_eos.append(int((new_tokens == eos_id).any().item()))

    return preds, golds, lengths, stopped_with_eos

# ==== Build eval sets (reuse your earlier dataset helpers if already defined) ====
SMALL_N = 60
try:
    val_eos   = make_examples_with_eos_flag(test_raw, SMALL_N, use_eos=True)
    val_noeos = make_examples_with_eos_flag(test_raw, SMALL_N, use_eos=False)
except Exception as e:
    print("Exercise C (dataset build) -> implement TODOs first:", e)
    val_eos, val_noeos = [], []

# ==== Run decode + compute metrics ====
if val_eos and val_noeos:
    preds_eos,   golds_eos,   lens_eos,   stops_eos   = decode_set(merged, val_eos,   k=SMALL_N)
    preds_noeos, golds_noeos, lens_noeos, stops_noeos = decode_set(merged, val_noeos, k=SMALL_N)

    # stricter accuracy (exact), plus a looser fallback (class-in-text)
    acc_exact_eos   = sum(exact_match(p, g)   for p, g in zip(preds_eos,   golds_eos))   / max(1, len(golds_eos))
    acc_exact_noeos = sum(exact_match(p, g)   for p, g in zip(preds_noeos, golds_noeos)) / max(1, len(golds_noeos))

    acc_loose_eos   = sum(class_in_text(p, g) for p, g in zip(preds_eos,   golds_eos))   / max(1, len(golds_eos))
    acc_loose_noeos = sum(class_in_text(p, g) for p, g in zip(preds_noeos, golds_noeos)) / max(1, len(golds_noeos))

    # repetition over full text
    rep_eos   = compute_repetition_rate_full(preds_eos, n_max=3, min_repeat=2)
    rep_noeos = compute_repetition_rate_full(preds_noeos, n_max=3, min_repeat=2)

    # length + eos stop rate
    avg_len_eos   = sum(lens_eos)   / max(1, len(lens_eos))
    avg_len_noeos = sum(lens_noeos) / max(1, len(lens_noeos))
    eos_stop_rate_eos   = sum(stops_eos)   / max(1, len(stops_eos))
    eos_stop_rate_noeos = sum(stops_noeos) / max(1, len(stops_noeos))

    summary = {
        "acc_exact_eos":        round(acc_exact_eos, 4),
        "acc_exact_noeos":      round(acc_exact_noeos, 4),
        "acc_loose_eos":        round(acc_loose_eos, 4),
        "acc_loose_noeos":      round(acc_loose_noeos, 4),
        "rep_eos":              round(rep_eos, 4),
        "rep_noeos":            round(rep_noeos, 4),
        "avg_len_eos":          round(avg_len_eos, 2),
        "avg_len_noeos":        round(avg_len_noeos, 2),
        "eos_stop_rate_eos":    round(eos_stop_rate_eos, 4),
        "eos_stop_rate_noeos":  round(eos_stop_rate_noeos, 4),
        "gen_cfg": {
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": DO_SAMPLE,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
            "num_beams": NUM_BEAMS
        }
    }
    print(summary)

    # If you want to quickly eyeball some outputs:
    for i in range(min(5, SMALL_N)):
        print(f"[EOS ] gold={golds_eos[i]!r}  pred={preds_eos[i]!r}")
    for i in range(min(5, SMALL_N)):
        print(f"[NOEOS] gold={golds_noeos[i]!r} pred={preds_noeos[i]!r}")

else:
    print("Exercise C waiting for TODO implementations.")


{'acc_exact_eos': 0.25, 'acc_exact_noeos': 0.25, 'acc_loose_eos': 0.25, 'acc_loose_noeos': 0.25, 'rep_eos': 0.0, 'rep_noeos': 0.0, 'avg_len_eos': 0.5, 'avg_len_noeos': 0.5, 'eos_stop_rate_eos': 1.0, 'eos_stop_rate_noeos': 1.0, 'gen_cfg': {'max_new_tokens': 25, 'do_sample': True, 'temperature': 0.8, 'top_p': 0.9, 'num_beams': 1}}
[EOS ] gold='cat'  pred=''
[EOS ] gold='ship'  pred='ship'
[EOS ] gold='ship'  pred='ship'
[EOS ] gold='airplane'  pred='air'
[EOS ] gold='frog'  pred='frog'
[NOEOS] gold='cat' pred=''
[NOEOS] gold='ship' pred='ship'
[NOEOS] gold='ship' pred='ship'
[NOEOS] gold='airplane' pred='air'
[NOEOS] gold='frog' pred='frog'
