**Stage 1 - Test Suite**

In [None]:
# @title 1. Setup Environment
import os
import sys
import subprocess

# 1.1 Install Dependencies
# The lines starting with "!" are IPython shell escapes, so this cell only
# runs inside a notebook kernel. transformers is installed from the GitHub
# repository rather than from a PyPI release.
print("Installing dependencies...")
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers peft accelerate scikit-learn pandas matplotlib
# System GL packages; presumably needed by the headless moderngl renderer
# used later in the notebook -- TODO confirm.
!apt-get install -y libgl1-mesa-glx xvfb > /dev/null
!pip install -q moderngl

# 1.2 Mount Drive
# The mount is skipped whenever the /content/drive directory already exists.
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# 1.3 Configure Paths
# All project assets (adapter checkpoint, GLSL library, generators) are
# addressed relative to PROJECT_ROOT on the mounted Drive.
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')
LIB_DIR = os.path.join(PROJECT_ROOT, 'lib')
GEN_DIR = os.path.join(LIB_DIR, 'generators')

# 1.4 Fix Imports (The "Package" Fix)
# Put LIB_DIR on sys.path and make sure both folders contain an __init__.py
# so that `generators` can be imported as a package from LIB_DIR.
if LIB_DIR not in sys.path:
    sys.path.append(LIB_DIR)

# NOTE(review): open() raises FileNotFoundError if the folder itself is
# missing; the folders are expected to exist on Drive already.
for folder in [LIB_DIR, GEN_DIR]:
    init_file = os.path.join(folder, '__init__.py')
    if not os.path.exists(init_file):
        with open(init_file, 'w') as f: f.write("")

print("Environment Ready.")

In [None]:
# @title 2. Load Model & Robust Renderer
import torch
import moderngl
import numpy as np
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, BitsAndBytesConfig
from peft import PeftModel
import importlib.util

# --- ROBUST RENDERER (Embedded to prevent import issues) ---
class ShaderRenderer:
    """Offscreen renderer that compiles a fragment-shader snippet and saves the frame.

    A standalone moderngl context is created on construction (EGL backend
    first, then moderngl's default backend). If both attempts fail the error
    is printed, ``ctx`` stays ``None`` and every ``render`` call returns
    ``False``.

    Attributes:
        width, height: Framebuffer size in pixels.
        ctx: The moderngl context, or ``None`` when creation failed.
        vbo, fbo: Vertex buffer and framebuffer (``None`` without a context).
        vert_shader: Pass-through vertex shader source.
        common_lib: Contents of ``common.glsl`` in ``LIB_DIR`` (empty string
            when the file is absent); prepended to every fragment shader.
        last_error: Message of the most recent failure, or ``None`` after a
            successful ``render``.
    """

    def __init__(self, width=256, height=256):
        self.width = width
        self.height = height
        self.ctx = None
        # Defined up front so the instance is consistent even without a context.
        self.vbo = None
        self.fbo = None
        self.vert_shader = ""
        self.common_lib = ""
        self.last_error = None
        try:
            self.ctx = moderngl.create_context(standalone=True, backend='egl')
        except Exception:
            try:
                self.ctx = moderngl.create_context(standalone=True)
            except Exception as e:
                print(f"Renderer Init Error: {e}")
                self.last_error = str(e)
                return

        # Geometry: four corners of clip space, drawn as a triangle strip.
        vertices = np.array([-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0], dtype='f4')
        self.vbo = self.ctx.buffer(vertices.tobytes())
        self.fbo = self.ctx.simple_framebuffer((self.width, self.height), components=3)
        self.vert_shader = '''
            #version 330
            in vec2 in_vert;
            out vec2 uv;
            void main() {
                uv = in_vert;
                gl_Position = vec4(in_vert, 0.0, 1.0);
            }
            '''

        # Load Common Lib
        common_path = os.path.join(LIB_DIR, 'common.glsl')
        if os.path.exists(common_path):
            with open(common_path, 'r') as f:
                self.common_lib = f.read()

    def render(self, fragment_code, output_path):
        """Compile ``fragment_code``, draw one frame and save it to ``output_path``.

        Args:
            fragment_code: GLSL source that defines ``mainImage`` so that it
                can be called as ``mainImage(color, gl_FragCoord.xy)``.
            output_path: Destination path handed to ``PIL.Image.save``.

        Returns:
            ``True`` when the image was written. ``False`` when no context is
            available or any step (compile, draw, read-back, save) raised;
            the reason is then kept in ``self.last_error``.
        """
        if not self.ctx:
            self.last_error = "No OpenGL context available."
            return False

        full_frag_shader = f'''
        #version 330
        uniform vec2 iResolution;
        out vec4 fragColor;
        {self.common_lib}
        {fragment_code}
        void main() {{
            vec4 color;
            mainImage(color, gl_FragCoord.xy);
            fragColor = color;
        }}
        '''
        prog = None
        vao = None
        try:
            prog = self.ctx.program(vertex_shader=self.vert_shader, fragment_shader=full_frag_shader)
            # The uniform is only present when the shader actually uses it.
            if 'iResolution' in prog:
                prog['iResolution'].value = (self.width, self.height)
            vao = self.ctx.simple_vertex_array(prog, self.vbo, 'in_vert')

            self.fbo.use()
            self.fbo.clear(0.0, 0.0, 0.0, 1.0)
            vao.render(moderngl.TRIANGLE_STRIP)

            data = self.fbo.read(components=3)
            # The framebuffer is read bottom row first, so flip for image order.
            image = Image.frombytes('RGB', self.fbo.size, data).transpose(Image.FLIP_TOP_BOTTOM)
            image.save(output_path)
        except Exception as e:
            # Invalid model-generated shaders are expected; report via the
            # return value but keep the reason for debugging.
            self.last_error = str(e)
            return False
        finally:
            # Release GL objects on every path so failed renders do not leak.
            for resource in (vao, prog):
                if resource is not None:
                    resource.release()

        self.last_error = None
        return True

# Initialize
renderer = ShaderRenderer(width=256, height=256)

# Load Generators
try:
    # Purge the package AND its submodules. Removing only 'generators' leaves
    # 'generators.primitives' cached in sys.modules, so the import below would
    # keep returning the stale module when this cell is re-run.
    for cached_name in list(sys.modules):
        if cached_name == 'generators' or cached_name.startswith('generators.'):
            del sys.modules[cached_name]
    # Files under LIB_DIR may have been created or edited while the kernel
    # was running, so let the path finders rescan their directories.
    importlib.invalidate_caches()
    from generators.primitives import generate_primitive
    print("Generators loaded.")
except ImportError as e:
    print(f"Generator Import Error: {e}")

# Load Model
# The base model is loaded with 4-bit NF4 quantization (float16 compute) and
# the adapter stored at ADAPTER_PATH is attached on top of it.
print(f"Loading Adapter from: {ADAPTER_PATH}")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()
# min_pixels == max_pixels pins the processor's image pixel budget to 256*256.
processor = Qwen2VLProcessor.from_pretrained(ADAPTER_PATH, min_pixels=256*256, max_pixels=256*256)
print("Model Ready.")

In [None]:
# @title 3. Run Test Suite (100 Samples)
import re
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt

# CONFIG
# TEST_SAMPLES: how many ground-truth shaders are generated and evaluated.
# BATCH_SIZE: how many images are passed to a single model.generate call.
TEST_SAMPLES = 100
BATCH_SIZE = 4

# Helpers
def parse_shape_name(text):
    """Return the lower-case shape label found in ``text``.

    An explicit ``// Shape: <name>`` or ``// Geometry: <name>`` comment wins.
    Otherwise the text is scanned for keyword pairs ("circle" with "length",
    "square" with "max") or the single keyword "triangle". ``"unknown"`` is
    returned when nothing matches.
    """
    tagged = re.search(r"// (?:Shape|Geometry):\s*(\w+)", text, re.IGNORECASE)
    if tagged is not None:
        return tagged.group(1).lower()

    # Fallback: keyword heuristics, checked in priority order.
    lowered = text.lower()
    keyword_rules = (
        ("circle", "length"),
        ("square", "max"),
        ("triangle", None),
    )
    for shape, companion in keyword_rules:
        if shape in lowered and (companion is None or companion in lowered):
            return shape
    return "unknown"

def extract_code(text):
    """Pull the shader source out of a model response.

    Resolution order:
      1. The first fenced block tagged ``glsl`` (any letter case).
      2. The first fenced block of any kind; a bare info string on the
         opening fence line (e.g. ``c``) is dropped so it does not end up
         inside the shader source.
      3. Everything from the first ``void mainImage`` onwards.
      4. The text unchanged.

    An unterminated fence (e.g. output cut off by the token limit) is read to
    the end of the text.
    """
    glsl_fence = re.search(r"```glsl", text, re.IGNORECASE)
    if glsl_fence:
        return text[glsl_fence.end():].split("```", 1)[0].strip()

    fence_start = text.find("```")
    if fence_start != -1:
        body = text[fence_start + 3:].split("```", 1)[0]
        first_line, newline, rest = body.partition("\n")
        # Only treat the first line as a language tag when it is a single
        # bare word; real code on the fence line is kept.
        if newline and re.fullmatch(r"[\w+#.-]*", first_line.strip()):
            body = rest
        return body.strip()

    entry_point = text.find("void mainImage")
    if entry_point != -1:
        return text[entry_point:].strip()
    return text

def calculate_mse(img1, img2):
    """Return the mean squared error between two images.

    Both inputs are resized to 256x256 via their ``resize`` method, converted
    to float arrays scaled by 1/255, and compared element-wise.
    """
    target_size = (256, 256)
    first, second = (
        np.array(picture.resize(target_size)).astype(float) / 255.0
        for picture in (img1, img2)
    )
    squared_diff = (first - second) ** 2
    return np.mean(squared_diff)

# --- PHASE 1: GENERATE DATA ---
# Each sample is produced by generate_primitive(i), rendered once, and kept
# in memory together with the shape label parsed from its analysis text.
print(f"Generating {TEST_SAMPLES} ground truth samples...")
gt_data = []
for i in range(TEST_SAMPLES):
    code, analysis = generate_primitive(i)
    shape = parse_shape_name(analysis) # Get label from GT analysis

    # A failed render would otherwise leave the previous sample's file in
    # place and silently use it as this sample's ground truth.
    if not renderer.render(code, "temp_gt.png"):
        raise RuntimeError(f"Ground-truth shader for sample {i} failed to render.")
    with Image.open("temp_gt.png") as gt_file:
        img = gt_file.convert("RGB")

    gt_data.append({"gt_code": code, "gt_shape": shape, "gt_image": img})

# --- PHASE 2: INFERENCE ---
print("Running Inference...")
predictions = []

# Decoder-only generation continues from the last position of every row, so
# any padding must sit on the left. The prompts built below are identical in
# text, so this is defensive rather than a fix for an observed failure.
processor.tokenizer.padding_side = "left"

for i in tqdm(range(0, TEST_SAMPLES, BATCH_SIZE)):
    batch = gt_data[i : i + BATCH_SIZE]

    # Prepare batch: one single-turn conversation per image.
    imgs = [item["gt_image"] for item in batch]
    prompts = [
        [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
            ]}
        ]
        for _ in imgs
    ]

    # Tokenize
    text_inputs = processor.apply_chat_template(prompts, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=text_inputs, images=imgs, padding=True, return_tensors="pt").to(model.device)

    # Generate (Strict Deterministic)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False  # Crucial for code
        )

    # Decode only the newly generated tokens (everything after the prompt).
    generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
    output_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    predictions.extend(output_texts)

# --- PHASE 3: EVALUATE ---
print("Evaluating...")
y_true = []
y_pred = []
visual_errors = []
compile_success = 0

for i, text in enumerate(predictions):
    gt = gt_data[i]

    pred_code = extract_code(text)
    pred_shape = parse_shape_name(text) # Parse model's own logic

    # Render Prediction
    success = renderer.render(pred_code, "temp_pred.png")

    # Predictions that fail to render are scored with an error of 1.0.
    mse = 1.0
    if success and os.path.exists("temp_pred.png"):
        compile_success += 1
        with Image.open("temp_pred.png") as pred_file:
            pred_img = pred_file.convert("RGB")
        mse = calculate_mse(gt["gt_image"], pred_img)

    y_true.append(gt["gt_shape"])
    y_pred.append(pred_shape)
    visual_errors.append(mse)

# --- REPORT ---
print("\n" + "="*40)
print(f"TEST REPORT (N={TEST_SAMPLES})")
print("="*40)
print(f"1. Compilation Success: {compile_success}/{TEST_SAMPLES} ({compile_success/TEST_SAMPLES:.1%})")
print(f"2. Average Visual Error (MSE): {np.mean(visual_errors):.4f} (Lower is better)")

print("\n3. Classification Report (Did it identify the shape?):")
print(classification_report(y_true, y_pred, zero_division=0))

print("\n4. Confusion Matrix:")
# Use the union of true and predicted labels so every row/column is named.
labels = sorted(set(y_true + y_pred))
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(pd.DataFrame(cm, index=labels, columns=labels))

In [None]:
# @title 4. Auto-Shutdown
# Intended to be queued after the test-suite cell so the Colab runtime is
# released once evaluation is done.
import time
from google.colab import runtime

print("Test suite finished.")
print("Shutting down runtime to save Compute Units in 60 seconds...")

# Grace period before the runtime is released (e.g. for pending Drive sync).
time.sleep(60)

print("Goodnight.")
runtime.unassign()