# Llena inspection notebook

Quick, small-scale inspection of datasets, tokenizer, model inputs/outputs.

In [None]:
# Optional: install deps in Colab (uncomment the line below to run it).
# NOTE(review): versions are unpinned; consider pinning them and using %pip
# instead of !pip so the install targets the running kernel's environment.
# !pip -q install torch torchvision transformers datasets peft bitsandbytes pillow tqdm wandb


In [None]:
import os, sys
from pathlib import Path

# Resolve the repository root: when the kernel was started inside notebooks/
# the root is the parent directory, otherwise the working directory is used.
working_dir = Path.cwd()
if working_dir.name == 'notebooks':
    repo_root = working_dir.parent
else:
    repo_root = working_dir
# Put the root at the front of sys.path so it is searched before other entries.
sys.path.insert(0, str(repo_root))

print('repo_root:', repo_root)


In [None]:
import torch
from datasets import load_dataset

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('device:', device)


## Load a small dataset slice

This loads only the first 50 examples of the TextVQA validation split (`validation[:50]`) to keep the cell fast.

In [None]:
# Request only the first 50 validation records so the cell stays quick.
ds = load_dataset(path='lmms-lab/textvqa', split='validation[:50]')
print(ds)

# Keep the first record for inspection; the cell output lists its field names.
sample = ds[0]
list(sample.keys())


In [None]:
# Inspect a sample: show the first validation record (`sample = ds[0]`) in
# full. As the cell's last expression it is rendered by the notebook.
sample


## Build Llena model + collator

Use a small config and run a single forward pass.

In [None]:
from mm.model import LlenaModel, LlenaModelConfig
from mm.collator import LlenaCollator
from transformers import SiglipImageProcessor

# Settings for a quick inspection run: both the vision tower and the LLM are
# flagged as frozen and gradient checkpointing is disabled.
llena_device = 'cuda' if device.type == 'cuda' else 'cpu'
cfg_kwargs = dict(
    llm_name='Qwen/Qwen2.5-0.5B-Instruct',
    vision_name='google/siglip-base-patch16-224',
    num_image_tokens=64,
    projector='mlp2',
    freeze_vision=True,
    freeze_llm=True,
    gradient_checkpointing=False,
    device=llena_device,
)
cfg = LlenaModelConfig(**cfg_kwargs)

# Build the model and switch it to eval mode for inspection.
model = LlenaModel(cfg)
model.eval()

# The image processor is loaded from the same checkpoint name as the vision
# model; the collator reuses the model's own tokenizer and image-token count.
image_proc = SiglipImageProcessor.from_pretrained(cfg.vision_name)
collator_kwargs = dict(
    tokenizer=model.tokenizer,
    image_processor=image_proc,
    max_seq_len=128,
    num_image_tokens=cfg.num_image_tokens,
    pad_to_multiple_of=None,
)
collator = LlenaCollator(**collator_kwargs)


In [None]:
# Build a small batch from the HF dataset
# Convert HF sample to VQASample format
from PIL import Image

def to_vqa(sample):
    """Convert one Hugging Face TextVQA record into the dict passed to the collator.

    Args:
        sample: Mapping with at least the keys 'image', 'question' and
            'answers'; 'answers' must be a non-empty sequence.

    Returns:
        A new dict with 'image' and 'question' copied over, 'answer' set to
        the first entry of 'answers', and 'answers' holding the full sequence.
    """
    record = {field: sample[field] for field in ('image', 'question')}
    all_answers = sample['answers']
    # The first listed answer is used as the single target answer.
    record['answer'] = all_answers[0]
    record['answers'] = all_answers
    return record

# Convert the first two dataset records and collate them into one batch.
batch = [to_vqa(ds[idx]) for idx in (0, 1)]
out = collator(batch)

# Only tensor entries are moved to the target device; anything else the
# collator returned is left out of batch_t.
batch_t = dict(
    (name, tensor.to(device))
    for name, tensor in out.items()
    if torch.is_tensor(tensor)
)

# Single forward pass with autograd disabled (inspection only).
forward_keys = ('pixel_values', 'input_ids', 'mm_attention_mask', 'mm_labels')
with torch.no_grad():
    outputs = model(**{name: batch_t[name] for name in forward_keys})

# Cell output: the loss and the shape of the logits tensor.
(outputs.loss, outputs.logits.shape)


In [None]:
# Decode the model's argmax tokens for the answer region (quick sanity check).
# Assumes LlenaModel returns unshifted causal-LM logits (the Hugging Face
# convention, where the logits at position t score the token at t + 1) —
# TODO confirm against mm.model. Under that convention the prediction for an
# answer token at position t lives at position t - 1, so the label mask is
# shifted by one step before being applied; indexing the logits with the
# unshifted mask would decode predictions that are off by one token.
logits = outputs.logits
pred_ids = logits.argmax(dim=-1)
# True where the first batch item has a supervised (non-ignored) label.
mask = batch_t['mm_labels'][0] != -100
# Drop the last prediction and the first mask entry so that prediction t - 1
# is paired with label t.
pred_seq = pred_ids[0, :-1][mask[1:]].tolist()
model.tokenizer.decode(pred_seq, skip_special_tokens=True)


## Optional: inspect processed JSONL dataset

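If you have processed data under `datasets/processed/textvqa` (a `validation.jsonl` file plus an `images/` directory), you can load it with `JsonlVQADataset`; if the file is missing, the cell below only prints a notice.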

In [None]:
from data.format import JsonlVQADataset

# Expected layout: <repo_root>/datasets/processed/textvqa/ holding
# validation.jsonl and an images/ directory.
proc_root = repo_root.joinpath('datasets', 'processed', 'textvqa')
jsonl = proc_root / 'validation.jsonl'
images = proc_root / 'images'

if not jsonl.exists():
    # Nothing to inspect; report where the file was expected.
    print('No processed JSONL found at', jsonl)
else:
    # max_samples=5 presumably caps how many records are loaded — verify
    # against data.format.JsonlVQADataset.
    ds_jsonl = JsonlVQADataset(
        annotations_path=jsonl,
        image_root=images,
        max_samples=5,
    )
    print('jsonl samples:', len(ds_jsonl))
    print(ds_jsonl[0])
