**Stage 1 - Test Suite**

In [None]:
# @title 1. Setup Environment
import os
import sys
import subprocess

# 1.1 Install Dependencies
# The `!` lines below are IPython shell escapes, so this cell only runs inside a
# notebook (Colab), not as a plain Python file.
# NOTE(review): transformers is installed from the GitHub main branch, so its
# version is unpinned and can change between runs; pin a release/commit if the
# test results need to be reproducible.
# NOTE(review): libgl1-mesa-glx is installed without a preceding
# `apt-get update`, and its stdout is discarded; presumably the GL library is
# needed by moderngl for rendering — confirm.
print("Installing dependencies...")
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers peft accelerate scikit-learn pandas
!apt-get install -y libgl1-mesa-glx > /dev/null
!pip install -q moderngl

# 1.2 Mount Drive
# Mount Google Drive at /content/drive unless that path already exists
# (e.g. when this cell is re-run in the same session).
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# 1.3 Configure Paths
# Project folder on Drive. ADAPTER_PATH is the fine-tuned adapter checkpoint that
# is later loaded with PeftModel; LIB_DIR holds the project's Python helpers.
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')
LIB_DIR = os.path.join(PROJECT_ROOT, 'lib')

# 1.4 Register Library
# LIB_DIR is inserted at the FRONT of sys.path so its modules (gl_renderer,
# generators) win over same-named installed packages; PROJECT_ROOT is appended
# at the end. The membership checks stop re-runs from adding duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
if LIB_DIR not in sys.path:
    sys.path.insert(0, LIB_DIR)

print("Environment Ready.")

In [None]:
# @title 2. Load Model & Generators
import torch
import importlib.util
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, BitsAndBytesConfig
from peft import PeftModel

# 2.1 Load Generator Library
def _load_module_from_file(module_name, relative_path, preset=None):
    """Execute ``LIB_DIR/relative_path`` as a top-level module called *module_name*.

    Follows the importlib "importing a source file directly" recipe. The new
    module is registered in ``sys.modules`` *before* its code runs. As a result,
    a later ``import module_name`` (including one inside another file loaded
    this way) gets this same module object instead of failing.

    Args:
        module_name: Name under which the module is registered in ``sys.modules``.
        relative_path: Path of the ``.py`` file, relative to ``LIB_DIR``.
        preset: Optional mapping of attributes to set on the module before its
            code executes; they become globals the file can read.

    Returns:
        The executed module object.
    """
    spec = importlib.util.spec_from_file_location(
        module_name, os.path.join(LIB_DIR, relative_path)
    )
    module = importlib.util.module_from_spec(spec)
    for attr_name, value in (preset or {}).items():
        setattr(module, attr_name, value)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module registered for later imports.
        sys.modules.pop(module_name, None)
        raise
    return module


try:
    from generators.primitives import generate_primitive
    from gl_renderer import ShaderRenderer
    print("Libraries loaded successfully.")
except ImportError:
    # Direct load fallback for Colab path issues
    print("Standard import failed. Using direct file injection...")

    # Load Renderer
    ShaderRenderer = _load_module_from_file("gl_renderer", "gl_renderer.py").ShaderRenderer

    # Load Base Generator
    mod_b = _load_module_from_file("base", "generators/base.py")

    # Load Primitives Generator. `base` is registered in sys.modules above, and
    # is also injected as a global (as before) in case primitives.py refers to
    # it without importing it — NOTE(review): confirm how primitives.py imports base.
    mod_p = _load_module_from_file(
        "primitives", "generators/primitives.py", preset={"base": mod_b}
    )
    generate_primitive = mod_p.generate_primitive

# 2.2 Initialize Renderer
# A single 256x256 renderer instance is reused for every shader rendered below.
renderer = ShaderRenderer(width=256, height=256)

# 2.3 Load Model
# Base checkpoint shared by the quantised model and its processor.
BASE_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

print(f"Loading Adapter from: {ADAPTER_PATH}")

# 4-bit NF4 weights with fp16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantised base model, then wrap it with the fine-tuned adapter.
model = PeftModel.from_pretrained(
    Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
    ),
    ADAPTER_PATH,
)
processor = Qwen2VLProcessor.from_pretrained(BASE_MODEL_ID)
print("Model Ready.")

In [None]:
# @title 3. Run Test Suite
import re
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix

# CONFIG
# TEST_SAMPLES: number of ground-truth shaders generated and evaluated
#   (100 is a quick ~3-minute run; 1000 takes ~40 minutes).
# BATCH_SIZE: images sent to the model per generate() call.
TEST_SAMPLES, BATCH_SIZE = 100, 4

# Helpers
def parse_shape_name(text):
    """Extract the shape label from a ``// Shape: <name>`` comment in *text*.

    Spaces or tabs are accepted after ``//`` and around the colon, so
    ``// Shape: Circle``, ``// Shape : Circle`` and ``//Shape:Circle`` all match.
    The ``Shape`` keyword is matched case-insensitively.

    Returns:
        The first word after the colon, lower-cased, or ``"unknown"`` when no
        such comment is present.
    """
    match = re.search(r"//[ \t]*Shape[ \t]*:\s*(\w+)", text, re.IGNORECASE)
    return match.group(1).lower() if match else "unknown"

def extract_code(text):
    """Pull the GLSL source out of a model response.

    Steps, applied in order:
      1. If a ``// GLSL CODE`` marker is present, keep only the text after the
         first marker (up to a second marker, if there is one).
      2. If the remaining text contains a ```` ```glsl ```` fence, keep only the
         body of the first fenced block. This also handles a fence placed
         *after* the marker, where the fence characters would otherwise end
         up in the shader source and break compilation.
      3. Return the result with surrounding whitespace stripped.
    """
    if "// GLSL CODE" in text:
        text = text.split("// GLSL CODE")[1]
    if "```glsl" in text:
        text = text.split("```glsl", 1)[1].split("```", 1)[0]
    return text.strip()

def calculate_mse(img1, img2):
    """Return the mean squared error between two images over all pixels/channels.

    Inputs are anything NumPy can convert (PIL images, arrays). Values are
    divided by 255 first, i.e. 8-bit channels are assumed, so identical images
    give 0.0 and black-vs-white gives 1.0.
    """
    left = np.asarray(img1, dtype=float) / 255.0
    right = np.asarray(img2, dtype=float) / 255.0
    diff = left - right
    return np.mean(diff * diff)

# --- PHASE 1: GENERATE GROUND TRUTH ---
# Build TEST_SAMPLES reference entries (source code, shape label, rendered image)
# with the generator library; the model is scored against these.
print(f"Generating {TEST_SAMPLES} ground truth samples using library...")
gt_data = []
gt_path = "temp_gt.png"

for i in range(TEST_SAMPLES):
    # Use the library to create a valid test case
    code, analysis = generate_primitive(i)
    shape = parse_shape_name(analysis)

    # Delete the previous sample's render first. Otherwise a render that fails
    # without writing a file would silently pair this sample's code with the
    # previous sample's image, corrupting every metric computed from it.
    try:
        os.remove(gt_path)
    except FileNotFoundError:
        pass

    if not renderer.render(code, gt_path) or not os.path.exists(gt_path):
        # Ground truth comes from the library and is expected to render; a
        # failure here points at the generator/renderer, not the model.
        raise RuntimeError(
            f"Ground-truth sample {i} failed to render (shape={shape!r})."
        )

    # The context manager closes the file; convert() returns a new in-memory image.
    with Image.open(gt_path) as src:
        img = src.convert("RGB")

    gt_data.append({
        "gt_code": code,
        "gt_shape": shape,
        "gt_image": img,
    })

# --- PHASE 2: RUN INFERENCE ---
print("Running Inference...")
predictions = []

# Batched generation with a decoder-only model needs LEFT padding. With right
# padding, the shorter prompts in a batch end in pad tokens and the model must
# continue after them; Transformers warns that this gives incorrect output.
# Set it explicitly instead of relying on the checkpoint's tokenizer config.
processor.tokenizer.padding_side = "left"

for i in tqdm(range(0, TEST_SAMPLES, BATCH_SIZE)):
    batch = gt_data[i : i + BATCH_SIZE]
    batch_images = [item["gt_image"] for item in batch]

    # Prepare Prompt: one single-turn user message per image. The same images
    # are also passed to the processor via `images=` below.
    prompts = [
        [{"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."},
        ]}]
        for img in batch_images
    ]
    text_inputs = [processor.apply_chat_template(p, add_generation_prompt=True) for p in prompts]

    inputs = processor(
        text=text_inputs,
        images=batch_images,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=512)

    # generate() returns prompt + continuation. Every row shares the same padded
    # prompt length, so dropping len(inp) tokens leaves only the new tokens.
    generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
    predictions.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))

# --- PHASE 3: EVALUATE ---
# For each prediction: parse the shape label, try to render the predicted code,
# and compare the render against the ground-truth image.
print("Evaluating Results...")
y_true = []
y_pred = []
visual_errors = []
compile_success = 0
pred_path = "temp_pred.png"

for i, text in enumerate(predictions):
    gt = gt_data[i]

    # 1. Parse Output
    pred_code = extract_code(text)
    pred_shape = parse_shape_name(text)

    # 2. Compile & Render
    # Remove the previous sample's render so a stale image can never be scored.
    try:
        os.remove(pred_path)
    except FileNotFoundError:
        pass

    # The shader comes from model output and may be arbitrary text. A renderer
    # exception is counted as a failed compile so one bad sample cannot abort
    # the whole evaluation after inference has already finished.
    try:
        success = renderer.render(pred_code, pred_path)
    except Exception as err:
        print(f"Sample {i}: renderer raised {type(err).__name__}: {err}")
        success = False

    # 3. Calculate Metrics
    mse = 1.0  # Default high error penalty (the maximum of the [0, 1]-scaled MSE)
    if success and os.path.exists(pred_path):
        compile_success += 1
        with Image.open(pred_path) as src:
            pred_img = src.convert("RGB")
        mse = calculate_mse(gt["gt_image"], pred_img)

    y_true.append(gt["gt_shape"])
    y_pred.append(pred_shape)
    visual_errors.append(mse)

# --- REPORT ---
# Size the report from the samples actually evaluated rather than TEST_SAMPLES.
# In a notebook the config cell can be edited and re-run without re-running the
# phases above, which would otherwise mislabel N and the success rate.
n_eval = len(y_true)

print("\n" + "="*40)
print(f"TEST REPORT (N={n_eval})")
print("="*40)

if n_eval == 0:
    print("\nNo samples were evaluated.")
else:
    print(f"\n1. Compilation Success: {compile_success}/{n_eval} ({compile_success/n_eval:.1%})")
    print(f"2. Average Visual Error (MSE): {np.mean(visual_errors):.4f}")

    print("\n3. Classification Report:")
    # zero_division=0 reports 0 (without a warning) for labels that are never
    # predicted or never present in the ground truth.
    print(classification_report(y_true, y_pred, zero_division=0))

    print("\n4. Confusion Matrix:")
    # Include all labels found in both truth and predictions (handles hallucinations)
    all_labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)
    print(pd.DataFrame(cm, index=all_labels, columns=all_labels))

In [None]:
# @title 4. Auto-Shutdown
# Intended to run last, after the test-suite cells above have finished; it waits
# a minute and then releases the Colab runtime.
import time
from google.colab import runtime

print("Test suite finished.")
print("Shutting down runtime to save Compute Units in 60 seconds...")

# Give time for the final logs to sync to Drive
time.sleep(60)

print("Goodnight.")
runtime.unassign()