# Gemma 3 270M Training on NVIDIA Jetson Orin Nano

This notebook is tuned for running locally on an NVIDIA Jetson Orin Nano. It uses a reduced batch size, context length, and gradient-accumulation count to fit the device's limited resources (8 GB of RAM, shared between CPU and GPU) and assumes the dataset is in the repository root.

## Prerequisites
- JetPack 6.x installed on Jetson
- PyTorch installed with CUDA support (on Jetson, use NVIDIA's JetPack-specific PyTorch wheel rather than a generic `pip install torch`)
- Dataset file: `net_5ghz_sft.jsonl` in the repository root

## Setup Instructions
1. Clone the repository: `git clone https://github.com/ryankyrillos/gemma3-270m-training.git`
2. Navigate to the directory: `cd gemma3-270m-training`
3. Ensure `net_5ghz_sft.jsonl` is in the root directory
4. Run: `jupyter notebook`
5. Open this notebook and execute cells in order

In [None]:
# Install required libraries (run if not already installed)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers datasets tiktoken tqdm numpy

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from tqdm.auto import tqdm
from contextlib import nullcontext
import os

# Report whether PyTorch can see a CUDA device and, if it can, name device 0
cuda_visible = torch.cuda.is_available()
print(f"CUDA available: {cuda_visible}")
if cuda_visible:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

## Step 1: Load and Prepare Dataset

The dataset is loaded from the local repository directory.

In [None]:
from datasets import load_dataset

# Load dataset from repository root and split it 90/10 into train/validation.
dataset_path = "./net_5ghz_sft.jsonl"
if os.path.exists(dataset_path):
    raw_ds = load_dataset("json", data_files=dataset_path)
    # Fixed seed so every run produces the same train/validation split
    ds = raw_ds['train'].train_test_split(test_size=0.1, seed=42)
    # Rename the held-out split so downstream code sees 'train' and 'validation'
    ds['validation'] = ds.pop('test')
    print("Dataset loaded successfully")
    print(f"Train size: {len(ds['train'])}, Validation size: {len(ds['validation'])}")
elif os.path.exists("train.bin") and os.path.exists("validation.bin"):
    # The raw dataset is only needed to build the token files; if they are
    # already cached, the rest of the notebook can run without it.
    print(f"Dataset not found at {dataset_path}. Please ensure net_5ghz_sft.jsonl is in the repository root.")
    print("Found existing train.bin / validation.bin; continuing with the cached token files.")
else:
    # Fail here with a clear message instead of a NameError on `ds` in a later cell
    raise FileNotFoundError(
        f"Dataset not found at {dataset_path}. Please ensure net_5ghz_sft.jsonl is in the repository root."
    )

## Step 2: Tokenize Dataset

Tokenize the dataset and create binary files for efficient training.

In [None]:
import tiktoken

# tiktoken's "gpt2" encoding; used below to turn text into token ids
enc = tiktoken.get_encoding("gpt2")

def process(example):
    """Flatten one chat example into text and tokenize it.

    Each entry of ``example['messages']`` is rendered as ``"<role>: <content>\\n"``;
    the rendered lines are concatenated in order and encoded with ``enc``.

    Returns a dict with the token ids (``'ids'``) and their count (``'len'``).
    """
    rendered = "".join(
        turn['role'] + ": " + turn['content'] + "\n"
        for turn in example['messages']
    )
    token_ids = enc.encode_ordinary(rendered)
    return {'ids': token_ids, 'len': len(token_ids)}

# Tokenize every split once and cache it as a flat uint16 token file (<split>.bin).
# Both files are checked: validation.bin is written last, so a run that was
# interrupted part-way no longer looks "done" just because train.bin exists.
if not (os.path.exists("train.bin") and os.path.exists("validation.bin")):
    tokenized = ds.map(
        process,
        remove_columns=['messages'],
        desc="tokenizing the splits",
        num_proc=4,  # Reduced for Jetson
    )
    for split, dset in tokenized.items():
        # Total number of tokens in this split = size of the output file
        arr_len = int(np.sum(dset['len'], dtype=np.uint64))
        filename = f'{split}.bin'
        # Named token_dtype (not dtype) so re-running this cell cannot clobber the
        # autocast `dtype` string set in the training-config cell.
        token_dtype = np.uint16
        arr = np.memmap(filename, dtype=token_dtype, mode='w+', shape=(arr_len,))
        # Never ask for more shards than there are rows: an empty shard makes
        # np.concatenate raise "need at least one array to concatenate".
        total_batches = max(1, min(512, len(dset)))  # Reduced for Jetson

        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            # Append this shard's tokens at the current write offset
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)
        arr.flush()
    print("Tokenization completed")
else:
    print("Binary files already exist, skipping tokenization")

## Step 3: Define Model Architecture

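Gemma 3 270M model definition optimized for Jetson. The `Gemma3Model` class and the `GEMMA3_CONFIG_270M` configuration must be defined in the cell below before the model is instantiated.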

In [None]:
# Model definition (simplified for brevity - include full Gemma3Model class here)
# ... (Insert the full model code from the original notebook)
# NOTE(review): Gemma3Model and GEMMA3_CONFIG_270M are not defined in the visible
# cells; until they are added above this point, the lines below raise NameError.

# Seed the RNG before building the model (presumably so weight init is repeatable)
torch.manual_seed(123)
model = Gemma3Model(GEMMA3_CONFIG_270M)

## Step 4: Training Configuration

Optimized parameters for Jetson Orin Nano's limited resources.

In [None]:
# Training Config - Optimized for Jetson
learning_rate = 1e-4  # peak learning rate, reached at the end of warmup
max_iters = 50000  # Reduced for testing
warmup_steps = 500
# Floor of the cosine decay. It must be below learning_rate: the previous value
# (5e-4) was above the peak, so the "decay" phase actually increased the LR.
min_lr = 1e-5
eval_iters = 200
batch_size = 4  # Reduced for Jetson
block_size = 64  # Reduced for Jetson
gradient_accumulation_steps = 8  # Reduced for Jetson

# Device and mixed-precision setup
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = 'cuda' if 'cuda' in device else 'cpu'
# Prefer bfloat16 when the GPU supports it; otherwise fall back to float16
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast is only used on CUDA; on CPU the context manager is a no-op
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Tensors created from here on default to `device`
torch.set_default_device(device)
torch.manual_seed(42)

## Step 5: Training Loop

Execute the training loop with progress monitoring.

In [None]:
# Optimizer and Scheduler
# NOTE(review): get_batch and estimate_loss are called below but are not defined in
# the visible cells; they must be defined before this cell runs.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.1, eps=1e-9)
# LR schedule (stepped once per loop iteration): linear warmup for warmup_steps,
# then cosine annealing towards min_lr for the remaining iterations.
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=warmup_steps)
scheduler_decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters - warmup_steps, eta_min=min_lr)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_decay], milestones=[warmup_steps])
# Loss scaling is only enabled for float16; otherwise the scaler calls pass through
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# Training loop
best_val_loss = float('inf')
best_model_params_path = "best_model_params.pt"
train_loss_list, validation_loss_list = [], []

model = model.to(device)

for epoch in tqdm(range(max_iters)):
    # Every eval_iters iterations (except iteration 0): evaluate, log, and
    # checkpoint the weights whenever validation loss improves.
    if epoch % eval_iters == 0 and epoch != 0:
        losses = estimate_loss(model)
        print(f"Epoch {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.5f}")
        train_loss_list.append(losses['train'])
        validation_loss_list.append(losses['val'])

        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            torch.save(model.state_dict(), best_model_params_path)

    X, y = get_batch("train")
    X, y = X.to(device), y.to(device)

    # Forward pass and loss under autocast
    with ctx:
        logits, loss = model(X, y)
        # Scale the loss down so gradients summed over the accumulation window average out
        loss = loss / gradient_accumulation_steps
    # Backward runs outside the autocast context
    scaler.scale(loss).backward()

    # Apply an optimizer step once per accumulation window (and on the final iteration)
    if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):
        # Unscale first so the clip threshold applies to the true gradients rather
        # than loss-scaled ones (a no-op when the scaler is disabled).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
    scheduler.step()

print("Training completed!")

## Step 6: Generate Text

Test the trained model by generating text.

In [None]:
# Load best model
# map_location places the weights on the current device; weights_only restricts
# unpickling to tensor data, which is all a state_dict contains.
best_state = torch.load(best_model_params_path, map_location=device, weights_only=True)
model.load_state_dict(best_state)
model.eval()

# Generate sample text
prompt_ids = enc.encode("System: You are a wireless network engineer.")
# One level of nesting gives a (1, T) batch; the previous [[...]] wrapping
# produced a 3-D (1, 1, T) tensor.
idx = torch.tensor([prompt_ids], dtype=torch.long).to(device)
# No gradients are needed for sampling; skipping them saves memory on the Jetson
with torch.no_grad():
    generated = model.generate(idx, max_new_tokens=50, temperature=0.8)
generated_text = enc.decode(generated[0].tolist())
print("Generated text:")
print(generated_text)