In [None]:
# @title 1. Environment and Stage 1 weight loading
import os
import sys
import torch
import time
from google.colab import drive

# 1. Mount drive.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# 2. Project paths.
PROJECT_ROOT = '/content/drive/My Drive/projects/EarthShader'
STAGE1_ADAPTER_PATH = os.path.join(PROJECT_ROOT, 'checkpoints/stage1_final')
STAGE2_OUTPUT = os.path.join(PROJECT_ROOT, 'checkpoints/stage2_final')
DATASET_PATH = os.path.join(PROJECT_ROOT, 'dataset/stage2/dataset.jsonl')

os.makedirs(STAGE2_OUTPUT, exist_ok=True)

# 3. Forced installation of latest bitsandbytes for 4-bit support.
print("Ensuring environment dependencies...")
!pip install -q -U bitsandbytes
!pip install -q git+https://github.com/huggingface/transformers peft datasets accelerate qwen-vl-utils trl

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel

# 4. Memory-optimized quantization config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16, # Use float16 for T4 compatibility.
    bnb_4bit_use_double_quant=True,
)

BASE_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
print(f"Loading base architecture: {BASE_MODEL_ID}")

model = Qwen2VLForConditionalGeneration.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map={"": 0}, #"auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

# 5. Inject Stage 1 weights as trainable adapters.
print(f"Loading Stage 1 foundation: {STAGE1_ADAPTER_PATH}")
model = PeftModel.from_pretrained(model, STAGE1_ADAPTER_PATH, is_trainable=True)

# 6. VRAM OPTIMIZATION: Freeze vision tower.
# We only need to train the language/logic layers for CSG.
for name, param in model.named_parameters():
    if "visual" in name:
        param.requires_grad = False

model.gradient_checkpointing_enable()
model.config.use_cache = False

processor = AutoProcessor.from_pretrained(
    BASE_MODEL_ID,
    min_pixels=256*256,
    max_pixels=256*256
)

print("[SUCCESS] Model loaded and vision tower frozen to save VRAM.")

In [None]:
# @title 2. Scenario collation and anchoring
from PIL import Image
from datasets import load_dataset
from torch.utils.data import DataLoader

SYSTEM_PROMPT = "You are an EarthShader SDF compiler. Translate the visual primitive into a valid GLSL mainImage function using the common.glsl library."

def final_collate_fn(batch, max_length=768):
    """Build a Qwen2-VL training batch whose labels supervise only the assistant reply.

    Each item must provide 'image_path', 'analysis' and 'code'. The assistant target is
    the analysis followed by the code inside a ```glsl fence. Labels are set to -100 on
    the prompt (system + user turn with the image) and on every padding position, so
    only the response contributes to the loss.

    Args:
        batch: List of dataset rows.
        max_length: Fixed padded/truncated sequence length (768 by default).

    Returns:
        The processor output with an added "labels" tensor.

    Raises:
        FileNotFoundError: If an image cannot be opened or decoded.
    """
    images, full_texts, prompt_only_texts = [], [], []
    for item in batch:
        try:
            # The context manager closes the source file; convert() returns an
            # independent, fully loaded copy.
            with Image.open(item['image_path']) as source_image:
                image = source_image.convert("RGB")
        except OSError as e:  # Also covers PIL.UnidentifiedImageError.
            raise FileNotFoundError(f"Missing training image: {item['image_path']}. Check your mount!") from e
        images.append(image)

        prompt_conv = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": "Reverse engineer the GLSL shader code for this texture. Include analysis."}
            ]}
        ]

        full_response = f"{item['analysis']}\n\n```glsl\n{item['code']}\n```"
        full_conv = prompt_conv + [{"role": "assistant", "content": [{"type": "text", "text": full_response}]}]

        prompt_only_texts.append(processor.apply_chat_template(prompt_conv, tokenize=False, add_generation_prompt=True))
        full_texts.append(processor.apply_chat_template(full_conv, tokenize=False, add_generation_prompt=False))

    inputs = processor(text=full_texts, images=images, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    # The prompt-only pass exists just to measure each prompt's token length
    # (including the expanded image tokens).
    inputs_prompts = processor(text=prompt_only_texts, images=images, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")

    labels = inputs["input_ids"].clone()
    attention = inputs["attention_mask"]
    # Padding must never be a training target.
    labels[attention == 0] = -100
    for i in range(len(batch)):
        prompt_len = int(inputs_prompts["attention_mask"][i].sum().item())
        # Positions of real tokens in order. Masking the first prompt_len of them
        # works for both left and right padding.
        real_positions = attention[i].nonzero(as_tuple=True)[0]
        labels[i, real_positions[:prompt_len]] = -100
    inputs["labels"] = labels
    return inputs

# Shuffle with a fixed seed. The training loop caps each epoch at TOTAL_ITERATIONS
# batches, so an unshuffled loader would revisit the same leading rows every epoch
# and never reach the rest of the dataset.
DATA_SEED = 42
raw_dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
assert len(raw_dataset) > 0, f"No training rows found in {DATASET_PATH}"
full_loader = DataLoader(
    raw_dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATA_SEED),
    collate_fn=final_collate_fn,
)

In [None]:
# @title 3. Training loop and weighted loss stabilization
import torch.nn as nn
from tqdm import tqdm
import bitsandbytes as bnb
from transformers import get_cosine_schedule_with_warmup
import time
import os
from google.colab import runtime

def _find_token_subsequence(tokens, pattern):
    """Return the start index of the first occurrence of ``pattern`` in ``tokens``, or -1."""
    width = len(pattern)
    if width == 0:
        return -1
    for start in range(len(tokens) - width + 1):
        if tokens[start:start + width] == pattern:
            return start
    return -1


def compute_stage2_loss(logits, labels, tokenizer, code_weight=3.0, csg_weight=5.0, text_weight=0.5):
    """Weighted causal-LM loss that up-weights GLSL code and CSG operator tokens.

    Weights for each supervised target token:
      * before the last token of the ```glsl fence: ``text_weight``
      * from that token onward: ``code_weight``
      * ``min`` / ``max`` tokens inside the code region: ``csg_weight``
    Rows without a fence keep weight 1.0 everywhere.

    Cross entropy is computed in float32 so that weighting half-precision logits
    cannot overflow.

    Args:
        logits: Model logits. Rows are iterated over dim 0, the sequence runs along
            dim -2 and the vocabulary along the last dim.
        labels: Target ids aligned with ``logits`` positions. -100 marks positions
            that are excluded from the loss.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        code_weight, csg_weight, text_weight: Per-region loss multipliers.

    Returns:
        A scalar float32 tensor: the weighted per-token losses summed and divided by
        the number of supervised tokens (0.0 when there are none).
    """
    # Default ignore_index=-100 gives zero loss at masked positions.
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    code_marker = tokenizer.encode("```glsl", add_special_tokens=False)

    # BPE vocabularies give "min"/" min" and "max"/" max" distinct ids, and GLSL
    # writes both forms ("return min(a, b)", "(max(x, y)"). Keep the ids of the
    # original "min max" encoding and add each variant that is a single token, so a
    # split such as [" ", "min"] can never turn plain spaces into CSG tokens.
    csg_tokens = set(tokenizer.encode("min max", add_special_tokens=False))
    for variant in ("min", " min", "max", " max"):
        variant_ids = tokenizer.encode(variant, add_special_tokens=False)
        if len(variant_ids) == 1:
            csg_tokens.update(variant_ids)

    # Causal LM shift: the logits at position t predict the label at t + 1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].contiguous()

    weights = torch.ones_like(shift_labels, dtype=torch.float32)
    for i in range(shift_labels.size(0)):
        row = shift_labels[i].tolist()
        # Match the whole fence token sequence. Its last token alone can also
        # occur inside the analysis prose.
        fence_start = _find_token_subsequence(row, code_marker)
        if fence_start < 0:
            continue
        marker_idx = fence_start + len(code_marker) - 1
        weights[i, :marker_idx] = text_weight
        weights[i, marker_idx:] = code_weight
        csg_positions = [idx for idx in range(marker_idx, len(row)) if row[idx] in csg_tokens]
        if csg_positions:
            weights[i, csg_positions] = csg_weight

    # Upcast to float32 before the cross entropy.
    flat_logits = shift_logits.reshape(-1, shift_logits.size(-1)).float()
    flat_labels = shift_labels.view(-1)

    loss = loss_fct(flat_logits, flat_labels)
    valid_mask = flat_labels != -100
    # clamp prevents 0/0 = NaN when every target in the batch is masked.
    num_valid = valid_mask.sum().clamp(min=1)
    return (loss * weights.view(-1))[valid_mask].sum() / num_valid

# 1. Hyperparameters for stabilized Stage 2.
GRAD_ACCUMULATION = 4
TOTAL_ITERATIONS = 2000  # Upper bound on micro-batches per epoch.
EPOCHS = 3
LEARNING_RATE = 1e-5  # Lowered to prevent divergence.
MAX_GRAD_NORM = 0.5  # Tighter gradient clipping to stabilize Boolean logic learning.
LOSS_SKIP_THRESHOLD = 100.0  # Batches above this loss are skipped as unstable.
CHECKPOINT_EVERY = 100  # Optimizer steps between intermediate adapter saves.

# Each epoch stops at whichever comes first: the end of the loader or TOTAL_ITERATIONS.
ITERS_PER_EPOCH = min(len(full_loader), TOTAL_ITERATIONS)
# Only complete accumulation windows trigger an optimizer/scheduler step.
TOTAL_STEPS = (ITERS_PER_EPOCH // GRAD_ACCUMULATION) * EPOCHS

model.train()
# Give the optimizer only the parameters that are actually trained (adapter weights).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = bnb.optim.PagedAdamW8bit(trainable_params, lr=LEARNING_RATE)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(TOTAL_STEPS * 0.1),
    num_training_steps=TOTAL_STEPS
)

global_step = 0

for epoch in range(EPOCHS):
    # Drop any partial accumulation window left over from the previous epoch.
    optimizer.zero_grad()
    pbar = tqdm(full_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", total=ITERS_PER_EPOCH)
    for step, batch in enumerate(pbar):
        if step >= ITERS_PER_EPOCH:
            break

        # Periodic memory cleanup for T4 stability.
        if step % 10 == 0:
            torch.cuda.empty_cache()

        batch = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**batch)

        # Calculate the loss in float32.
        loss = compute_stage2_loss(outputs.logits, batch["labels"], processor.tokenizer)
        loss_value = loss.item()

        # Skip the batch if the loss explodes despite precautions. No backward() ran
        # for it, so there is nothing to clear. Zeroing here would also throw away the
        # gradients already accumulated in the current window.
        if not torch.isfinite(loss).item() or loss_value > LOSS_SKIP_THRESHOLD:
            print(f"Skipping unstable loss at step {step}.")
            del outputs, loss  # Release the graph before the next forward pass.
            continue

        (loss / GRAD_ACCUMULATION).backward()
        del outputs, loss  # Free the large logits tensor before the next forward pass.

        if (step + 1) % GRAD_ACCUMULATION == 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Displays the loss of the last micro-batch in the window.
            pbar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })

            # Save checkpoints periodically to protect progress.
            if global_step % CHECKPOINT_EVERY == 0:
                model.save_pretrained(os.path.join(STAGE2_OUTPUT, f"step-{global_step}"))
    pbar.close()

# 2. Final save for the refined Stage 2 weights.
model.save_pretrained(STAGE2_OUTPUT)
print("\n[SUCCESS] Stabilized Stage 2 training complete.")
# Runtime shutdown is left to the dedicated auto-shutdown cell that follows.
# Unassigning here made that cell unreachable.

In [None]:
# @title 4. Auto shutdown
import time
from google.colab import runtime

# Give Google Drive a grace period to flush the final adapter files, then unassign
# the Colab runtime so it stops consuming compute units.
SYNC_GRACE_SECONDS = 60
SHUTDOWN_MESSAGES = (
    "Training sequence has finished. Synchronizing drive files...",
    "The system is going offline to preserve compute units.",
)

print(SHUTDOWN_MESSAGES[0])
time.sleep(SYNC_GRACE_SECONDS)
print(SHUTDOWN_MESSAGES[1])
runtime.unassign()