# Gemma 3 270M Training on NVIDIA Jetson Orin Nano

This notebook is tailored for running locally on an NVIDIA Jetson Orin Nano. It includes adjustments for limited resources (8 GB RAM) and assumes the dataset is in the repository root. Note: as configured below, all training and inference run on the CPU; the Jetson's GPU is not used.

## Prerequisites
- JetPack 6.x installed on Jetson
- PyTorch installed with CUDA support
- Dataset file: `net_5ghz_sft.jsonl` in the repository root

## Setup Instructions
1. Clone the repository: `git clone https://github.com/ryankyrillos/gemma3-270m-training.git`
2. Navigate to the directory: `cd gemma3-270m-training`
3. Ensure `net_5ghz_sft.jsonl` is in the root directory
4. Run: `jupyter notebook`
5. Open this notebook and execute cells in order

In [None]:
# Install required libraries (run if not already installed)
!pip install transformers datasets tiktoken tqdm numpy

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from tqdm.auto import tqdm
from contextlib import nullcontext
import os

# Check device availability.
# Execution is pinned to the CPU (CUDA was reported as unavailable on this setup).
device = "cpu"
_device_report = [f"Using device: {device}"]
if device == "cpu":
    _device_report.append(f"CPU cores available: {os.cpu_count()}")
# One print of newline-joined lines emits exactly the same text as two prints.
print("\n".join(_device_report))
del _device_report

## Step 1: Load and Prepare Dataset

The dataset is loaded from the local repository directory.

In [None]:
from datasets import load_dataset

# Load the chat-format SFT dataset from the repository root and hold out 10%
# of it as a validation split.
dataset_path = "./net_5ghz_sft.jsonl"
if os.path.exists(dataset_path):
    ds = load_dataset("json", data_files=dataset_path)
    # Fixed seed so the train/validation split is reproducible across runs.
    ds = ds['train'].train_test_split(test_size=0.1, seed=42)
    ds['validation'] = ds.pop('test')
    print("Dataset loaded successfully")
    print(f"Train size: {len(ds['train'])}, Validation size: {len(ds['validation'])}")
elif all(os.path.exists(path) for path in ("train.bin", "validation.bin")):
    # The raw dataset is only needed to build the token files; if both are
    # already cached, training can proceed without it.
    print(f"Dataset not found at {dataset_path}; using the cached train.bin / validation.bin instead.")
else:
    # Fail loudly here: the tokenization cell needs `ds`, so continuing would
    # only surface later as a confusing NameError.
    raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please ensure net_5ghz_sft.jsonl is in the repository root.")

## Step 2: Tokenize Dataset

Tokenize the dataset with the GPT-2 BPE tokenizer (`tiktoken`) and write the token IDs to `train.bin` and `validation.bin`, which are memory-mapped during training.

In [None]:
import tiktoken

# GPT-2 BPE tokenizer; token ids are later stored as uint16.
enc = tiktoken.get_encoding("gpt2")

def process(example):
    """Flatten one chat example into ``role: content`` lines and tokenize it.

    Args:
        example: a dataset row whose ``'messages'`` entry is a list of dicts
            with ``'role'`` and ``'content'`` strings.

    Returns:
        ``{'ids': list[int], 'len': int}``. An end-of-text token is appended so
        that, once every example is concatenated into one flat token stream,
        conversation boundaries remain visible to the model.
    """
    text = "".join(msg['role'] + ": " + msg['content'] + "\n" for msg in example['messages'])
    ids = enc.encode_ordinary(text)
    ids.append(enc.eot_token)  # separator between concatenated examples
    return {'ids': ids, 'len': len(ids)}

# Tokenize once and cache each split as a flat uint16 token stream
# (`train.bin`, `validation.bin`). Tokenization is skipped only when *both*
# files exist, so a partially written cache gets rebuilt.
_bin_files = [f"{split}.bin" for split in ("train", "validation")]
if not all(os.path.exists(path) for path in _bin_files):
    tokenized = ds.map(
        process,
        remove_columns=['messages'],
        desc="tokenizing the splits",
        num_proc=4,  # Reduced for Jetson
    )
    for split, dset in tokenized.items():
        arr_len = int(np.sum(dset['len'], dtype=np.uint64))
        filename = f'{split}.bin'
        if arr_len == 0:
            # np.memmap cannot create a zero-length file; fail with a clear message.
            raise ValueError(f"split '{split}' produced no tokens; cannot write {filename}")
        # The model's vocab_size (50257) fits in uint16.
        dtype = np.uint16
        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
        # Write in contiguous shards to bound peak memory. Never request more
        # shards than rows: an empty shard makes np.concatenate raise, which
        # happens for small splits (e.g. a 10% validation split under 512 rows).
        total_batches = min(512, len(dset))  # Reduced for Jetson

        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)
        arr.flush()
        del arr  # release the write-mode map before the files are reopened read-only
    print("Tokenization completed")
else:
    print("Binary files already exist, skipping tokenization")

In [None]:
# Open both token files read-only as memory maps, so batches are sliced from
# disk instead of loading whole splits into RAM.
train_data, val_data = (
    np.memmap('train.bin', dtype=np.uint16, mode='r'),
    np.memmap('validation.bin', dtype=np.uint16, mode='r'),
)

def get_batch(split):
    """Sample ``batch_size`` random next-token windows from a split.

    ``'train'`` reads ``train_data``; any other split name reads ``val_data``.
    Returns ``(x, y)``: int64 tensors of shape ``(batch_size, block_size)``,
    where ``y`` is ``x`` shifted one token to the right.
    """
    source = train_data if split == 'train' else val_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # One slice of block_size + 1 tokens per start covers both input and target.
    windows = [source[s:s + block_size + 1].astype(np.int64) for s in starts]
    x = torch.stack([torch.from_numpy(w[:-1]) for w in windows])
    y = torch.stack([torch.from_numpy(w[1:]) for w in windows])
    return x, y

def estimate_loss(model):
    """Average the loss over ``eval_iters`` random batches of each split.

    The model is switched to eval mode for the measurement and back to train
    mode afterwards. Returns ``{'train': t, 'validation': t}`` where each ``t``
    is a 0-d tensor holding the mean loss.
    """
    model.eval()
    results = {}
    with torch.inference_mode():
        for split_name in ('train', 'validation'):
            batch_losses = []
            for _ in range(eval_iters):
                inputs, targets = get_batch(split_name)
                inputs = inputs.to(device)
                targets = targets.to(device)
                with ctx:
                    _, batch_loss = model(inputs, targets)
                batch_losses.append(batch_loss.item())
            results[split_name] = torch.tensor(batch_losses).mean()
    model.train()
    return results

## Step 3: Define Model Architecture

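A compact, Gemma 3–inspired decoder-only model sized for the Jetson. Note: despite the `270M` name, the configuration below (6 layers, 512-dim embeddings, GPT-2 vocabulary) has roughly 70M parameters.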

In [None]:
# Model definition

class RMSNorm(nn.Module):
    """Root-mean-square layer norm over the last dimension.

    Computes ``x / sqrt(mean(x**2) + eps) * (1 + scale) [+ shift]``. ``scale``
    is zero-initialised and applied as ``1 + scale``, so a fresh layer is a
    pure normalisation. ``shift`` exists only when ``bias=True``.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.zeros(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        # Divide by the RMS, not the L2 norm: the L2 norm is sqrt(emb_dim) times
        # larger and would shrink every activation by that factor.
        normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        out = normed * (1.0 + self.scale)
        if self.shift is not None:
            out = out + self.shift
        return out

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: causal multi-head self-attention, then a GELU MLP.

    Reads from ``cfg``:
      - ``emb_dim`` (required)
      - ``n_heads`` (default 1)
      - ``rope_base`` (default 10000)

    The per-head size is ``emb_dim // n_heads``; ``cfg["head_dim"]`` is not read.
    Rotary position embeddings (RoPE) are applied to queries and keys. Positions
    are 0..T-1 of the current input, which matches this model's
    full-sequence forward passes (no KV cache).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        self.n_heads = cfg.get("n_heads", 1)
        if emb_dim % self.n_heads != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be divisible by n_heads ({self.n_heads})")
        self.head_dim = emb_dim // self.n_heads
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim ({self.head_dim}) must be even for rotary embeddings")
        self.rope_base = cfg.get("rope_base", 10000)
        # Same layers, created in the same order as before, so state_dicts stay compatible.
        self.attn = nn.Linear(emb_dim, 3 * emb_dim)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.ff = nn.Linear(emb_dim, 4 * emb_dim)
        self.ff_proj = nn.Linear(4 * emb_dim, emb_dim)
        self.norm1 = RMSNorm(emb_dim)
        self.norm2 = RMSNorm(emb_dim)

    def _rope_cos_sin(self, seq_len, device, dtype):
        """Return RoPE cos/sin tables of shape (seq_len, head_dim)."""
        exponents = torch.arange(0, self.head_dim, 2, device=device, dtype=torch.float32) / self.head_dim
        inv_freq = 1.0 / (self.rope_base ** exponents)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)     # (T, head_dim / 2)
        angles = torch.cat((angles, angles), dim=-1)  # (T, head_dim)
        return angles.cos().to(dtype), angles.sin().to(dtype)

    @staticmethod
    def _apply_rope(t, cos, sin):
        """Rotate feature pairs (i, i + head_dim/2) of ``t`` by the RoPE angles."""
        half = t.size(-1) // 2
        rotated = torch.cat((-t[..., half:], t[..., :half]), dim=-1)
        return t * cos + rotated * sin

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.attn(self.norm1(x))
        q, k, v = qkv.split(C, dim=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim) so attention runs per head.
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        cos, sin = self._rope_cos_sin(T, x.device, q.dtype)
        q = self._apply_rope(q, cos, sin)
        k = self._apply_rope(k, cos, sin)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.proj(y)
        x = x + self.ff_proj(F.gelu(self.ff(self.norm2(x))))
        return x

class Gemma3Model(nn.Module):
    """Decoder-only LM: token embedding -> TransformerBlocks -> RMSNorm -> linear LM head.

    Reads from ``cfg``:
      - ``vocab_size``, ``emb_dim``, ``n_layers``, ``norm_eps``, ``dtype``
      - ``context_length`` (optional): the longest suffix ``generate`` feeds back in
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        self.final_norm = RMSNorm(cfg["emb_dim"], eps=cfg["norm_eps"])
        self.lm_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        # Plain attribute (not a buffer), so the state_dict layout is unchanged.
        self.context_length = cfg.get("context_length")

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (batch, seq_len).

        ``logits`` has shape (batch, seq_len, vocab_size). ``loss`` is the mean
        cross-entropy against ``targets`` (entries equal to -1 are ignored), or
        ``None`` when no targets are given.

        Raises:
            ValueError: if ``idx`` is not 2-D.
        """
        if idx.dim() != 2:
            raise ValueError(f"idx must be a 2-D (batch, seq_len) tensor of token ids, got shape {tuple(idx.shape)}")
        x = self.tok_emb(idx)
        for block in self.blocks:
            x = block(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (batch, seq_len) and return it.

        ``temperature <= 0`` selects greedy (argmax) decoding. Dividing by zero
        would otherwise yield NaN probabilities and crash ``torch.multinomial``.
        ``top_k`` restricts sampling to the k most likely tokens.
        """
        for _ in range(max_new_tokens):
            # Feed at most context_length tokens back in, when configured.
            idx_cond = idx if self.context_length is None else idx[:, -self.context_length:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature <= 0:
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# Model hyper-parameters. NOTE: despite the "270M" label, these settings give
# roughly 70M parameters:
#   - two 50257x512 vocabulary matrices (embedding + LM head), ~51.5M together
#   - six transformer blocks of ~3.15M each
# vocab_size matches the tiktoken "gpt2" encoding used for tokenization.
GEMMA3_CONFIG_270M = dict(
    vocab_size=50257,
    emb_dim=512,
    n_layers=6,
    n_heads=8,
    head_dim=64,
    context_length=1024,
    norm_eps=1e-6,
    rope_base=10000,
    dtype=torch.float32,
)

# Seed right before construction so the random weight initialisation is reproducible.
torch.manual_seed(123)
model = Gemma3Model(GEMMA3_CONFIG_270M)

## Step 4: Training Configuration

Optimized parameters for Jetson Orin Nano's limited resources.

In [None]:
# Training Config - Optimized for CPU on Jetson
learning_rate = 1e-4  # peak LR, reached at the end of warmup
max_iters = 100  # Further reduced for testing
warmup_steps = 10
# Floor of the cosine decay. It must sit below learning_rate: when eta_min is
# above the base LR, CosineAnnealingLR *raises* the LR during the "decay" phase
# (the previous value, 5e-4, was 5x the peak LR).
min_lr = learning_rate / 10
eval_iters = 10  # batches per split in estimate_loss; the training loop also uses it as the eval interval
batch_size = 1  # Reduced for CPU
block_size = 16  # Reduced for CPU
gradient_accumulation_steps = 2  # Reduced for CPU

device = "cpu"
device_type = 'cpu'
dtype = 'float32'  # Use float32 for CPU
ptdtype = torch.float32
ctx = nullcontext()  # no mixed precision: a no-op stand-in where an autocast context would go

torch.set_default_device(device)
torch.manual_seed(42)

## Step 5: Training Loop

Execute the training loop with progress monitoring.

In [None]:
# Optimizer and Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.1, eps=1e-9)
# Linear warmup for `warmup_steps` iterations, then cosine annealing towards `min_lr`.
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=warmup_steps)
scheduler_decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters - warmup_steps, eta_min=min_lr)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_decay], milestones=[warmup_steps])

# Training loop
best_val_loss = float('inf')
best_model_params_path = "best_model_params.pt"
train_loss_list, validation_loss_list = [], []

model = model.to(device)

for epoch in tqdm(range(max_iters)):
    X, y = get_batch("train")
    X, y = X.to(device), y.to(device)

    with ctx:
        logits, loss = model(X, y)
        # Average over the micro-batches that form one optimizer step.
        loss = loss / gradient_accumulation_steps
    # Backward runs outside the autocast slot, as the PyTorch AMP docs recommend.
    loss.backward()

    if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # The schedule is measured in iterations (T_max is derived from max_iters),
    # so it advances every iteration, not every optimizer step.
    scheduler.step()

    completed = epoch + 1
    # Evaluate every `eval_iters` iterations and always after the last one, so
    # the final weights are scored and a checkpoint exists even for short runs.
    if completed % eval_iters == 0 or completed == max_iters:
        losses = estimate_loss(model)
        train_loss = float(losses['train'])
        val_loss = float(losses['validation'])  # estimate_loss keys this split as 'validation'
        print(f"Epoch {completed}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.5f}")
        train_loss_list.append(train_loss)
        validation_loss_list.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

print("Training completed!")

## Step 6: Generate Text

Test the trained model by generating text.

In [None]:
# Load the best checkpoint saved during training, if there is one.
if os.path.exists(best_model_params_path):
    # weights_only=True refuses to unpickle arbitrary objects; map_location
    # places the tensors on the active device regardless of where they were saved.
    model.load_state_dict(torch.load(best_model_params_path, map_location=device, weights_only=True))
else:
    print(f"No checkpoint found at {best_model_params_path}; generating with the current weights.")
model.eval()

# Generate sample text.
# The prompt must be a 2-D (batch, seq_len) LongTensor: wrap the id list exactly once.
# NOTE(review): training text is built as "<role>: <content>\n" from the dataset's
# role strings; confirm "System" matches their capitalisation.
prompt_ids = enc.encode("System: You are a wireless network engineer.")
idx = torch.tensor([prompt_ids], dtype=torch.long).to(device)
generated = model.generate(idx, max_new_tokens=50, temperature=0.8)
generated_text = enc.decode(generated[0].tolist())
print("Generated text:")
print(generated_text)