# TypeScript SLM Training on TPU (Optimized)

This notebook trains the TypeScript-SLM model using Google Colab TPUs.

## Setup Instructions:
1. **Select TPU Runtime**: Go to `Runtime > Change runtime type > TPU v2 or v5e`
2. **Run the installation cell below and RESTART RUNTIME when prompted**
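3. **Upload your data**: Upload `train.jsonl` to the Colab working directory (the notebook loads it by that relative path), or mount Google Drive and update the `data_files` path accordingly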
4. **After restart, skip the installation cell and run all other cells**

**Optimized for TPU: uses a reduced batch size and sequence length to avoid out-of-memory (OOM) errors.**

In [None]:
# STEP 1: Run this cell ONCE, then Runtime > Restart runtime
# After restart, SKIP this cell and continue from the next cell

# Install PyTorch and PyTorch/XLA for TPU
!pip install torch torch-xla -f https://storage.googleapis.com/libtpu-releases/index.html

# Install other dependencies
!pip install -U transformers datasets peft accelerate trl wandb

print("\n" + "="*70)
print("✓ Installation complete!")
print("="*70)
print("\n⚠️  NEXT STEP: Runtime > Restart runtime")
print("⚠️  After restart, SKIP this cell and run from the next cell\n")
print("="*70)

In [None]:
# Verify TPU setup and import libraries
import os
import torch

print(f"PyTorch version: {torch.__version__}")

# Import torch_xla and acquire the XLA device in one guarded step, so that a
# missing package and a missing TPU runtime each produce their own hint before
# the original exception is re-raised.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    print(f"✓ torch-xla version: {torch_xla.__version__}")
    print(f"✓ TPU device: {device}")
    print("\n✓ TPU is ready to use!")
except ImportError:
    # torch_xla could not be imported: show the install commands again.
    print("\n".join((
        "\n❌ ERROR: torch-xla is not installed!",
        "\nPlease run the installation commands:",
        "!pip install torch torch-xla -f https://storage.googleapis.com/libtpu-releases/index.html",
        "!pip install -U transformers datasets peft accelerate trl wandb",
        "\nThen: Runtime > Restart runtime",
    )))
    raise
except Exception as err:
    # Any other failure: report it and point at the runtime-type setting.
    print("\n".join((
        f"\n❌ ERROR: {err}",
        "\nMake sure you selected a TPU runtime:",
        "Runtime > Change runtime type > TPU",
    )))
    raise

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

# Check if TPU is available; otherwise fall back to CUDA, then CPU.
# `except Exception` replaces the bare `except:` so that KeyboardInterrupt and
# SystemExit still propagate, and the reason the TPU was skipped is reported
# instead of being silently discarded.
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    print(f"Using TPU: {device}")
except Exception as tpu_error:
    print(f"TPU unavailable ({type(tpu_error).__name__}: {tpu_error})")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

# Login to Hugging Face (optional)
from huggingface_hub import login
# To authenticate, uncomment the call below and supply a token. Do not save or
# commit the notebook with a real token in it.
# NOTE(review): calling login() with no arguments is expected to prompt for the
# token interactively, which avoids hardcoding it — confirm against the
# huggingface_hub documentation.
# login(token="YOUR_TOKEN")

In [None]:
# Configuration
model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"   # base checkpoint to fine-tune
new_model = "typescript-slm-1.5b"                 # output directory for the trained weights

# Load dataset (JSON Lines file expected in the current working directory)
dataset = load_dataset("json", data_files="train.jsonl", split="train")
print(f"Loaded {len(dataset)} samples")

# Fail fast on unusable data: the formatting function used for training reads
# example["text"], so an empty dataset or a missing "text" column would
# otherwise only surface later, inside the trainer's preprocessing.
if len(dataset) == 0:
    raise ValueError("train.jsonl contains no samples")
if "text" not in dataset.column_names:
    raise ValueError(
        f"train.jsonl must provide a 'text' field; found columns: {dataset.column_names}"
    )

In [None]:
# Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Only fall back to EOS for padding when the tokenizer ships without a pad token.
# Unconditionally aliasing pad to EOS makes the two indistinguishable, and
# collators that mask labels by pad_token_id would then also mask real EOS
# tokens, so the model may never learn to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.model_max_length = 1024  # Set max length for the tokenizer

# device_map="auto" is intentionally not passed (removed for TPU compatibility).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

# Enable gradient checkpointing to save memory
model.gradient_checkpointing_enable()
# The generation KV cache is not used during training and is incompatible with
# gradient checkpointing; disable it explicitly rather than relying on the
# library to override it with a warning.
model.config.use_cache = False

In [None]:
# LoRA Configuration
# Adapters are attached to every attention projection and every MLP projection.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # MLP projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [None]:
# Training Arguments - OPTIMIZED FOR MEMORY
training_arguments = TrainingArguments(
    # --- output, checkpointing and logging ---
    output_dir="./results",
    save_steps=100,
    logging_steps=10,
    report_to="wandb",
    # --- schedule length ---
    num_train_epochs=3,
    max_steps=-1,
    # --- batching: one sample per device, accumulated over 16 steps ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    group_by_length=True,
    # --- optimizer and learning-rate schedule ---
    optim="adamw_torch",
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # --- numeric precision: bfloat16, fp16 off ---
    bf16=True,
    fp16=False,
    # --- memory: explicitly enable gradient checkpointing ---
    gradient_checkpointing=True,
)

In [None]:
# Formatting hook handed to SFTTrainer
def formatting_func(example):
    """Return the training text stored under the example's ``"text"`` key.

    Raises:
        KeyError: if the example has no ``"text"`` entry.
    """
    training_text = example["text"]
    return training_text

# Trainer - built with the current TRL keyword names
# (the tokenizer is supplied through `processing_class`, not `tokenizer`).
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset,
    formatting_func=formatting_func,
    processing_class=tokenizer,
    peft_config=peft_config,
)

In [None]:
# Start Training
# Runs the training loop on the trainer built in the previous cell. As the last
# expression in the cell, the value returned by train() is displayed as the
# cell's output.
trainer.train()

In [None]:
# Save Model
# Save through the Trainer rather than calling save_pretrained on trainer.model
# directly: Trainer.save_model goes through the Trainer's device-aware save path
# (including its TPU/XLA handling), so weights are written from the correct
# process and in a loadable form.
trainer.save_model(new_model)
# Write the tokenizer files next to the weights so the output directory can be
# loaded on its own.
tokenizer.save_pretrained(new_model)