### Sort mem
Jax shotguns a bunch of memoery on the GPU but it doesn't need it

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [2]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.11 GB
  Total: 8.00 GB
  Utilization: 1.3%


### Define dataset class

In [3]:
from datasets import load_dataset
import numpy as np  # or jax.numpy as jnp if needed
import pandas as pd
from pathlib import Path

class TLDRDataset:
    def __init__(self, data_dir, tokenizer, split, max_length=550):
        """
        Load TLDR dataset from local parquet files.
        
        Args:
            data_dir: Path to directory containing parquet files
            tokenizer: Tokenizer to use
            split: 'train', 'valid', or 'test'
            max_length: Maximum sequence length
        """
        # Load the parquet file
        parquet_file = Path(data_dir) / f"tldr_{split}.parquet"
        if not parquet_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {parquet_file}")
        
        df = pd.read_parquet(parquet_file)
        
        # Combine prompt and label for training (teacher forcing)
        self.examples = [row["prompt"] + row["label"] for _, row in df.iterrows()]
        
        # Limit validation set size for faster iteration
        if "valid" in split:
            self.examples = self.examples[:2000]
        
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        print(f"Loaded {len(self.examples)} examples from {parquet_file}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Tokenize the text
        enc = self.tokenizer(
            self.examples[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        return {
            "input_ids": np.array(enc["input_ids"], dtype=np.int32),
            "attention_mask": np.array(enc["attention_mask"], dtype=np.int32),
            "labels": np.array(enc["input_ids"], dtype=np.int32),  # teacher forcing
        }


  from .autonotebook import tqdm as notebook_tqdm


### Load model and tokeniser
Stick to gpt2 now for compatibility, then move to qwen. GPT2 = 124m params, 500mb. qwen0.6b = 550m params, 2gb. So beware 3x qwen0.6b on my 8gb gpu might tank its memory.

In [4]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# 1. Tokenizer is identical
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Flax (JAX) model
#    .from_pretrained returns a FlaxAutoModelForCausalLM whose weights live in model.params
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2", dtype=jnp.float16)

# 3. If you’ve added new tokens, resize just like in PyTorch:
#    model = model.resize_token_embeddings(len(tokenizer))

# 4. Make sure padding is configured
model.config.pad_token_id = tokenizer.eos_token_id

# 5. Pull out the parameter dict for training
params = model.params

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


In [5]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.65 GB
  Total: 8.00 GB
  Utilization: 8.1%


### Inspect model 

In [6]:
# Inspect JAX/Flax model architecture and parameters
import jax
from jax.tree_util import tree_map
from flax.core import freeze

print("🔍 JAX/Flax Model Inspection")
print("=" * 50)

# 1. Basic model info
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")
print(f"Vocab size: {model.config.vocab_size}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Number of attention heads: {model.config.n_head}")

print("\n📊 Parameter Analysis")
print("=" * 30)

# 2. Count parameters (JAX way)
def count_params(params):
    """Count total parameters in a JAX parameter tree"""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

total_params = count_params(params)
print(f"Total parameters: {total_params:,}")
print(f"Total parameters (millions): {total_params / 1_000_000:.2f}M")

# 3. Inspect parameter structure
print("\n🏗️ Parameter Structure")
print("=" * 25)

def print_param_shapes(params, prefix=""):
    """Recursively print parameter shapes"""
    if isinstance(params, dict):
        for key, value in params.items():
            print_param_shapes(value, f"{prefix}.{key}" if prefix else key)
    else:
        print(f"{prefix}: {params.shape} ({params.dtype})")

print_param_shapes(params)

# 4. Memory usage estimation
def estimate_memory(params):
    """Estimate memory usage in MB"""
    total_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(params))
    return total_bytes / (1024 ** 2)

memory_mb = estimate_memory(params)
print(f"\n💾 Estimated memory usage: {memory_mb:.2f} MB")

# 5. Compare with different dtypes
print(f"\n🔢 Memory usage by dtype:")
print(f"  float32: {memory_mb:.2f} MB")
print(f"  float16: {memory_mb / 2:.2f} MB") 
print(f"  bfloat16: {memory_mb / 2:.2f} MB")

🔍 JAX/Flax Model Inspection
Model type: <class 'transformers.models.gpt2.modeling_flax_gpt2.FlaxGPT2LMHeadModel'>
Model config: GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "pad_token_id": 50256,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab

### Do stuff

In [7]:
train_dataset = TLDRDataset("../data", tokenizer, split="train")
val_dataset   = TLDRDataset("../data", tokenizer, split="valid")

print(f"📊 Dataset loaded:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

# Preview a sample
sample = train_dataset[0]
print(f"\n📝 Sample data shapes:")
print(f"  input_ids: {sample['input_ids'].shape}")
print(f"  attention_mask: {sample['attention_mask'].shape}")
print(f"  labels: {sample['labels'].shape}")

# Preview actual content
print(f"\n📋 Sample content:")
print(f"  First 100 chars of tokenized text: {train_dataset.examples[0][:100]}...")
print(f"  Input IDs (first 10): {sample['input_ids'][:10]}")
print(f"  Labels (first 10): {sample['labels'][:10]}")


Loaded 116722 examples from ../data/tldr_train.parquet
Loaded 2000 examples from ../data/tldr_valid.parquet
📊 Dataset loaded:
  Training samples: 116722
  Validation samples: 2000

📝 Sample data shapes:
  input_ids: (550,)
  attention_mask: (550,)
  labels: (550,)

📋 Sample content:
  First 100 chars of tokenized text: SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or...
  Input IDs (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]
  Labels (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]


### Train

In [None]:
import jax.numpy as jnp
from jax import random

def create_data_loader(dataset, batch_size, shuffle=True):
    # Pre-tokenize everything into arrays
    all_data = [dataset[i] for i in range(len(dataset))]
    
    # Convert to JAX arrays upfront
    input_ids = jnp.array([x["input_ids"] for x in all_data])
    attention_mask = jnp.array([x["attention_mask"] for x in all_data])
    labels = jnp.array([x["labels"] for x in all_data])
    
    num_batches = len(dataset) // batch_size  # Drop incomplete batches
    
    if shuffle:
        indices = random.permutation(random.PRNGKey(42), len(dataset))
        input_ids = input_ids[indices]
        attention_mask = attention_mask[indices]
        labels = labels[indices]
    
    # Yield consistent-shaped batches
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        
        yield {
            "input_ids": input_ids[start_idx:end_idx],
            "attention_mask": attention_mask[start_idx:end_idx], 
            "labels": labels[start_idx:end_idx]
        }

# Example usage
print("🔄 Creating data loaders...")

# Create data loaders
batch_size = 8  # Small batch for demo

train_loader = create_data_loader(train_dataset, batch_size, shuffle=True)
val_loader = create_data_loader(val_dataset, batch_size, shuffle=False)

# Test the data loader
print("📦 Testing batch loading...")
batch = next(iter(train_loader))

print(f"Batch shapes:")
print(f"  input_ids: {batch['input_ids'].shape}")
print(f"  attention_mask: {batch['attention_mask'].shape}")
print(f"  labels: {batch['labels'].shape}")
print(f"  Data type: {batch['input_ids'].dtype}")

# Show how many batches we'll have
train_batches = (len(train_dataset) + batch_size - 1) // batch_size
val_batches = (len(val_dataset) + batch_size - 1) // batch_size
print(f"\n📊 Batch counts:")
print(f"  Training batches: {train_batches}")
print(f"  Validation batches: {val_batches}")

print("\n✅ Data loader ready for JAX training!")

🔄 Creating data loaders...
📦 Testing batch loading...
Batch shapes:
  input_ids: (4, 550)
  attention_mask: (4, 550)
  labels: (4, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 29181
  Validation batches: 500

✅ Data loader ready for JAX training!
Batch shapes:
  input_ids: (4, 550)
  attention_mask: (4, 550)
  labels: (4, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 29181
  Validation batches: 500

✅ Data loader ready for JAX training!


In [9]:
import optax
from flax.training import train_state

# Initialize the optimizer
tx = optax.adam(1e-5)

# What is state? 
# It holds the model parameters, optimizer state, and any auxiliary data needed for training.
state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

In [None]:
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from typing import Any, Dict, Tuple
from tqdm import trange
import time
step_times = []


# Define the training step function
@jax.jit
def train_step(
    state: train_state.TrainState, 
    batch: Dict[str, jnp.ndarray],
    dropout_rng: jax.random.PRNGKey
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """
    Perform a single training step (forward, loss, backward, update).
    
    Args:
        state: TrainState containing params & optimizer state.
        batch: Dict with keys "input_ids", "attention_mask", "labels" of shape (B, L).
    
    Returns:
        new_state: Updated TrainState after applying gradients.
        loss: Scalar loss for this batch.
    """
    
    def loss_fn(params: Any) -> jnp.ndarray:
        # Forward pass: get logits [batch, seq_len, vocab_size]
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            train=True,
            dropout_rng=dropout_rng  # Use dropout RNG for training
        )
        logits = outputs.logits
        
        # Causal LM shift: predict token t given inputs up to t-1
        shift_logits = logits[..., :-1, :]           # drop last logit
        shift_labels = batch["labels"][..., 1:]       # drop first label
        
        # Compute per-token cross-entropy
        loss = optax.softmax_cross_entropy_with_integer_labels(
            shift_logits, shift_labels
        )  # shape: (batch, seq_len-1)
        
        # Mask out padding tokens from loss
        mask = batch["attention_mask"][..., 1:]       # same shift as labels
        loss = jnp.sum(loss * mask) / jnp.sum(mask)   # mean over non-pad tokens
        
        return loss

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Apply gradients to update parameters
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# Add memory monitoring and progress tracking to training
def train_loop_with_monitoring(
    state: train_state.TrainState,
    train_loader: Any,
    num_epochs: int = 3,
    rng_key: jnp.ndarray = None
) -> train_state.TrainState:
    """
    Training loop with memory monitoring and progress tracking.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(42)
    
    # Calculate total batches for progress tracking
    train_batches = (len(train_dataset) + batch_size - 1) // batch_size

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        steps = 0
        running_loss = 0.0
        
        print(f"📊 Starting Epoch {epoch}/{num_epochs} ({train_batches} batches)")
        
        # Iterate over all batches
        for batch in train_loader:
            # Split RNG key for this step
            rng_key, dropout_rng = random.split(rng_key)

            # Single training step with timing
            step_start = time.time()
            state, loss = train_step(state, batch, dropout_rng)
            step_time = time.time() - step_start
            step_times.append(step_time)
            
            if steps <= 5:  # Print first 5 steps
                print(f"Step {steps} took: {step_time:.2f}s")

            # Process loss and increment counters
            current_loss = float(loss)
            epoch_loss += current_loss
            running_loss += current_loss
            steps += 1
            
            # Progress tracking every 50 steps
            if steps % 25 == 0:
                avg_loss_25 = running_loss / 25
                progress_pct = (steps / train_batches) * 100
                print(f"  📈 Step {steps}/{train_batches} ({progress_pct:.1f}%) - Avg loss (last 25): {avg_loss_25:.4f}")
                running_loss = 0.0  # Reset running loss
        
        avg_loss = epoch_loss / steps
        print(f"✅ Epoch {epoch} complete — avg loss: {avg_loss:.4f}")
        print("📊 End of epoch memory:")
        print("-" * 50)
    
    return state

# Use this instead of the regular train_loop
num_epochs = 5  # Start with fewer epochs to see the pattern
state = train_loop_with_monitoring(state, train_loader, num_epochs=num_epochs)

# After training, `state.params` holds your fine-tuned model weights.



📊 Starting Epoch 1/5 (29181 batches)
Step 0 took: 37.69s
Step 0 took: 37.69s
Step 1 took: 0.24s
Step 1 took: 0.24s
Step 2 took: 0.24s
Step 2 took: 0.24s
Step 3 took: 0.24s
Step 3 took: 0.24s
Step 4 took: 0.24s
Step 4 took: 0.24s
Step 5 took: 0.25s
Step 5 took: 0.25s
  📈 Step 25/29181 (0.1%) - Avg loss (last 25): 4.0178
  📈 Step 25/29181 (0.1%) - Avg loss (last 25): 4.0178
  📈 Step 50/29181 (0.2%) - Avg loss (last 25): 4.3545
  📈 Step 50/29181 (0.2%) - Avg loss (last 25): 4.3545
  📈 Step 75/29181 (0.3%) - Avg loss (last 25): 4.3457
  📈 Step 75/29181 (0.3%) - Avg loss (last 25): 4.3457
  📈 Step 100/29181 (0.3%) - Avg loss (last 25): 3.9540
  📈 Step 100/29181 (0.3%) - Avg loss (last 25): 3.9540
  📈 Step 125/29181 (0.4%) - Avg loss (last 25): 3.9294
  📈 Step 125/29181 (0.4%) - Avg loss (last 25): 3.9294
  📈 Step 150/29181 (0.5%) - Avg loss (last 25): 3.9046
  📈 Step 150/29181 (0.5%) - Avg loss (last 25): 3.9046
  📈 Step 175/29181 (0.6%) - Avg loss (last 25): 3.8267
  📈 Step 175/29181 (0.6%

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x78882fabb4c0>>
Traceback (most recent call last):
  File "/home/bwilliams/mlx/week6/bb-finetune/ben_dev/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


  📈 Step 225/29181 (0.8%) - Avg loss (last 25): 3.6598
