**Stage 1 - Validation**

In [None]:
# @title 1. Environment and paths setup
import os
import sys
from google.colab import drive

# 1. Mount drive to access the trained adapters and renderer.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# 2. Project paths using the correct Drive prefix.
PROJECT_ROOT = '/content/drive/My Drive/projects/EarthShader'
# Points to the committed stage 1 adapter folder.
ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')
LIB_DIR = os.path.join(PROJECT_ROOT, 'lib')

# 3. Ensure dependencies are present for the T4 environment.
try:
    import moderngl
except ImportError:
    print("Installing missing dependencies...")
    # These libraries are required for the vision tower and LoRA inference.
    !pip install -q moderngl moderngl-window peft bitsandbytes accelerate qwen-vl-utils

# 4. Link library for the renderer and generator.
if LIB_DIR not in sys.path:
    sys.path.append(LIB_DIR)

from gl_renderer import ShaderRenderer

# Initialize the renderer at 256 for faster validation cycles.
# This resolution matches the max_pixels used during training.
renderer = ShaderRenderer(width=256, height=256)

print(f"[SUCCESS] Environment ready. Renderer initialized at 256x256.")
print(f"Validation path set to: {ADAPTER_PATH}")

In [None]:
# @title 2. Load model and stage 1 adapter
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
import os

# 1. Configuration and committed checkpoint.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
# Re-derived from PROJECT_ROOT so this cell always targets the finalized Stage 1 weights.
ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')

# 2. Initialize processor first to ensure it is defined for later cells.
# min_pixels == max_pixels pins the image pixel budget to 256*256.
print(f"Loading processor for {MODEL_ID}...")
processor = AutoProcessor.from_pretrained(MODEL_ID, min_pixels=256*256, max_pixels=256*256)

# Fail fast when the adapter is missing. Only printing here would leave `model`
# undefined and make the inference cells die later with an unrelated NameError.
if not os.path.exists(ADAPTER_PATH):
    raise FileNotFoundError(
        f"[ERROR] Checkpoint not found at {ADAPTER_PATH}. "
        "Ensure you have run the training notebook to completion."
    )

# 3. Load the base model in 4-bit for T4 memory safety.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

print("Loading base model and applying adapter...")
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

# 4. Apply the Stage 1 final adapter and switch to inference mode.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()

print("[SUCCESS] Stage 1 model and processor ready for evaluation.")

In [None]:
# @title 3. Inference: The full diagnostic and 5-column view
import random
import matplotlib.pyplot as plt
from PIL import Image
import textwrap
import gc
import torch
import os

# 1. Access the generator logic from your project lib.
from generators.primitives import generate_primitive

# System-role text placed at the start of every conversation built in this cell.
# This prompt must match the one used in the training collator logic exactly.
# NOTE(review): the training collator is not part of this notebook - confirm the
# string is still identical whenever the training prompt changes.
SYSTEM_PROMPT = "You are an EarthShader SDF compiler. Translate the visual primitive into a valid GLSL mainImage function using the common.glsl library."

def parse_response(text):
    """Split raw model output into its analysis and code sections.

    The code section starts at the first marker that occurs in ``text``,
    trying the markers in priority order: an opening ```` ```glsl ```` fence,
    then ``void mainImage``, then a ``//`` comment.

    Args:
        text: Decoded model output.

    Returns:
        A ``(analysis, code)`` tuple of strings. ``analysis`` is everything
        before the code marker, stripped. ``code`` is the shader source with
        the opening fence removed and truncated at the first closing fence,
        stripped. If no marker is found the result is ``(text, "")`` with
        ``text`` returned unchanged.
    """
    fence_open = "```glsl"
    fence_close = "```"
    code_start_markers = [fence_open, "void mainImage", "//"]

    idx = -1
    for marker in code_start_markers:
        idx = text.find(marker)
        if idx != -1:
            break

    if idx == -1:
        return text, ""

    analysis = text[:idx].strip()
    code_raw = text[idx:].strip()
    if code_raw.startswith(fence_open):
        code_raw = code_raw[len(fence_open):]

    # Cut at the first closing fence rather than only when the output ends with
    # one: anything the model writes after the fence would otherwise be handed
    # to the shader compiler together with the stray backticks.
    close_idx = code_raw.find(fence_close)
    if close_idx != -1:
        code_raw = code_raw[:close_idx]

    return analysis, code_raw.strip()

# 2. Diagnostic loop.
# For each sample: render a generated ground-truth shader, ask the model to
# reverse engineer it, render the model's code, and show both in a 5-column row.
num_samples = 3
GT_IMAGE_PATH = "temp_gt.png"
fig = plt.figure(figsize=(30, 12))

for i in range(num_samples):
    print(f"Testing sample {i+1}...")

    # A. Generate input (Ground Truth).
    # Using random seeds outside the training range (0-2000) for fair testing.
    gt_code, gt_analysis = generate_primitive(random.randint(5000, 10000))

    # The same file name is reused for every sample, so drop any leftover image
    # first; otherwise a failed render would silently reuse the previous picture.
    if os.path.exists(GT_IMAGE_PATH):
        os.remove(GT_IMAGE_PATH)
    renderer.render(gt_code, GT_IMAGE_PATH)
    if not os.path.exists(GT_IMAGE_PATH):
        raise RuntimeError(f"Ground-truth shader for sample {i+1} failed to render.")
    # convert() returns an in-memory copy, so the file handle can be closed right away.
    with Image.open(GT_IMAGE_PATH) as gt_file:
        gt_image = gt_file.convert("RGB")

    # B. Prompt construction matching the training collator logic.
    conversation = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
        ]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=[prompt], images=[gt_image], padding=True, return_tensors="pt").to(model.device)

    # C. Generate with low temperature for stable reasoning.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=1024, temperature=0.2, do_sample=True)

    # D. Cleanup and decode: keep only the tokens produced after the prompt.
    generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    pred_analysis, pred_code = parse_response(output_text)

    # E. Render model output.
    temp_pred_path = f"temp_pred_{i}.png"
    pred_success = renderer.render(pred_code, temp_pred_path)

    if pred_success and os.path.exists(temp_pred_path):
        # Same RGB conversion as the ground truth so both panels are displayed alike.
        with Image.open(temp_pred_path) as pred_file:
            pred_image = pred_file.convert("RGB")
    else:
        # Create a dark red fallback to indicate a compilation/rendering failure,
        # sized like the ground-truth image.
        pred_image = Image.new("RGB", gt_image.size, (100, 0, 0))

    # --- VISUALIZATION (5 Columns) ---
    row_start = i * 5

    ax_gt = fig.add_subplot(num_samples, 5, row_start + 1)
    ax_gt.imshow(gt_image)
    ax_gt.set_title("Input (GT)")
    ax_gt.axis('off')

    ax_pred = fig.add_subplot(num_samples, 5, row_start + 2)
    ax_pred.imshow(pred_image)
    ax_pred.set_title("Model Result")
    ax_pred.axis('off')

    ax_gt_logic = fig.add_subplot(num_samples, 5, row_start + 3)
    ax_gt_logic.text(0, 1, textwrap.fill(gt_analysis, width=35), fontsize=9, va='top')
    ax_gt_logic.set_title("GT Logic")
    ax_gt_logic.axis('off')

    ax_reasoning = fig.add_subplot(num_samples, 5, row_start + 4)
    ax_reasoning.text(0, 1, textwrap.fill(pred_analysis, width=35), fontsize=9, va='top', color='blue')
    ax_reasoning.set_title("Model Reasoning")
    ax_reasoning.axis('off')

    ax_code = fig.add_subplot(num_samples, 5, row_start + 5)
    # Only the first 600 characters are shown to keep the panel readable.
    ax_code.text(0, 1, pred_code[:600], family='monospace', fontsize=8, va='top')
    ax_code.set_title("Generated Code")
    ax_code.axis('off')

    # F. Memory cleanup.
    del inputs, output_ids, generated_ids
    torch.cuda.empty_cache()
    gc.collect()

fig.tight_layout()
plt.show()

In [None]:
# @title 4. Raw output dump and diagnostic
import gc
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image

# Imported here as well so this cell does not depend on the 5-column cell
# having been executed first.
from generators.primitives import generate_primitive

# The system prompt anchors the model to its "compiler" persona.
SYSTEM_PROMPT = "You are an EarthShader SDF compiler. Translate the visual primitive into a valid GLSL mainImage function using the common.glsl library."

# 1. Run a single, deep test and print the model output without any parsing.
num_samples = 1
GT_IMAGE_PATH = "temp_gt.png"

for i in range(num_samples):
    # A. Generate Ground Truth Input.
    # Seeds start at 5000, the same lower bound as the 5-column diagnostic, to
    # stay outside the training range (0-2000) so the dump is not a trained sample.
    gt_code, gt_analysis = generate_primitive(random.randint(5000, 100000))

    # temp_gt.png is shared with other cells; remove any leftover so a failed
    # render cannot be mistaken for a fresh image.
    if os.path.exists(GT_IMAGE_PATH):
        os.remove(GT_IMAGE_PATH)
    renderer.render(gt_code, GT_IMAGE_PATH)
    if not os.path.exists(GT_IMAGE_PATH):
        raise RuntimeError(f"Ground-truth shader for sample {i+1} failed to render.")
    # convert() returns an in-memory copy, so the file handle can be closed right away.
    with Image.open(GT_IMAGE_PATH) as gt_file:
        gt_image = gt_file.convert("RGB")

    # B. Generate Prediction with the system role included.
    conversation = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
        ]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=[prompt], images=[gt_image], padding=True, return_tensors="pt").to(model.device)

    # Low temperature helps maintain structural consistency.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=1024, temperature=0.1, do_sample=True)

    # Keep only the tokens produced after the prompt.
    generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # --- RAW OUTPUT DUMP ---
    print(f"\n{'='*60}\nSAMPLE {i+1} RAW OUTPUT\n{'='*60}")
    print(output_text)
    print(f"{'='*60}\n")

    # C. Cleanup.
    # generated_ids is a slice of output_ids, so it has to be released too
    # before the cached GPU memory can actually be returned.
    del inputs, output_ids, generated_ids
    torch.cuda.empty_cache()
    gc.collect()