**Stage 1 - Training**

In [None]:
# @title 1. Setup and dependencies
import os
import sys
import subprocess
from google.colab import drive

# Mount Google Drive. The mount point directory existing is not proof of a live
# mount (an empty /content/drive may remain after an unmount), and writing to an
# unmounted path would silently put checkpoints on ephemeral runtime disk.
# So check for the MyDrive folder that PROJECT_ROOT below lives under.
DRIVE_MOUNT_POINT = '/content/drive'
if not os.path.isdir(os.path.join(DRIVE_MOUNT_POINT, 'MyDrive')):
    drive.mount(DRIVE_MOUNT_POINT)

# Define the project paths for the EarthShader environment.
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
DATASET_DIR = os.path.join(PROJECT_ROOT, 'dataset/stage1')
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_adapter')

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Install the necessary libraries for Qwen2-VL support into the kernel's own
# interpreter (sys.executable). Unlike `!pip`, check_call raises on a failed
# install, so "Run All" stops here instead of failing confusingly in a later cell.
# NOTE(review): transformers comes from git HEAD and nothing is pinned; for a
# reproducible "committed" run, pin exact versions (e.g. taken from `pip freeze`).
print("Installing environment dependencies...")
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "git+https://github.com/huggingface/transformers",
    "peft", "datasets", "bitsandbytes", "accelerate", "qwen-vl-utils", "trl",
])

In [None]:
# @title 2. Model staging and local download
import os
import shutil
from huggingface_hub import snapshot_download

MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
LOCAL_MODEL_DIR = "/content/qwen_local"

# snapshot_download with local_dir reuses files already staged there and only
# fetches what is missing, so wiping the directory on every run would force a
# full re-download of the checkpoint. Set FORCE_REDOWNLOAD = True only when the
# staged copy is suspected to be corrupt.
FORCE_REDOWNLOAD = False
if FORCE_REDOWNLOAD and os.path.exists(LOCAL_MODEL_DIR):
    shutil.rmtree(LOCAL_MODEL_DIR)

# Download the base Qwen2-VL model to local runtime storage.
print(f"Staging {MODEL_ID} to local runtime...")

# `local_dir_use_symlinks` and `resume_download` are deprecated no-ops in current
# huggingface_hub (local_dir holds real files, downloads always resume), so they
# are no longer passed. `local_model_path` is what later cells load from.
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    local_dir=LOCAL_MODEL_DIR,
)

In [None]:
# @title 3. Model loading and 4 bit quantization
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

# Fixed image budget: min_pixels == max_pixels pins every image to (roughly) the
# same pixel area, and therefore a near-constant number of visual tokens.
IMAGE_PIXELS = 256 * 256

# Configure 4-bit quantization for T4 VRAM safety.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load model AND processor from the same staged snapshot (`local_model_path` is
# returned by snapshot_download in the staging cell). This keeps their configs
# in lockstep and avoids a second trip to the Hub for the processor.
# trust_remote_code is not needed: the model class is imported directly from
# transformers, so no repository code has to be executed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    local_model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

processor = AutoProcessor.from_pretrained(
    local_model_path,
    min_pixels=IMAGE_PIXELS,
    max_pixels=IMAGE_PIXELS,
)

print("Model and processor successfully loaded into memory.")

In [None]:
# @title 4. Model preparation and lora setup
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# Prepare the quantized model for k-bit training.
model = prepare_model_for_kbit_training(model)

# LoRA targets in the language model (attention + MLP projections).
TEXT_LORA_MODULES = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# LoRA targets in the vision tower. `proj` was already targeted; `qkv`, `fc1`
# and `fc2` extend the adapters over the rest of each vision block so that
# vision adaptation lives in (and is saved with) the adapter.
# NOTE(review): these names assume Qwen2-VL's vision-block module naming, and
# `proj` may also match the patch-embedding layer — verify against
# `model.named_modules()`; the count printed below should be non-zero.
VISION_LORA_MODULES = ["proj", "qkv", "fc1", "fc2"]

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=TEXT_LORA_MODULES + VISION_LORA_MODULES,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

# Do NOT blanket-unfreeze base vision parameters with requires_grad=True:
# PeftModel.save_pretrained writes only adapter (and modules_to_save) weights,
# so any base parameter trained that way would be silently missing from every
# checkpoint, and the saved adapter could not reproduce the trained model.
vision_lora_params = sum(
    param.numel()
    for name, param in model.named_parameters()
    if "visual" in name and "lora_" in name and param.requires_grad
)
print(f"Trainable vision-tower LoRA parameters: {vision_lora_params:,}")

model.print_trainable_parameters()

In [None]:
# @title 5. Scenario collation and anchoring
import warnings

from PIL import Image
import torch

# This prompt anchors the model to the EarthShader compiler API.
SYSTEM_PROMPT = "You are an EarthShader SDF compiler. Translate the visual primitive into a valid GLSL mainImage function using the common.glsl library."
USER_PROMPT = "Reverse engineer the GLSL shader code for this texture. Include analysis."
MAX_SEQ_LEN = 768     # token budget per sample; longer samples are truncated
IGNORE_INDEX = -100   # label value excluded from the loss


def _load_rgb_image(path):
    """Load `path` as an RGB image; if missing/unreadable, warn and return a black 256x256 placeholder."""
    try:
        # The context manager closes the file handle; convert() returns a new
        # image, so the result stays usable after the file is closed.
        with Image.open(path) as img:
            return img.convert("RGB")
    except OSError as exc:
        # FileNotFoundError and PIL's UnidentifiedImageError are both OSError.
        # The placeholder keeps the run going, but this sample's labels no longer
        # describe its image, so the substitution is made visible.
        warnings.warn(f"Image unreadable, using black placeholder: {path!r} ({exc})")
        return Image.new("RGB", (256, 256), (0, 0, 0))


def final_collate_fn(batch):
    """Collate dataset rows into Qwen2-VL training inputs with response-only labels.

    Each row must provide 'image_path', 'analysis' and 'code'. The assistant
    target is the analysis followed by a ```glsl fenced code block.

    Returns the processor output (input_ids, attention_mask, image tensors, ...)
    plus 'labels', in which prompt AND padding positions are IGNORE_INDEX so the
    loss only sees response tokens.
    """
    images, full_texts, prompt_only_texts = [], [], []

    for item in batch:
        images.append(_load_rgb_image(item['image_path']))

        prompt_conv = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": USER_PROMPT}
            ]}
        ]

        # Use the structured response with dense reasoning labels.
        full_response = f"{item['analysis']}\n\n```glsl\n{item['code']}\n```"
        full_conv = prompt_conv + [{"role": "assistant", "content": [{"type": "text", "text": full_response}]}]

        prompt_only_texts.append(processor.apply_chat_template(prompt_conv, tokenize=False, add_generation_prompt=True))
        full_texts.append(processor.apply_chat_template(full_conv, tokenize=False, add_generation_prompt=False))

    # Pad to the longest sample in the batch instead of always to MAX_SEQ_LEN
    # (with batch_size=1 this removes padding entirely); truncation still caps
    # every sample at MAX_SEQ_LEN.
    processor_kwargs = dict(padding=True, max_length=MAX_SEQ_LEN, truncation=True, return_tensors="pt")
    inputs = processor(text=full_texts, images=images, **processor_kwargs)
    inputs_prompts = processor(text=prompt_only_texts, images=images, **processor_kwargs)

    attention_mask = inputs["attention_mask"]
    labels = inputs["input_ids"].clone()

    # Padding must never be a training target; otherwise the model is taught to
    # emit pad tokens.
    labels[attention_mask == 0] = IGNORE_INDEX

    # Mask the prompt tokens to focus the loss on the response. The prompt is
    # assumed to tokenize to an exact prefix of the full conversation.
    prompt_lens = inputs_prompts["attention_mask"].sum(dim=1)
    for i in range(len(batch)):
        real_positions = attention_mask[i].nonzero(as_tuple=True)[0]
        if real_positions.numel() == 0:
            continue
        # Offset by the first real token so the mask is also right under left padding.
        prompt_end = int(real_positions[0]) + int(prompt_lens[i])
        labels[i, :prompt_end] = IGNORE_INDEX

    inputs["labels"] = labels
    return inputs

In [None]:
# @title 6. Dataset and dataloader initialization
from datasets import load_dataset
from torch.utils.data import DataLoader
import os
import torch

# Committed Stage 1 run parameters.
N_SAMPLES = 2000
SEED = 42

# 1. Load the dataset from the local registry.
dataset_path = os.path.join(DATASET_DIR, "dataset.jsonl")
raw_dataset = load_dataset("json", data_files=dataset_path, split="train")

# 2. Select exactly N_SAMPLES samples (the first ones) for the committed Stage 1 run.
# Fail with an explicit message instead of an opaque error from select().
if len(raw_dataset) < N_SAMPLES:
    raise ValueError(
        f"Stage 1 needs {N_SAMPLES} samples but {dataset_path} has only {len(raw_dataset)}."
    )
raw_dataset = raw_dataset.select(range(N_SAMPLES))

# 3. Configure the data loader.
# Shuffling with a fixed-seed generator randomizes the order (training does not
# follow whatever order the JSONL was written in) while keeping it identical on
# every run. Gradient accumulation is configured in the training cell.
full_loader = DataLoader(
    raw_dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    collate_fn=final_collate_fn,
    num_workers=2,
    pin_memory=True,
)

print(f"Dataset initialization complete. Committed with {len(raw_dataset)} samples.")

In [None]:
# @title 7. Training loop and weighted loss
import torch.nn as nn
from tqdm import tqdm
import bitsandbytes as bnb
import os

def _find_marker_end(row, marker):
    """Return the index of the last token of the first complete `marker` run in `row`, or None."""
    if not marker:
        return None
    width = len(marker)
    for start in range(len(row) - width + 1):
        if row[start:start + width] == marker:
            return start + width - 1
    return None


def compute_weighted_loss(logits, labels, tokenizer, code_weight=5.0, text_weight=0.2):
    """Applies higher weight to GLSL code tokens to prioritize syntax accuracy.

    Each (shifted) label row is split at the first complete occurrence of the
    tokenized "```glsl" fence. Positions before the fence's last token get
    `text_weight`; that token and everything after it get `code_weight`. Rows
    without the fence keep weight 1.0.

    Args:
        logits: (batch, seq, vocab) model logits.
        labels: (batch, seq) target ids aligned with `logits`; -100 = ignored.
        tokenizer: object with `encode(text, add_special_tokens=False)`.
        code_weight: multiplier for fence + code tokens.
        text_weight: multiplier for reasoning tokens before the fence.

    Returns:
        Scalar tensor: the sum of weighted per-token losses divided by the number
        of non-ignored tokens (not by the sum of weights). It is 0 when every
        label is ignored.
    """
    code_marker = tokenizer.encode("```glsl", add_special_tokens=False)
    loss_fct = nn.CrossEntropyLoss(reduction='none')

    # Standard shift for causal language modeling: position t predicts token t+1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    weights = torch.ones_like(shift_labels, dtype=torch.float32)

    for i in range(shift_labels.size(0)):
        # Match the whole fence, not just its last token, so a stray occurrence
        # of that token in the reasoning text cannot move the text/code boundary.
        marker_idx = _find_marker_end(shift_labels[i].tolist(), code_marker)
        if marker_idx is None:
            continue  # Default weighting if the marker is missing.
        weights[i, :marker_idx] = text_weight
        weights[i, marker_idx:] = code_weight

    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_labels = shift_labels.view(-1)
    # Cross-entropy in float32 even if the logits arrive in half precision.
    loss = loss_fct(flat_logits.float(), flat_labels)

    # Mean over non-ignored tokens; clamp avoids 0/0 = NaN when all are ignored.
    valid_mask = (flat_labels != -100)
    num_valid = valid_mask.sum().clamp(min=1)
    weighted_loss = (loss * weights.view(-1))[valid_mask].sum() / num_valid
    return weighted_loss

# 1. Configuration for the committed 2,000-sample run.
GRAD_ACCUMULATION = 4
TOTAL_ITERATIONS = 2000
CHECKPOINT_EVERY = 100  # optimizer steps between periodic adapter checkpoints

model.train()
# Only trainable parameters go to the optimizer and the gradient clipper;
# frozen base weights never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
# Using PagedAdamW8bit to stay within T4 VRAM limits.
optimizer = bnb.optim.PagedAdamW8bit(trainable_params, lr=1e-4)
optimizer.zero_grad()
global_step = 0


def _apply_accumulated_update():
    """Clip the accumulated gradients, take one optimizer step, then reset them."""
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()
    optimizer.zero_grad()


micro_steps = 0     # micro-batches consumed so far
window_losses = []  # finite micro-batch losses since the last optimizer step

pbar = tqdm(full_loader, desc="Training stage 1", total=min(TOTAL_ITERATIONS, len(full_loader)))
for step, batch in enumerate(pbar):
    # Stop exactly at the commitment limit.
    if step >= TOTAL_ITERATIONS:
        break

    batch = {k: v.to(model.device) for k, v in batch.items()}
    # Labels feed only the weighted loss below. Passing them into the model
    # would make it also compute its own (discarded) unweighted loss over the
    # full vocabulary, costing memory and time.
    labels = batch.pop("labels")
    outputs = model(**batch)

    loss = compute_weighted_loss(outputs.logits, labels, processor.tokenizer)
    micro_steps += 1
    if torch.isfinite(loss):
        (loss / GRAD_ACCUMULATION).backward()
        window_losses.append(loss.item())
    else:
        # NaN/inf gradients pass straight through clip_grad_norm_ and would
        # corrupt the adapter on the next optimizer step; skip this micro-batch.
        print(f"\n[WARN] Non-finite loss at micro-step {step}; skipping backward.")

    if micro_steps % GRAD_ACCUMULATION == 0:
        _apply_accumulated_update()
        global_step += 1
        if window_losses:
            # Report the mean over the accumulation window, not one micro-batch.
            pbar.set_postfix({'weighted_loss': f'{sum(window_losses) / len(window_losses):.4f}'})
        window_losses = []

        # Save periodic checkpoints every CHECKPOINT_EVERY global steps.
        if global_step % CHECKPOINT_EVERY == 0:
            model.save_pretrained(os.path.join(CHECKPOINT_DIR, f"step-{global_step}"))
pbar.close()

# Apply gradients left over when the micro-batch count is not a multiple of
# GRAD_ACCUMULATION (they were scaled by 1/GRAD_ACCUMULATION, so this final
# update is proportionally smaller).
if micro_steps % GRAD_ACCUMULATION:
    _apply_accumulated_update()
    global_step += 1

# 2. Final save for the locked Stage 1 artifact.
model.save_pretrained(os.path.join(PROJECT_ROOT, "checkpoints/stage1_final"))
print("\n[SUCCESS] Stage 1 training complete. Final weights saved to stage1_final.")

In [None]:
# @title 8. Auto shutdown
import time
from google.colab import drive, runtime

# Ensure the drive has synced the final adapter files before the VM goes away.
# flush_and_unmount explicitly flushes outstanding Drive writes instead of
# guessing with a fixed delay; a fixed 60 s wait remains as a best-effort
# fallback so a flush problem never blocks the shutdown.
print("Training sequence has finished. Synchronizing drive files...")
try:
    drive.flush_and_unmount()
except Exception as exc:
    print(f"Drive flush failed ({exc}); falling back to a fixed 60 s wait.")
    time.sleep(60)

print("The system is going offline to preserve compute units.")
runtime.unassign()