# TODO:
- try lora
- try reward model
- ppo

### 📊 Training Configuration - Fast Experimentation

For quick experimentation, we'll use a subset of the dataset. Change `USE_FULL_DATASET` to `True` when you want to train on the complete dataset.

In [1]:
# 🔧 EXPERIMENT CONFIGURATION
USE_FULL_DATASET = True  # Set to True for full training, False for quick experiments
EXPERIMENT_SIZE = 200    # Use only this many samples for quick testing

print(f"🎯 Configuration:")
print(f"  USE_FULL_DATASET: {USE_FULL_DATASET}")
if not USE_FULL_DATASET:
    print(f"  EXPERIMENT_SIZE: {EXPERIMENT_SIZE} samples")
    print(f"  Estimated training time: ~{EXPERIMENT_SIZE // 8 * 0.5 / 60:.1f} minutes per epoch")

🎯 Configuration:
  USE_FULL_DATASET: True


### Sort mem
Jax shotguns a bunch of memoery on the GPU but it doesn't need it

In [2]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [3]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.11 GB
  Total: 8.00 GB
  Utilization: 1.3%


### Define dataset class

In [4]:
from datasets import load_dataset
import numpy as np  # or jax.numpy as jnp if needed
import pandas as pd
from pathlib import Path

class TLDRDataset:
    def __init__(self, data_dir, tokenizer, split, max_length=550):
        """
        Load TLDR dataset from local parquet files.
        
        Args:
            data_dir: Path to directory containing parquet files
            tokenizer: Tokenizer to use
            split: 'train', 'valid', or 'test'
            max_length: Maximum sequence length
        """
        # Load the parquet file
        parquet_file = Path(data_dir) / f"tldr_{split}.parquet"
        if not parquet_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {parquet_file}")
        
        df = pd.read_parquet(parquet_file)
        
        # Combine prompt and label for training (teacher forcing)
        self.examples = [row["prompt"] + row["label"] for _, row in df.iterrows()]
        
        # Limit validation set size for faster iteration
        if "valid" in split:
            self.examples = self.examples[:2000]
        
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        print(f"Loaded {len(self.examples)} examples from {parquet_file}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Tokenize the text
        enc = self.tokenizer(
            self.examples[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        return {
            "input_ids": np.array(enc["input_ids"], dtype=np.int32),
            "attention_mask": np.array(enc["attention_mask"], dtype=np.int32),
            "labels": np.array(enc["input_ids"], dtype=np.int32),  # teacher forcing
        }


  from .autonotebook import tqdm as notebook_tqdm


### Load model and tokeniser
Stick to gpt2 now for compatibility, then move to qwen. GPT2 = 124m params, 500mb. qwen0.6b = 550m params, 2gb. So beware 3x qwen0.6b on my 8gb gpu might tank its memory.

In [5]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# 1. Tokenizer is identical
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Flax (JAX) model
#    .from_pretrained returns a FlaxAutoModelForCausalLM whose weights live in model.params
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2", dtype=jnp.float16) # Back to your working config

# 3. If you’ve added new tokens, resize just like in PyTorch:
#    model = model.resize_token_embeddings(len(tokenizer))

# 4. Make sure padding is configured
model.config.pad_token_id = tokenizer.eos_token_id

# 5. Pull out the parameter dict for training
params = model.params

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


In [6]:
import jax
import subprocess

def check_memory_usage():
    # JAX device info
    device = jax.devices()[0]
    print(f"JAX device: {device}")
    
    # Try JAX memory stats (newer versions)
    if hasattr(device, 'memory_stats'):
        stats = device.memory_stats()
        if stats is not None:
            print(f"JAX memory stats:")
            print(f"  Bytes in use: {stats.get('bytes_in_use', 0) / 1024**3:.2f} GB")
            print(f"  Pool bytes: {stats.get('pool_bytes', 0) / 1024**3:.2f} GB")
            print(f"  Peak bytes: {stats.get('peak_bytes_in_use', 0) / 1024**3:.2f} GB")
        else:
            print("JAX memory stats not available")
    else:
        print("JAX memory_stats method not available")
    
    # Use nvidia-smi command instead
    try:
        result = subprocess.run([
            'nvidia-smi', 
            '--query-gpu=memory.used,memory.total', 
            '--format=csv,noheader,nounits'
        ], capture_output=True, text=True, check=True)
        
        used, total = result.stdout.strip().split(', ')
        used_gb = int(used) / 1024
        total_gb = int(total) / 1024
        print(f"\nNVIDIA GPU memory:")
        print(f"  Used: {used_gb:.2f} GB")
        print(f"  Total: {total_gb:.2f} GB")
        print(f"  Utilization: {used_gb/total_gb*100:.1f}%")
    except Exception as e:
        print(f"Could not get GPU memory info: {e}")

# Run this before and during training
check_memory_usage()

JAX device: cuda:0
JAX memory stats not available

NVIDIA GPU memory:
  Used: 0.65 GB
  Total: 8.00 GB
  Utilization: 8.1%


### Inspect model 

In [7]:
# Inspect JAX/Flax model architecture and parameters
import jax
from jax.tree_util import tree_map
from flax.core import freeze

print("🔍 JAX/Flax Model Inspection")
print("=" * 50)

# 1. Basic model info
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")
print(f"Vocab size: {model.config.vocab_size}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Number of attention heads: {model.config.n_head}")

print("\n📊 Parameter Analysis")
print("=" * 30)

# 2. Count parameters (JAX way)
def count_params(params):
    """Count total parameters in a JAX parameter tree"""
    return sum(x.size for x in jax.tree_util.tree_leaves(params))

total_params = count_params(params)
print(f"Total parameters: {total_params:,}")
print(f"Total parameters (millions): {total_params / 1_000_000:.2f}M")

# 3. Inspect parameter structure
print("\n🏗️ Parameter Structure")
print("=" * 25)

def print_param_shapes(params, prefix=""):
    """Recursively print parameter shapes"""
    if isinstance(params, dict):
        for key, value in params.items():
            print_param_shapes(value, f"{prefix}.{key}" if prefix else key)
    else:
        print(f"{prefix}: {params.shape} ({params.dtype})")

print_param_shapes(params)

# 4. Memory usage estimation
def estimate_memory(params):
    """Estimate memory usage in MB"""
    total_bytes = sum(x.nbytes for x in jax.tree_util.tree_leaves(params))
    return total_bytes / (1024 ** 2)

memory_mb = estimate_memory(params)
print(f"\n💾 Estimated memory usage: {memory_mb:.2f} MB")

# 5. Compare with different dtypes
print(f"\n🔢 Memory usage by dtype:")
print(f"  float32: {memory_mb:.2f} MB")
print(f"  float16: {memory_mb / 2:.2f} MB") 
print(f"  bfloat16: {memory_mb / 2:.2f} MB")

🔍 JAX/Flax Model Inspection
Model type: <class 'transformers.models.gpt2.modeling_flax_gpt2.FlaxGPT2LMHeadModel'>
Model config: GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "pad_token_id": 50256,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab

### Do stuff

In [8]:
# Create datasets
train_dataset = TLDRDataset("../data", tokenizer, split="train")
val_dataset   = TLDRDataset("../data", tokenizer, split="valid")

# Apply experiment configuration
if not USE_FULL_DATASET:
    print(f"🧪 Using experiment mode: limiting to {EXPERIMENT_SIZE} samples")
    # Create a subset by limiting the examples list
    train_dataset.examples = train_dataset.examples[:EXPERIMENT_SIZE]
    # Keep validation dataset smaller too for quick validation
    val_dataset.examples = val_dataset.examples[:min(100, len(val_dataset.examples))]

print(f"📊 Dataset loaded:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

if not USE_FULL_DATASET:
    estimated_time = (len(train_dataset) // 8) * 0.5 / 60  # batch_size=8, 0.5s per step
    print(f"  ⏱️  Estimated training time per epoch: ~{estimated_time:.1f} minutes")

# Preview a sample
sample = train_dataset[0]
print(f"\n📝 Sample data shapes:")
print(f"  input_ids: {sample['input_ids'].shape}")
print(f"  attention_mask: {sample['attention_mask'].shape}")
print(f"  labels: {sample['labels'].shape}")

# Preview actual content
print(f"\n📋 Sample content:")
print(f"  First 100 chars of tokenized text: {train_dataset.examples[0][:100]}...")
print(f"  Input IDs (first 10): {sample['input_ids'][:10]}")
print(f"  Labels (first 10): {sample['labels'][:10]}")


Loaded 116722 examples from ../data/tldr_train.parquet
Loaded 2000 examples from ../data/tldr_valid.parquet
📊 Dataset loaded:
  Training samples: 116722
  Validation samples: 2000

📝 Sample data shapes:
  input_ids: (550,)
  attention_mask: (550,)
  labels: (550,)

📋 Sample content:
  First 100 chars of tokenized text: SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or...
  Input IDs (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]
  Labels (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]


### Train

In [9]:
import jax.numpy as jnp
from jax import random

def create_data_loader(dataset, batch_size, shuffle=True):
    # Pre-tokenize everything into arrays (your original working approach)
    all_data = [dataset[i] for i in range(len(dataset))]
    
    # Convert to JAX arrays upfront
    input_ids = jnp.array([x["input_ids"] for x in all_data])
    attention_mask = jnp.array([x["attention_mask"] for x in all_data])
    labels = jnp.array([x["labels"] for x in all_data])
    
    num_batches = len(dataset) // batch_size  # Drop incomplete batches
    
    if shuffle:
        indices = random.permutation(random.PRNGKey(42), len(dataset))
        input_ids = input_ids[indices]
        attention_mask = attention_mask[indices]
        labels = labels[indices]
    
    # Yield consistent-shaped batches
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        
        yield {
            "input_ids": input_ids[start_idx:end_idx],
            "attention_mask": attention_mask[start_idx:end_idx], 
            "labels": labels[start_idx:end_idx]
        }

# Example usage
print("🔄 Creating data loaders...")

# Create data loaders
BATCH_SIZE = 4  # Back to your working batch size

train_loader = create_data_loader(train_dataset, BATCH_SIZE, shuffle=True)
val_loader = create_data_loader(val_dataset, BATCH_SIZE, shuffle=False)

# Test the data loader
print("📦 Testing batch loading...")
batch = next(iter(train_loader))

print(f"Batch shapes:")
print(f"  input_ids: {batch['input_ids'].shape}")
print(f"  attention_mask: {batch['attention_mask'].shape}")
print(f"  labels: {batch['labels'].shape}")
print(f"  Data type: {batch['input_ids'].dtype}")

# Show how many batches we'll have
train_batches = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
val_batches = (len(val_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
print(f"\n📊 Batch counts:")
print(f"  Training batches: {train_batches}")
print(f"  Validation batches: {val_batches}")

print("\n✅ Data loader ready for JAX training!")

🔄 Creating data loaders...
📦 Testing batch loading...
Batch shapes:
  input_ids: (4, 550)
  attention_mask: (4, 550)
  labels: (4, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 29181
  Validation batches: 500

✅ Data loader ready for JAX training!
Batch shapes:
  input_ids: (4, 550)
  attention_mask: (4, 550)
  labels: (4, 550)
  Data type: int32

📊 Batch counts:
  Training batches: 29181
  Validation batches: 500

✅ Data loader ready for JAX training!


In [10]:
import optax
from flax.training import train_state

# Initialize the optimizer
tx = optax.adam(1e-5)

# What is state? 
# It holds the model parameters, optimizer state, and any auxiliary data needed for training.
state = train_state.TrainState.create(apply_fn=model.__call__,
                                      params=params,
                                      tx=tx)

### 🚀 Training Speed Summary

**Understanding your training times:**

- **Full dataset (116,722 samples)**: ~1.8 hours per epoch (14,590 batches × 0.45s each)
- **Experiment mode (1,000 samples)**: ~1 minute per epoch (125 batches × 0.45s each)

**First step is always slow (~40s)** due to JAX JIT compilation, then it runs at normal speed.

**To switch between modes:**
- `USE_FULL_DATASET = False` → Quick experiments (1,000 samples)
- `USE_FULL_DATASET = True` → Full training (116,722 samples)

### 💾 Model Checkpointing

We'll use Flax's built-in checkpointing system for clean, versioned model saving.

In [11]:
import os
from pathlib import Path
from flax.training import checkpoints
import json
from typing import Any

# 📁 Checkpoint configuration
CHECKPOINT_DIR = "../models"  # Correct relative path from ben_dev/misc/ to ben_dev/models/
MAX_CHECKPOINTS = 3  # Keep only the 3 most recent checkpoints

def save_checkpoint(state, epoch, checkpoint_dir=CHECKPOINT_DIR, max_to_keep=MAX_CHECKPOINTS):
    """
    Save model checkpoint using Flax's built-in checkpointing.
    
    Args:
        state: TrainState containing params, opt_state, and step
        epoch: Current epoch number 
        checkpoint_dir: Directory to save checkpoints (can be relative)
        max_to_keep: Maximum number of checkpoints to keep (auto-cleanup)
    """
    # Convert relative path to absolute (Flax requires absolute paths)
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    # Ensure checkpoint directory exists
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    # Save the complete TrainState (params + optimizer state + step)
    # Flax automatically handles versioning with step numbers
    checkpoints.save_checkpoint(
        ckpt_dir=str(checkpoint_dir),  # Convert Path to string for Flax
        target=state,  # This saves the entire TrainState
        step=state.step,  # Use training step for versioning
        keep=max_to_keep,  # Auto-cleanup: keep only last N checkpoints
        overwrite=True  # Allow overwriting if step already exists
    )
    
    # Also save model config for easy loading later
    config_path = checkpoint_dir / "model_config.json"
    if not config_path.exists():
        config_dict = {
            "model_type": "gpt2",
            "vocab_size": model.config.vocab_size,
            "n_embd": model.config.n_embd,
            "n_layer": model.config.n_layer,
            "n_head": model.config.n_head,
            "max_length": 550,  # Your max sequence length
            "dtype": "float16"  # Your model dtype
        }
        with open(config_path, 'w') as f:
            json.dump(config_dict, f, indent=2)
    
    print(f"💾 Checkpoint saved: epoch {epoch}, step {state.step}")
    print(f"   📂 Location: {checkpoint_dir}/checkpoint_{state.step}")


print(f"📂 Checkpoints will be saved to: {CHECKPOINT_DIR}")
print(f"🔄 Will keep {MAX_CHECKPOINTS} most recent checkpoints")

# Add memory monitoring and progress tracking to training
def train_loop(
    state: train_state.TrainState,
    train_loader: Any,
    num_epochs: int = 3,
    rng_key: jnp.ndarray = None,
    log_every: int = 100  # Print loss every N steps
) -> train_state.TrainState:
    """
    Training loop with memory monitoring and progress tracking.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(42)
    
    # Calculate total steps across all epochs
    train_batches = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
    total_steps = train_batches * num_epochs

    # Single progress bar for all training
    pbar = tqdm(
        total=total_steps,
        desc="🚀 Training",
        ncols=120,
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
    )
    
    global_step = 0
    total_loss = 0.0
    total_steps_logged = 0

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        epoch_steps = 0
        
        # Create fresh data loader for each epoch
        data_loader = create_data_loader(train_dataset, batch_size=BATCH_SIZE)
        
        for batch in data_loader:
            # Split RNG key for this step
            rng_key, dropout_rng = random.split(rng_key)

            # Single training step
            step_start = time.time()
            state, loss = train_step(state, batch, dropout_rng)
            step_time = time.time() - step_start
            
            # Process loss and increment counters
            current_loss = float(loss)
            epoch_loss += current_loss
            epoch_steps += 1
            global_step += 1
            total_loss += current_loss
            total_steps_logged += 1
            
            # Update progress bar
            pbar.set_postfix({
                'epoch': f'{epoch}/{num_epochs}',
                'loss': f'{current_loss:.4f}',
                'avg_loss': f'{epoch_loss/epoch_steps:.4f}',
                'step_time': f'{step_time:.2f}s'
            })
            pbar.update(1)
            
            # 📊 Log loss every N steps
            if global_step % log_every == 0:
                avg_loss_so_far = total_loss / total_steps_logged
                tqdm.write(f"🔥 Step {global_step}/{total_steps} | Loss: {current_loss:.4f} | Avg Loss: {avg_loss_so_far:.4f} | Step Time: {step_time:.2f}s")
            
            # Optional: Print timing for first few steps
            if global_step <= 10:
                tqdm.write(f"⏱️  Global step {global_step} took: {step_time:.2f}s")
        
        # Print epoch summary
        avg_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0.0
        tqdm.write(f"✅ Epoch {epoch} complete — avg loss: {avg_loss:.4f}")
        
        # 💾 Save checkpoint at end of each epoch
        save_checkpoint(state, epoch)
        tqdm.write(f"📁 Checkpoint saved for epoch {epoch}")
    
    pbar.close()
    return state



📂 Checkpoints will be saved to: ../models
🔄 Will keep 3 most recent checkpoints


### 📊 Enhanced Training Loop with Frequent Logging

The training loop now includes **frequent progress logging** to help you monitor training in real-time:

**Key Features:**
- **`log_every` parameter**: Control how often to print detailed loss updates
- **Real-time progress**: See loss, average loss, and step timing during training
- **Flexible logging frequency**: 
  - `log_every=10` → Very frequent (good for debugging)
  - `log_every=100` → Moderate frequency (good for monitoring)
  - `log_every=500` → Less frequent (good for production)

**What you'll see:**
- 🔥 Step-by-step loss updates every N steps
- ⏱️ Timing information for the first 10 steps (to monitor JIT compilation)
- ✅ Epoch summaries with average loss
- 📁 Checkpoint saving confirmations

This gives you much better visibility into training progress compared to only seeing loss at the end of each epoch!

In [None]:
import jax
import jax
import jax.numpy as jnp
from jax import random, value_and_grad
import optax
from flax.training import train_state
from typing import Any, Dict, Tuple
from tqdm import trange, tqdm
import time
step_times = []

# 🎯 TRAINING CONFIGURATION
# Set log_every to control how often to print loss updates during training
# - log_every=10: Very frequent logging (every 10 steps) - good for debugging
# - log_every=100: Moderate logging (every 100 steps) - good for monitoring progress  
# - log_every=500: Less frequent logging - good for production runs
LOG_EVERY = 25  # Print loss every 25 steps for our quick experiment

print(f"🔧 Training Configuration:")
print(f"  📊 Will log progress every {LOG_EVERY} steps")
print(f"  🎯 Total steps per epoch: ~{(len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE}")

# Define the training step function
@jax.jit
def train_step(
    state: train_state.TrainState, 
    batch: Dict[str, jnp.ndarray],
    dropout_rng: jax.random.PRNGKey
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """
    Perform a single training step (forward, loss, backward, update).
    
    Args:
        state: TrainState containing params & optimizer state.
        batch: Dict with keys "input_ids", "attention_mask", "labels" of shape (B, L).
    
    Returns:
        new_state: Updated TrainState after applying gradients.
        loss: Scalar loss for this batch.
    """
    
    def loss_fn(params: Any) -> jnp.ndarray:
        # Forward pass: get logits [batch, seq_len, vocab_size]
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            params=params,
            train=True,
            dropout_rng=dropout_rng  # Use dropout RNG for training
        )
        logits = outputs.logits
        
        # Causal LM shift: predict token t given inputs up to t-1
        shift_logits = logits[..., :-1, :]           # drop last logit
        shift_labels = batch["labels"][..., 1:]       # drop first label
        
        # Compute per-token cross-entropy
        loss = optax.softmax_cross_entropy_with_integer_labels(
            shift_logits, shift_labels
        )  # shape: (batch, seq_len-1)
        
        # Mask out padding tokens from loss
        mask = batch["attention_mask"][..., 1:]       # same shift as labels
        loss = jnp.sum(loss * mask) / jnp.sum(mask)   # mean over non-pad tokens
        
        return loss

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Apply gradients to update parameters
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# Add memory monitoring and progress tracking to training
def train_loop(
    state: train_state.TrainState,
    train_dataset: Any,  # Pass dataset instead of loader
    num_epochs: int = 3,
    rng_key: jnp.ndarray = None,
    log_every: int = 100  # Print progress every N steps
) -> train_state.TrainState:
    """
    Training loop with memory monitoring and progress tracking.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(42)
    
    # Calculate total steps across all epochs
    train_batches = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
    total_steps = train_batches * num_epochs

    # Single progress bar for all training
    pbar = tqdm(
        total=total_steps,
        desc="🚀 Training",
        ncols=120,
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
    )
    
    global_step = 0

    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0.0
        epoch_steps = 0
        
        # Create a fresh data loader for each epoch
        train_loader = create_data_loader(train_dataset, BATCH_SIZE, shuffle=True)
        
        for batch in train_loader:
            # Split RNG key for this step
            rng_key, dropout_rng = random.split(rng_key)

            # Single training step
            step_start = time.time()
            state, loss = train_step(state, batch, dropout_rng)
            step_time = time.time() - step_start
            
            # Process loss and increment counters
            current_loss = float(loss)
            epoch_loss += current_loss
            epoch_steps += 1
            global_step += 1
            
            # Update progress bar
            pbar.set_postfix({
                'epoch': f'{epoch}/{num_epochs}',
                'loss': f'{current_loss:.4f}',
                'avg_loss': f'{epoch_loss/epoch_steps:.4f}',
                'step_time': f'{step_time:.2f}s'
            })
            pbar.update(1)
            
            # Print detailed progress every log_every steps
            if global_step % log_every == 0 or global_step <= 10:
                avg_loss_so_far = epoch_loss / epoch_steps
                steps_per_sec = 1.0 / step_time if step_time > 0 else 0
                tqdm.write(
                    f"📊 Step {global_step:,} | "
                    f"Epoch {epoch}/{num_epochs} | "
                    f"Loss: {current_loss:.4f} | "
                    f"Avg Loss: {avg_loss_so_far:.4f} | "
                    f"Speed: {steps_per_sec:.1f} steps/s | "
                    f"Step Time: {step_time:.2f}s"
                )
            
            # Print timing for first few steps
            elif global_step <= 10:
                tqdm.write(f"⏱️  Global step {global_step} took: {step_time:.2f}s")
        
        # Print epoch summary
        avg_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0.0
        total_epoch_time = epoch_steps * step_time  # Rough estimate
        tqdm.write(f"✅ Epoch {epoch} complete — avg loss: {avg_loss:.4f} | "
                  f"Steps: {epoch_steps} | Time: {total_epoch_time/60:.1f}m")
        
        # 💾 Save checkpoint at end of each epoch
        save_checkpoint(state, epoch)
        tqdm.write(f"📁 Checkpoint saved for epoch {epoch} (step {state.step})")
    
    pbar.close()
    return state

# 🚀 FIXED: Pass train_dataset (not train_loader) to train_loop
print(f"🔧 Starting training with enhanced logging...")
print(f"   Dataset size: {len(train_dataset)} samples")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Steps per epoch: {(len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE}")

# Use this instead of the regular train_loop
num_epochs = 5  # Start with fewer epochs to see the pattern
state = train_loop(state, train_dataset, num_epochs=num_epochs, log_every=LOG_EVERY)  # ✅ Fixed: pass train_dataset

print("🎉 Training complete!")
# After training, `state.params` holds your fine-tuned model weights.



🔧 Training Configuration:
  📊 Will log progress every 25 steps
  🎯 Total steps per epoch: ~29181
🔧 Starting training with enhanced logging...
   Dataset size: 116722 samples
   Batch size: 4
   Steps per epoch: 29181


🚀 Training:   0%|                                                             | 1/145905 [01:44<4235:10:46, 104.50s/it]

📊 Step 1 | Epoch 1/5 | Loss: 3.2637 | Avg Loss: 3.2637 | Speed: 0.0 steps/s | Step Time: 38.05s


🚀 Training:   0%|                                                              | 2/145905 [01:45<1774:35:33, 43.79s/it]

📊 Step 2 | Epoch 1/5 | Loss: 3.1680 | Avg Loss: 3.2158 | Speed: 0.8 steps/s | Step Time: 1.28s


🚀 Training:   0%|                                                               | 3/145905 [01:46<968:55:41, 23.91s/it]

📊 Step 3 | Epoch 1/5 | Loss: 3.4746 | Avg Loss: 3.3021 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                               | 4/145905 [01:46<590:19:43, 14.57s/it]

📊 Step 4 | Epoch 1/5 | Loss: 3.4766 | Avg Loss: 3.3457 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                               | 5/145905 [01:46<381:03:05,  9.40s/it]

📊 Step 5 | Epoch 1/5 | Loss: 3.2422 | Avg Loss: 3.3250 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                               | 6/145905 [01:46<254:56:14,  6.29s/it]

📊 Step 6 | Epoch 1/5 | Loss: 3.2793 | Avg Loss: 3.3174 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                               | 7/145905 [01:47<174:52:13,  4.31s/it]

📊 Step 7 | Epoch 1/5 | Loss: 3.4531 | Avg Loss: 3.3368 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                               | 8/145905 [01:47<122:24:49,  3.02s/it]

📊 Step 8 | Epoch 1/5 | Loss: 3.5352 | Avg Loss: 3.3616 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                                | 9/145905 [01:47<87:13:40,  2.15s/it]

📊 Step 9 | Epoch 1/5 | Loss: 3.5938 | Avg Loss: 3.3874 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                               | 10/145905 [01:47<63:28:09,  1.57s/it]

📊 Step 10 | Epoch 1/5 | Loss: 3.3359 | Avg Loss: 3.3822 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                               | 25/145905 [01:51<10:15:31,  3.95it/s]

📊 Step 25 | Epoch 1/5 | Loss: 3.6348 | Avg Loss: 3.4469 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   0%|                                                               | 50/145905 [01:57<10:05:21,  4.02it/s]

📊 Step 50 | Epoch 1/5 | Loss: 4.4219 | Avg Loss: 3.8817 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                               | 75/145905 [02:04<10:13:07,  3.96it/s]

📊 Step 75 | Epoch 1/5 | Loss: 4.4219 | Avg Loss: 4.0197 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 100/145905 [02:10<10:10:51,  3.98it/s]

📊 Step 100 | Epoch 1/5 | Loss: 4.2109 | Avg Loss: 4.1003 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 125/145905 [02:16<10:29:00,  3.86it/s]

📊 Step 125 | Epoch 1/5 | Loss: 4.2266 | Avg Loss: 4.1704 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 150/145905 [02:23<10:16:53,  3.94it/s]

📊 Step 150 | Epoch 1/5 | Loss: 4.4023 | Avg Loss: 4.1821 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 175/145905 [02:29<10:30:35,  3.85it/s]

📊 Step 175 | Epoch 1/5 | Loss: 3.9160 | Avg Loss: 4.1719 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 200/145905 [02:36<10:33:28,  3.83it/s]

📊 Step 200 | Epoch 1/5 | Loss: 3.9043 | Avg Loss: 4.1526 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 225/145905 [02:42<10:21:17,  3.91it/s]

📊 Step 225 | Epoch 1/5 | Loss: 4.0039 | Avg Loss: 4.1269 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 250/145905 [02:48<10:13:38,  3.96it/s]

📊 Step 250 | Epoch 1/5 | Loss: 3.8223 | Avg Loss: 4.0987 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|                                                              | 275/145905 [02:55<10:32:37,  3.84it/s]

📊 Step 275 | Epoch 1/5 | Loss: 4.0039 | Avg Loss: 4.0773 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   0%|▏                                                             | 300/145905 [03:01<10:09:05,  3.98it/s]

📊 Step 300 | Epoch 1/5 | Loss: 4.0195 | Avg Loss: 4.0680 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 325/145905 [03:07<10:07:32,  3.99it/s]

📊 Step 325 | Epoch 1/5 | Loss: 3.8555 | Avg Loss: 4.0490 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 350/145905 [03:14<10:20:59,  3.91it/s]

📊 Step 350 | Epoch 1/5 | Loss: 3.4141 | Avg Loss: 4.0265 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   0%|▏                                                             | 375/145905 [03:20<10:32:06,  3.84it/s]

📊 Step 375 | Epoch 1/5 | Loss: 3.8535 | Avg Loss: 4.0109 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   0%|▏                                                             | 400/145905 [03:26<10:05:39,  4.00it/s]

📊 Step 400 | Epoch 1/5 | Loss: 3.6816 | Avg Loss: 3.9955 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   0%|▏                                                             | 425/145905 [03:33<10:21:27,  3.90it/s]

📊 Step 425 | Epoch 1/5 | Loss: 3.5723 | Avg Loss: 3.9755 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 450/145905 [03:39<10:24:10,  3.88it/s]

📊 Step 450 | Epoch 1/5 | Loss: 3.7305 | Avg Loss: 3.9621 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   0%|▏                                                             | 475/145905 [03:46<10:20:58,  3.90it/s]

📊 Step 475 | Epoch 1/5 | Loss: 3.4746 | Avg Loss: 3.9436 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 500/145905 [03:52<10:42:22,  3.77it/s]

📊 Step 500 | Epoch 1/5 | Loss: 3.7441 | Avg Loss: 3.9306 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   0%|▏                                                             | 525/145905 [03:58<10:12:25,  3.96it/s]

📊 Step 525 | Epoch 1/5 | Loss: 3.5801 | Avg Loss: 3.9171 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 550/145905 [04:05<10:05:52,  4.00it/s]

📊 Step 550 | Epoch 1/5 | Loss: 3.4121 | Avg Loss: 3.9016 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|▏                                                             | 575/145905 [04:11<10:20:30,  3.90it/s]

📊 Step 575 | Epoch 1/5 | Loss: 3.6523 | Avg Loss: 3.8871 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|▎                                                             | 600/145905 [04:17<10:08:44,  3.98it/s]

📊 Step 600 | Epoch 1/5 | Loss: 3.5176 | Avg Loss: 3.8749 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|▎                                                             | 625/145905 [04:24<10:09:02,  3.98it/s]

📊 Step 625 | Epoch 1/5 | Loss: 3.4199 | Avg Loss: 3.8628 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   0%|▎                                                             | 650/145905 [04:30<10:24:16,  3.88it/s]

📊 Step 650 | Epoch 1/5 | Loss: 3.4258 | Avg Loss: 3.8502 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|▎                                                             | 675/145905 [04:36<10:17:48,  3.92it/s]

📊 Step 675 | Epoch 1/5 | Loss: 3.3086 | Avg Loss: 3.8391 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   0%|▎                                                             | 700/145905 [04:43<10:07:53,  3.98it/s]

📊 Step 700 | Epoch 1/5 | Loss: 3.4180 | Avg Loss: 3.8280 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   0%|▎                                                             | 725/145905 [04:49<10:09:39,  3.97it/s]

📊 Step 725 | Epoch 1/5 | Loss: 3.7246 | Avg Loss: 3.8209 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   1%|▎                                                             | 750/145905 [04:56<10:11:11,  3.96it/s]

📊 Step 750 | Epoch 1/5 | Loss: 3.6270 | Avg Loss: 3.8128 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   1%|▎                                                             | 775/145905 [05:02<10:12:08,  3.95it/s]

📊 Step 775 | Epoch 1/5 | Loss: 3.6543 | Avg Loss: 3.8067 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▎                                                             | 800/145905 [05:08<10:11:19,  3.96it/s]

📊 Step 800 | Epoch 1/5 | Loss: 3.8008 | Avg Loss: 3.7973 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▎                                                             | 825/145905 [05:15<10:06:53,  3.98it/s]

📊 Step 825 | Epoch 1/5 | Loss: 3.9023 | Avg Loss: 3.7903 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   1%|▎                                                             | 850/145905 [05:21<10:09:36,  3.97it/s]

📊 Step 850 | Epoch 1/5 | Loss: 3.4414 | Avg Loss: 3.7817 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▎                                                             | 875/145905 [05:27<10:10:08,  3.96it/s]

📊 Step 875 | Epoch 1/5 | Loss: 3.9395 | Avg Loss: 3.7739 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   1%|▍                                                             | 900/145905 [05:34<10:30:39,  3.83it/s]

📊 Step 900 | Epoch 1/5 | Loss: 3.1504 | Avg Loss: 3.7654 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▍                                                             | 925/145905 [05:40<10:14:35,  3.93it/s]

📊 Step 925 | Epoch 1/5 | Loss: 3.6758 | Avg Loss: 3.7602 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                             | 950/145905 [05:47<10:25:16,  3.86it/s]

📊 Step 950 | Epoch 1/5 | Loss: 3.3750 | Avg Loss: 3.7530 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   1%|▍                                                             | 975/145905 [05:53<10:11:57,  3.95it/s]

📊 Step 975 | Epoch 1/5 | Loss: 3.5391 | Avg Loss: 3.7469 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                            | 1000/145905 [05:59<10:35:39,  3.80it/s]

📊 Step 1,000 | Epoch 1/5 | Loss: 3.7090 | Avg Loss: 3.7405 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▍                                                            | 1025/145905 [06:06<10:07:19,  3.98it/s]

📊 Step 1,025 | Epoch 1/5 | Loss: 3.3262 | Avg Loss: 3.7329 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                            | 1050/145905 [06:12<10:25:54,  3.86it/s]

📊 Step 1,050 | Epoch 1/5 | Loss: 3.9121 | Avg Loss: 3.7280 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                            | 1075/145905 [06:19<10:24:42,  3.86it/s]

📊 Step 1,075 | Epoch 1/5 | Loss: 3.3125 | Avg Loss: 3.7197 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                            | 1100/145905 [06:25<10:33:39,  3.81it/s]

📊 Step 1,100 | Epoch 1/5 | Loss: 3.2695 | Avg Loss: 3.7143 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   1%|▍                                                            | 1125/145905 [06:31<10:28:32,  3.84it/s]

📊 Step 1,125 | Epoch 1/5 | Loss: 3.4766 | Avg Loss: 3.7080 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▍                                                            | 1150/145905 [06:38<10:09:26,  3.96it/s]

📊 Step 1,150 | Epoch 1/5 | Loss: 3.3613 | Avg Loss: 3.7011 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▍                                                            | 1175/145905 [06:44<10:13:28,  3.93it/s]

📊 Step 1,175 | Epoch 1/5 | Loss: 3.3301 | Avg Loss: 3.6955 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1200/145905 [06:51<10:09:30,  3.96it/s]

📊 Step 1,200 | Epoch 1/5 | Loss: 3.4414 | Avg Loss: 3.6899 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   1%|▌                                                            | 1225/145905 [06:57<10:28:47,  3.83it/s]

📊 Step 1,225 | Epoch 1/5 | Loss: 3.5137 | Avg Loss: 3.6860 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1250/145905 [07:04<10:08:24,  3.96it/s]

📊 Step 1,250 | Epoch 1/5 | Loss: 3.4746 | Avg Loss: 3.6804 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1275/145905 [07:10<10:33:04,  3.81it/s]

📊 Step 1,275 | Epoch 1/5 | Loss: 3.7930 | Avg Loss: 3.6751 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   1%|▌                                                            | 1300/145905 [07:17<10:10:56,  3.94it/s]

📊 Step 1,300 | Epoch 1/5 | Loss: 3.4043 | Avg Loss: 3.6705 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1325/145905 [07:23<10:18:34,  3.90it/s]

📊 Step 1,325 | Epoch 1/5 | Loss: 3.3691 | Avg Loss: 3.6664 | Speed: 3.8 steps/s | Step Time: 0.27s


🚀 Training:   1%|▌                                                            | 1350/145905 [07:29<10:32:42,  3.81it/s]

📊 Step 1,350 | Epoch 1/5 | Loss: 3.2070 | Avg Loss: 3.6619 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▌                                                            | 1375/145905 [07:36<10:41:18,  3.76it/s]

📊 Step 1,375 | Epoch 1/5 | Loss: 3.0332 | Avg Loss: 3.6571 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▌                                                            | 1400/145905 [07:42<10:06:49,  3.97it/s]

📊 Step 1,400 | Epoch 1/5 | Loss: 3.4297 | Avg Loss: 3.6536 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1425/145905 [07:49<10:37:51,  3.78it/s]

📊 Step 1,425 | Epoch 1/5 | Loss: 3.6562 | Avg Loss: 3.6500 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▌                                                            | 1450/145905 [07:55<10:11:38,  3.94it/s]

📊 Step 1,450 | Epoch 1/5 | Loss: 3.4492 | Avg Loss: 3.6458 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   1%|▌                                                            | 1475/145905 [08:01<10:08:31,  3.96it/s]

📊 Step 1,475 | Epoch 1/5 | Loss: 3.4199 | Avg Loss: 3.6423 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1500/145905 [08:08<10:16:30,  3.90it/s]

📊 Step 1,500 | Epoch 1/5 | Loss: 3.2402 | Avg Loss: 3.6382 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1525/145905 [08:14<10:26:27,  3.84it/s]

📊 Step 1,525 | Epoch 1/5 | Loss: 3.7129 | Avg Loss: 3.6352 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▋                                                            | 1550/145905 [08:21<10:12:37,  3.93it/s]

📊 Step 1,550 | Epoch 1/5 | Loss: 3.6758 | Avg Loss: 3.6307 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   1%|▋                                                            | 1575/145905 [08:27<10:25:49,  3.84it/s]

📊 Step 1,575 | Epoch 1/5 | Loss: 3.5527 | Avg Loss: 3.6274 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1600/145905 [08:34<10:15:00,  3.91it/s]

📊 Step 1,600 | Epoch 1/5 | Loss: 2.9355 | Avg Loss: 3.6225 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1625/145905 [08:40<10:06:22,  3.97it/s]

📊 Step 1,625 | Epoch 1/5 | Loss: 3.3965 | Avg Loss: 3.6191 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1650/145905 [08:46<10:28:35,  3.82it/s]

📊 Step 1,650 | Epoch 1/5 | Loss: 3.2656 | Avg Loss: 3.6156 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▋                                                            | 1675/145905 [08:53<10:10:59,  3.93it/s]

📊 Step 1,675 | Epoch 1/5 | Loss: 3.3340 | Avg Loss: 3.6118 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▋                                                            | 1700/145905 [08:59<10:31:54,  3.80it/s]

📊 Step 1,700 | Epoch 1/5 | Loss: 3.1543 | Avg Loss: 3.6079 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   1%|▋                                                            | 1725/145905 [09:06<10:02:31,  3.99it/s]

📊 Step 1,725 | Epoch 1/5 | Loss: 3.2910 | Avg Loss: 3.6050 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   1%|▋                                                            | 1750/145905 [09:12<10:11:42,  3.93it/s]

📊 Step 1,750 | Epoch 1/5 | Loss: 3.7422 | Avg Loss: 3.6029 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▋                                                            | 1775/145905 [09:19<10:08:23,  3.95it/s]

📊 Step 1,775 | Epoch 1/5 | Loss: 3.0234 | Avg Loss: 3.5992 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▊                                                            | 1800/145905 [09:25<10:21:01,  3.87it/s]

📊 Step 1,800 | Epoch 1/5 | Loss: 3.6055 | Avg Loss: 3.5966 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 1825/145905 [09:31<10:04:44,  3.97it/s]

📊 Step 1,825 | Epoch 1/5 | Loss: 3.4102 | Avg Loss: 3.5934 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▊                                                            | 1850/145905 [09:38<10:17:55,  3.89it/s]

📊 Step 1,850 | Epoch 1/5 | Loss: 3.4355 | Avg Loss: 3.5904 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▊                                                            | 1875/145905 [09:44<10:11:12,  3.93it/s]

📊 Step 1,875 | Epoch 1/5 | Loss: 3.2383 | Avg Loss: 3.5877 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▊                                                            | 1900/145905 [09:51<10:22:14,  3.86it/s]

📊 Step 1,900 | Epoch 1/5 | Loss: 3.2773 | Avg Loss: 3.5850 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 1925/145905 [09:57<10:24:11,  3.84it/s]

📊 Step 1,925 | Epoch 1/5 | Loss: 3.3594 | Avg Loss: 3.5822 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   1%|▊                                                            | 1950/145905 [10:03<10:26:56,  3.83it/s]

📊 Step 1,950 | Epoch 1/5 | Loss: 3.3730 | Avg Loss: 3.5794 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 1975/145905 [10:10<10:24:50,  3.84it/s]

📊 Step 1,975 | Epoch 1/5 | Loss: 3.2969 | Avg Loss: 3.5769 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 2000/145905 [10:16<10:17:21,  3.88it/s]

📊 Step 2,000 | Epoch 1/5 | Loss: 3.4414 | Avg Loss: 3.5742 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 2025/145905 [10:23<10:25:15,  3.84it/s]

📊 Step 2,025 | Epoch 1/5 | Loss: 3.9258 | Avg Loss: 3.5726 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▊                                                            | 2050/145905 [10:29<10:10:54,  3.92it/s]

📊 Step 2,050 | Epoch 1/5 | Loss: 3.5840 | Avg Loss: 3.5695 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   1%|▊                                                            | 2075/145905 [10:35<10:19:53,  3.87it/s]

📊 Step 2,075 | Epoch 1/5 | Loss: 3.2148 | Avg Loss: 3.5668 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▉                                                            | 2100/145905 [10:42<10:12:35,  3.91it/s]

📊 Step 2,100 | Epoch 1/5 | Loss: 3.2070 | Avg Loss: 3.5649 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▉                                                            | 2125/145905 [10:48<10:19:58,  3.87it/s]

📊 Step 2,125 | Epoch 1/5 | Loss: 3.4023 | Avg Loss: 3.5622 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   1%|▉                                                            | 2150/145905 [10:55<10:01:07,  3.99it/s]

📊 Step 2,150 | Epoch 1/5 | Loss: 3.3242 | Avg Loss: 3.5594 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   1%|▉                                                            | 2175/145905 [11:01<10:15:08,  3.89it/s]

📊 Step 2,175 | Epoch 1/5 | Loss: 3.7383 | Avg Loss: 3.5577 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|▉                                                            | 2200/145905 [11:08<10:14:32,  3.90it/s]

📊 Step 2,200 | Epoch 1/5 | Loss: 3.2969 | Avg Loss: 3.5561 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|▉                                                            | 2225/145905 [11:14<10:01:16,  3.98it/s]

📊 Step 2,225 | Epoch 1/5 | Loss: 3.2422 | Avg Loss: 3.5539 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|▉                                                            | 2250/145905 [11:20<10:25:05,  3.83it/s]

📊 Step 2,250 | Epoch 1/5 | Loss: 3.2500 | Avg Loss: 3.5516 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|▉                                                            | 2275/145905 [11:27<10:47:59,  3.69it/s]

📊 Step 2,275 | Epoch 1/5 | Loss: 3.4922 | Avg Loss: 3.5502 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|▉                                                            | 2300/145905 [11:33<10:05:40,  3.95it/s]

📊 Step 2,300 | Epoch 1/5 | Loss: 3.1465 | Avg Loss: 3.5476 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|▉                                                            | 2325/145905 [11:40<10:04:29,  3.96it/s]

📊 Step 2,325 | Epoch 1/5 | Loss: 3.5254 | Avg Loss: 3.5453 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|▉                                                            | 2350/145905 [11:46<10:00:01,  3.99it/s]

📊 Step 2,350 | Epoch 1/5 | Loss: 3.1738 | Avg Loss: 3.5437 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|▉                                                            | 2375/145905 [11:53<10:06:29,  3.94it/s]

📊 Step 2,375 | Epoch 1/5 | Loss: 3.4316 | Avg Loss: 3.5422 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2400/145905 [11:59<10:06:26,  3.94it/s]

📊 Step 2,400 | Epoch 1/5 | Loss: 3.5000 | Avg Loss: 3.5408 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2425/145905 [12:05<10:19:05,  3.86it/s]

📊 Step 2,425 | Epoch 1/5 | Loss: 3.6758 | Avg Loss: 3.5387 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2450/145905 [12:12<10:22:13,  3.84it/s]

📊 Step 2,450 | Epoch 1/5 | Loss: 3.6602 | Avg Loss: 3.5369 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2475/145905 [12:18<10:13:26,  3.90it/s]

📊 Step 2,475 | Epoch 1/5 | Loss: 3.4551 | Avg Loss: 3.5351 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█                                                            | 2500/145905 [12:25<10:05:04,  3.95it/s]

📊 Step 2,500 | Epoch 1/5 | Loss: 3.5645 | Avg Loss: 3.5337 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2525/145905 [12:31<10:04:12,  3.96it/s]

📊 Step 2,525 | Epoch 1/5 | Loss: 3.8359 | Avg Loss: 3.5320 | Speed: 4.2 steps/s | Step Time: 0.24s


🚀 Training:   2%|█                                                            | 2550/145905 [12:38<10:15:35,  3.88it/s]

📊 Step 2,550 | Epoch 1/5 | Loss: 3.4180 | Avg Loss: 3.5298 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2575/145905 [12:44<10:28:28,  3.80it/s]

📊 Step 2,575 | Epoch 1/5 | Loss: 3.3438 | Avg Loss: 3.5283 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   2%|█                                                            | 2600/145905 [12:51<10:03:42,  3.96it/s]

📊 Step 2,600 | Epoch 1/5 | Loss: 3.2461 | Avg Loss: 3.5268 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█                                                            | 2625/145905 [12:57<10:26:26,  3.81it/s]

📊 Step 2,625 | Epoch 1/5 | Loss: 3.4238 | Avg Loss: 3.5254 | Speed: 3.8 steps/s | Step Time: 0.26s


🚀 Training:   2%|█                                                            | 2650/145905 [13:03<10:00:53,  3.97it/s]

📊 Step 2,650 | Epoch 1/5 | Loss: 3.0605 | Avg Loss: 3.5233 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█                                                            | 2675/145905 [13:10<10:20:14,  3.85it/s]

📊 Step 2,675 | Epoch 1/5 | Loss: 3.2031 | Avg Loss: 3.5204 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▏                                                           | 2700/145905 [13:16<10:01:18,  3.97it/s]

📊 Step 2,700 | Epoch 1/5 | Loss: 3.4121 | Avg Loss: 3.5189 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▏                                                           | 2725/145905 [13:23<10:21:18,  3.84it/s]

📊 Step 2,725 | Epoch 1/5 | Loss: 3.3223 | Avg Loss: 3.5173 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▏                                                           | 2750/145905 [13:29<10:13:24,  3.89it/s]

📊 Step 2,750 | Epoch 1/5 | Loss: 3.2480 | Avg Loss: 3.5156 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▏                                                           | 2775/145905 [13:35<10:04:50,  3.94it/s]

📊 Step 2,775 | Epoch 1/5 | Loss: 3.3711 | Avg Loss: 3.5141 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▏                                                           | 2800/145905 [13:42<10:06:56,  3.93it/s]

📊 Step 2,800 | Epoch 1/5 | Loss: 3.2480 | Avg Loss: 3.5126 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▏                                                           | 2825/145905 [13:48<10:00:38,  3.97it/s]

📊 Step 2,825 | Epoch 1/5 | Loss: 3.3496 | Avg Loss: 3.5112 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▏                                                           | 2850/145905 [13:55<10:18:49,  3.85it/s]

📊 Step 2,850 | Epoch 1/5 | Loss: 3.5273 | Avg Loss: 3.5102 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▏                                                           | 2875/145905 [14:01<10:20:25,  3.84it/s]

📊 Step 2,875 | Epoch 1/5 | Loss: 3.3594 | Avg Loss: 3.5086 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▏                                                           | 2900/145905 [14:08<10:14:35,  3.88it/s]

📊 Step 2,900 | Epoch 1/5 | Loss: 3.2246 | Avg Loss: 3.5070 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█▏                                                           | 2925/145905 [14:14<10:20:18,  3.84it/s]

📊 Step 2,925 | Epoch 1/5 | Loss: 3.2656 | Avg Loss: 3.5056 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▎                                                            | 2950/145905 [14:20<9:56:16,  4.00it/s]

📊 Step 2,950 | Epoch 1/5 | Loss: 3.1348 | Avg Loss: 3.5042 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                            | 2975/145905 [14:27<9:58:47,  3.98it/s]

📊 Step 2,975 | Epoch 1/5 | Loss: 3.4941 | Avg Loss: 3.5029 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                            | 3000/145905 [14:33<9:56:18,  3.99it/s]

📊 Step 3,000 | Epoch 1/5 | Loss: 3.3281 | Avg Loss: 3.5016 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                            | 3025/145905 [14:39<9:57:13,  3.99it/s]

📊 Step 3,025 | Epoch 1/5 | Loss: 3.2812 | Avg Loss: 3.5000 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                           | 3050/145905 [14:46<10:13:24,  3.88it/s]

📊 Step 3,050 | Epoch 1/5 | Loss: 3.4824 | Avg Loss: 3.4983 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▎                                                            | 3075/145905 [14:52<9:56:07,  3.99it/s]

📊 Step 3,075 | Epoch 1/5 | Loss: 3.1934 | Avg Loss: 3.4967 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█▎                                                           | 3100/145905 [14:59<10:02:02,  3.95it/s]

📊 Step 3,100 | Epoch 1/5 | Loss: 3.2246 | Avg Loss: 3.4949 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                           | 3125/145905 [15:05<10:15:36,  3.87it/s]

📊 Step 3,125 | Epoch 1/5 | Loss: 3.1875 | Avg Loss: 3.4936 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▎                                                           | 3150/145905 [15:11<10:09:30,  3.90it/s]

📊 Step 3,150 | Epoch 1/5 | Loss: 3.3711 | Avg Loss: 3.4919 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                            | 3175/145905 [15:18<9:54:38,  4.00it/s]

📊 Step 3,175 | Epoch 1/5 | Loss: 3.1953 | Avg Loss: 3.4903 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▎                                                           | 3200/145905 [15:24<10:18:43,  3.84it/s]

📊 Step 3,200 | Epoch 1/5 | Loss: 3.5664 | Avg Loss: 3.4888 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▎                                                           | 3225/145905 [15:30<10:08:55,  3.91it/s]

📊 Step 3,225 | Epoch 1/5 | Loss: 3.5996 | Avg Loss: 3.4873 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▍                                                            | 3250/145905 [15:37<9:52:34,  4.01it/s]

📊 Step 3,250 | Epoch 1/5 | Loss: 3.3281 | Avg Loss: 3.4863 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█▎                                                           | 3275/145905 [15:43<10:16:55,  3.85it/s]

📊 Step 3,275 | Epoch 1/5 | Loss: 3.4609 | Avg Loss: 3.4847 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▍                                                            | 3300/145905 [15:49<9:54:48,  4.00it/s]

📊 Step 3,300 | Epoch 1/5 | Loss: 3.1914 | Avg Loss: 3.4834 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                            | 3325/145905 [15:56<9:56:41,  3.98it/s]

📊 Step 3,325 | Epoch 1/5 | Loss: 3.3105 | Avg Loss: 3.4827 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                            | 3350/145905 [16:02<9:55:03,  3.99it/s]

📊 Step 3,350 | Epoch 1/5 | Loss: 3.1113 | Avg Loss: 3.4809 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                           | 3375/145905 [16:08<10:05:56,  3.92it/s]

📊 Step 3,375 | Epoch 1/5 | Loss: 3.2715 | Avg Loss: 3.4796 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   2%|█▍                                                           | 3400/145905 [16:15<10:01:16,  3.95it/s]

📊 Step 3,400 | Epoch 1/5 | Loss: 3.2031 | Avg Loss: 3.4781 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                            | 3425/145905 [16:21<9:57:49,  3.97it/s]

📊 Step 3,425 | Epoch 1/5 | Loss: 3.1719 | Avg Loss: 3.4767 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                            | 3450/145905 [16:28<9:58:28,  3.97it/s]

📊 Step 3,450 | Epoch 1/5 | Loss: 3.2520 | Avg Loss: 3.4757 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                           | 3475/145905 [16:34<10:14:21,  3.86it/s]

📊 Step 3,475 | Epoch 1/5 | Loss: 3.2188 | Avg Loss: 3.4746 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                           | 3500/145905 [16:40<10:13:25,  3.87it/s]

📊 Step 3,500 | Epoch 1/5 | Loss: 3.4316 | Avg Loss: 3.4731 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▍                                                           | 3525/145905 [16:47<10:08:00,  3.90it/s]

📊 Step 3,525 | Epoch 1/5 | Loss: 3.3633 | Avg Loss: 3.4720 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▌                                                            | 3550/145905 [16:53<9:51:21,  4.01it/s]

📊 Step 3,550 | Epoch 1/5 | Loss: 3.0371 | Avg Loss: 3.4712 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   2%|█▌                                                            | 3575/145905 [16:59<9:57:01,  3.97it/s]

📊 Step 3,575 | Epoch 1/5 | Loss: 3.1289 | Avg Loss: 3.4698 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▌                                                           | 3600/145905 [17:06<10:14:11,  3.86it/s]

📊 Step 3,600 | Epoch 1/5 | Loss: 3.2734 | Avg Loss: 3.4687 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   2%|█▌                                                            | 3625/145905 [17:12<9:54:50,  3.99it/s]

📊 Step 3,625 | Epoch 1/5 | Loss: 3.1992 | Avg Loss: 3.4675 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▌                                                           | 3650/145905 [17:19<10:09:57,  3.89it/s]

📊 Step 3,650 | Epoch 1/5 | Loss: 3.1484 | Avg Loss: 3.4667 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▌                                                           | 3675/145905 [17:25<10:17:31,  3.84it/s]

📊 Step 3,675 | Epoch 1/5 | Loss: 3.1406 | Avg Loss: 3.4655 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▌                                                           | 3700/145905 [17:32<10:12:08,  3.87it/s]

📊 Step 3,700 | Epoch 1/5 | Loss: 3.5098 | Avg Loss: 3.4642 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▌                                                            | 3725/145905 [17:38<9:53:01,  4.00it/s]

📊 Step 3,725 | Epoch 1/5 | Loss: 3.1523 | Avg Loss: 3.4632 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▌                                                           | 3750/145905 [17:44<10:13:47,  3.86it/s]

📊 Step 3,750 | Epoch 1/5 | Loss: 3.2480 | Avg Loss: 3.4619 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▌                                                            | 3775/145905 [17:51<9:56:02,  3.97it/s]

📊 Step 3,775 | Epoch 1/5 | Loss: 3.3398 | Avg Loss: 3.4605 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▌                                                           | 3800/145905 [17:57<10:11:20,  3.87it/s]

📊 Step 3,800 | Epoch 1/5 | Loss: 3.5488 | Avg Loss: 3.4593 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▌                                                           | 3825/145905 [18:03<10:15:47,  3.85it/s]

📊 Step 3,825 | Epoch 1/5 | Loss: 3.3047 | Avg Loss: 3.4585 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▌                                                           | 3850/145905 [18:10<10:17:31,  3.83it/s]

📊 Step 3,850 | Epoch 1/5 | Loss: 3.1621 | Avg Loss: 3.4577 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▌                                                           | 3875/145905 [18:16<10:14:39,  3.85it/s]

📊 Step 3,875 | Epoch 1/5 | Loss: 3.2422 | Avg Loss: 3.4566 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                            | 3900/145905 [18:23<9:51:22,  4.00it/s]

📊 Step 3,900 | Epoch 1/5 | Loss: 3.0410 | Avg Loss: 3.4555 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▋                                                           | 3925/145905 [18:29<10:10:58,  3.87it/s]

📊 Step 3,925 | Epoch 1/5 | Loss: 3.0020 | Avg Loss: 3.4543 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                           | 3950/145905 [18:35<10:07:59,  3.89it/s]

📊 Step 3,950 | Epoch 1/5 | Loss: 3.2461 | Avg Loss: 3.4533 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                           | 3975/145905 [18:42<10:14:53,  3.85it/s]

📊 Step 3,975 | Epoch 1/5 | Loss: 3.2129 | Avg Loss: 3.4523 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                            | 4000/145905 [18:48<9:52:12,  3.99it/s]

📊 Step 4,000 | Epoch 1/5 | Loss: 3.1367 | Avg Loss: 3.4512 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▋                                                           | 4025/145905 [18:54<10:14:59,  3.85it/s]

📊 Step 4,025 | Epoch 1/5 | Loss: 3.7578 | Avg Loss: 3.4504 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                           | 4050/145905 [19:01<10:11:13,  3.87it/s]

📊 Step 4,050 | Epoch 1/5 | Loss: 3.0352 | Avg Loss: 3.4491 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▋                                                            | 4075/145905 [19:07<9:57:52,  3.95it/s]

📊 Step 4,075 | Epoch 1/5 | Loss: 3.3633 | Avg Loss: 3.4480 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▋                                                           | 4100/145905 [19:14<10:13:43,  3.85it/s]

📊 Step 4,100 | Epoch 1/5 | Loss: 3.3320 | Avg Loss: 3.4464 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▋                                                           | 4125/145905 [19:20<10:12:05,  3.86it/s]

📊 Step 4,125 | Epoch 1/5 | Loss: 3.9473 | Avg Loss: 3.4456 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▊                                                            | 4150/145905 [19:26<9:56:01,  3.96it/s]

📊 Step 4,150 | Epoch 1/5 | Loss: 3.2949 | Avg Loss: 3.4448 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▊                                                            | 4175/145905 [19:33<9:50:11,  4.00it/s]

📊 Step 4,175 | Epoch 1/5 | Loss: 3.0762 | Avg Loss: 3.4436 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                           | 4200/145905 [19:39<10:15:18,  3.84it/s]

📊 Step 4,200 | Epoch 1/5 | Loss: 3.2539 | Avg Loss: 3.4431 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▊                                                            | 4225/145905 [19:46<9:52:55,  3.98it/s]

📊 Step 4,225 | Epoch 1/5 | Loss: 3.4062 | Avg Loss: 3.4418 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4250/145905 [19:52<9:53:39,  3.98it/s]

📊 Step 4,250 | Epoch 1/5 | Loss: 3.2969 | Avg Loss: 3.4408 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4275/145905 [19:58<9:50:49,  4.00it/s]

📊 Step 4,275 | Epoch 1/5 | Loss: 3.2598 | Avg Loss: 3.4398 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                           | 4300/145905 [20:05<10:30:54,  3.74it/s]

📊 Step 4,300 | Epoch 1/5 | Loss: 3.3848 | Avg Loss: 3.4386 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4325/145905 [20:11<9:52:07,  3.99it/s]

📊 Step 4,325 | Epoch 1/5 | Loss: 3.3242 | Avg Loss: 3.4373 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4350/145905 [20:17<9:56:05,  3.96it/s]

📊 Step 4,350 | Epoch 1/5 | Loss: 2.9961 | Avg Loss: 3.4362 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4375/145905 [20:24<9:50:39,  3.99it/s]

📊 Step 4,375 | Epoch 1/5 | Loss: 3.2480 | Avg Loss: 3.4357 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                            | 4400/145905 [20:30<9:48:09,  4.01it/s]

📊 Step 4,400 | Epoch 1/5 | Loss: 3.3711 | Avg Loss: 3.4351 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▊                                                           | 4425/145905 [20:36<10:09:59,  3.87it/s]

📊 Step 4,425 | Epoch 1/5 | Loss: 3.1387 | Avg Loss: 3.4337 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▊                                                           | 4450/145905 [20:43<10:01:19,  3.92it/s]

📊 Step 4,450 | Epoch 1/5 | Loss: 2.9629 | Avg Loss: 3.4329 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▉                                                            | 4475/145905 [20:49<9:47:30,  4.01it/s]

📊 Step 4,475 | Epoch 1/5 | Loss: 3.2559 | Avg Loss: 3.4320 | Speed: 4.1 steps/s | Step Time: 0.24s


🚀 Training:   3%|█▉                                                            | 4500/145905 [20:55<9:53:21,  3.97it/s]

📊 Step 4,500 | Epoch 1/5 | Loss: 3.3438 | Avg Loss: 3.4307 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                           | 4525/145905 [21:02<10:14:29,  3.83it/s]

📊 Step 4,525 | Epoch 1/5 | Loss: 3.5508 | Avg Loss: 3.4297 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                           | 4550/145905 [21:08<10:12:10,  3.85it/s]

📊 Step 4,550 | Epoch 1/5 | Loss: 3.2969 | Avg Loss: 3.4290 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▉                                                            | 4575/145905 [21:15<9:54:33,  3.96it/s]

📊 Step 4,575 | Epoch 1/5 | Loss: 3.2207 | Avg Loss: 3.4283 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                            | 4600/145905 [21:21<9:55:42,  3.95it/s]

📊 Step 4,600 | Epoch 1/5 | Loss: 3.0664 | Avg Loss: 3.4274 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                           | 4625/145905 [21:28<10:07:41,  3.87it/s]

📊 Step 4,625 | Epoch 1/5 | Loss: 3.4434 | Avg Loss: 3.4263 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                            | 4650/145905 [21:34<9:51:27,  3.98it/s]

📊 Step 4,650 | Epoch 1/5 | Loss: 3.5352 | Avg Loss: 3.4255 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                            | 4675/145905 [21:40<9:47:54,  4.00it/s]

📊 Step 4,675 | Epoch 1/5 | Loss: 3.0977 | Avg Loss: 3.4245 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                            | 4700/145905 [21:47<9:50:01,  3.99it/s]

📊 Step 4,700 | Epoch 1/5 | Loss: 3.3223 | Avg Loss: 3.4235 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|█▉                                                           | 4725/145905 [21:53<10:08:57,  3.86it/s]

📊 Step 4,725 | Epoch 1/5 | Loss: 3.1309 | Avg Loss: 3.4222 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|█▉                                                           | 4750/145905 [22:00<10:08:03,  3.87it/s]

📊 Step 4,750 | Epoch 1/5 | Loss: 3.2832 | Avg Loss: 3.4215 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|██                                                            | 4775/145905 [22:06<9:56:15,  3.94it/s]

📊 Step 4,775 | Epoch 1/5 | Loss: 3.6328 | Avg Loss: 3.4210 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                           | 4800/145905 [22:12<10:00:58,  3.91it/s]

📊 Step 4,800 | Epoch 1/5 | Loss: 3.1172 | Avg Loss: 3.4199 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                           | 4825/145905 [22:19<10:00:33,  3.92it/s]

📊 Step 4,825 | Epoch 1/5 | Loss: 3.3965 | Avg Loss: 3.4191 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                            | 4850/145905 [22:25<9:50:24,  3.98it/s]

📊 Step 4,850 | Epoch 1/5 | Loss: 3.2148 | Avg Loss: 3.4188 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                            | 4875/145905 [22:31<9:48:24,  3.99it/s]

📊 Step 4,875 | Epoch 1/5 | Loss: 3.2324 | Avg Loss: 3.4178 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                           | 4900/145905 [22:38<10:16:35,  3.81it/s]

📊 Step 4,900 | Epoch 1/5 | Loss: 3.4863 | Avg Loss: 3.4168 | Speed: 4.0 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                           | 4925/145905 [22:44<10:10:41,  3.85it/s]

📊 Step 4,925 | Epoch 1/5 | Loss: 3.2617 | Avg Loss: 3.4162 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|██                                                            | 4950/145905 [22:51<9:57:13,  3.93it/s]

📊 Step 4,950 | Epoch 1/5 | Loss: 3.4121 | Avg Loss: 3.4156 | Speed: 3.9 steps/s | Step Time: 0.25s


🚀 Training:   3%|██                                                           | 4975/145905 [22:57<10:11:18,  3.84it/s]

📊 Step 4,975 | Epoch 1/5 | Loss: 3.3086 | Avg Loss: 3.4146 | Speed: 3.9 steps/s | Step Time: 0.26s


🚀 Training:   3%|██                                                           | 5000/145905 [23:03<10:05:37,  3.88it/s]

📊 Step 5,000 | Epoch 1/5 | Loss: 3.1875 | Avg Loss: 3.4141 | Speed: 4.1 steps/s | Step Time: 0.25s


🚀 Training:   3%|██▏                                                           | 5013/145905 [23:07<9:48:42,  3.99it/s]

### 🔄 Checkpoint Usage Examples

Here's how to use the checkpoint system for different scenarios:

In [None]:
def load_checkpoint(checkpoint_dir=CHECKPOINT_DIR, step=None):
    """
    Load model checkpoint using Flax's built-in checkpointing.
    
    Args:
        checkpoint_dir: Directory containing checkpoints (can be relative)
        step: Specific step to load (None = latest)
    
    Returns:
        restored_state: TrainState with loaded params/optimizer/step
        epoch_info: Dict with loaded step information
    """
    # Convert relative path to absolute (Flax requires absolute paths)
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
    
    # Create a dummy state for restoration template
    # (Flax needs the structure to know what to restore)
    dummy_state = train_state.TrainState.create(
        apply_fn=model.__call__,
        params=params,  # Use current params as template
        tx=optax.adam(1e-5)  # Use current optimizer as template
    )
    
    # Restore the checkpoint
    restored_state = checkpoints.restore_checkpoint(
        ckpt_dir=str(checkpoint_dir),  # Convert Path to string for Flax
        target=dummy_state,  # Template for structure
        step=step  # None = latest, or specify step number
    )
    
    epoch_info = {
        "step": restored_state.step,
        "approximate_epoch": restored_state.step // ((len(train_dataset) // BATCH_SIZE))
    }
    
    print(f"📥 Checkpoint loaded: step {restored_state.step}")
    print(f"   🔄 Approximate epoch: {epoch_info['approximate_epoch']}")
    
    return restored_state, epoch_info

# 🔍 Utility function to list available checkpoints
def list_checkpoints(checkpoint_dir=CHECKPOINT_DIR):
    """List all available checkpoints in the directory."""
    # Convert relative path to absolute
    checkpoint_dir = Path(checkpoint_dir).resolve()
    
    if not checkpoint_dir.exists():
        print(f"No checkpoint directory found: {checkpoint_dir}")
        return []
    
    # Get all checkpoint steps by scanning for checkpoint_* directories
    checkpoint_pattern = checkpoint_dir / "checkpoint_*"
    checkpoint_dirs = list(checkpoint_dir.glob("checkpoint_*"))
    
    if not checkpoint_dirs:
        print(f"No checkpoints found in {checkpoint_dir}")
        return []
    
    # Extract step numbers from directory names
    available_steps = []
    for ckpt_dir in checkpoint_dirs:
        try:
            step = int(ckpt_dir.name.replace("checkpoint_", ""))
            available_steps.append(step)
        except ValueError:
            continue  # Skip invalid checkpoint directory names
    
    # Sort steps numerically
    available_steps.sort()
    
    print(f"📋 Available checkpoints in {checkpoint_dir}:")
    for step in available_steps:
        approximate_epoch = step // ((len(train_dataset) // BATCH_SIZE))
        print(f"   Step {step} (≈ epoch {approximate_epoch})")
    
    return available_steps


In [None]:
# 📋 Example 1: List all available checkpoints
print("=" * 50)
print("🔍 LISTING AVAILABLE CHECKPOINTS")
print("=" * 50)
available_steps = list_checkpoints()

# 📥 Example 2: Load the latest checkpoint (for resuming training)
print("\n" + "=" * 50)
print("📥 LOADING LATEST CHECKPOINT")
print("=" * 50)
try:
    # Uncomment these lines when you have checkpoints saved:
    # restored_state, info = load_checkpoint()
    # print(f"Resumed from step {info['step']}, approximate epoch {info['approximate_epoch']}")
    # state = restored_state  # Use this state to continue training
    print("💡 Uncomment the lines above once you have saved checkpoints!")
except Exception as e:
    print(f"No checkpoints found yet: {e}")

# 📥 Example 3: Load a specific checkpoint by step number
print("\n" + "=" * 50)
print("📥 LOADING SPECIFIC CHECKPOINT")
print("=" * 50)
try:
    # Example: load checkpoint from step 500
    # restored_state, info = load_checkpoint(step=500)
    print("💡 Use load_checkpoint(step=N) to load a specific checkpoint")
except Exception as e:
    print(f"Specific checkpoint example: {e}")

# 🚀 Example 4: Resume training from checkpoint
print("\n" + "=" * 50)
print("🚀 RESUMING TRAINING")
print("=" * 50)
print("""
# To resume training from the latest checkpoint:
restored_state, info = load_checkpoint()
state = restored_state

# Then continue training normally:
state = train_loop(state, train_loader, num_epochs=additional_epochs)

# The training will pick up from where it left off!
""")

print("✅ Checkpoint examples ready to use!")

🔍 LISTING AVAILABLE CHECKPOINTS
📋 Available checkpoints in /home/bwilliams/mlx/week6/bb-finetune/ben_dev/models:
   Step 24 (≈ epoch 0)
   Step 14589 (≈ epoch 0)

📥 LOADING LATEST CHECKPOINT
💡 Uncomment the lines above once you have saved checkpoints!

📥 LOADING SPECIFIC CHECKPOINT
💡 Use load_checkpoint(step=N) to load a specific checkpoint

🚀 RESUMING TRAINING

# To resume training from the latest checkpoint:
restored_state, info = load_checkpoint()
state = restored_state

# Then continue training normally:
state = train_loop(state, train_loader, num_epochs=additional_epochs)

# The training will pick up from where it left off!

✅ Checkpoint examples ready to use!
