# TODO:
- try lora
- try reward model
- ppo

### 📊 Training Configuration - Fast Experimentation

For quick experimentation, we'll use a subset of the dataset. Change `USE_FULL_DATASET` to `True` when you want to train on the complete dataset.

In [None]:
# 🔧 EXPERIMENT CONFIGURATION
USE_FULL_DATASET = True  # Set to True for full training, False for quick experiments
EXPERIMENT_SIZE = 200    # Use only this many samples for quick testing

print(f"🎯 Configuration:")
print(f"  USE_FULL_DATASET: {USE_FULL_DATASET}")
if not USE_FULL_DATASET:
    print(f"  EXPERIMENT_SIZE: {EXPERIMENT_SIZE} samples")
    print(f"  Estimated training time: ~{EXPERIMENT_SIZE // 8 * 0.5 / 60:.1f} minutes per epoch")

🎯 Configuration:
  USE_FULL_DATASET: False
  EXPERIMENT_SIZE: 200 samples
  Estimated training time: ~0.2 minutes per epoch


### Sort mem
Jax shotguns a bunch of memoery on the GPU but it doesn't need it

In [2]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [3]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.11 GB
  Total: 8.00 GB
  Utilization: 1.3%


### Define dataset class

In [4]:
from datasets import load_dataset
import numpy as np  # or jax.numpy as jnp if needed
import pandas as pd
from pathlib import Path

class TLDRDataset:
    def __init__(self, data_dir, tokenizer, split, max_length=550):
        """
        Load TLDR dataset from local parquet files.
        
        Args:
            data_dir: Path to directory containing parquet files
            tokenizer: Tokenizer to use
            split: 'train', 'valid', or 'test'
            max_length: Maximum sequence length
        """
        # Load the parquet file
        parquet_file = Path(data_dir) / f"tldr_{split}.parquet"
        if not parquet_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {parquet_file}")
        
        df = pd.read_parquet(parquet_file)
        
        # Combine prompt and label for training (teacher forcing)
        self.examples = [row["prompt"] + row["label"] for _, row in df.iterrows()]
        
        # Limit validation set size for faster iteration
        if "valid" in split:
            self.examples = self.examples[:2000]
        
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        print(f"Loaded {len(self.examples)} examples from {parquet_file}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Tokenize the text
        enc = self.tokenizer(
            self.examples[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        return {
            "input_ids": np.array(enc["input_ids"], dtype=np.int32),
            "attention_mask": np.array(enc["attention_mask"], dtype=np.int32),
            "labels": np.array(enc["input_ids"], dtype=np.int32),  # teacher forcing
        }


  from .autonotebook import tqdm as notebook_tqdm


### Load model and tokeniser
Stick to gpt2 now for compatibility, then move to qwen. GPT2 = 124m params, 500mb. qwen0.6b = 550m params, 2gb. So beware 3x qwen0.6b on my 8gb gpu might tank its memory.

In [5]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# 1. Tokenizer is identical
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Flax (JAX) model
#    .from_pretrained returns a FlaxAutoModelForCausalLM whose weights live in model.params
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2", dtype=jnp.float16) # Back to your working config

# 3. If you’ve added new tokens, resize just like in PyTorch:
#    model = model.resize_token_embeddings(len(tokenizer))

# 4. Make sure padding is configured
model.config.pad_token_id = tokenizer.eos_token_id

# 5. Pull out the parameter dict for training
params = model.params

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


In [6]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.65 GB
  Total: 8.00 GB
  Utilization: 8.1%


### Inspect model 

In [7]:
# Inspect JAX/Flax model architecture and parameters
import jax
from jax.tree_util import tree_map
from flax.core import freeze

print("🔍 JAX/Flax Model Inspection")
print("=" * 50)

# 1. Basic model info
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")
print(f"Vocab size: {model.config.vocab_size}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Number of attention heads: {model.config.n_head}")

print("\n📊 Parameter Analysis")
print("=" * 30)

# 2. Count parameters (JAX way)
def count_params(params):
    """Count total parameters in a JAX parameter tree"""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

total_params = count_params(params)
print(f"Total parameters: {total_params:,}")
print(f"Total parameters (millions): {total_params / 1_000_000:.2f}M")

# 3. Inspect parameter structure
print("\n🏗️ Parameter Structure")
print("=" * 25)

def print_param_shapes(params, prefix=""):
    """Recursively print parameter shapes"""
    if isinstance(params, dict):
        for key, value in params.items():
            print_param_shapes(value, f"{prefix}.{key}" if prefix else key)
    else:
        print(f"{prefix}: {params.shape} ({params.dtype})")

print_param_shapes(params)

# 4. Memory usage estimation
def estimate_memory(params):
    """Estimate memory usage in MB"""
    total_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(params))
    return total_bytes / (1024 ** 2)

memory_mb = estimate_memory(params)
print(f"\n💾 Estimated memory usage: {memory_mb:.2f} MB")

# 5. Compare with different dtypes
print(f"\n🔢 Memory usage by dtype:")
print(f"  float32: {memory_mb:.2f} MB")
print(f"  float16: {memory_mb / 2:.2f} MB") 
print(f"  bfloat16: {memory_mb / 2:.2f} MB")

🔍 JAX/Flax Model Inspection
Model type: <class 'transformers.models.gpt2.modeling_flax_gpt2.FlaxGPT2LMHeadModel'>
Model config: GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "pad_token_id": 50256,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab

### Do stuff

In [8]:
# Create datasets
train_dataset = TLDRDataset("../data", tokenizer, split="train")
val_dataset   = TLDRDataset("../data", tokenizer, split="valid")

# Apply experiment configuration
if not USE_FULL_DATASET:
    print(f"🧪 Using experiment mode: limiting to {EXPERIMENT_SIZE} samples")
    # Create a subset by limiting the examples list
    train_dataset.examples = train_dataset.examples[:EXPERIMENT_SIZE]
    # Keep validation dataset smaller too for quick validation
    val_dataset.examples = val_dataset.examples[:min(100, len(val_dataset.examples))]

print(f"📊 Dataset loaded:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

if not USE_FULL_DATASET:
    estimated_time = (len(train_dataset) // 8) * 0.5 / 60  # batch_size=8, 0.5s per step
    print(f"  ⏱️  Estimated training time per epoch: ~{estimated_time:.1f} minutes")

# Preview a sample
sample = train_dataset[0]
print(f"\n📝 Sample data shapes:")
print(f"  input_ids: {sample['input_ids'].shape}")
print(f"  attention_mask: {sample['attention_mask'].shape}")
print(f"  labels: {sample['labels'].shape}")

# Preview actual content
print(f"\n📋 Sample content:")
print(f"  First 100 chars of tokenized text: {train_dataset.examples[0][:100]}...")
print(f"  Input IDs (first 10): {sample['input_ids'][:10]}")
print(f"  Labels (first 10): {sample['labels'][:10]}")


Loaded 116722 examples from ../data/tldr_train.parquet
Loaded 2000 examples from ../data/tldr_valid.parquet
🧪 Using experiment mode: limiting to 200 samples
📊 Dataset loaded:
  Training samples: 200
  Validation samples: 100
  ⏱️  Estimated training time per epoch: ~0.2 minutes

📝 Sample data shapes:
  input_ids: (550,)
  attention_mask: (550,)
  labels: (550,)

📋 Sample content:
  First 100 chars of tokenized text: SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or...
  Input IDs (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]
  Labels (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]


### Train

In [9]:
import jax.numpy as jnp
from jax import random

def create_data_loader(dataset, batch_size, shuffle=True):
    # Pre-tokenize everything into arrays (your original working approach)
    all_data = [dataset[i] for i in range(len(dataset))]
    
    # Convert to JAX arrays upfront
    input_ids = jnp.array([x["input_ids"] for x in all_data])
    attention_mask = jnp.array([x["attention_mask"] for x in all_data])
    labels = jnp.array([x["labels"] for x in all_data])
    
    num_batches = len(dataset) // batch_size  # Drop incomplete batches
    
    if shuffle:
        indices = random.permutation(random.PRNGKey(42), len(dataset))
        input_ids = input_ids[indices]
        attention_mask = attention_mask[indices]
        labels = labels[indices]
    
    # Yield consistent-shaped batches
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        
        yield {
            "input_ids": input_ids[start_idx:end_idx],
            "attention_mask": attention_mask[start_idx:end_idx], 
            "labels": labels[start_idx:end_idx]
        }

# Example usage
print("🔄 Creating data loaders...")

# Create data loaders
BATCH_SIZE = 8  # Back to your working batch size

train_loader = create_data_loader(train_dataset, BATCH_SIZE, shuffle=True)
val_loader = create_data_loader(val_dataset, BATCH_SIZE, shuffle=False)

# Test the data loader
print("📦 Testing batch loading...")
batch = next(iter(train_loader))

print(f"Batch shapes:")
print(f"  input_ids: {batch['input_ids'].shape}")
print(f"  attention_mask: {batch['attention_mask'].shape}")
print(f"  labels: {batch['labels'].shape}")
print(f"  Data type: {batch['input_ids'].dtype}")

# Show how many batches we'll have
train_batches = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
val_batches = (len(val_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
print(f"\n📊 Batch counts:")
print(f"  Training batches: {train_batches}")
print(f"  Validation batches: {val_batches}")

print("\n✅ Data loader ready for JAX training!")

🔄 Creating data loaders...
📦 Testing batch loading...
Batch shapes:
  input_ids: (8, 550)
  attention_mask: (8, 550)
  labels: (8, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 25
  Validation batches: 13

✅ Data loader ready for JAX training!
Batch shapes:
  input_ids: (8, 550)
  attention_mask: (8, 550)
  labels: (8, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 25
  Validation batches: 13

✅ Data loader ready for JAX training!


In [10]:
import optax
from flax.training import train_state

# Initialize the optimizer
tx = optax.adam(1e-5)

# What is state? 
# It holds the model parameters, optimizer state, and any auxiliary data needed for training.
state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

### 🚀 Training Speed Summary

**Understanding your training times:**

- **Full dataset (116,722 samples)**: ~1.8 hours per epoch (14,590 batches × 0.45s each)
- **Experiment mode (1,000 samples)**: ~1 minute per epoch (125 batches × 0.45s each)

**First step is always slow (~40s)** due to JAX JIT compilation, then it runs at normal speed.

**To switch between modes:**
- `USE_FULL_DATASET = False` → Quick experiments (1,000 samples)
- `USE_FULL_DATASET = True` → Full training (116,722 samples)

### 💾 Model Checkpointing

We'll use Flax's built-in checkpointing system for clean, versioned model saving.

In [11]:
import os
from pathlib import Path
from flax.training import checkpoints
import json

# 📁 Checkpoint configuration
CHECKPOINT_DIR = "../models"  # Correct relative path from ben_dev/misc/ to ben_dev/models/
MAX_CHECKPOINTS = 3  # Keep only the 3 most recent checkpoints

def save_checkpoint(state, epoch, checkpoint_dir=CHECKPOINT_DIR, max_to_keep=MAX_CHECKPOINTS):
    """
    Save model checkpoint using Flax's built-in checkpointing.
    
    Args:
        state: TrainState containing params, opt_state, and step
        epoch: Current epoch number 
        checkpoint_dir: Directory to save checkpoints (can be relative)
        max_to_keep: Maximum number of checkpoints to keep (auto-cleanup)
    """
    # Convert relative path to absolute (Flax requires absolute paths)
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    # Ensure checkpoint directory exists
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    # Save the complete TrainState (params + optimizer state + step)
    # Flax automatically handles versioning with step numbers
    checkpoints.save_checkpoint(
        ckpt_dir=str(checkpoint_dir),  # Convert Path to string for Flax
        target=state,  # This saves the entire TrainState
        step=state.step,  # Use training step for versioning
        keep=max_to_keep,  # Auto-cleanup: keep only last N checkpoints
        overwrite=True  # Allow overwriting if step already exists
    )
    
    # Also save model config for easy loading later
    config_path = checkpoint_dir / "model_config.json"
    if not config_path.exists():
        config_dict = {
            "model_type": "gpt2",
            "vocab_size": model.config.vocab_size,
            "n_embd": model.config.n_embd,
            "n_layer": model.config.n_layer,
            "n_head": model.config.n_head,
            "max_length": 550,  # Your max sequence length
            "dtype": "float16"  # Your model dtype
        }
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2)
    
    print(f"💾 Checkpoint saved: epoch {epoch}, step {state.step}")
    print(f"   📂 Location: {checkpoint_dir}/checkpoint_{state.step}")


print(f"📂 Checkpoints will be saved to: {CHECKPOINT_DIR}")
print(f"🔄 Will keep {MAX_CHECKPOINTS} most recent checkpoints")



📂 Checkpoints will be saved to: ../models
🔄 Will keep 3 most recent checkpoints


In [12]:
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from typing import Any, Dict, Tuple
from tqdm import trange, tqdm
import time
step_times = []


# Define the training step function
@jax.jit
def train_step(
    state: train_state.TrainState, 
    batch: Dict[str, jnp.ndarray],
    dropout_rng: jax.random.PRNGKey
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """
    Perform a single training step (forward, loss, backward, update).
    
    Args:
        state: TrainState containing params & optimizer state.
        batch: Dict with keys "input_ids", "attention_mask", "labels" of shape (B, L).
    
    Returns:
        new_state: Updated TrainState after applying gradients.
        loss: Scalar loss for this batch.
    """
    
    def loss_fn(params: Any) -> jnp.ndarray:
        # Forward pass: get logits [batch, seq_len, vocab_size]
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            train=True,
            dropout_rng=dropout_rng  # Use dropout RNG for training
        )
        logits = outputs.logits
        
        # Causal LM shift: predict token t given inputs up to t-1
        shift_logits = logits[..., :-1, :]           # drop last logit
        shift_labels = batch["labels"][..., 1:]       # drop first label
        
        # Compute per-token cross-entropy
        loss = optax.softmax_cross_entropy_with_integer_labels(
            shift_logits, shift_labels
        )  # shape: (batch, seq_len-1)
        
        # Mask out padding tokens from loss
        mask = batch["attention_mask"][..., 1:]       # same shift as labels
        loss = jnp.sum(loss * mask) / jnp.sum(mask)   # mean over non-pad tokens
        
        return loss

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Apply gradients to update parameters
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# Add memory monitoring and progress tracking to training
def train_loop(
    state: train_state.TrainState,
    train_loader: Any,
    num_epochs: int = 3,
    rng_key: jnp.ndarray = None
) -> train_state.TrainState:
    """
    Training loop with memory monitoring and progress tracking.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(42)
    
    # Calculate total steps across all epochs
    train_batches = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
    total_steps = train_batches * num_epochs

    # Single progress bar for all training
    pbar = tqdm(
        total=total_steps,
        desc="🚀 Training",
        ncols=120,
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
    )
    
    global_step = 0

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        epoch_steps = 0
        
        for batch in train_loader:
            # Split RNG key for this step
            rng_key, dropout_rng = random.split(rng_key)

            # Single training step
            step_start = time.time()
            state, loss = train_step(state, batch, dropout_rng)
            step_time = time.time() - step_start
            
            # Process loss and increment counters
            current_loss = float(loss)
            epoch_loss += current_loss
            epoch_steps += 1
            global_step += 1
            
            # Update progress bar
            pbar.set_postfix({
                'epoch': f'{epoch}/{num_epochs}',
                'loss': f'{current_loss:.4f}',
                'avg_loss': f'{epoch_loss/epoch_steps:.4f}',
                'step_time': f'{step_time:.2f}s'
            })
            pbar.update(1)
            
            # Optional: Print timing for first few steps
            if global_step <= 10:
                tqdm.write(f"⏱️  Global step {global_step} took: {step_time:.2f}s")
        
        # Print epoch summary
        avg_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0.0
        tqdm.write(f"✅ Epoch {epoch} complete — avg loss: {avg_loss:.4f}")
        
        # 💾 Save checkpoint at end of each epoch
        save_checkpoint(state, epoch)
        tqdm.write(f"📁 Checkpoint saved for epoch {epoch}")
    
    pbar.close()
    return state

# Use this instead of the regular train_loop
num_epochs = 5  # Start with fewer epochs to see the pattern
state = train_loop(state, train_loader, num_epochs=num_epochs)

# After training, `state.params` holds your fine-tuned model weights.



🚀 Training:   0%|                                                                              | 0/125 [00:00<?, ?it/s]2025-07-17 18:38:31.330716: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 3.25GiB (3492387526 bytes) by rematerialization; only reduced to 3.78GiB (4061437030 bytes), down from 5.05GiB (5420351484 bytes) originally
2025-07-17 18:38:31.330716: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 3.25GiB (3492387526 bytes) by rematerialization; only reduced to 3.78GiB (4061437030 bytes), down from 5.05GiB (5420351484 bytes) originally
🚀 Training:   1%|▌                                                                   | 1/125 [00:41<1:26:08, 41.69s/it]

⏱️  Global step 1 took: 41.65s


🚀 Training:   2%|█                                                                     | 2/125 [00:42<36:43, 17.91s/it]

⏱️  Global step 2 took: 1.27s


🚀 Training:   2%|█▋                                                                    | 3/125 [00:43<20:12,  9.94s/it]

⏱️  Global step 3 took: 0.44s


🚀 Training:   3%|██▏                                                                   | 4/125 [00:43<12:29,  6.19s/it]

⏱️  Global step 4 took: 0.45s


🚀 Training:   4%|██▊                                                                   | 5/125 [00:44<08:14,  4.12s/it]

⏱️  Global step 5 took: 0.44s


🚀 Training:   5%|███▎                                                                  | 6/125 [00:44<05:42,  2.88s/it]

⏱️  Global step 6 took: 0.45s


🚀 Training:   6%|███▉                                                                  | 7/125 [00:45<04:06,  2.09s/it]

⏱️  Global step 7 took: 0.45s


🚀 Training:   6%|████▍                                                                 | 8/125 [00:45<03:03,  1.57s/it]

⏱️  Global step 8 took: 0.46s


🚀 Training:   7%|█████                                                                 | 9/125 [00:46<02:21,  1.22s/it]

⏱️  Global step 9 took: 0.46s


🚀 Training:   8%|█████▌                                                               | 10/125 [00:46<01:53,  1.01it/s]

⏱️  Global step 10 took: 0.46s


🚀 Training:  19%|█████████████▏                                                       | 24/125 [00:53<00:47,  2.14it/s]

✅ Epoch 1 complete — avg loss: 3.5143


🚀 Training:  19%|█████████████▏                                                       | 24/125 [00:54<00:47,  2.14it/s]

💾 Checkpoint saved: epoch 1, step 24
   📂 Location: /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models/checkpoint_24
📁 Checkpoint saved for epoch 1
✅ Epoch 2 complete — avg loss: 0.0000


🚀 Training:  19%|█████████████▏                                                       | 24/125 [00:56<00:47,  2.14it/s]

💾 Checkpoint saved: epoch 2, step 24
   📂 Location: /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models/checkpoint_24
📁 Checkpoint saved for epoch 2
✅ Epoch 3 complete — avg loss: 0.0000


🚀 Training:  19%|█████████████▏                                                       | 24/125 [00:58<00:47,  2.14it/s]

💾 Checkpoint saved: epoch 3, step 24
   📂 Location: /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models/checkpoint_24
📁 Checkpoint saved for epoch 3
✅ Epoch 4 complete — avg loss: 0.0000


🚀 Training:  19%|█████████████▏                                                       | 24/125 [01:00<00:47,  2.14it/s]

💾 Checkpoint saved: epoch 4, step 24
   📂 Location: /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models/checkpoint_24
📁 Checkpoint saved for epoch 4
✅ Epoch 5 complete — avg loss: 0.0000


🚀 Training:  19%|█████████████▏                                                       | 24/125 [01:02<04:24,  2.62s/it]

💾 Checkpoint saved: epoch 5, step 24
   📂 Location: /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models/checkpoint_24
📁 Checkpoint saved for epoch 5





### 🔄 Checkpoint Usage Examples

Here's how to use the checkpoint system for different scenarios:

In [17]:
def load_checkpoint(checkpoint_dir=CHECKPOINT_DIR, step=None):
    """
    Load model checkpoint using Flax's built-in checkpointing.
    
    Args:
        checkpoint_dir: Directory containing checkpoints (can be relative)
        step: Specific step to load (None = latest)
    
    Returns:
        restored_state: TrainState with loaded params/optimizer/step
        epoch_info: Dict with loaded step information
    """
    # Convert relative path to absolute (Flax requires absolute paths)
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
    
    # Create a dummy state for restoration template
    # (Flax needs the structure to know what to restore)
    dummy_state = train_state.TrainState.create(
        apply_fn=model.__call__,
        params=params,  # Use current params as template
        tx=optax.adam(1e-5)  # Use current optimizer as template
    )
    
    # Restore the checkpoint
    restored_state = checkpoints.restore_checkpoint(
        ckpt_dir=str(checkpoint_dir),  # Convert Path to string for Flax
        target=dummy_state,  # Template for structure
        step=step  # None = latest, or specify step number
    )
    
    epoch_info = {
        "step": restored_state.step,
        "approximate_epoch": restored_state.step // ((len(train_dataset) // BATCH_SIZE))
    }
    
    print(f"📥 Checkpoint loaded: step {restored_state.step}")
    print(f"   🔄 Approximate epoch: {epoch_info['approximate_epoch']}")
    
    return restored_state, epoch_info

# 🔍 Utility function to list available checkpoints
def list_checkpoints(checkpoint_dir=CHECKPOINT_DIR):
    """List all available checkpoints in the directory."""
    # Convert relative path to absolute
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    if not checkpoint_dir.exists():
        print(f"No checkpoint directory found: {checkpoint_dir}")
        return []
    
    # Get all checkpoint steps by scanning for checkpoint_* directories
    checkpoint_pattern = checkpoint_dir / "checkpoint_*"
    checkpoint_dirs = list(checkpoint_dir.glob("checkpoint_*"))
    
    if not checkpoint_dirs:
        print(f"No checkpoints found in {checkpoint_dir}")
        return []
    
    # Extract step numbers from directory names
    available_steps = []
    for ckpt_dir in checkpoint_dirs:
        try:
            step = int(ckpt_dir.name.replace("checkpoint_", ""))
            available_steps.append(step)
        except ValueError:
            continue  # Skip invalid checkpoint directory names
    
    # Sort steps numerically
    available_steps.sort()
    
    print(f"📋 Available checkpoints in {checkpoint_dir}:")
    for step in available_steps:
        approximate_epoch = step // ((len(train_dataset) // BATCH_SIZE))
        print(f"   Step {step} (≈ epoch {approximate_epoch})")
    
    return available_steps


In [18]:
# 📋 Example 1: List all available checkpoints
print("=" * 50)
print("🔍 LISTING AVAILABLE CHECKPOINTS")
print("=" * 50)
available_steps = list_checkpoints()

# 📥 Example 2: Load the latest checkpoint (for resuming training)
print("\n" + "=" * 50)
print("📥 LOADING LATEST CHECKPOINT")
print("=" * 50)
try:
    # Uncomment these lines when you have checkpoints saved:
    # restored_state, info = load_checkpoint()
    # print(f"Resumed from step {info['step']}, approximate epoch {info['approximate_epoch']}")
    # state = restored_state  # Use this state to continue training
    print("💡 Uncomment the lines above once you have saved checkpoints!")
except Exception as e:
    print(f"No checkpoints found yet: {e}")

# 📥 Example 3: Load a specific checkpoint by step number
print("\n" + "=" * 50)
print("📥 LOADING SPECIFIC CHECKPOINT")
print("=" * 50)
try:
    # Example: load checkpoint from step 500
    # restored_state, info = load_checkpoint(step=500)
    print("💡 Use load_checkpoint(step=N) to load a specific checkpoint")
except Exception as e:
    print(f"Specific checkpoint example: {e}")

# 🚀 Example 4: Resume training from checkpoint
print("\n" + "=" * 50)
print("🚀 RESUMING TRAINING")
print("=" * 50)
print("""
# To resume training from the latest checkpoint:
restored_state, info = load_checkpoint()
state = restored_state

# Then continue training normally:
state = train_loop(state, train_loader, num_epochs=additional_epochs)

# The training will pick up from where it left off!
""")

print("✅ Checkpoint examples ready to use!")

🔍 LISTING AVAILABLE CHECKPOINTS
📋 Available checkpoints in /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models:
   Step 24 (≈ epoch 0)

📥 LOADING LATEST CHECKPOINT
💡 Uncomment the lines above once you have saved checkpoints!

📥 LOADING SPECIFIC CHECKPOINT
💡 Use load_checkpoint(step=N) to load a specific checkpoint

🚀 RESUMING TRAINING

# To resume training from the latest checkpoint:
restored_state, info = load_checkpoint()
state = restored_state

# Then continue training normally:
state = train_loop(state, train_loader, num_epochs=additional_epochs)

# The training will pick up from where it left off!

✅ Checkpoint examples ready to use!
