# Mamba-MoE 509M - Production Training on Kaggle TPU

**STEP-BY-STEP NOTEBOOK**

Run each cell sequentially. You can download checkpoints at any step!

- **Target**: 10GB data, ~2.5B tokens
- **Model**: 509M parameters
- **Hardware**: TPU v5e-8
- **Time**: ~4-5 hours total

## STEP 1: Setup & Installation

In [None]:
# Clone repo
%cd /kaggle/working
!rm -rf mamba-moe-300m
!git clone https://github.com/rgprince/mamba-moe-300m.git
%cd mamba-moe-300m

# Install dependencies (if needed)
!pip install -q jax[tpu] flax optax chex einops pyyaml pydantic datasets sentencepiece tqdm -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

print("\n‚úÖ Setup complete!")

## STEP 2: Download & Prepare Training Data (10GB)

In [None]:
from datasets import load_dataset
from tqdm import tqdm
from pathlib import Path

# STEP 2 script: stream text documents from the configured Hugging Face
# datasets, keep the ones longer than MIN_TEXT_CHARS, and join them into one
# big string (`combined_text`) that the later cells consume.
print("="*70)
print("DOWNLOADING 10GB HIGH-QUALITY DATA")
print("="*70)
print("Streaming mode - only downloads what we need!\n")

# Target: 10GB total = 2M samples from FineWeb-Edu
datasets_config = [
    {
        "name": "HuggingFaceFW/fineweb-edu",
        "split": "train",
        "samples": 2_000_000,  # 2M samples = ~9.5-10GB
        "description": "FineWeb-Edu: Educational web content"
    },
]

# Documents with this many characters or fewer are skipped.
MIN_TEXT_CHARS = 100

all_texts = []
total_chars = 0

for ds_config in datasets_config:
    print(f"\nLoading {ds_config['description']}...")
    print(f"Streaming {ds_config['samples']:,} samples...")

    # Initialised before the try so the summary below is always well-defined,
    # even when loading or streaming fails part-way through.
    samples_collected = 0
    dataset_chars = 0

    try:
        dataset = load_dataset(
            ds_config['name'],
            split=ds_config['split'],
            streaming=True,
            trust_remote_code=False
        )

        # [-1] (rather than [1]) also works for names without an "org/" prefix.
        short_name = ds_config['name'].split('/')[-1][:20]

        # The context manager closes the progress bar even if streaming raises.
        with tqdm(total=ds_config['samples'], desc=f"  {short_name}") as pbar:
            for item in dataset:
                text = item.get('text') or item.get('content') or ''

                if len(text) <= MIN_TEXT_CHARS:
                    continue

                all_texts.append(text)
                dataset_chars += len(text)
                samples_collected += 1
                pbar.update(1)

                if samples_collected >= ds_config['samples']:
                    break

    except Exception as e:
        # Best-effort: report the failure and move on to the next dataset.
        # Texts appended before the failure remain in all_texts.
        print(f"‚ö† Failed: {e}")

    # Counted unconditionally so total_chars matches what is in all_texts,
    # including partial results from a dataset that failed mid-stream.
    total_chars += dataset_chars

    print(f"‚úì Collected {samples_collected:,} samples")
    print(f"‚úì Size: {dataset_chars/1e6:.1f}MB ({dataset_chars/1e9:.2f}GB)")

# Stop here with a clear message instead of reporting "Data ready!" for an
# empty corpus and failing later in the tokenizer cell.
if not all_texts:
    raise RuntimeError(
        "No text was collected from any dataset; cannot build the training corpus."
    )

# Combine all text
# NOTE: all_texts and combined_text are both kept alive, so peak RAM is
# roughly twice the corpus size.
combined_text = "\n\n".join(all_texts)

# Rough estimate: ~4 characters per token, expressed in millions of tokens.
estimated_tokens_m = len(combined_text) / 4 / 1e6
# 509 = model size in millions of parameters (as printed below).
tokens_per_param = estimated_tokens_m / 509

print(f"\n{'='*70}")
print(f"DATASET SUMMARY")
print(f"{'='*70}")
print(f"Total samples: {len(all_texts):,}")
print(f"Total size: {len(combined_text)/1e6:.1f}MB ({len(combined_text)/1e9:.2f}GB)")
print(f"Estimated tokens: ~{estimated_tokens_m:.1f}M tokens")
print(f"\nMemorization check:")
print(f"  Model params: 509M")
print(f"  Data tokens: ~{estimated_tokens_m:.1f}M")
print(f"  Ratio: {tokens_per_param:.2f} tokens/param")
if tokens_per_param > 2:
    print(f"  Status: ‚úÖ GOOD (>2) - No memorization!")
else:
    print(f"  Status: ‚ö†Ô∏è LOW (<2) - May memorize")

print(f"\n‚úÖ Data ready! You can now train the tokenizer.")

## STEP 3: Train Tokenizer (SentencePiece BPE)

**📥 DOWNLOAD CHECKPOINT**: After this step, you can download `data/tokenizer.model` and `data/tokenizer.vocab`.

In [None]:
from src.data import SPTokenizer

# STEP 3 script: write a prefix of the corpus to disk and train a
# SentencePiece BPE tokenizer on that file.
print("="*70, "TRAINING TOKENIZER", "="*70, sep="\n")

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Only a prefix of the corpus is used for tokenizer training (prevents RAM
# overflow!). The slice counts characters of combined_text.
TOKENIZER_SAMPLE_SIZE = 200_000_000  # 200MB
tokenizer_sample = combined_text[:TOKENIZER_SAMPLE_SIZE]

tokenizer_train_file = data_dir / "tokenizer_sample.txt"
print(f"Saving tokenizer sample ({len(tokenizer_sample)/1e6:.1f}MB)...")
tokenizer_train_file.write_text(tokenizer_sample, encoding='utf-8')

print("\nTraining tokenizer (vocab_size=8000)...")
print("(Using sample - prevents RAM overflow!)\n")

# The trainer reads the sample file written above and writes its outputs
# using the "data/tokenizer" prefix.
tokenizer = SPTokenizer.train(
    model_type="bpe",
    vocab_size=8000,
    input_sentence_size=2_000_000,
    input_files=[str(tokenizer_train_file)],
    model_prefix=str(data_dir / "tokenizer"),
)

print("\n‚úÖ Tokenizer trained!")
print(f"   Vocab size: {tokenizer.vocab_size}")
print("   Files saved: data/tokenizer.model, data/tokenizer.vocab")
print("\nüì• DOWNLOAD: You can download tokenizer files from data/ folder")

## STEP 4: Tokenize Data & Create Training Batches

In [None]:
import jax.numpy as jnp

# STEP 4 script: encode the corpus into token ids and cut it into
# next-token-prediction batches of shape (BATCH_SIZE, SEQ_LEN).
print("="*70)
print("TOKENIZING DATA & CREATING BATCHES")
print("="*70)

# Config
BATCH_SIZE = 1
SEQ_LEN = 128
TOTAL_STEPS = 50000

print(f"\nTokenizing FULL {len(combined_text):,} characters...")
print(f"(Using all data for training!)\n")

# NOTE(review): the whole corpus is encoded in a single call, but the loop
# below stops after TOTAL_STEPS batches, i.e. it consumes at most
# TOTAL_STEPS * BATCH_SIZE * SEQ_LEN + 1 tokens. Consider encoding only what
# is needed if time or memory becomes a problem.
token_ids = tokenizer.encode(combined_text, add_bos=False, add_eos=False)
print(f"‚úì Tokenized: {len(token_ids):,} tokens")

# Create batches
print(f"\nCreating {TOTAL_STEPS:,} batches...")
batches = []

# Each batch consumes BATCH_SIZE consecutive windows of SEQ_LEN tokens, so the
# outer loop must advance by that many tokens. Advancing by SEQ_LEN only (as
# before) made consecutive batches share rows whenever BATCH_SIZE > 1.
tokens_per_batch = BATCH_SIZE * SEQ_LEN

# A batch starting at `batch_start` reads tokens up to index
# batch_start + tokens_per_batch (labels are inputs shifted by one), so every
# start strictly below len(token_ids) - tokens_per_batch is valid.
for batch_start in range(0, len(token_ids) - tokens_per_batch, tokens_per_batch):
    if len(batches) >= TOTAL_STEPS:
        break

    row_starts = range(batch_start, batch_start + tokens_per_batch, SEQ_LEN)

    batches.append({
        'input_ids': jnp.array(
            [token_ids[s:s + SEQ_LEN] for s in row_starts], dtype=jnp.int32
        ),
        # Labels are the inputs shifted one token to the right.
        'labels': jnp.array(
            [token_ids[s + 1:s + SEQ_LEN + 1] for s in row_starts], dtype=jnp.int32
        ),
    })

# Fail with a clear message instead of an IndexError on batches[0] below.
if not batches:
    raise ValueError(
        f"Not enough tokens ({len(token_ids):,}) to build a single batch of "
        f"{BATCH_SIZE} x {SEQ_LEN} tokens."
    )

print(f"‚úì Created {len(batches):,} batches")
print(f"  Batch shape: {batches[0]['input_ids'].shape}")
print(f"\n‚úÖ Data ready for training!")

## STEP 5: Load Model (509M parameters)

In [None]:
import jax
from jax import random
from src.model import create_model_from_config, ModelConfig
from src.training import (
    create_train_step,
    create_train_state,
    create_learning_rate_schedule,
    create_optimizer,
    CheckpointManager,
    ConsoleLogger
)

# STEP 5 script: build the model from its YAML config, initialise parameters,
# and assemble the optimizer, train state and jitted train step.
print("="*70, "LOADING MODEL", "="*70, sep="\n")

config_path = 'configs/model_config.yaml'
model = create_model_from_config(config_path)
model_config = ModelConfig.from_yaml(config_path)

print(f"\nModel: {model_config.name}")
print(f"Layers: {model_config.num_layers}")
print(f"Hidden dim: {model_config.hidden_dim}")

# Initialize parameters
print("\nInitializing parameters...")
# One seed key is split into a carry-over key, an init key and a dropout key.
rng, init_rng, dropout_rng = random.split(random.PRNGKey(42), 3)

# A single all-ones sequence of SEQ_LEN token ids is the example input that
# model.init uses to initialise the parameters.
dummy_input = jnp.ones(shape=(1, SEQ_LEN), dtype=jnp.int32)
variables = model.init(init_rng, dummy_input, deterministic=True)
params = variables['params']

# Total number of scalar entries across every leaf of the parameter tree.
param_count = sum(map(lambda leaf: leaf.size, jax.tree_util.tree_leaves(params)))
print(f"‚úì Parameters: {param_count / 1e6:.1f}M")

# Setup training
LEARNING_RATE = 3e-4

lr_schedule = create_learning_rate_schedule(
    schedule_type='cosine',
    max_learning_rate=LEARNING_RATE,
    warmup_steps=100,
    total_steps=TOTAL_STEPS,
)

optimizer = create_optimizer(
    learning_rate_fn=lr_schedule,
    max_grad_norm=1.0,
    weight_decay=0.1,
)

state = create_train_state(model, params, optimizer, lr_schedule, dropout_rng)
# Build the step function and wrap it in jax.jit in one go.
train_step = jax.jit(create_train_step(model, lr_schedule))

print("\n‚úÖ Model ready for training!")

## STEP 6: Train! (50,000 steps = ~4 hours)

**📥 DOWNLOAD CHECKPOINTS**: Every 500 steps, a checkpoint is saved to the `checkpoints/` folder.

In [None]:
import time

import jax  # already imported in STEP 5; repeated so this cell states its own dependency

# STEP 6 script: run the jitted train step over the prepared batches, log
# metrics every LOG_EVERY steps and save a checkpoint every SAVE_EVERY steps.
print("="*70)
print("STARTING TRAINING")
print("="*70)

SAVE_EVERY = 500
LOG_EVERY = 50

# Setup checkpoints
ckpt_dir = Path("checkpoints")
ckpt_dir.mkdir(exist_ok=True)
ckpt_manager = CheckpointManager(
    checkpoint_dir=str(ckpt_dir),
    max_to_keep=3,
    save_interval_steps=SAVE_EVERY
)

# The loop can only run as many steps as there are prepared batches, which
# may be fewer than TOTAL_STEPS.
num_steps = min(TOTAL_STEPS, len(batches))
if num_steps == 0:
    # Without this guard the summary at the end fails with a NameError on
    # `metrics`, which hides the real cause.
    raise RuntimeError("No training batches available - run STEP 4 first.")

print(f"\nConfig:")
print(f"  Total steps: {TOTAL_STEPS:,}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Sequence length: {SEQ_LEN}")
print(f"  Save every: {SAVE_EVERY} steps")
print(f"  Log every: {LOG_EVERY} steps")
print(f"\n(Compiling on first step - takes ~1 min)")
print("="*70 + "\n")

# NOTE(review): `logger` is constructed but never used below; logging is done
# with plain print calls.
logger = ConsoleLogger(log_interval=LOG_EVERY)
start_time = time.time()
compile_time = None

for step in range(num_steps):
    batch = batches[step]

    step_start = time.time()
    state, metrics = train_step(state, batch)
    # JAX dispatches device work asynchronously: the call above can return
    # before the step has actually run. Wait for the results so step_time
    # (and the tokens/sec derived from it) measures real compute time.
    jax.block_until_ready(metrics)
    step_time = time.time() - step_start

    if step == 0:
        # The first call also pays for tracing and compilation.
        compile_time = step_time
        print(f"‚úì Compilation done ({compile_time:.1f}s)\n")

    metrics['step_time'] = step_time
    # Step 0 is dominated by compilation, so its throughput is reported as 0.
    metrics['tokens_per_sec'] = (BATCH_SIZE * SEQ_LEN) / step_time if step > 0 else 0

    # Log (step 0 is covered because 0 % LOG_EVERY == 0)
    if step % LOG_EVERY == 0:
        loss = float(metrics['loss'])
        ppl = float(metrics['perplexity'])
        lr = float(metrics['learning_rate'])
        tps = int(metrics['tokens_per_sec'])
        elapsed = time.time() - start_time
        print(f"Step {step:5d} | loss={loss:.4f} ppl={ppl:8.2f} lr={lr:.6f} | {tps:,} tok/s | {elapsed/60:.1f}min")

    # Save checkpoint
    if (step + 1) % SAVE_EVERY == 0:
        # NOTE(review): the checkpoint is stored under the 0-based `step`
        # while the message reports step+1 completed steps. Confirm that
        # CheckpointManager does not skip saves whose step is not a multiple
        # of save_interval_steps.
        ckpt_manager.save(state, step, metadata={'loss': float(metrics['loss'])})
        print(f"üì• Checkpoint saved at step {step+1} - You can download from checkpoints/")

total_time = time.time() - start_time

print(f"\n{'='*70}")
print("‚úÖ TRAINING COMPLETE!")
print(f"{'='*70}")
print(f"\nResults:")
# Report the steps actually executed, not the configured target.
print(f"  ‚úì Trained {num_steps:,} steps")
print(f"  ‚úì Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
print(f"  ‚úì Final loss: {float(metrics['loss']):.4f}")
print(f"  ‚úì Final perplexity: {float(metrics['perplexity']):.2f}")
print(f"\nüì• DOWNLOAD:")
print(f"  - Tokenizer: data/tokenizer.model")
print(f"  - Checkpoints: checkpoints/")
print(f"  - Logs: (in notebook output)")
print(f"\nüéâ Training complete on {len(combined_text)/1e9:.2f}GB of clean data!")