**Stage 1 - Validation**

In [None]:
# @title 1. Install Dependencies
import os
import sys

# 1.1 Force Upgrade Libraries
print("Installing libraries...")
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers peft accelerate
!pip install -q moderngl
!apt-get install -y libgl1-mesa-glx > /dev/null

# 1.2 Restart Check
try:
    import bitsandbytes
    from transformers.utils import is_bitsandbytes_available
    if not is_bitsandbytes_available():
        print("\n" + "="*60)
        print("  PLEASE RESTART RUNTIME (Runtime > Restart Session)")
        print("Then run this cell again.")
        print("="*60 + "\n")
    else:
        print("  Environment Ready.")
except ImportError:
    pass

In [None]:
# @title 2. Load Model, Adapter & STRICT Renderer
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, BitsAndBytesConfig
from peft import PeftModel
import os
import sys
import moderngl
import numpy as np
from PIL import Image

# 1. Configuration
PROJECT_ROOT = '/content/drive/MyDrive/projects/EarthShader'
ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')
LIB_DIR = os.path.join(PROJECT_ROOT, 'lib')

# 2. Mount Drive
# Test for MyDrive (the parent of PROJECT_ROOT) rather than the bare mount
# point: /content/drive can exist as an empty directory while Drive is not
# mounted, and checking only that path would skip the mount.
from google.colab import drive
if not os.path.isdir('/content/drive/MyDrive'):
    drive.mount('/content/drive')

# 3. --- STRICT RENDERER ---
class ShaderRenderer:
    """Offscreen ModernGL renderer for Shadertoy-style ``mainImage`` fragment code.

    Each ``render`` call wraps the caller's code with ``common.glsl`` (loaded
    from ``LIB_DIR`` at construction, if present) and a ``main`` that calls
    ``mainImage``. It then draws a quad spanning clip space into an RGB
    framebuffer and saves the result as an image file.
    """

    def __init__(self, width=256, height=256):
        """Create the GL context, quad geometry and framebuffer; load common.glsl.

        Args:
            width: Framebuffer width in pixels (also passed as ``iResolution.x``).
            height: Framebuffer height in pixels (also passed as ``iResolution.y``).

        Raises:
            RuntimeError: If neither the EGL nor the default standalone
                ModernGL context can be created.
        """
        self.width = width
        self.height = height
        self.ctx = None
        # Error text from the most recent failed ``render`` call; None after a success.
        self.last_error = None

        # ATTEMPT 1: EGL (Required for Colab)
        try:
            self.ctx = moderngl.create_context(standalone=True, backend='egl')
            print("[Renderer] Success: EGL Backend initialized.")
        except Exception as e:
            print(f"[Renderer] EGL Failed: {e}")
            # ATTEMPT 2: Standard (Fallback - usually fails on Colab)
            try:
                self.ctx = moderngl.create_context(standalone=True)
                print("[Renderer] Warning: Using Standard Backend (May produce noise).")
            except Exception as e2:
                raise RuntimeError(
                    f"CRITICAL: Could not create ANY ModernGL context.\nEGL: {e}\nStd: {e2}"
                ) from e2

        # Geometry: four corners of clip space, drawn as a TRIANGLE_STRIP in render().
        vertices = np.array([-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0], dtype='f4')
        self.vbo = self.ctx.buffer(vertices.tobytes())
        self.fbo = self.ctx.simple_framebuffer((self.width, self.height), components=3)
        # Pass-through vertex shader; `uv` is emitted but the wrapper fragment
        # shader built in render() does not read it.
        self.vert_shader = '''
        #version 330
        in vec2 in_vert;
        out vec2 uv;
        void main() {
            uv = in_vert;
            gl_Position = vec4(in_vert, 0.0, 1.0);
        }
        '''

        # Load Common Lib (prepended to every fragment shader in render()).
        common_path = os.path.join(LIB_DIR, 'common.glsl')
        self.common_lib = ""
        if os.path.exists(common_path):
            with open(common_path, 'r') as f:
                self.common_lib = f.read()
            print(f"[Renderer] common.glsl loaded ({len(self.common_lib)} chars).")
        else:
            print(f"[Renderer] WARNING: common.glsl NOT FOUND at {common_path}")

    def render(self, fragment_code, output_path):
        """Compile *fragment_code*, render one frame, and save it to *output_path*.

        *fragment_code* must define ``void mainImage(out vec4, in vec2)``. It is
        called with ``gl_FragCoord.xy`` and may read ``uniform vec2 iResolution``.

        Returns:
            True if the image was rendered and saved; False on any failure
            (compile/link error, GL error, save error). The failure's message
            is kept in ``self.last_error``. GL objects created for this call are
            released in every case.
        """
        full_frag_shader = f'''
        #version 330
        uniform vec2 iResolution;
        out vec4 fragColor;
        {self.common_lib}
        {fragment_code}
        void main() {{
            vec4 color;
            mainImage(color, gl_FragCoord.xy);
            fragColor = color;
        }}
        '''
        prog = None
        vao = None
        try:
            prog = self.ctx.program(vertex_shader=self.vert_shader, fragment_shader=full_frag_shader)
            # The uniform can be optimized out if the shader never reads it.
            if 'iResolution' in prog:
                prog['iResolution'].value = (self.width, self.height)
            vao = self.ctx.simple_vertex_array(prog, self.vbo, 'in_vert')

            # CLEAR AND RENDER
            self.fbo.use()
            self.fbo.clear(0.0, 0.0, 0.0, 1.0)  # Clear to Black
            vao.render(moderngl.TRIANGLE_STRIP)

            # READ: GL rows start at the bottom, image rows at the top, hence the flip.
            data = self.fbo.read(components=3)
            image = Image.frombytes('RGB', self.fbo.size, data).transpose(Image.FLIP_TOP_BOTTOM)
            image.save(output_path)
        except Exception as e:
            # Best-effort by design: model-generated GLSL is frequently invalid,
            # so report failure to the caller instead of raising.
            self.last_error = str(e)
            return False
        finally:
            # Release on success AND failure so failed shaders don't leak GL objects.
            if vao is not None:
                vao.release()
            if prog is not None:
                prog.release()
        self.last_error = None
        return True

# Initialize the offscreen renderer shared by ground-truth and predicted shaders.
renderer = ShaderRenderer(height=256, width=256)

# 4. Load Model
# Quantize the base weights to 4-bit NF4 (fp16 compute), then attach the
# fine-tuned adapter stored at ADAPTER_PATH.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()

# The processor is read from the adapter directory; min_pixels == max_pixels
# pins the pixel budget given to the vision encoder.
processor = Qwen2VLProcessor.from_pretrained(
    ADAPTER_PATH,
    max_pixels=256 * 256,
    min_pixels=256 * 256,
)
print("[SUCCESS] Pipeline Ready.")

In [None]:
# @title 3. Inference: The Full Diagnostic (5-Column View)
import random
import matplotlib.pyplot as plt
from PIL import Image
import torch
import os
import sys
import importlib.util
import textwrap

# 1. Import Generator Logic
# Put LIB_DIR on the import path so `generators.primitives`
# (LIB_DIR/generators/primitives.py) resolves as a normal package import.
# Without this, the `try` below only succeeds if an earlier cell or the
# working directory already provides a `generators` package.
if LIB_DIR not in sys.path:
    sys.path.insert(0, LIB_DIR)

try:
    from generators.primitives import generate_primitive
except ImportError:
    # Fallback: load base.py and primitives.py straight from their file paths.
    # The loaded `base` module is bound into primitives' namespace before it
    # executes, so primitives can reference it as a global.
    spec_b = importlib.util.spec_from_file_location("base", os.path.join(LIB_DIR, "generators/base.py"))
    mod_b = importlib.util.module_from_spec(spec_b)
    spec_b.loader.exec_module(mod_b)
    spec_p = importlib.util.spec_from_file_location("primitives", os.path.join(LIB_DIR, "generators/primitives.py"))
    mod_p = importlib.util.module_from_spec(spec_p)
    mod_p.base = mod_b
    spec_p.loader.exec_module(mod_p)
    generate_primitive = mod_p.generate_primitive

# 2. Helpers
def parse_response(text):
    """Split raw model output into ``(analysis, code)``.

    The code is located by the first ```glsl fence, or failing that by the
    first ``void mainImage``. Everything before that point is the analysis.
    The code runs until the next ``` fence (or the end of the text if the
    fence was never closed), so prose written after the code block is not
    sent to the shader compiler.

    In the ``void mainImage`` fallback, a fence opener left dangling on the
    last line of the analysis (e.g. a bare ``` or ```c) is removed.

    Returns:
        (analysis, code), both stripped. When no code marker is found, returns
        ``(text, "")`` with *text* unchanged.
    """
    fence = "```glsl"
    start = text.find(fence)
    if start != -1:
        code_start = start + len(fence)
        analysis = text[:start].strip()
    else:
        start = text.find("void mainImage")
        if start == -1:
            return text, ""  # No code found
        code_start = start
        analysis = text[:start].strip()
        analysis_lines = analysis.splitlines()
        if analysis_lines and analysis_lines[-1].lstrip().startswith("```"):
            analysis = "\n".join(analysis_lines[:-1]).strip()

    code = text[code_start:]
    # Cut at the closing fence so any trailing commentary is excluded.
    fence_end = code.find("```")
    if fence_end != -1:
        code = code[:fence_end]
    return analysis, code.strip()

# 3. Inference Loop
rows = 3  # one row per random sample
cols = 5  # Input | Prediction | GT logic | Model logic | Generated code


def _render_ground_truth(output_path="temp_gt.png", max_attempts=10):
    """Generate a random primitive shader and render it to *output_path*.

    Retries with a fresh random seed when rendering fails, so a failed render
    never leaves a stale (or missing) image file to be opened.

    Returns:
        (gt_code, gt_analysis, gt_image) with gt_image loaded as RGB.

    Raises:
        RuntimeError: If no generated shader renders within *max_attempts*.
    """
    for _ in range(max_attempts):
        code, analysis = generate_primitive(random.randint(0, 100000))
        if renderer.render(code, output_path):
            with Image.open(output_path) as img:
                return code, analysis, img.convert("RGB")
    raise RuntimeError(f"No ground-truth shader rendered in {max_attempts} attempts.")


def _render_prediction(code, output_path):
    """Render predicted *code*; return a solid red image if it fails to render."""
    if renderer.render(code, output_path):
        with Image.open(output_path) as img:
            return img.convert("RGB")
    print("  Prediction failed to compile/render; showing red placeholder.")
    return Image.new("RGB", (renderer.width, renderer.height), (255, 0, 0))


def _text_panel(position, text, title, *, family, fontsize, color=None):
    """Draw *text* top-left aligned in subplot *position* with a bold title, no axes."""
    plt.subplot(rows, cols, position)
    plt.text(0, 1, text, family=family, fontsize=fontsize, verticalalignment='top')
    title_kwargs = {"fontweight": "bold"}
    if color is not None:
        title_kwargs["color"] = color
    plt.title(title, **title_kwargs)
    plt.axis('off')


print(f"Running Inference on {rows} Random Samples (Full Diagnostic)...\n")
# Extra wide layout for 5 columns
plt.figure(figsize=(30, 12))

for i in range(rows):
    print(f"Sample {i+1}/{rows}...")
    base = i * cols  # subplot index just before this row's first panel

    # A. Generate Ground Truth (Input)
    gt_code, gt_analysis, gt_image = _render_ground_truth()

    # B. Prompt
    conversation = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
        ]}
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=[text_prompt], images=[gt_image], padding=True, return_tensors="pt").to(model.device)

    # C. Generate (do_sample=False -> deterministic greedy decoding)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # D. Process Output: drop the prompt tokens and decode only the new ones.
    generated_ids = output_ids[:, inputs.input_ids.shape[1]:]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    pred_analysis, pred_code = parse_response(output_text)
    pred_image = _render_prediction(pred_code, f"temp_pred_{i}.png")

    # --- VISUALIZATION ---
    # Cols 1-2: rendered input and prediction
    for offset, img, title in ((1, gt_image, "Input (GT)"), (2, pred_image, "Prediction")):
        plt.subplot(rows, cols, base + offset)
        plt.imshow(img)
        plt.title(title, fontweight='bold')
        plt.axis('off')

    # Col 3: GT Analysis
    _text_panel(base + 3, textwrap.fill(gt_analysis, width=35), "Ground Truth Logic",
                family='sans-serif', fontsize=9, color='green')

    # Col 4: Pred Analysis (the model's "thought process"); warn when empty
    disp_analysis = pred_analysis or "[NO ANALYSIS GENERATED]"
    _text_panel(base + 4, textwrap.fill(disp_analysis, width=35), "Model's Logic",
                family='sans-serif', fontsize=9, color='blue')

    # Col 5: Generated Code (truncated to keep the panel readable)
    display_code = pred_code[:600] + ("\n..." if len(pred_code) > 600 else "")
    _text_panel(base + 5, display_code, "Generated Code", family='monospace', fontsize=8)

plt.tight_layout()
plt.show()