# 🐍 ZeroLang Snake Game Training

Bu notebook Snake oyunu için gerekli fonksiyonları öğretmek üzere tasarlandı.

## Hedef Yetenekler
1. **Array Operations**: get, set, push, pop, length
2. **Point/Vector**: x,y koordinatları, distance
3. **Collision Detection**: bounds check, point equality
4. **Game Math**: modulo, clamp, random (LCG)
5. **Direction**: enum-like values, movement vectors

## Runtime
- **Step 1**: Free CPU (veri üretimi)
- **Step 2**: H100 GPU (training)

---
## Step 1: Environment Setup

In [None]:
# Install LLVM and wasm-tools
!apt-get update -qq
!apt-get install -qq -y llvm lld clang

# Install wasm-tools
!wget -q https://github.com/bytecodealliance/wasm-tools/releases/download/v1.219.1/wasm-tools-1.219.1-x86_64-linux.tar.gz
!tar -xzf wasm-tools-1.219.1-x86_64-linux.tar.gz
!cp wasm-tools-1.219.1-x86_64-linux/wasm-tools /usr/local/bin/
!chmod +x /usr/local/bin/wasm-tools

# Verify
!clang --version | head -1
!wasm-tools --version

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create output directory
!mkdir -p /content/drive/MyDrive/zerolang/snake_data

---
## Step 2: Snake Game Function Generator

Elle hazırlanmış, yüksek kaliteli C fonksiyonları → WAT çiftleri üretiyoruz.

In [None]:
import subprocess
import json
import os
import tempfile
from typing import List, Tuple, Optional

# Snake oyunu i√ßin el yapƒ±mƒ± fonksiyonlar
SNAKE_FUNCTIONS = [
    # ========== TEMEL MATEMATƒ∞K ==========
    {
        "instruction": "Implement: int abs(int x) - returns absolute value",
        "code": "int abs_val(int x) { return x < 0 ? -x : x; }",
        "test": "abs_val",
        "category": "math"
    },
    {
        "instruction": "Implement: int min(int a, int b) - returns minimum of two integers",
        "code": "int min(int a, int b) { return a < b ? a : b; }",
        "test": "min",
        "category": "math"
    },
    {
        "instruction": "Implement: int max(int a, int b) - returns maximum of two integers",
        "code": "int max(int a, int b) { return a > b ? a : b; }",
        "test": "max",
        "category": "math"
    },
    {
        "instruction": "Implement: int clamp(int value, int min_val, int max_val) - clamps value between min and max",
        "code": "int clamp(int value, int min_val, int max_val) { if (value < min_val) return min_val; if (value > max_val) return max_val; return value; }",
        "test": "clamp",
        "category": "math"
    },
    {
        "instruction": "Implement: int mod(int a, int b) - returns a modulo b (always positive)",
        "code": "int mod(int a, int b) { int r = a % b; return r < 0 ? r + b : r; }",
        "test": "mod",
        "category": "math"
    },
    {
        "instruction": "Implement: int sign(int x) - returns -1, 0, or 1 based on sign of x",
        "code": "int sign(int x) { if (x < 0) return -1; if (x > 0) return 1; return 0; }",
        "test": "sign",
        "category": "math"
    },
    {
        "instruction": "Implement: int wrap(int value, int max) - wraps value around 0 to max-1",
        "code": "int wrap(int value, int max) { while (value < 0) value += max; return value % max; }",
        "test": "wrap",
        "category": "math"
    },
    
    # ========== POINT/VECTOR ƒ∞≈ûLEMLERƒ∞ ==========
    {
        "instruction": "Implement: int pack_point(int x, int y) - packs two 16-bit coords into one 32-bit int",
        "code": "int pack_point(int x, int y) { return ((x & 0xFFFF) << 16) | (y & 0xFFFF); }",
        "test": "pack_point",
        "category": "point"
    },
    {
        "instruction": "Implement: int unpack_x(int packed) - extracts x coordinate from packed point",
        "code": "int unpack_x(int packed) { return (packed >> 16) & 0xFFFF; }",
        "test": "unpack_x",
        "category": "point"
    },
    {
        "instruction": "Implement: int unpack_y(int packed) - extracts y coordinate from packed point",
        "code": "int unpack_y(int packed) { return packed & 0xFFFF; }",
        "test": "unpack_y",
        "category": "point"
    },
    {
        "instruction": "Implement: int point_equal(int p1, int p2) - returns 1 if points are equal, 0 otherwise",
        "code": "int point_equal(int p1, int p2) { return p1 == p2 ? 1 : 0; }",
        "test": "point_equal",
        "category": "point"
    },
    {
        "instruction": "Implement: int manhattan_distance(int x1, int y1, int x2, int y2) - returns Manhattan distance between two points",
        "code": "int manhattan_distance(int x1, int y1, int x2, int y2) { int dx = x1 - x2; int dy = y1 - y2; if (dx < 0) dx = -dx; if (dy < 0) dy = -dy; return dx + dy; }",
        "test": "manhattan_distance",
        "category": "point"
    },
    {
        "instruction": "Implement: int move_point(int packed, int dx, int dy) - moves a packed point by delta",
        "code": "int move_point(int packed, int dx, int dy) { int x = (packed >> 16) & 0xFFFF; int y = packed & 0xFFFF; x += dx; y += dy; return ((x & 0xFFFF) << 16) | (y & 0xFFFF); }",
        "test": "move_point",
        "category": "point"
    },
    
    # ========== Y√ñN (DIRECTION) ƒ∞≈ûLEMLERƒ∞ ==========
    # Direction encoding: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT
    {
        "instruction": "Implement: int get_dx(int direction) - returns x delta for direction (0=UP, 1=RIGHT, 2=DOWN, 3=LEFT)",
        "code": "int get_dx(int direction) { if (direction == 1) return 1; if (direction == 3) return -1; return 0; }",
        "test": "get_dx",
        "category": "direction"
    },
    {
        "instruction": "Implement: int get_dy(int direction) - returns y delta for direction (0=UP, 1=RIGHT, 2=DOWN, 3=LEFT)",
        "code": "int get_dy(int direction) { if (direction == 0) return -1; if (direction == 2) return 1; return 0; }",
        "test": "get_dy",
        "category": "direction"
    },
    {
        "instruction": "Implement: int opposite_direction(int dir) - returns opposite direction",
        "code": "int opposite_direction(int dir) { return (dir + 2) % 4; }",
        "test": "opposite_direction",
        "category": "direction"
    },
    {
        "instruction": "Implement: int turn_left(int dir) - returns direction after turning left",
        "code": "int turn_left(int dir) { return (dir + 3) % 4; }",
        "test": "turn_left",
        "category": "direction"
    },
    {
        "instruction": "Implement: int turn_right(int dir) - returns direction after turning right",
        "code": "int turn_right(int dir) { return (dir + 1) % 4; }",
        "test": "turn_right",
        "category": "direction"
    },
    {
        "instruction": "Implement: int is_opposite(int dir1, int dir2) - returns 1 if directions are opposite",
        "code": "int is_opposite(int dir1, int dir2) { return ((dir1 + 2) % 4) == dir2 ? 1 : 0; }",
        "test": "is_opposite",
        "category": "direction"
    },
    
    # ========== BOUNDS/COLLISION ==========
    {
        "instruction": "Implement: int in_bounds(int x, int y, int width, int height) - returns 1 if point is within bounds",
        "code": "int in_bounds(int x, int y, int width, int height) { return (x >= 0 && x < width && y >= 0 && y < height) ? 1 : 0; }",
        "test": "in_bounds",
        "category": "collision"
    },
    {
        "instruction": "Implement: int hit_wall(int x, int y, int width, int height) - returns 1 if point hits wall",
        "code": "int hit_wall(int x, int y, int width, int height) { return (x < 0 || x >= width || y < 0 || y >= height) ? 1 : 0; }",
        "test": "hit_wall",
        "category": "collision"
    },
    {
        "instruction": "Implement: int wrap_x(int x, int width) - wraps x coordinate within grid width",
        "code": "int wrap_x(int x, int width) { while (x < 0) x += width; return x % width; }",
        "test": "wrap_x",
        "category": "collision"
    },
    {
        "instruction": "Implement: int wrap_y(int y, int height) - wraps y coordinate within grid height",
        "code": "int wrap_y(int y, int height) { while (y < 0) y += height; return y % height; }",
        "test": "wrap_y",
        "category": "collision"
    },
    
    # ========== RANDOM NUMBER (LCG) ==========
    {
        "instruction": "Implement: int lcg_next(int seed) - Linear Congruential Generator, returns next random value",
        "code": "int lcg_next(int seed) { return (seed * 1103515245 + 12345) & 0x7FFFFFFF; }",
        "test": "lcg_next",
        "category": "random"
    },
    {
        "instruction": "Implement: int random_range(int seed, int min, int max) - returns random number in range [min, max)",
        "code": "int random_range(int seed, int min_val, int max_val) { int range = max_val - min_val; if (range <= 0) return min_val; return min_val + (seed % range); }",
        "test": "random_range",
        "category": "random"
    },
    {
        "instruction": "Implement: int random_position(int seed, int width, int height) - returns packed random position",
        "code": "int random_position(int seed, int width, int height) { int x = seed % width; int y = (seed / width) % height; return ((x & 0xFFFF) << 16) | (y & 0xFFFF); }",
        "test": "random_position",
        "category": "random"
    },
    
    # ========== GAME STATE ==========
    {
        "instruction": "Implement: int make_game_state(int score, int length, int alive) - packs game state into int",
        "code": "int make_game_state(int score, int length, int alive) { return ((score & 0xFFFF) << 16) | ((length & 0x7FFF) << 1) | (alive & 1); }",
        "test": "make_game_state",
        "category": "state"
    },
    {
        "instruction": "Implement: int get_score(int state) - extracts score from game state",
        "code": "int get_score(int state) { return (state >> 16) & 0xFFFF; }",
        "test": "get_score",
        "category": "state"
    },
    {
        "instruction": "Implement: int get_length(int state) - extracts snake length from game state",
        "code": "int get_length(int state) { return (state >> 1) & 0x7FFF; }",
        "test": "get_length",
        "category": "state"
    },
    {
        "instruction": "Implement: int is_alive(int state) - returns 1 if snake is alive",
        "code": "int is_alive(int state) { return state & 1; }",
        "test": "is_alive",
        "category": "state"
    },
    {
        "instruction": "Implement: int add_score(int state, int points) - adds points to game state",
        "code": "int add_score(int state, int points) { int score = ((state >> 16) & 0xFFFF) + points; int rest = state & 0xFFFF; return ((score & 0xFFFF) << 16) | rest; }",
        "test": "add_score",
        "category": "state"
    },
    {
        "instruction": "Implement: int increment_length(int state) - increases snake length by 1",
        "code": "int increment_length(int state) { int score = (state >> 16) & 0xFFFF; int length = ((state >> 1) & 0x7FFF) + 1; int alive = state & 1; return ((score & 0xFFFF) << 16) | ((length & 0x7FFF) << 1) | alive; }",
        "test": "increment_length",
        "category": "state"
    },
    {
        "instruction": "Implement: int set_dead(int state) - sets alive flag to 0",
        "code": "int set_dead(int state) { return state & ~1; }",
        "test": "set_dead",
        "category": "state"
    },
    
    # ========== GRID/INDEX OPERATIONS ==========
    {
        "instruction": "Implement: int grid_index(int x, int y, int width) - converts 2D coords to 1D array index",
        "code": "int grid_index(int x, int y, int width) { return y * width + x; }",
        "test": "grid_index",
        "category": "grid"
    },
    {
        "instruction": "Implement: int index_to_x(int index, int width) - extracts x from 1D index",
        "code": "int index_to_x(int index, int width) { return index % width; }",
        "test": "index_to_x",
        "category": "grid"
    },
    {
        "instruction": "Implement: int index_to_y(int index, int width) - extracts y from 1D index",
        "code": "int index_to_y(int index, int width) { return index / width; }",
        "test": "index_to_y",
        "category": "grid"
    },
    {
        "instruction": "Implement: int grid_size(int width, int height) - returns total grid size",
        "code": "int grid_size(int width, int height) { return width * height; }",
        "test": "grid_size",
        "category": "grid"
    },
    
    # ========== SNAKE MOVEMENT ==========
    {
        "instruction": "Implement: int move_snake_head(int head, int direction, int width, int height) - moves head in direction with wrapping",
        "code": """int move_snake_head(int head, int direction, int width, int height) {
    int x = (head >> 16) & 0xFFFF;
    int y = head & 0xFFFF;
    if (direction == 0) y--;
    else if (direction == 1) x++;
    else if (direction == 2) y++;
    else if (direction == 3) x--;
    while (x < 0) x += width;
    while (y < 0) y += height;
    x = x % width;
    y = y % height;
    return ((x & 0xFFFF) << 16) | (y & 0xFFFF);
}""",
        "test": "move_snake_head",
        "category": "snake"
    },
    {
        "instruction": "Implement: int check_food_collision(int head, int food) - returns 1 if head is at food position",
        "code": "int check_food_collision(int head, int food) { return head == food ? 1 : 0; }",
        "test": "check_food_collision",
        "category": "snake"
    },
    {
        "instruction": "Implement: int calculate_score(int base_score, int length, int multiplier) - calculates total score",
        "code": "int calculate_score(int base_score, int length, int multiplier) { return base_score + (length * multiplier); }",
        "test": "calculate_score",
        "category": "snake"
    },
    {
        "instruction": "Implement: int should_grow(int ate_food) - returns 1 if snake should grow",
        "code": "int should_grow(int ate_food) { return ate_food != 0 ? 1 : 0; }",
        "test": "should_grow",
        "category": "snake"
    },
    
    # ========== MEMORY/ARRAY HELPERS ==========
    {
        "instruction": "Implement: int circular_index(int index, int size) - returns index wrapped for circular buffer",
        "code": "int circular_index(int index, int size) { while (index < 0) index += size; return index % size; }",
        "test": "circular_index",
        "category": "array"
    },
    {
        "instruction": "Implement: int next_index(int current, int size) - returns next index in circular buffer",
        "code": "int next_index(int current, int size) { return (current + 1) % size; }",
        "test": "next_index",
        "category": "array"
    },
    {
        "instruction": "Implement: int prev_index(int current, int size) - returns previous index in circular buffer",
        "code": "int prev_index(int current, int size) { return (current + size - 1) % size; }",
        "test": "prev_index",
        "category": "array"
    },
    {
        "instruction": "Implement: int buffer_full(int head, int tail, int size) - returns 1 if circular buffer is full",
        "code": "int buffer_full(int head, int tail, int size) { return ((head + 1) % size) == tail ? 1 : 0; }",
        "test": "buffer_full",
        "category": "array"
    },
    {
        "instruction": "Implement: int buffer_empty(int head, int tail) - returns 1 if circular buffer is empty",
        "code": "int buffer_empty(int head, int tail) { return head == tail ? 1 : 0; }",
        "test": "buffer_empty",
        "category": "array"
    },
    {
        "instruction": "Implement: int buffer_length(int head, int tail, int size) - returns number of elements in circular buffer",
        "code": "int buffer_length(int head, int tail, int size) { if (head >= tail) return head - tail; return size - tail + head; }",
        "test": "buffer_length",
        "category": "array"
    }
]

print(f"Toplam {len(SNAKE_FUNCTIONS)} fonksiyon tanƒ±mlandƒ±")
print(f"Kategoriler: {set(f['category'] for f in SNAKE_FUNCTIONS)}")

---
## Step 3: C → WAT Compiler

In [None]:
def compile_c_to_wat(c_code: str, func_name: str) -> Optional[str]:
    """Compile a C snippet to WebAssembly text format (WAT).

    The source is written to a temporary directory, compiled to a wasm32
    module with clang/lld, and then printed as WAT with ``wasm-tools print``.

    Args:
        c_code: C source text to compile.
        func_name: Label used only in the diagnostic messages.

    Returns:
        The WAT text, or ``None`` when a tool is not installed or either
        step exits with a non-zero status (a diagnostic is printed).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        c_file = os.path.join(tmpdir, "func.c")
        wasm_file = os.path.join(tmpdir, "func.wasm")
        wat_file = os.path.join(tmpdir, "func.wat")

        with open(c_file, "w", encoding="utf-8") as f:
            f.write(c_code)

        # Each step is (diagnostic label, command line); they run in order
        # and the first failure aborts the pipeline.
        steps = [
            (
                "Compile error",
                [
                    "clang",
                    "--target=wasm32",
                    "-O2",
                    "-nostdlib",
                    "-fuse-ld=lld",
                    "-Wl,--no-entry",
                    "-Wl,--export-all",
                    "-o", wasm_file,
                    c_file,
                ],
            ),
            (
                "WAT conversion error",
                ["wasm-tools", "print", wasm_file, "-o", wat_file],
            ),
        ]

        for label, cmd in steps:
            try:
                result = subprocess.run(cmd, capture_output=True, text=True)
            except OSError as exc:
                # Raised (as FileNotFoundError) when the executable is not
                # installed; report it like any other failure instead of
                # letting it abort the whole data-generation run.
                print(f"{label} for {func_name}: {exc}")
                return None
            if result.returncode != 0:
                print(f"{label} for {func_name}: {result.stderr}")
                return None

        with open(wat_file, "r", encoding="utf-8") as f:
            return f.read()

# Test compilation: smoke-test the toolchain on a trivial function before
# generating the real dataset.
test_wat = compile_c_to_wat("int test_add(int a, int b) { return a + b; }", "test_add")
if test_wat:
    print("‚úÖ Compiler working!")
    # Show only the first 200 characters of the generated WAT.
    print(test_wat[:200] + "...")
else:
    # compile_c_to_wat returned None (or empty text).
    print("‚ùå Compiler failed!")

---
## Step 4: Generate Training Data

In [None]:
import random

def generate_instruction_variations(base_instruction: str, func_name: str) -> List[str]:
    """Build alternative phrasings of one instruction.

    Args:
        base_instruction: Prompt of the form
            ``"Implement: <signature> - <description>"``; other shapes are
            accepted and simply yield fewer variations.
        func_name: Currently unused; kept so existing callers keep working.

    Returns:
        A list that starts with ``base_instruction`` itself, followed by
        prefix variations (when it starts with ``"Implement: "``) and
        natural-language variations (when it mentions "returns" and has a
        ``" - "`` separated description).
    """
    variations = [base_instruction]

    # Swap the "Implement: " prefix for other verbs, plus the bare signature.
    prefix = "Implement: "
    if base_instruction.startswith(prefix):
        core = base_instruction[len(prefix):]
        variations.extend([
            f"Write: {core}",
            f"Create: {core}",
            f"Code: {core}",
            core,  # Just the signature
        ])

    # Natural-language variations built from the description part.
    if "returns" in base_instruction.lower():
        # Split on the FIRST " - " only, so a description that itself
        # contains " - " (e.g. "returns a - 1") is kept whole instead of
        # silently producing no variations.
        _, separator, desc = base_instruction.partition(" - ")
        if separator:
            variations.append(f"Function that {desc}")
            variations.append(f"Write a function that {desc}")

    return variations

def generate_training_data() -> List[dict]:
    """Compile every entry of SNAKE_FUNCTIONS and expand it into examples.

    Each function is compiled to WAT once; every instruction variation then
    becomes one record sharing that WAT output. Functions that fail to
    compile are reported and skipped.

    Returns:
        Records with the keys ``instruction``, ``output``, ``category`` and
        ``func_name``.
    """
    examples = []
    total = len(SNAKE_FUNCTIONS)

    for position, spec in enumerate(SNAKE_FUNCTIONS, start=1):
        name = spec['test']
        print(f"[{position}/{total}] Processing {name}...", end=" ")

        wat_text = compile_c_to_wat(spec['code'], name)
        if not wat_text:
            print("‚ùå FAILED")
            continue

        prompts = generate_instruction_variations(spec['instruction'], name)

        # One training record per phrasing, all pointing at the same WAT.
        examples.extend(
            {
                "instruction": prompt,
                "output": wat_text,
                "category": spec['category'],
                "func_name": name,
            }
            for prompt in prompts
        )

        print(f"‚úÖ ({len(prompts)} variations)")

    return examples

# Generate data
print("Generating training data...\n")
training_data = generate_training_data()
print(f"\n‚úÖ Generated {len(training_data)} training examples")

In [None]:
# Show category distribution
from collections import Counter

# Tally how many generated examples belong to each category.
categories = Counter()
for record in training_data:
    categories[record['category']] += 1

print("Category Distribution:")
for cat in sorted(categories):
    count = categories[cat]
    print(f"  {cat}: {count}")

---
## Step 5: Convert to ChatML Format

In [None]:
SYSTEM_PROMPT = """You are ZeroLang, an AI that generates WebAssembly Text Format (WAT) code.
Given a function description or signature, output valid WAT code that implements the function.
Output only the WAT code, no explanations."""


def to_chatml(data: List[dict]) -> List[dict]:
    """Wrap instruction/output records as ChatML conversations.

    Args:
        data: Records that each provide ``"instruction"`` and ``"output"``.

    Returns:
        One ``{"messages": [...]}`` dict per record, holding the system
        prompt, the instruction as the user turn and the output as the
        assistant turn, in input order.
    """

    def as_conversation(record: dict) -> dict:
        # A fresh message list per record, so no dicts are shared.
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": record["instruction"]}
        assistant_turn = {"role": "assistant", "content": record["output"]}
        return {"messages": [system_turn, user_turn, assistant_turn]}

    return [as_conversation(record) for record in data]

# Convert
chatml_data = to_chatml(training_data)
print(f"Converted {len(chatml_data)} examples to ChatML format")

# Show example
print("\nExample:")
print(json.dumps(chatml_data[0], indent=2)[:500] + "...")

In [None]:
# Shuffle and split
# Fixed seed so the in-place shuffle is reproducible for the same input order.
random.seed(42)
random.shuffle(chatml_data)

# The first 90% (rounded down) becomes the training split, the rest validation.
split_idx = int(0.9 * len(chatml_data))
train_data, val_data = chatml_data[:split_idx], chatml_data[split_idx:]

for split_name, split_rows in (("Train", train_data), ("Val", val_data)):
    print(f"{split_name}: {len(split_rows)} examples")

In [None]:
# Save to Google Drive
output_dir = "/content/drive/MyDrive/zerolang/snake_data"
train_file = os.path.join(output_dir, "train_snake.jsonl")
val_file = os.path.join(output_dir, "val_snake.jsonl")

# Write JSONL: one JSON object per line, training split first.
for jsonl_path, records in ((train_file, train_data), (val_file, val_data)):
    with open(jsonl_path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in records)

print(f"‚úÖ Saved to:")
for saved_path in (train_file, val_file):
    print(f"   {saved_path}")

---
## Step 6: Merge with Existing Data (Optional)

Mevcut training data ile birleştir.

In [None]:
# Check for existing data
existing_train = "/content/drive/MyDrive/zerolang/data/train_chatml_large.jsonl"
existing_val = "/content/drive/MyDrive/zerolang/data/val_chatml_large.jsonl"


def _read_jsonl(path):
    """Load a JSONL file into a list, ignoring blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Merge only when BOTH existing splits are present, so a missing validation
# file cannot crash the cell after the training file has been found.
have_existing_train = os.path.exists(existing_train)
have_existing_val = os.path.exists(existing_val)
merge_data = have_existing_train and have_existing_val

if have_existing_train and not have_existing_val:
    print(f"Existing validation data missing: {existing_val}")

if merge_data:
    print(f"Found existing training data: {existing_train}")

    # Load existing
    existing_train_data = _read_jsonl(existing_train)
    existing_val_data = _read_jsonl(existing_val)

    print(f"Existing: {len(existing_train_data)} train, {len(existing_val_data)} val")
    print(f"Snake: {len(train_data)} train, {len(val_data)} val")

    # Merge (list concatenation builds new lists; the inputs are untouched)
    merged_train = existing_train_data + train_data
    merged_val = existing_val_data + val_data

    # Shuffle so the snake examples are interleaved with the existing ones
    random.shuffle(merged_train)
    random.shuffle(merged_val)

    print(f"Merged: {len(merged_train)} train, {len(merged_val)} val")
else:
    print("No existing data found. Using snake data only.")
    # Copy so the merged lists never alias the snake-only splits.
    merged_train = list(train_data)
    merged_val = list(val_data)

In [None]:
# Save merged data
merged_output_dir = "/content/drive/MyDrive/zerolang/data"
os.makedirs(merged_output_dir, exist_ok=True)

merged_train_file = os.path.join(merged_output_dir, "train_chatml_snake.jsonl")
merged_val_file = os.path.join(merged_output_dir, "val_chatml_snake.jsonl")

# One JSON object per line; training split first, then validation.
for jsonl_path, records in (
    (merged_train_file, merged_train),
    (merged_val_file, merged_val),
):
    with open(jsonl_path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in records)

print(f"‚úÖ Merged data saved to:")
for saved_path in (merged_train_file, merged_val_file):
    print(f"   {saved_path}")

---
## Step 7: Training (H100 GPU Required)

⚠️ **Bu aşamada Runtime'ı H100 GPU'ya değiştir!**

Runtime → Change runtime type → A100 or H100

In [None]:
# Verify GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv

In [None]:
# Install training dependencies
!pip install -q transformers datasets peft accelerate bitsandbytes trl

In [None]:
# This cell runs after the runtime is switched to a GPU. Switching the Colab
# runtime starts a fresh kernel, so nothing imported or mounted in the
# earlier (CPU) steps can be relied on here: import os and mount Drive again.
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_dataset
from google.colab import drive

drive.mount('/content/drive')

# Config
MODEL_NAME = "Qwen/Qwen2.5-Coder-14B-Instruct"
OUTPUT_DIR = "/content/drive/MyDrive/zerolang/models/snake_v1"

# Load data: prefer the merged dataset, fall back to the snake-only files
# when the merged training file does not exist.
train_file = "/content/drive/MyDrive/zerolang/data/train_chatml_snake.jsonl"
val_file = "/content/drive/MyDrive/zerolang/data/val_chatml_snake.jsonl"

if not os.path.exists(train_file):
    train_file = "/content/drive/MyDrive/zerolang/snake_data/train_snake.jsonl"
    val_file = "/content/drive/MyDrive/zerolang/snake_data/val_snake.jsonl"

print(f"Loading data from: {train_file}")

dataset = load_dataset("json", data_files={
    "train": train_file,
    "validation": val_file
})

print(f"Train: {len(dataset['train'])} examples")
print(f"Val: {len(dataset['validation'])} examples")

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Keep the tokenizer's own pad token when it defines one and fall back to EOS
# only when it is missing. NOTE(review): reusing EOS as padding can make a
# collator that masks pad positions also mask the real end-of-turn token, so
# the model is never trained to stop - confirm against the installed TRL.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Quantization config: 4-bit NF4 weights with double quantization, bfloat16
# as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Load model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Prepare the quantized model for k-bit (here 4-bit) training via PEFT.
model = prepare_model_for_kbit_training(model)
print("‚úÖ Model loaded")

In [None]:
# LoRA config
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Format function
def format_chat(example):
    """Render one example's ``messages`` into a single training string.

    The module-level tokenizer's chat template is applied without
    tokenizing, and the text is returned under the ``"text"`` key.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}

# Add a rendered "text" column to every split via format_chat.
formatted_dataset = dataset.map(format_chat)

# Training arguments
# Per device, one optimizer step covers 2 x 8 = 16 examples
# (per_device_train_batch_size x gradient_accumulation_steps).
# Evaluation and checkpointing both run once per epoch, which
# load_best_model_at_end relies on; only the 2 newest checkpoints are kept.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=5,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    bf16=True,
    gradient_checkpointing=True,
    report_to="none"
)

# Trainer
# NOTE(review): the pip install of trl in this notebook is unpinned. Passing
# tokenizer=, dataset_text_field= and max_seq_length= directly to SFTTrainer
# appears to be accepted only by older TRL releases (newer ones seem to expect
# processing_class= and an SFTConfig) - confirm against the installed version
# or pin trl.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["validation"],
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048
)

In [None]:
# Train!
print("Starting Training")
trainer.train()

In [None]:
# Save
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"‚úÖ Model saved to {OUTPUT_DIR}")

---
## Step 8: Test the Model

In [None]:
# Test prompts
test_prompts = [
    "Implement: int pack_point(int x, int y) - packs two coords into one int",
    "Implement: int get_dx(int direction) - returns x delta for direction",
    "Implement: int in_bounds(int x, int y, int width, int height)",
    "Implement: int move_snake_head(int head, int direction, int width, int height)",
    "Implement: int lcg_next(int seed) - random number generator"
]

def generate(prompt):
    """Generate WAT code for one instruction with the fine-tuned model.

    Uses the module-level ``SYSTEM_PROMPT``, ``tokenizer`` and ``model``;
    all three must already be defined in the running kernel.

    Args:
        prompt: The user instruction (function signature or description).

    Returns:
        The decoded completion only (prompt tokens are stripped), with
        special tokens removed.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_length = inputs['input_ids'].shape[1]

    # The LoRA adapters are configured with dropout, which is only disabled
    # in eval mode. The model may still be in training mode after
    # trainer.train(), so switch explicitly before sampling.
    model.eval()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    # Drop the echoed prompt tokens and decode only the newly generated part.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response

# Test
for prompt in test_prompts:
    print(f"\n{'='*60}")
    print(f"Prompt: {prompt}")
    print(f"{'='*60}")
    result = generate(prompt)
    print(result[:500] + "..." if len(result) > 500 else result)

---
## Step 9: Deploy as API (Gradio)

In [None]:
!pip install -q gradio

In [None]:
import gradio as gr

def predict(instruction):
    """Gradio callback: return the generated WAT code for ``instruction``."""
    wat_code = generate(instruction)
    return wat_code

demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Instruction", placeholder="Implement: int add(int a, int b)"),
    outputs=gr.Textbox(label="WAT Code", lines=20),
    title="üêç ZeroLang Snake Edition",
    description="Generate WebAssembly from natural language - Snake Game Functions"
)

demo.launch(share=True)