**Stage 1 - Training**

In [None]:
# @title 1. Setup & Dependencies
import os
import sys
import subprocess

# 1.1 Mount Drive
from google.colab import drive
# Mount Google Drive only when the mount point is not present yet.
_DRIVE_MOUNT_POINT = '/content/drive'
if not os.path.exists(_DRIVE_MOUNT_POINT):
    drive.mount(_DRIVE_MOUNT_POINT)

# 1.2 Project Paths
# Everything the notebook reads or writes lives under PROJECT_ROOT on Drive.
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
DATASET_DIR = os.path.join(PROJECT_ROOT, 'dataset/stage1')
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_adapter')
LOG_DIR = os.path.join(PROJECT_ROOT, 'logs')

# Output folders are created up front; the dataset folder is expected to exist.
for _output_dir in (CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# 1.3 Install Standard HF Libraries
# transformers is installed from the GitHub repository (no version pin);
# the remaining packages come from PyPI.
print("Installing Hugging Face Libraries...")
packages = [
    "git+https://github.com/huggingface/transformers",
    "peft",
    "datasets",
    "bitsandbytes",
    "accelerate",
    "qwen-vl-utils",
    "trl",
]
# Invoke pip through the running interpreter so packages land in this kernel's
# environment; check_call raises if pip exits with a non-zero status.
_pip_install_cmd = [sys.executable, "-m", "pip", "install", *packages]
subprocess.check_call(_pip_install_cmd)

print("Environment Ready.")

In [None]:
# @title 2. Download Model (Stable Version)
import os
import shutil
from huggingface_hub import snapshot_download

# 1. Decide whether the existing local copy can be kept.
# The local directory is tagged with the model id it was downloaded for.
# It is wiped only when that tag is missing or names a different model
# (e.g. the earlier, incompatible 2.5 download). Re-running this cell for the
# same model therefore keeps the files and lets the download resume instead
# of deleting everything and starting from scratch.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
LOCAL_MODEL_DIR = "/content/qwen_local"
MODEL_MARKER_FILE = os.path.join(LOCAL_MODEL_DIR, ".earthshader_model_id")


def _read_model_marker(marker_path):
    """Return the model id stored in *marker_path*, or None if it cannot be read."""
    try:
        with open(marker_path, "r", encoding="utf-8") as marker:
            return marker.read().strip()
    except OSError:
        return None


if os.path.exists(LOCAL_MODEL_DIR) and _read_model_marker(MODEL_MARKER_FILE) != MODEL_ID:
    print("Cleaning up incompatible model files...")
    shutil.rmtree(LOCAL_MODEL_DIR)

# Tag the directory before downloading so an interrupted download is still
# recognised as belonging to MODEL_ID on the next run.
os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
with open(MODEL_MARKER_FILE, "w", encoding="utf-8") as marker:
    marker.write(MODEL_ID)

# 2. Download the Stable Qwen2-VL (Not 2.5)
# NOTE(review): recent huggingface_hub releases reportedly deprecate
# `local_dir_use_symlinks` and `resume_download` -- confirm against the
# installed version before removing them.
print(f"Downloading {MODEL_ID}...")
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    local_dir=LOCAL_MODEL_DIR,
    local_dir_use_symlinks=False,
    resume_download=True
)

print(f"Model ready at: {local_model_path}")

In [None]:
# @title 3. Load Model
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

MODEL_PATH = "/content/qwen_local"

# --- RESOLUTION SETTINGS ---
# Pixel bounds handed to the processor (passed as min_pixels / max_pixels).
MIN_PIXELS = 224 * 224
MAX_PIXELS = 256 * 256

print("Loading Processor with:")
print(f" - Min Resolution: 224x224 ({MIN_PIXELS} pixels)")
print(f" - Max Resolution: 256x256 ({MAX_PIXELS} pixels)")

# 1. Load Processor from the local snapshot with the pixel bounds above.
_pixel_bounds = {"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
processor = Qwen2VLProcessor.from_pretrained(MODEL_PATH, **_pixel_bounds)

# 2. Load Base Model (4-bit)
# NF4 quantization with double quantization; compute runs in float16.
_HALF_PRECISION = torch.float16
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=_HALF_PRECISION,
    bnb_4bit_use_double_quant=True,
)

# The empty-string key in device_map maps the whole model to device 0.
_model_load_options = {
    "quantization_config": BNB_CONFIG,
    "device_map": {"": 0},
    "torch_dtype": _HALF_PRECISION,
    "low_cpu_mem_usage": True,
}
model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_PATH, **_model_load_options)

# 3. Apply LoRA to the language-model projections
# Gradient checkpointing is switched on before the model is wrapped by PEFT.
model.gradient_checkpointing_enable()

# Module-name suffixes that receive LoRA adapters.
_ATTENTION_PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj"]
_FFN_PROJECTIONS = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=_ATTENTION_PROJECTIONS + _FFN_PROJECTIONS,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    modules_to_save=[],  # no extra fully-trained modules besides the adapters
)

# Wrap the quantized base model; from here on `model` is the PEFT model.
model = get_peft_model(model, lora_config)

# Vision tower: mark existing attention parameters as trainable.
# No second LoRA configuration is injected here. This step only flips
# `requires_grad` on parameters of modules whose name contains 'visual' and
# one of the attention-projection names below, and then reports honestly what
# happened.
# NOTE(review): the visual encoder may name its attention layers differently
# (for example a fused qkv projection). If the "Found N" line prints 0, the
# vision tower is NOT being trained -- confirm the module names with
# `model.named_modules()`.
print("\nUnfreezing vision tower attention parameters...")
from peft import inject_adapter_in_model  # kept from the original cell; currently unused

VISION_ATTENTION_NAMES = ('q_proj', 'k_proj', 'v_proj', 'out_proj')

vision_target_modules = []
_seen_param_ids = set()      # a parameter can be reached through nested modules
vision_params_unfrozen = 0
vision_params_skipped = 0

for name, module in model.named_modules():
    if 'visual' not in name:
        continue
    if not any(target in name for target in VISION_ATTENTION_NAMES):
        continue

    vision_target_modules.append(name)
    for param in module.parameters():
        if id(param) in _seen_param_ids:
            continue
        _seen_param_ids.add(id(param))
        # Only floating-point tensors can require gradients; packed integer
        # weights (e.g. quantized layers) would raise a RuntimeError.
        if param.is_floating_point():
            param.requires_grad = True
            vision_params_unfrozen += 1
        else:
            vision_params_skipped += 1

print(f"Found {len(vision_target_modules)} vision attention layers")

model.print_trainable_parameters()
model.enable_input_require_grads()

if vision_params_unfrozen:
    print(
        f"\nVision tower: {vision_params_unfrozen} parameter tensors marked trainable "
        f"({vision_params_skipped} non-floating-point tensors left frozen). "
        "No LoRA adapters were injected into the vision tower."
    )
else:
    print(
        "\nWARNING: no vision tower parameters were unfrozen "
        f"({len(vision_target_modules)} matching modules, "
        f"{vision_params_skipped} non-floating-point tensors skipped). "
        "The vision tower stays frozen."
    )

In [None]:
# @title 4. Dataset & DataLoader
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# 1. Define Paths (Self-Contained)
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
DATASET_DIR = os.path.join(PROJECT_ROOT, 'dataset/stage1')

# 2. Define Dataset Class
class ShaderDataset(Dataset):
    """In-memory list of JSONL records whose image file exists on disk.

    Each line of the JSONL file is expected to be a JSON object with at least
    an ``image_path`` key. Lines that are blank, are not valid JSON, lack the
    key, or point to a missing image are skipped and counted. Items are
    returned as the raw parsed dictionaries; images are not opened here.

    Args:
        jsonl_path: Path to the JSONL index file. If it does not exist, an
            error is printed and the dataset is left empty.
    """

    def __init__(self, jsonl_path):
        self.samples = []
        if not os.path.exists(jsonl_path):
            print(f"Error: {jsonl_path} not found.")
            return

        skipped = 0
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                    has_image = os.path.exists(entry['image_path'])
                except (ValueError, KeyError, TypeError):
                    # ValueError covers json.JSONDecodeError; KeyError and
                    # TypeError cover records without a usable 'image_path'.
                    skipped += 1
                    continue
                if has_image:
                    self.samples.append(entry)
                else:
                    skipped += 1

        print(f"Loaded {len(self.samples)} valid samples.")
        if skipped:
            print(f"Skipped {skipped} malformed or image-less records.")

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the parsed JSON record at position *idx*."""
        return self.samples[idx]

# 3. Define Collate Function
def collate_fn(batch):
    """Turn a list of dataset records into one processor batch with labels.

    For every record the image is opened from ``item['image_path']`` and a
    two-turn conversation (user: image + instruction, assistant:
    ``item['code']``) is rendered with the processor's chat template. The
    batch is tokenized with dynamic padding and ``labels`` is a copy of
    ``input_ids`` with pad-token positions set to -100. Prompt tokens are NOT
    masked here, so the loss covers the whole sequence.

    Args:
        batch: List of dicts with ``image_path`` and ``code`` keys.

    Returns:
        The processor output with an extra ``labels`` entry.
    """
    images = []
    texts = []

    for item in batch:
        # Load Image on the fly to save RAM
        try:
            image = Image.open(item['image_path']).convert("RGB")
        except Exception as exc:
            # Best effort: keep the batch alive with a black placeholder, but
            # say so. `Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            print(f"WARNING: image load failed ({exc!r}); using a black placeholder")
            image = Image.new("RGB", (256, 256), (0, 0, 0))  # Fallback

        images.append(image)

        # 'code' already contains the Analysis header from primitives.py,
        # so it is used directly to avoid duplication.
        full_response = item['code']

        # Standard Qwen2-VL Prompt Format
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": full_response}]
            }
        ]

        # Render to text only; tokenization happens once for the whole batch.
        text_prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        texts.append(text_prompt)

    # Process Batch
    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        return_tensors="pt",
    )

    # Create Labels
    inputs["labels"] = inputs["input_ids"].clone()

    # Mask padding so it does not contribute to the loss
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        inputs["labels"][inputs["input_ids"] == pad_token_id] = -100

    return inputs

# 4. Initialize Loader
# Build the dataset from the JSONL index inside DATASET_DIR.
jsonl_file = os.path.join(DATASET_DIR, 'dataset.jsonl')
dataset = ShaderDataset(jsonl_file)

_train_loader_options = {
    "batch_size": 1,            # T4 Limit
    "shuffle": True,
    "collate_fn": collate_fn,
    "num_workers": 2,           # Pre-load images in background
    "pin_memory": True,
}
train_dataloader = DataLoader(dataset, **_train_loader_options)

print("DataLoader ready.")

In [None]:
# @title 5. Full Training Run (Stabilized)
from torch.optim import AdamW
import bitsandbytes as bnb
from tqdm import tqdm
import torch
import gc
import os
import math
from PIL import Image

# 1. Configuration
# Hyper-parameters read by the collator and the training loop in this cell.
EPOCHS = 1             # number of passes over full_loader
GRAD_ACCUMULATION = 4  # micro-batches accumulated per optimizer update; each loss is divided by this
LEARNING_RATE = 1e-4  # LOWERED for stability (was 2e-4)
MAX_GRAD_NORM = 1.0   # ADDED to prevent explosion; passed to clip_grad_norm_
SAVE_STEPS = 50        # a checkpoint is written every SAVE_STEPS optimizer updates
MAX_LENGTH = 512       # max_length passed to the processor (padding and truncation target)

# 2. Define Collator (like collate_fn above, but pads to MAX_LENGTH and masks the prompt so only the response is trained on)
def smart_collate_fn(batch):
    """Collate records into a fixed-length batch with response-only labels.

    Each record is rendered twice with the processor's chat template: once as
    the user prompt alone (with the generation prompt appended) and once as
    the full user + assistant conversation. Both are tokenized to MAX_LENGTH.
    The prompt-only pass exists solely to measure how many tokens the prompt
    occupies. ``labels`` is a copy of the full ``input_ids`` where the prompt
    span and every padding position are set to -100.

    The prompt span is located from the attention mask (first attended
    position + prompt length), so the masking is correct whether the
    tokenizer pads on the left or on the right.

    Args:
        batch: List of dicts with ``image_path`` and ``code`` keys.

    Returns:
        The processor output for the full conversations plus ``labels``.
    """
    instruction = "Reverse engineer the GLSL shader code for this texture. Include analysis."

    images = []
    full_texts = []
    prompt_only_texts = []

    for item in batch:
        try:
            image = Image.open(item['image_path']).convert("RGB")
        except Exception as exc:
            # Best effort: substitute a black image but report it. Using
            # `Exception` lets KeyboardInterrupt / SystemExit propagate.
            print(f"WARNING: image load failed ({exc!r}); using a black placeholder")
            image = Image.new("RGB", (256, 256), (0, 0, 0))
        images.append(image)

        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]
        }
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": item['code']}]
        }

        prompt_only_texts.append(
            processor.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=True)
        )
        full_texts.append(
            processor.apply_chat_template([user_turn, assistant_turn], tokenize=False, add_generation_prompt=False)
        )

    inputs = processor(
        text=full_texts,
        images=images,
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )

    # Second pass only measures the prompt length of every sample.
    inputs_prompts = processor(
        text=prompt_only_texts,
        images=images,
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )

    attention = inputs["attention_mask"]
    labels = inputs["input_ids"].clone()
    seq_len = labels.shape[1]

    for i, prompt_mask in enumerate(inputs_prompts["attention_mask"]):
        prompt_len = int(prompt_mask.sum().item())
        attended = attention[i].nonzero(as_tuple=True)[0]
        if attended.numel() == 0:
            continue  # nothing but padding; handled by the mask below
        # Real tokens start at the first attended position: 0 with right
        # padding, the pad count with left padding.
        start = int(attended[0].item())
        labels[i, start:min(start + prompt_len, seq_len)] = -100

    # Ignore padding via the attention mask rather than the pad-token id, so a
    # genuine end-of-sequence token is never masked when pad and EOS share an id.
    labels[attention == 0] = -100

    inputs["labels"] = labels
    return inputs

# 3. Create Loader
# Same dataset as before, collated with the prompt-masking collator.
_full_loader_options = {
    "batch_size": 1,
    "shuffle": True,
    "collate_fn": smart_collate_fn,
    "num_workers": 2,
    "pin_memory": True,
}
full_loader = DataLoader(dataset, **_full_loader_options)

# 4. Optimizer
# Only parameters that currently require gradients are handed to the optimizer.
params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = bnb.optim.PagedAdamW8bit(params, lr=LEARNING_RATE)

# 5. Training Loop
# Gradient accumulation with three safety nets:
#   * a micro-batch with a non-finite loss is dropped BEFORE backward(), so it
#     can never reach the optimizer;
#   * an update whose gradient norm is non-finite is discarded instead of
#     being applied;
#   * gradients still pending when an epoch ends are applied rather than lost.
model.train()

print(f"  Starting STABILIZED TRAINING (LR={LEARNING_RATE}, Clip={MAX_GRAD_NORM})")

gc.collect()
torch.cuda.empty_cache()

global_step = 0    # optimizer updates applied so far
current_loss = 0   # sum of raw micro-batch losses in the open accumulation window
accumulated = 0    # micro-batches whose gradients are in the open window

# Start from clean gradients even if this cell is re-run after an interruption.
optimizer.zero_grad()

for epoch in range(EPOCHS):
    progress_bar = tqdm(full_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    batches_in_epoch = len(full_loader)

    for step, batch in enumerate(progress_bar):
        try:
            # Move to GPU
            batch = {k: v.to(model.device) for k, v in batch.items()}

            # Forward
            outputs = model(**batch, use_cache=False)
            micro_loss = outputs.loss.item()

            if math.isfinite(micro_loss):
                # Scale so a full window sums to the mean micro-batch loss.
                (outputs.loss / GRAD_ACCUMULATION).backward()
                current_loss += micro_loss
                accumulated += 1
            else:
                print(f"WARNING: non-finite loss at step {step}, skipping this batch")

            # Update Step: window is full, or the epoch ends with pending
            # gradients. A partial window is applied at reduced scale because
            # each loss was already divided by GRAD_ACCUMULATION.
            window_full = accumulated == GRAD_ACCUMULATION
            epoch_ending = (step + 1) == batches_in_epoch
            if accumulated > 0 and (window_full or epoch_ending):
                # SAFETY: Clip Gradients (returns the norm before clipping)
                grad_norm = torch.nn.utils.clip_grad_norm_(params, MAX_GRAD_NORM)

                if torch.isfinite(grad_norm):
                    optimizer.step()
                    global_step += 1
                    progress_bar.set_postfix({'loss': f'{current_loss / accumulated:.4f}'})

                    # Saving
                    if global_step % SAVE_STEPS == 0:
                        save_path = os.path.join(CHECKPOINT_DIR, f"checkpoint-{global_step}")
                        model.save_pretrained(save_path)
                        processor.save_pretrained(save_path)
                else:
                    print("WARNING: non-finite gradient norm, discarding this update")

                optimizer.zero_grad()
                current_loss = 0
                accumulated = 0

        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print(f"OOM at step {step}. Clearing cache...")
            # Drop the partially accumulated window so bookkeeping and
            # gradients stay consistent.
            optimizer.zero_grad()
            current_loss = 0
            accumulated = 0
            torch.cuda.empty_cache()

# 6. Final Save
final_path = os.path.join(PROJECT_ROOT, "checkpoints/stage1_final")
print(f"\nSaving FINAL STABLE model to {final_path}")
model.save_pretrained(final_path)
processor.save_pretrained(final_path)
print("Training Complete.")

In [None]:
# @title 6. Auto-Shutdown
# This cell will only run after the training cell finishes.
import time
from google.colab import drive
from google.colab import runtime

# Grace period before the runtime is released; also leaves time to cancel.
SHUTDOWN_DELAY_SECONDS = 60

print("Training finished. Saving is complete.")
print(f"Shutting down runtime to save Compute Units in {SHUTDOWN_DELAY_SECONDS} seconds...")

time.sleep(SHUTDOWN_DELAY_SECONDS)

# A fixed sleep does not guarantee that checkpoints have reached Drive.
# Explicitly flush pending writes before the VM is released.
try:
    drive.flush_and_unmount()
except Exception as exc:
    # Keep the original behaviour of always shutting down, but make the
    # failure visible in the cell output.
    print(f"WARNING: Drive flush failed ({exc!r}); shutting down anyway.")

print("Goodnight.")
runtime.unassign()