# 🚀 Train Spectral Neural Network on Google Colab

**Hardware:** T4 GPU (15GB VRAM) - FREE on Colab!

**What this does:**
- ✅ Clones your code from GitHub (tafolabi009/RNN)
- ✅ Trains Spectral LM on WikiText-103 with proper BPE tokenization
- ✅ Mixed precision training (FP16)
- ✅ Pre-tokenized dataset (cached for speed)
- ✅ Saves checkpoints to Google Drive
- ✅ Tests generation quality

**Expected Results:**
- Training time: ~2 hours (3 epochs, 10K samples)
- Perplexity: ~20-30 (excellent!)
- Text quality: Coherent sentences, no gibberish

**Steps:**
1. Runtime → Change runtime type → T4 GPU ⚡
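2. Run the cells in order. ⌘/Ctrl + F9 runs everything, including both training options in Step 8, so run that step's cells individually and pick Option A or Option B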
3. Wait ~2 hours
4. Download checkpoint from Drive

## 📋 Step 1: Check GPU

In [None]:
# Verify T4 GPU is available
!nvidia-smi --query-gpu=name,memory.total --format=csv

## 📦 Step 2: Install Dependencies

In [None]:
# Install required packages
!pip install -q transformers datasets tqdm

## 💾 Step 3: Mount Google Drive (for checkpoints)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory
!mkdir -p /content/drive/MyDrive/spectral_checkpoints
print("✅ Google Drive mounted")

## 📥 Step 4: Clone Repository

In [None]:
# Clone your GitHub repo
!git clone https://github.com/tafolabi009/RNN.git

# Change to repo directory
%cd RNN

# List files to verify
!ls -la resonance_nn/

## ✅ Step 5: Verify Setup

In [None]:
# Test imports
import torch
from transformers import GPT2TokenizerFast
from datasets import load_dataset

print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Test model import
from resonance_nn.spectral_optimized import SpectralLanguageModel, CONFIGS
print("✅ Spectral model imported successfully!")

## 🧹 Step 6: Clear GPU Memory

In [None]:
import gc

torch.cuda.empty_cache()
gc.collect()

print("🧹 GPU memory cleared")

## 🎯 Step 7: Configure Training

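**Memory-optimized for T4 GPU (15GB).** The `CONFIG` dictionary below documents the intended settings. The inline code in Option B hard-codes the same values rather than reading `CONFIG`, so change them in both places.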

In [None]:
# Configuration
CONFIG = {
    'model_size': 'tiny',              # 63M params
    'max_seq_len': 512,                 # Longer sequences than local training
    'batch_size': 4,                    # T4 can handle this
    'gradient_accumulation_steps': 16,  # Effective batch = 64
    'num_epochs': 3,
    'learning_rate': 3e-4,
    'max_samples': 10000,               # Quick training (remove for full dataset)
    'output_dir': '/content/drive/MyDrive/spectral_checkpoints',
    'cache_file': 'tokenized_colab_cache.pkl',
}

print("📋 Training Configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

## 🚀 Step 8: Start Training

**This uses the ultrafast training script with pre-tokenization.**

**Estimated time: ~2 hours for 3 epochs**
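
To make the time estimate concrete, the sketch below works out how many batches and accumulation windows the Step 7 settings imply. It assumes that 10,000 texts survive the length filter and that the 90/10 train/validation split from the Option B code applies; `training_plan` and `seconds_per_batch_budget` are illustrative names, not part of the repository.

```python
import math

def training_plan(max_samples=10_000, train_fraction=0.9, batch_size=4,
                  grad_accum_steps=16, num_epochs=3):
    """Derive batch and accumulation counts from the notebook's settings."""
    train_samples = int(max_samples * train_fraction)
    val_samples = max_samples - train_samples
    # DataLoader keeps the last partial batch by default, hence ceil
    train_batches = math.ceil(train_samples / batch_size)
    val_batches = math.ceil(val_samples / batch_size)
    # Full accumulation windows per epoch, plus batches that do not fill a window
    full_windows, leftover_batches = divmod(train_batches, grad_accum_steps)
    return {
        'effective_batch_size': batch_size * grad_accum_steps,
        'train_batches': train_batches,
        'val_batches': val_batches,
        'full_accumulation_windows': full_windows,
        'leftover_batches': leftover_batches,
        'total_batches': (train_batches + val_batches) * num_epochs,
    }

def seconds_per_batch_budget(plan, target_hours):
    """Average time available per batch if the whole run should fit in target_hours."""
    return target_hours * 3600 / plan['total_batches']

plan = training_plan()
for name, value in plan.items():
    print(f"{name}: {value}")
print(f"seconds per batch for a 2-hour run: {seconds_per_batch_budget(plan, 2.0):.2f}")
```

With these settings each epoch has 2,250 training batches and 250 validation batches, so a 2-hour run leaves just under one second per batch on average. If the progress bar shows a much slower rate, expect the run to take proportionally longer.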

In [None]:
# Check whether train_ultrafast.py exists; if not, fall back to the inline code in Option B
import os

if os.path.exists('train_ultrafast.py'):
    print("✅ Found train_ultrafast.py, use Option A")
    training_script = 'train_ultrafast.py'
else:
    print("⚠️  train_ultrafast.py not found, use Option B (inline training code)")
    training_script = None

# List available training scripts
!ls -la *.py | grep train

### Option A: If train_ultrafast.py exists

In [None]:
# Run ultrafast training (if file exists)
!python train_ultrafast.py

### Option B: Inline Training Code (if scripts missing)

In [None]:
# Inline training code - run this if train_ultrafast.py doesn't exist
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast as torch_autocast, GradScaler as TorchGradScaler
import pickle
import math
from pathlib import Path
from tqdm import tqdm
from transformers import GPT2TokenizerFast
from datasets import load_dataset
from resonance_nn.spectral_optimized import SpectralLanguageModel, CONFIGS

class PreTokenizedDataset(Dataset):
    def __init__(self, tokenized_data):
        self.data = tokenized_data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

def prepare_dataset(tokenizer, max_samples=10000, max_length=512, cache_file='tokenized_cache.pkl'):
    cache_path = Path(cache_file)
    
    if cache_path.exists():
        print(f"📦 Loading cached data...")
        with open(cache_path, 'rb') as f:
            data = pickle.load(f)
        print(f"✅ Loaded {len(data['train'])} train / {len(data['val'])} val")
        return data
    
    print(f"📦 Loading WikiText-103...")
    dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
    texts = [ex['text'] for ex in dataset if len(ex['text'].strip()) > 50][:max_samples]
    
    print(f"🔄 Tokenizing {len(texts)} texts...")
    tokenized_samples = []
    
    for text in tqdm(texts, desc="Tokenizing"):
        encoded = tokenizer(text, max_length=max_length, truncation=True, padding='max_length', return_tensors='pt')
        input_ids = encoded['input_ids'].squeeze(0)
        labels = input_ids.clone()
        labels[:-1] = input_ids[1:]
        labels[-1] = 0
        tokenized_samples.append({
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': encoded['attention_mask'].squeeze(0)
        })
    
    split_idx = int(len(tokenized_samples) * 0.9)
    data = {'train': tokenized_samples[:split_idx], 'val': tokenized_samples[split_idx:]}
    
    print(f"💾 Caching...")
    with open(cache_path, 'wb') as f:
        pickle.dump(data, f)
    return data

# Setup
device = torch.device('cuda')
print(f"\n{'='*80}")
print("COLAB TRAINING - SPECTRAL NEURAL NETWORK")
print(f"{'='*80}\n")

# Config
config = CONFIGS['tiny']
config.vocab_size = 50257
config.max_seq_len = 512
config.use_gradient_checkpointing = True

# Tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Data
data = prepare_dataset(tokenizer, max_samples=10000, max_length=512)
train_dataset = PreTokenizedDataset(data['train'])
val_dataset = PreTokenizedDataset(data['val'])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2, pin_memory=True)

print(f"✅ Train batches: {len(train_loader):,}")
print(f"✅ Val batches: {len(val_loader):,}")

# Model
model = SpectralLanguageModel(config).to(device)
num_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"✅ Model: {num_params:.1f}M parameters")

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scaler = TorchGradScaler('cuda')

# Training
print(f"\n{'='*80}")
print("STARTING TRAINING")
print(f"{'='*80}\n")

num_epochs = 3
gradient_accumulation_steps = 16
best_val_loss = float('inf')
global_step = 0

for epoch in range(num_epochs):
    print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()
    
    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}")):
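        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)

        # Next-token targets: position i predicts token i+1. Padding and the final
        # position get -100 so the loss skips them (the pad id equals EOS, not 0).
        labels = torch.full_like(input_ids, -100)
        labels[:, :-1] = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)

        with torch_autocast('cuda'):
            logits = model(input_ids, attention_mask)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
            loss = loss / gradient_accumulation_steps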
        
        scaler.scale(loss).backward()
        epoch_loss += loss.item() * gradient_accumulation_steps
        
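        # Also step on the last batch so leftover gradients are not discarded at epoch end
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):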
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
    
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            # Next-token targets; -100 marks padding and the final position so they are ignored
            labels = torch.full_like(input_ids, -100)
            labels[:, :-1] = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
            logits = model(input_ids, attention_mask)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader)
    val_perplexity = math.exp(avg_val_loss)
    
    print(f"\n✅ Epoch {epoch + 1} complete!")
    print(f"   Train Loss: {epoch_loss / len(train_loader):.4f}")
    print(f"   Val Loss: {avg_val_loss:.4f}")
    print(f"   Val Perplexity: {val_perplexity:.2f}")
    
    # Save
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint_path = Path('/content/drive/MyDrive/spectral_checkpoints/spectral_colab_best.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': config,
            'epoch': epoch,
            'global_step': global_step,
            'best_val_loss': best_val_loss
        }, checkpoint_path)
        print(f"   💾 Best model saved: {checkpoint_path}")

print(f"\n{'='*80}")
print("🎉 TRAINING COMPLETE!")
print(f"{'='*80}")
print(f"Best perplexity: {math.exp(best_val_loss):.2f}")

## 📊 Step 9: Monitor GPU (Optional)

In [None]:
# Check GPU usage during training
!nvidia-smi

## 🧪 Step 10: Test the Trained Model

In [None]:
# Load model
import torch
from transformers import GPT2TokenizerFast
from resonance_nn.spectral_optimized import SpectralLanguageModel

checkpoint_path = '/content/drive/MyDrive/spectral_checkpoints/spectral_colab_best.pth'
checkpoint = torch.load(checkpoint_path, weights_only=False)

config = checkpoint['config']
model = SpectralLanguageModel(config)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda().eval()

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

print("✅ Model loaded!")
print(f"   Best val loss: {checkpoint.get('best_val_loss', 'N/A')}")
print(f"   Training steps: {checkpoint.get('global_step', 'N/A')}")

In [None]:
# Generate text samples
import torch.nn.functional as F

def generate_text(prompt, max_new_tokens=100, temperature=0.7, top_k=40, top_p=0.9):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    generated = input_ids.clone()
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if generated.size(1) >= model.config.max_seq_len:
                break
            
            logits = model(generated)
            next_token_logits = logits[:, -1, :] / temperature
            
            # Top-k
            if top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = float('-inf')
            
            # Top-p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')
            
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
    
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Test prompts
prompts = [
    "The history of artificial intelligence",
    "In the year 2050,",
    "Machine learning is",
]

print("\n" + "="*80)
print("GENERATED TEXT SAMPLES")
print("="*80)

for prompt in prompts:
    print(f"\n📝 Prompt: {prompt}")
    print("-" * 80)
    text = generate_text(prompt, max_new_tokens=80, temperature=0.7)
    print(text)
    print("="*80)

## 💾 Step 11: Download Checkpoints

In [None]:
# List checkpoints
!ls -lh /content/drive/MyDrive/spectral_checkpoints/

In [None]:
# Download to your computer
from google.colab import files
files.download('/content/drive/MyDrive/spectral_checkpoints/spectral_colab_best.pth')

## 📊 Expected Results

**After 3 epochs on 10K samples (512 token sequences):**
- ✅ Validation perplexity: **20-30** (excellent!)
- ✅ Training time: **~2 hours** on T4 GPU
- ✅ Text quality: Coherent sentences, proper grammar
- ✅ Memory usage: ~10-12GB (safe for T4)
- ✅ No gibberish! (proper BPE tokenization)
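
The perplexity reported by the inline training code is `math.exp(avg_val_loss)`, the exponential of the mean validation cross-entropy in nats. A small sketch of the conversion shows which loss values the 20-30 range corresponds to; the helper names are illustrative.

```python
import math

def perplexity_from_loss(mean_cross_entropy):
    """Perplexity is the exponential of the mean cross-entropy (natural log base)."""
    return math.exp(mean_cross_entropy)

def loss_from_perplexity(perplexity):
    """Inverse conversion: the mean cross-entropy that yields a given perplexity."""
    return math.log(perplexity)

for ppl in (20, 30):
    print(f"perplexity {ppl} <-> validation loss {loss_from_perplexity(ppl):.2f}")
```

A perplexity of 20-30 therefore corresponds to a validation loss of roughly 3.0-3.4, which is the number to watch in the per-epoch log output.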

**Compare to your local training:**
- Local (256 tokens, GTX 1660 Ti): Perplexity 22.22, lots of repetition
- Colab (512 tokens, T4 GPU): Better perplexity, longer coherent text

**To improve further:**
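1. Pass `max_samples=None` to `prepare_dataset` to train on the full WikiText-103 training split (~1.8M lines before the length filter). Simply removing the argument falls back to the default of 10,000, and an existing cache file must be deleted first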
2. Increase epochs to 10-20
3. Try `model_size='small'` (428M params) if you have Colab Pro
4. Train on OpenWebText or C4 dataset

**Model sizes on T4 (15GB):**
- ✅ `tiny` (63M): Batch size 4, seq len 512 → **~10GB**
- ⚠️ `small` (428M): Batch size 1, seq len 256 → **~14GB** (tight!)
- ❌ `base` (1B+): Needs A100 (40GB) - Colab Pro required
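
As a rough cross-check on the list above, the sketch below estimates the memory taken by weights, gradients, and AdamW state alone. It assumes FP32 weights (4 bytes), FP32 gradients (4 bytes), and two FP32 AdamW moment buffers (8 bytes), i.e. 16 bytes per parameter. Activations, the CUDA context, and temporary FP16 copies come on top and depend on batch size and sequence length, so these figures are lower bounds rather than measurements.

```python
# Assumed per-parameter cost: FP32 weight + FP32 gradient + two FP32 AdamW moments
BYTES_PER_PARAM = 4 + 4 + 8

def training_state_gb(num_params):
    """Lower bound on training memory in GB (1e9 bytes), excluding activations."""
    return num_params * BYTES_PER_PARAM / 1e9

# Parameter counts as quoted in this notebook
for name, params in [('tiny', 63e6), ('small', 428e6), ('base', 1e9)]:
    print(f"{name}: at least {training_state_gb(params):.1f} GB before activations")
```

Under these assumptions a 1B-parameter model needs about 16 GB for training state alone, which is consistent with `base` not fitting on a 15GB T4.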

## 🔧 Troubleshooting

### Out of Memory:
```python
# Reduce batch size and sequence length
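config.max_seq_len = 256  # Instead of 512
# Re-tokenize to match: delete the cache file first, because prepare_dataset reuses it regardless of max_length
data = prepare_dataset(tokenizer, max_samples=10000, max_length=256)
# In DataLoader: batch_size=2  # Instead of 4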
```

### Training too slow:
```python
# Reduce dataset size
# Delete the cache file first; an existing cache is reused regardless of max_samples
data = prepare_dataset(tokenizer, max_samples=5000)  # Instead of 10000
```
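
One thing to watch when changing `max_samples` or `max_length`: the inline `prepare_dataset` returns the cached file whenever it exists, whatever arguments are passed. A possible workaround, sketched here with a hypothetical helper `cache_file_for`, is to encode the settings in the cache filename so that each combination gets its own cache.

```python
from pathlib import Path

def cache_file_for(max_samples, max_length, prefix="tokenized"):
    """Build a cache filename that changes whenever the tokenization settings change."""
    samples = "all" if max_samples is None else str(max_samples)
    return Path(f"{prefix}_{samples}_len{max_length}.pkl")

cache_file = cache_file_for(max_samples=5000, max_length=512)
print(cache_file)
# Then pass it through the existing cache_file parameter, for example:
# data = prepare_dataset(tokenizer, max_samples=5000, max_length=512, cache_file=cache_file)
```

This costs one extra tokenization pass per new setting but avoids silently training on data prepared with old parameters.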

### Colab disconnects:
- Keep browser tab active
- Click in the notebook occasionally
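- Colab's free tier disconnects idle sessions and caps any single session at roughly 12 hours; exact limits vary with availability
- In Option B, the best checkpoint is written to Drive at the end of each epoch in which validation loss improves, so only progress since the last improving epoch is lost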
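
After a disconnect, it helps to check which checkpoints actually reached Drive before resuming. A minimal sketch, assuming checkpoints are `.pth` files in the output directory from Step 7 (the helper name `latest_checkpoint` is illustrative):

```python
from pathlib import Path

def latest_checkpoint(checkpoint_dir, pattern="*.pth"):
    """Return the most recently modified checkpoint file, or None if there is none."""
    candidates = [p for p in Path(checkpoint_dir).glob(pattern) if p.is_file()]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)

path = latest_checkpoint('/content/drive/MyDrive/spectral_checkpoints')
print(path if path else "No checkpoint found yet")
```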

### Resume training:
```python
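# Load checkpoint and continue
# weights_only=False because the checkpoint also stores the config object (only load files you trust)
checkpoint = torch.load('/content/drive/MyDrive/spectral_checkpoints/spectral_colab_best.pth', weights_only=False)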
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Continue training from checkpoint['epoch'] + 1
```
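
The checkpoint stores a zero-based `epoch` index, so a resumed loop should start one past it. In the inline code that index belongs to the epoch with the best validation loss, which is not necessarily the last epoch that ran. A minimal sketch with a plain dictionary standing in for the loaded checkpoint (the helper `epochs_to_run` and the example values are hypothetical):

```python
def epochs_to_run(checkpoint, num_epochs):
    """Return the zero-based epoch indices still to train after loading a checkpoint."""
    start_epoch = checkpoint.get('epoch', -1) + 1
    return range(start_epoch, num_epochs)

# Stand-in for the dictionary returned by torch.load; values are illustrative only
checkpoint = {'epoch': 0, 'global_step': 140, 'best_val_loss': 3.2}

for epoch in epochs_to_run(checkpoint, num_epochs=3):
    print(f"Would train epoch {epoch + 1}/3")
```

The `GradScaler` state is not part of the saved checkpoint, so on resume the loss scale starts again from its default value.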