# Training pipeline: Qwen2-VL-2B Fine-tuning for GLSL Generation

**Goal:** Fine-tune Qwen2-VL-2B-Instruct to generate GLSL shader code from rendered images.

**Strategy:** Start with a simple LoRA fine-tuning setup, then iterate on training stability.

In [None]:
# @title Setup: Mount Drive & Install Dependencies
#
# Colab-only cell: mounts Google Drive (the dataset/checkpoint paths used by
# the configuration cell live under /content/drive) and installs pinned
# library versions. The `!pip` line is IPython shell syntax, so this cell is
# not plain Python.

import os
import sys  # NOTE(review): not used in this cell.

# 1. Mount Google Drive
# NOTE(review): `userdata` (Colab secrets) is imported but not used here.
from google.colab import drive, userdata
print("[SYSTEM] Mounting Google Drive...")
drive.mount('/content/drive')
print(" Drive mounted")

# 2. Install core dependencies with compatible versions
# NOTE(review): if any of these packages were already imported in this
# runtime, the newly installed versions only take effect after a restart.
print("\n[SETUP] Installing dependencies...")
!pip install -q transformers==4.45.0 accelerate==0.34.0 peft==0.12.0 datasets==2.20.0 pillow tqdm

print("\n Setup complete")
# __import__ imports each module inline just to report its version / CUDA status.
print(f" PyTorch: {__import__('torch').__version__}")
print(f" Transformers: {__import__('transformers').__version__}")
print(f" GPU Available: {__import__('torch').cuda.is_available()}")

In [None]:
# @title Configuration

# Every path and hyperparameter used by the later cells, in one place.
CONFIG = {
    # Paths
    "dataset_dir": "/content/drive/My Drive/projects/EarthShader/dataset",
    # Dedicated folder so this run does not overwrite earlier checkpoints.
    "output_dir": "/content/drive/My Drive/projects/EarthShader/checkpoints_training",

    # Model
    "model_name": "Qwen/Qwen2-VL-2B-Instruct",
    # "model_name": "Qwen/Qwen2-VL-7B-Instruct",  # Uncomment for a 7B run
    # Max tokens per example (prompt incl. image tokens + code); the
    # collate function truncates anything longer.
    "max_seq_length": 1024,

    # Training
    "num_train_epochs": 15,
    "per_device_train_batch_size": 1,
    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps.
    "gradient_accumulation_steps": 8,
    "gradient_checkpointing": True,    # Trade compute for activation memory
    "learning_rate": 1e-5,             # Deliberately low for stability
    # The three *_steps values below count optimizer updates, not micro-batches.
    "warmup_steps": 50,
    "logging_steps": 10,
    "save_steps": 200,
    "max_samples": None,               # Cap on dataset pairs; None = use all

    # LoRA
    "use_lora": True,
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
}

# Create output directory
os.makedirs(CONFIG["output_dir"], exist_ok=True)
print(" Configuration loaded for Training Pipeline")

In [None]:
# @title Load Dataset

import os
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

class ShaderDataset(Dataset):
    """Paired (rendered image, GLSL source) samples read from one directory.

    A sample is a ``<name>.jpg`` file that has a sibling ``<name>.glsl``
    file in the same directory; images without a matching shader are skipped.

    Args:
        dataset_dir: Directory containing the ``.jpg`` / ``.glsl`` pairs.
        max_samples: If not None, keep at most this many *valid* pairs, taken
            in sorted filename order. ``None`` keeps every pair.
    """

    def __init__(self, dataset_dir, max_samples=None):
        self.dataset_dir = dataset_dir

        # Sorted so the sample order (and the max_samples subset) is deterministic.
        image_files = sorted(
            name for name in os.listdir(dataset_dir) if name.endswith('.jpg')
        )

        self.samples = []
        for img_file in tqdm(image_files, desc="Validating dataset"):
            # Cap on valid pairs, so unpaired images don't eat into the budget.
            if max_samples is not None and len(self.samples) >= max_samples:
                break

            # splitext strips only the final extension; str.replace would also
            # rewrite a '.jpg' occurring earlier in the file name.
            base_name = os.path.splitext(img_file)[0]
            glsl_path = os.path.join(dataset_dir, base_name + '.glsl')

            # The image itself came from listdir; only its shader partner
            # needs checking (isfile also rejects a directory named *.glsl).
            if os.path.isfile(glsl_path):
                self.samples.append({
                    'image_path': os.path.join(dataset_dir, img_file),
                    'glsl_path': glsl_path,
                })

        print(f" Loaded {len(self.samples)} valid pairs")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image': RGB PIL image, 'glsl_code': str}`` for sample *idx*."""
        sample = self.samples[idx]

        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that no longer depends on it.
        with Image.open(sample['image_path']) as img:
            image = img.convert('RGB')

        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(sample['glsl_path'], 'r', encoding='utf-8') as f:
            glsl_code = f.read()

        return {
            'image': image,
            'glsl_code': glsl_code
        }

# Load dataset
print("[DATA] Loading dataset...")
dataset = ShaderDataset(
    CONFIG["dataset_dir"],
    max_samples=CONFIG["max_samples"]
)

# Fail fast on an empty dataset: otherwise the training loop runs zero steps,
# still saves an (untrained) model, and the inference cell fails later.
if len(dataset) == 0:
    raise RuntimeError(
        f"No valid .jpg/.glsl pairs found in {CONFIG['dataset_dir']!r}; "
        "check the path and that Google Drive is mounted."
    )

# Show sample
sample = dataset[0]
print("\n Sample check:")
print(f"  Image size: {sample['image'].size}")
print(f"  Code length: {len(sample['glsl_code'])} chars")
print(f"  Code preview: {sample['glsl_code'][:100]}...")

In [None]:
# @title Load Model & Processor

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import LoraConfig, get_peft_model

print(f"[MODEL] Loading {CONFIG['model_name']}...")

# Load processor (text tokenizer + image preprocessing) for the same checkpoint.
processor = AutoProcessor.from_pretrained(
    CONFIG["model_name"],
    trust_remote_code=True
)
print(" Processor loaded")

# Half-precision weights, no quantization. bfloat16 shares float32's exponent
# range and is far less prone to overflow/NaN than float16, but only GPUs with
# compute capability >= 8.0 (Ampere+) run it natively; older GPUs use float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
model_dtype = torch.bfloat16 if use_bf16 else torch.float16

model = Qwen2VLForConditionalGeneration.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=model_dtype,
    device_map="auto",
    trust_remote_code=True,
)

# Enable gradient checkpointing to save activation memory.
if CONFIG.get("gradient_checkpointing", False):
    model.gradient_checkpointing_enable()
    # When LoRA freezes the base weights below, the input-embedding output no
    # longer requires grad. With reentrant checkpointing (the transformers
    # default) the checkpointed decoder layers - and the LoRA weights inside
    # them - then get no backward path, so loss.backward() fails or trains
    # nothing. Forcing the embedding output to require grad restores it.
    model.enable_input_require_grads()
    print(" Gradient checkpointing enabled")

print(f" Model loaded ({model_dtype})")

# Apply LoRA (without quantization prep)
if CONFIG["use_lora"]:
    print("\n[LORA] Applying LoRA adapters...")

    # Freeze every base weight; only the adapter weights added by
    # get_peft_model are meant to be trained.
    for param in model.parameters():
        param.requires_grad = False

    lora_config = LoraConfig(
        r=CONFIG["lora_r"],
        lora_alpha=CONFIG["lora_alpha"],
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=CONFIG["lora_dropout"],
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    print(" LoRA applied")

print(f"\n Model ready on {model.device}")

In [None]:
# @title Prepare Training Data

from torch.utils.data import DataLoader
import torch

def collate_fn(batch):
    """Turn a list of dataset items into a Qwen2-VL training batch.

    Each item (a dict with 'image' and 'glsl_code') becomes a one-turn chat:
    a fixed user prompt containing the image, followed by the shader source as
    the assistant reply. Labels are the input ids with the prompt tokens and
    any padding set to -100 (ignored by the loss), so training only scores the
    shader code and its closing <|im_end|>.

    Returns:
        The processor output for the full texts, plus a 'labels' tensor.
    """
    # Important: use a specific prompt. Explicitly asking for "Shadertoy-style"
    # code that starts with "void mainImage" keeps the model from generating
    # C++ setup code. The inference cell must use this exact string.
    prompt = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Generate Shadertoy-style GLSL code to render this image. The code should start with void mainImage.<|im_end|>\n<|im_start|>assistant\n"

    images = [item['image'] for item in batch]
    full_texts = [prompt + item['glsl_code'] + "<|im_end|>" for item in batch]
    prompt_texts = [prompt] * len(batch)

    # Process full text (prompt + target)
    inputs = processor(
        text=full_texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=CONFIG["max_seq_length"],
        return_tensors="pt"
    )

    # Process the prompt alone (with the same image, so <|image_pad|> expands
    # to the same number of tokens) to learn how long the prompt is.
    inputs_prompts = processor(
        text=prompt_texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=CONFIG["max_seq_length"],
        return_tensors="pt"
    )

    attention_mask = inputs["attention_mask"]
    labels = inputs["input_ids"].clone()

    # Never compute loss on padding tokens (relevant whenever batch size > 1).
    labels[attention_mask == 0] = -100

    for i in range(len(batch)):
        prompt_len = int(inputs_prompts["attention_mask"][i].sum())
        # Index through the real (non-pad) positions so the prompt mask lands
        # on the right tokens for both left- and right-padded batches.
        real_positions = attention_mask[i].nonzero(as_tuple=True)[0]
        labels[i, real_positions[:prompt_len]] = -100

    inputs["labels"] = labels
    return inputs

# Build the training DataLoader: shuffled batches assembled by collate_fn.
# num_workers=0 means every batch is built in the main process.
train_dataloader = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_size=CONFIG["per_device_train_batch_size"],
    shuffle=True,
    num_workers=0,
)

print(" Training Pipeline DataLoader ready")

In [None]:
# @title Training Loop

import math
import os

import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

print("[TRAIN] Starting training...\n")

accum_steps = CONFIG["gradient_accumulation_steps"]
# Clipping threshold; falls back to 1.0 when the config does not set one.
max_grad_norm = CONFIG.get("max_grad_norm", 1.0)

# Optimize only trainable (LoRA) parameters; frozen base weights never
# receive gradients, so handing them to AdamW is pure overhead.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=CONFIG["learning_rate"])

# Every epoch ends with an optimizer step even when its last accumulation
# group is partial, so the number of updates per epoch is rounded up.
num_batches = len(train_dataloader)
updates_per_epoch = math.ceil(num_batches / accum_steps)
num_training_steps = updates_per_epoch * CONFIG["num_train_epochs"]

# Setup scheduler (stepped once per optimizer update)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=CONFIG["warmup_steps"],
    num_training_steps=num_training_steps
)

# Training state. zero_grad() discards stale gradients, e.g. from an
# interrupted previous run of this cell.
model.train()
optimizer.zero_grad()
global_step = 0
total_loss = 0.0

for epoch in range(CONFIG["num_train_epochs"]):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{CONFIG['num_train_epochs']}")
    print(f"{'='*60}\n")

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")

    for step, batch in enumerate(progress_bar):
        # Move batch to device
        batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)

        # Scale by the size of this micro-batch's accumulation group (the
        # epoch's last group may be smaller than accum_steps) so every update
        # uses the mean loss of its group.
        group_start = (step // accum_steps) * accum_steps
        group_size = min(accum_steps, num_batches - group_start)
        loss = outputs.loss / group_size
        loss.backward()

        total_loss += loss.item()

        # Update at the end of each group, including a partial final group,
        # so no gradients leak into the next epoch's first update.
        is_group_end = (step + 1) % accum_steps == 0 or step + 1 == num_batches
        if not is_group_end:
            continue

        # Clip to guard against occasional gradient spikes.
        torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1

        # Logging: total_loss holds one group-mean loss per update.
        if global_step % CONFIG["logging_steps"] == 0:
            avg_loss = total_loss / CONFIG["logging_steps"]
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })
            total_loss = 0.0

        # Save checkpoint (adapter weights when LoRA is active, plus processor)
        if global_step % CONFIG["save_steps"] == 0:
            checkpoint_dir = os.path.join(CONFIG["output_dir"], f"checkpoint-{global_step}")
            print(f"\n Saving checkpoint to {checkpoint_dir}")
            model.save_pretrained(checkpoint_dir)
            processor.save_pretrained(checkpoint_dir)

# Final save
final_dir = os.path.join(CONFIG["output_dir"], "final_model")
print(f"\n Saving final model to {final_dir}")
model.save_pretrained(final_dir)
processor.save_pretrained(final_dir)

print("\n Training complete!")

In [None]:
# @title Test Inference
import random
import torch
from PIL import Image

print("[TEST] Running Inference...\n")

# Pick one random (image, shader) pair to eyeball the model's output.
# NOTE(review): the pair comes from the training data itself, so this shows
# how well the model fits, not how it generalizes to unseen images.
test_sample = random.choice(dataset)
test_image, true_code = test_sample['image'], test_sample['glsl_code']

# Critical: this must be exactly the prompt used in training (collate_fn).
# TODO(stefan): Make this a constant for reuse.
text = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Generate Shadertoy-style GLSL code to render this image. The code should start with void mainImage.<|im_end|>\n<|im_start|>assistant\n"

# Inference mode (disables dropout) before generating.
model.eval()

inputs = processor(
    text=[text],
    images=[test_image],
    return_tensors="pt",
).to(model.device)

# Sampling plus a repetition penalty helps break "void main" loops; the
# slightly lower temperature keeps the code reasonably precise.
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.15,
        max_new_tokens=1024,
    )

# Drop the echoed prompt tokens; keep only the newly generated ones.
prompt_length = inputs['input_ids'].shape[1]
generated_ids = output_ids[0][prompt_length:]
generated_text = processor.decode(generated_ids, skip_special_tokens=True)

# Display results
separator = "=" * 60
print(separator)
print("GENERATED SHADERTOY CODE:")
print(separator)
print(generated_text)
print("\n" + separator)
print("TRUE CODE snippet:")
print(separator)
print(true_code[:300] + "...")

# Show image
from IPython.display import display
print("\nTest Image:")
display(test_image.resize((256, 256)))