In [1]:
# TEXT GENERATION - Run this AFTER training completes
import torch
from model import DeepSeekV3Config, DeepSeekV3ForCausalLM
import tiktoken
import os

# Check if checkpoint exists
checkpoint_path = "checkpoint_5500.pt"

if not os.path.exists(checkpoint_path):
    print("=" * 70)
    print("‚ö†Ô∏è  ERROR: No trained model found!")
    print("=" * 70)
    print(f"\nThe checkpoint file '{checkpoint_path}' does not exist.")
    print("\nPlease run the training cell FIRST to create the model.")
    print("\nThe training cell will:")
    print("  1. Create and train DeepSeek-V3 from scratch")
    print("  2. Save the trained model to 'checkpoint_5500.pt'")
    print("  3. Then you can run this generation cell")
    print("\n" + "=" * 70)
else:
    # Load the trained model
    print("Loading trained model...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    config = checkpoint['config']
    model = DeepSeekV3ForCausalLM(config)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Device detection
    if torch.cuda.is_available():
        device = 'cuda'
        print(f"Using device: {device} ({torch.cuda.get_device_name(0)})")
    elif torch.backends.mps.is_available():
        device = 'mps'
        print(f"Using device: {device} (Apple Silicon GPU)")
    else:
        device = 'cpu'
        print(f"Using device: {device}")

    model = model.to(device)
    model.eval()

    print(f"\n‚úÖ Model loaded successfully!")
    print(f"Parameters: {checkpoint['total_params']:,}")

    # Tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Generate text
    def generate_text(prompt_text, max_new_tokens=100, temperature=0.8, top_k=50):
        """Generate text from a prompt"""
        # Encode the prompt
        input_ids = tokenizer.encode(prompt_text)
        input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
        
        print(f"\nPrompt: '{prompt_text}'")
        print(f"Generating {max_new_tokens} tokens...")
        print("-" * 70)
        
        # Generate
        with torch.no_grad():
            generated = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k
            )
        
        # Decode
        generated_text = tokenizer.decode(generated[0].cpu().tolist())
        print(generated_text)
        print("-" * 70)
        return generated_text

    # Example generations
    print("\n" + "=" * 70)
    print("TEXT GENERATION EXAMPLES")
    print("=" * 70)

    # Example 1
    generate_text("Once upon a time", max_new_tokens=50, temperature=0.8)

    # Example 2
    generate_text("The meaning of life is", max_new_tokens=50, temperature=0.7)

    # Example 3 - Custom prompt (uncomment to use)
    # generate_text("Your custom prompt here", max_new_tokens=100, temperature=0.8)

‚ö†Ô∏è  ERROR: No trained model found!

The checkpoint file 'checkpoint_5500.pt' does not exist.

Please run the training cell FIRST to create the model.

The training cell will:
  1. Create and train DeepSeek-V3 from scratch
  2. Save the trained model to 'checkpoint_5500.pt'
  3. Then you can run this generation cell



# PHASE 1: Train for 5000 steps and save checkpoint

Run this cell first. It will:
- Train your model for exactly 5000 steps
- Automatically stop after 5000 steps
- Save a checkpoint to `checkpoint_5000.pt`

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
from tqdm import tqdm
import math

# Import DeepSeekV3 from model.py
from model import DeepSeekV3Config, DeepSeekV3ForCausalLM


# Dataset
class TextDataset(Dataset):
    def __init__(self, filepath, tokenizer, block_size=1024):
        with open(filepath, 'r', encoding='utf-8') as f:
            text = f.read()
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.tokens = tokenizer.encode(text)

    def __len__(self):
        return len(self.tokens) - self.block_size

    def __getitem__(self, idx):
        input_ids = self.tokens[idx:idx + self.block_size]
        target_ids = self.tokens[idx + 1:idx + self.block_size + 1]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


# ==========================================
# PHASE 1: TRAIN FOR 5000 STEPS (Plain PyTorch)
# ==========================================

print("=" * 70)
print("PHASE 1: TRAINING FOR 5000 STEPS (Plain PyTorch)")
print("=" * 70)

# Configuration
USE_SMALL_MODEL = False  # Set to False to use full 135M model

if USE_SMALL_MODEL:
    print("\n[Using SMALL model]")
    config = DeepSeek-V3Config(
        vocab_size=50257,
        hidden_size=384,
        intermediate_size=1024,
        num_hidden_layers=12,
        num_attention_heads=6,
        num_key_value_heads=2,
        max_position_embeddings=1024,
        rms_norm_eps=1e-5,
        rope_theta=10000,
        tie_word_embeddings=True,
        attention_dropout=0.0,
    )
    block_size = 512
    batch_size = 4
    max_lr = 3e-4
    accumulate_grad_batches = 4
else:
    print("\n[Using FULL DeepSeek-V3-135M architecture]")
    config = DeepSeekV3Config(
        vocab_size=50257,
        hidden_size=576,
        num_hidden_layers=30,
        num_attention_heads=9,
        kv_lora_rank=512,
        moe_intermediate_size=256,
        n_shared_experts=1,
        n_routed_experts=8,
        num_experts_per_tok=2,
        max_position_embeddings=2048,
        rms_norm_eps=1e-5,
        rope_theta=10000.0,
        tie_word_embeddings=True,
        attention_dropout=0.0,
    )
    block_size = 1024
    batch_size = 2
    max_lr = 1e-3
    accumulate_grad_batches = 8

# Training parameters
warmup_steps = 100
max_steps = 5000  # Will stop at EXACTLY 5000 steps
log_interval = 50  # Print every 50 steps to avoid output overflow

# Device setup
if torch.cuda.is_available():
    device = 'cuda'
    print(f"\nUsing device: {device} ({torch.cuda.get_device_name(0)})")
elif torch.backends.mps.is_available():
    device = 'mps'
    print(f"\nUsing device: {device} (Apple Silicon GPU)")
else:
    device = 'cpu'
    print(f"\nUsing device: {device}")

# Setup
tokenizer = tiktoken.get_encoding("gpt2")

# Dataset and DataLoader
dataset = TextDataset("input-1.txt", tokenizer, block_size)
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=0,
    pin_memory=False
)

# Model
model = DeepSeekV3ForCausalLM(config)
model = model.to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Configuration:")
print(f"  - Architecture: DeepSeek-V3 (MLA + MoE)")
print(f"  - Vocabulary size: {config.vocab_size:,}")
print(f"  - Hidden size: {config.hidden_size}")
print(f"  - Layers: {config.num_hidden_layers}")
print(f"  - Attention heads: {config.num_attention_heads}")
print(f"  - KV LoRA Rank: {config.kv_lora_rank}")
print(f"  - MoE Experts: 1 Shared + {config.n_routed_experts} Routed")
print(f"  - Total parameters: {total_params:,}")

print(f"\nTraining Configuration (Phase 1):")
print(f"  - Max steps: {max_steps}")
print(f"  - Batch size: {batch_size}")
print(f"  - Gradient accumulation: {accumulate_grad_batches}")
print(f"  - Effective batch size: {batch_size * accumulate_grad_batches}")
print(f"  - Warmup steps: {warmup_steps}")
print(f"  - Max learning rate: {max_lr}")
print(f"  - Logging interval: Every {log_interval} steps")

# Optimizer - separate decay parameters
decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if param.requires_grad:
        if 'bias' in name or 'norm' in name or 'embed' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0}
], lr=max_lr, betas=(0.9, 0.95), eps=1e-8)

# Learning rate scheduler
def get_lr(step, warmup_steps, max_steps, max_lr):
    if step < warmup_steps:
        return max_lr * (step / warmup_steps)
    else:
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return max_lr * max(0.1, 1.0 - progress)

# Clear cache
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()

torch.set_float32_matmul_precision('high')

# ==========================================
# INITIAL LOSS CHECK
# ==========================================
print("\n" + "=" * 70)
print("INITIAL LOSS CHECK (before training)")
print("=" * 70)

model.eval()
with torch.no_grad():
    # Get a sample batch
    sample_batch = next(iter(dataloader))
    input_ids, target_ids = sample_batch
    input_ids = input_ids.to(device)
    target_ids = target_ids.to(device)
    
    # Forward pass
    outputs = model(input_ids=input_ids, labels=target_ids)
    initial_loss = outputs['loss'].item()
    
    print(f"\nInitial loss (random weights): {initial_loss:.4f}")
    print(f"Expected loss for random model: ~{math.log(config.vocab_size):.2f}")
    
    if initial_loss > 50:
        print(f"\n‚ö†Ô∏è  WARNING: Loss is unusually high ({initial_loss:.4f})!")
        print("This suggests a potential issue with loss calculation.")
    elif 8 < initial_loss < 15:
        print(f"\n‚úÖ Loss is in expected range for random initialization!")
    else:
        print(f"\n‚ö†Ô∏è  Loss is {initial_loss:.4f}, which is outside typical range (8-15)")

print("=" * 70 + "\n")
model.train()

# ==========================================
# TRAINING LOOP
# ==========================================
print("=" * 70)
print("STARTING PHASE 1 TRAINING")
print(f"Will stop at EXACTLY {max_steps} steps")
print(f"Logging every {log_interval} steps")
print("=" * 70 + "\n")

# Training loop
update_step = 0  # Actual optimizer update steps
batch_idx = 0    # Batch counter for gradient accumulation
accumulated_loss = 0.0
optimizer.zero_grad()

# Create infinite dataloader
def cycle(dataloader):
    while True:
        for batch in dataloader:
            yield batch

data_iter = cycle(dataloader)

# Progress bar
pbar = tqdm(total=max_steps, desc="Training", unit="step")

while update_step < max_steps:
    # Get batch
    input_ids, target_ids = next(data_iter)
    input_ids = input_ids.to(device)
    target_ids = target_ids.to(device)

    # Forward pass
    outputs = model(input_ids=input_ids, labels=target_ids)
    loss = outputs['loss']

    # Scale loss for gradient accumulation
    loss = loss / accumulate_grad_batches
    loss.backward()

    # Accumulate the UNSCALED loss for logging (multiply back)
    accumulated_loss += loss.item() * accumulate_grad_batches
    batch_idx += 1

    # Update weights after accumulation
    if batch_idx % accumulate_grad_batches == 0:
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update learning rate
        lr = get_lr(update_step, warmup_steps, max_steps, max_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        # Calculate average loss over accumulated batches
        avg_loss = accumulated_loss / accumulate_grad_batches
        
        # Update progress bar every step
        pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr:.6f}'})
        
        # Print only at log_interval to avoid output overflow
        if (update_step + 1) % log_interval == 0 or update_step == 0:
            print(f"Step {update_step + 1}/{max_steps} | Loss: {avg_loss:.4f} | LR: {lr:.6f}")
        
        accumulated_loss = 0.0

        pbar.update(1)
        update_step += 1
        batch_idx = 0  # Reset batch counter

        # EXACT STOP at max_steps
        if update_step >= max_steps:
            print(f"\n‚úÖ Reached {max_steps} steps - stopping training!")
            break

pbar.close()

print("\n" + "=" * 70)
print("PHASE 1 COMPLETED!")
print("=" * 70)

# Save checkpoint
checkpoint_path = "checkpoint_5000.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config,
    'total_params': total_params,
    'global_step': update_step,
}, checkpoint_path)

print(f"\n‚úÖ Checkpoint saved to: {checkpoint_path}")
print(f"‚úÖ Trained for exactly {update_step} steps")
print(f"‚úÖ Final loss: {avg_loss:.4f}")
print("\nYou can now run PHASE 2 to continue training for 500 more steps!")

# Cleanup
del model, optimizer
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared.")

PHASE 1: TRAINING FOR 5000 STEPS (Plain PyTorch)

[Using FULL DeepSeek-V3-135M architecture]

Using device: mps (Apple Silicon GPU)

Model Configuration:
  - Architecture: DeepSeek-V3 (MLA + MoE)
  - Vocabulary size: 50,257
  - Hidden size: 576
  - Layers: 30
  - Attention heads: 9
  - KV LoRA Rank: 512
  - MoE Experts: 1 Shared + 8 Routed
  - Total parameters: 174,012,288

Training Configuration (Phase 1):
  - Max steps: 5000
  - Batch size: 2
  - Gradient accumulation: 8
  - Effective batch size: 16
  - Warmup steps: 100
  - Max learning rate: 0.001
  - Logging interval: Every 50 steps

INITIAL LOSS CHECK (before training)

Initial loss (random weights): 10.9710
Expected loss for random model: ~10.82

‚úÖ Loss is in expected range for random initialization!

STARTING PHASE 1 TRAINING
Will stop at EXACTLY 5000 steps
Logging every 50 steps



Training:   0%|          | 1/5000 [00:35<48:47:38, 35.14s/step, loss=11.1523, lr=0.000000]

Step 1/5000 | Loss: 11.1523 | LR: 0.000000


Training:   1%|          | 50/5000 [12:14<17:04:33, 12.42s/step, loss=6.3426, lr=0.000490]

Step 50/5000 | Loss: 6.3426 | LR: 0.000490


Training:   2%|‚ñè         | 100/5000 [23:09<20:32:24, 15.09s/step, loss=6.1240, lr=0.000990]

Step 100/5000 | Loss: 6.1240 | LR: 0.000990


Training:   3%|‚ñé         | 150/5000 [36:05<21:20:46, 15.84s/step, loss=5.4385, lr=0.000990]

Step 150/5000 | Loss: 5.4385 | LR: 0.000990


Training:   4%|‚ñç         | 200/5000 [49:15<20:21:59, 15.27s/step, loss=5.0636, lr=0.000980]

Step 200/5000 | Loss: 5.0636 | LR: 0.000980


Training:   5%|‚ñå         | 250/5000 [1:01:50<19:38:42, 14.89s/step, loss=4.7123, lr=0.000970]

Step 250/5000 | Loss: 4.7123 | LR: 0.000970


Training:   6%|‚ñå         | 300/5000 [1:14:20<19:53:27, 15.24s/step, loss=4.5043, lr=0.000959]

Step 300/5000 | Loss: 4.5043 | LR: 0.000959


Training:   7%|‚ñã         | 350/5000 [1:26:22<18:52:55, 14.62s/step, loss=4.1652, lr=0.000949]

Step 350/5000 | Loss: 4.1652 | LR: 0.000949


Training:   8%|‚ñä         | 400/5000 [1:38:36<18:40:00, 14.61s/step, loss=3.6805, lr=0.000939]

Step 400/5000 | Loss: 3.6805 | LR: 0.000939


Training:   9%|‚ñâ         | 450/5000 [1:50:50<18:22:51, 14.54s/step, loss=3.0098, lr=0.000929]

Step 450/5000 | Loss: 3.0098 | LR: 0.000929


Training:  10%|‚ñà         | 500/5000 [2:02:45<17:49:06, 14.25s/step, loss=2.6873, lr=0.000919]

Step 500/5000 | Loss: 2.6873 | LR: 0.000919


Training:  11%|‚ñà         | 550/5000 [2:14:36<17:45:03, 14.36s/step, loss=1.8318, lr=0.000908]

Step 550/5000 | Loss: 1.8318 | LR: 0.000908


Training:  12%|‚ñà‚ñè        | 600/5000 [2:26:30<17:30:40, 14.33s/step, loss=1.0757, lr=0.000898]

Step 600/5000 | Loss: 1.0757 | LR: 0.000898


Training:  13%|‚ñà‚ñé        | 650/5000 [2:38:12<17:02:48, 14.11s/step, loss=0.5748, lr=0.000888]

Step 650/5000 | Loss: 0.5748 | LR: 0.000888


Training:  14%|‚ñà‚ñç        | 700/5000 [2:49:48<16:26:11, 13.76s/step, loss=0.4340, lr=0.000878]

Step 700/5000 | Loss: 0.4340 | LR: 0.000878


Training:  15%|‚ñà‚ñå        | 750/5000 [3:01:28<16:40:00, 14.12s/step, loss=0.3486, lr=0.000868]

Step 750/5000 | Loss: 0.3486 | LR: 0.000868


Training:  16%|‚ñà‚ñå        | 800/5000 [3:13:06<16:16:44, 13.95s/step, loss=0.3401, lr=0.000857]

Step 800/5000 | Loss: 0.3401 | LR: 0.000857


Training:  17%|‚ñà‚ñã        | 850/5000 [3:24:42<16:03:46, 13.93s/step, loss=0.2799, lr=0.000847]

Step 850/5000 | Loss: 0.2799 | LR: 0.000847


Training:  18%|‚ñà‚ñä        | 900/5000 [3:36:19<16:00:17, 14.05s/step, loss=0.2515, lr=0.000837]

Step 900/5000 | Loss: 0.2515 | LR: 0.000837


Training:  19%|‚ñà‚ñâ        | 950/5000 [3:47:52<15:30:52, 13.79s/step, loss=0.2300, lr=0.000827]

Step 950/5000 | Loss: 0.2300 | LR: 0.000827


Training:  20%|‚ñà‚ñà        | 1000/5000 [3:59:28<15:36:16, 14.04s/step, loss=0.2030, lr=0.000817]

Step 1000/5000 | Loss: 0.2030 | LR: 0.000817


Training:  21%|‚ñà‚ñà        | 1050/5000 [4:11:02<15:09:21, 13.81s/step, loss=0.2127, lr=0.000806]

Step 1050/5000 | Loss: 0.2127 | LR: 0.000806


Training:  22%|‚ñà‚ñà‚ñè       | 1100/5000 [4:22:37<15:04:40, 13.92s/step, loss=0.2036, lr=0.000796]

Step 1100/5000 | Loss: 0.2036 | LR: 0.000796


Training:  23%|‚ñà‚ñà‚ñé       | 1150/5000 [4:34:14<14:55:26, 13.95s/step, loss=0.1756, lr=0.000786]

Step 1150/5000 | Loss: 0.1756 | LR: 0.000786


Training:  24%|‚ñà‚ñà‚ñç       | 1200/5000 [4:45:42<14:37:52, 13.86s/step, loss=0.1667, lr=0.000776]

Step 1200/5000 | Loss: 0.1667 | LR: 0.000776


Training:  25%|‚ñà‚ñà‚ñå       | 1250/5000 [4:57:17<14:35:38, 14.01s/step, loss=0.1696, lr=0.000766]

Step 1250/5000 | Loss: 0.1696 | LR: 0.000766


Training:  26%|‚ñà‚ñà‚ñå       | 1300/5000 [5:08:49<13:44:03, 13.36s/step, loss=0.1456, lr=0.000755]

Step 1300/5000 | Loss: 0.1456 | LR: 0.000755


Training:  27%|‚ñà‚ñà‚ñã       | 1350/5000 [5:20:23<14:01:15, 13.83s/step, loss=0.1908, lr=0.000745]

Step 1350/5000 | Loss: 0.1908 | LR: 0.000745


Training:  28%|‚ñà‚ñà‚ñä       | 1400/5000 [5:31:57<13:53:54, 13.90s/step, loss=0.1628, lr=0.000735]

Step 1400/5000 | Loss: 0.1628 | LR: 0.000735


Training:  29%|‚ñà‚ñà‚ñâ       | 1450/5000 [5:43:31<13:39:01, 13.84s/step, loss=0.1273, lr=0.000725]

Step 1450/5000 | Loss: 0.1273 | LR: 0.000725


Training:  30%|‚ñà‚ñà‚ñà       | 1500/5000 [5:55:05<13:28:58, 13.87s/step, loss=0.1493, lr=0.000714]

Step 1500/5000 | Loss: 0.1493 | LR: 0.000714


Training:  31%|‚ñà‚ñà‚ñà       | 1550/5000 [6:06:42<13:17:28, 13.87s/step, loss=0.1456, lr=0.000704]

Step 1550/5000 | Loss: 0.1456 | LR: 0.000704


Training:  32%|‚ñà‚ñà‚ñà‚ñè      | 1600/5000 [6:18:15<13:05:27, 13.86s/step, loss=0.1315, lr=0.000694]

Step 1600/5000 | Loss: 0.1315 | LR: 0.000694


Training:  33%|‚ñà‚ñà‚ñà‚ñé      | 1650/5000 [6:29:47<12:59:58, 13.97s/step, loss=0.1299, lr=0.000684]

Step 1650/5000 | Loss: 0.1299 | LR: 0.000684


Training:  34%|‚ñà‚ñà‚ñà‚ñç      | 1700/5000 [6:41:21<12:43:35, 13.88s/step, loss=0.1255, lr=0.000674]

Step 1700/5000 | Loss: 0.1255 | LR: 0.000674


Training:  35%|‚ñà‚ñà‚ñà‚ñå      | 1750/5000 [6:52:53<12:31:02, 13.87s/step, loss=0.1234, lr=0.000663]

Step 1750/5000 | Loss: 0.1234 | LR: 0.000663


Training:  36%|‚ñà‚ñà‚ñà‚ñå      | 1800/5000 [7:04:28<12:24:19, 13.96s/step, loss=0.1352, lr=0.000653]

Step 1800/5000 | Loss: 0.1352 | LR: 0.000653


Training:  37%|‚ñà‚ñà‚ñà‚ñã      | 1850/5000 [7:16:00<12:00:06, 13.72s/step, loss=0.1188, lr=0.000643]

Step 1850/5000 | Loss: 0.1188 | LR: 0.000643


Training:  38%|‚ñà‚ñà‚ñà‚ñä      | 1900/5000 [7:27:33<12:02:38, 13.99s/step, loss=0.1174, lr=0.000633]

Step 1900/5000 | Loss: 0.1174 | LR: 0.000633


Training:  39%|‚ñà‚ñà‚ñà‚ñâ      | 1950/5000 [7:39:05<11:46:19, 13.90s/step, loss=0.1084, lr=0.000623]

Step 1950/5000 | Loss: 0.1084 | LR: 0.000623


Training:  40%|‚ñà‚ñà‚ñà‚ñà      | 2000/5000 [7:50:39<11:35:52, 13.92s/step, loss=0.1069, lr=0.000612]

Step 2000/5000 | Loss: 0.1069 | LR: 0.000612


Training:  41%|‚ñà‚ñà‚ñà‚ñà      | 2050/5000 [8:02:13<11:30:04, 14.04s/step, loss=0.1160, lr=0.000602]

Step 2050/5000 | Loss: 0.1160 | LR: 0.000602


Training:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 2100/5000 [8:13:43<11:08:16, 13.83s/step, loss=0.1054, lr=0.000592]

Step 2100/5000 | Loss: 0.1054 | LR: 0.000592


Training:  43%|‚ñà‚ñà‚ñà‚ñà‚ñé     | 2150/5000 [8:25:16<10:58:34, 13.86s/step, loss=0.1106, lr=0.000582]

Step 2150/5000 | Loss: 0.1106 | LR: 0.000582


Training:  44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 2200/5000 [8:36:44<10:40:18, 13.72s/step, loss=0.1146, lr=0.000572]

Step 2200/5000 | Loss: 0.1146 | LR: 0.000572


Training:  45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 2250/5000 [8:48:18<10:39:33, 13.95s/step, loss=0.1178, lr=0.000561]

Step 2250/5000 | Loss: 0.1178 | LR: 0.000561


Training:  46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 2300/5000 [8:59:45<10:18:09, 13.74s/step, loss=0.0973, lr=0.000551]

Step 2300/5000 | Loss: 0.0973 | LR: 0.000551


Training:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 2350/5000 [9:11:14<10:18:26, 14.00s/step, loss=0.0938, lr=0.000541]

Step 2350/5000 | Loss: 0.0938 | LR: 0.000541


Training:  48%|‚ñà‚ñà‚ñà‚ñà‚ñä     | 2400/5000 [9:22:42<9:58:37, 13.81s/step, loss=0.0923, lr=0.000531] 

Step 2400/5000 | Loss: 0.0923 | LR: 0.000531


Training:  49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 2450/5000 [9:34:14<10:02:02, 14.17s/step, loss=0.0881, lr=0.000521]

Step 2450/5000 | Loss: 0.0881 | LR: 0.000521


Training:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 2500/5000 [9:45:40<9:26:49, 13.60s/step, loss=0.1014, lr=0.000510] 

Step 2500/5000 | Loss: 0.1014 | LR: 0.000510


Training:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 2550/5000 [9:57:13<9:28:30, 13.92s/step, loss=0.0930, lr=0.000500]

Step 2550/5000 | Loss: 0.0930 | LR: 0.000500


Training:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 2600/5000 [10:08:42<8:59:58, 13.50s/step, loss=0.1005, lr=0.000490]

Step 2600/5000 | Loss: 0.1005 | LR: 0.000490


Training:  53%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé    | 2650/5000 [10:20:11<8:57:37, 13.73s/step, loss=0.0884, lr=0.000480]

Step 2650/5000 | Loss: 0.0884 | LR: 0.000480


Training:  54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 2700/5000 [10:31:38<8:48:10, 13.78s/step, loss=0.0974, lr=0.000470]

Step 2700/5000 | Loss: 0.0974 | LR: 0.000470


Training:  55%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 2750/5000 [10:43:12<8:42:44, 13.94s/step, loss=0.0888, lr=0.000459]

Step 2750/5000 | Loss: 0.0888 | LR: 0.000459


Training:  56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 2800/5000 [10:54:42<8:24:37, 13.76s/step, loss=0.0917, lr=0.000449]

Step 2800/5000 | Loss: 0.0917 | LR: 0.000449


Training:  57%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã    | 2850/5000 [11:06:10<8:16:12, 13.85s/step, loss=0.0857, lr=0.000439]

Step 2850/5000 | Loss: 0.0857 | LR: 0.000439


Training:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 2900/5000 [11:17:42<7:59:37, 13.70s/step, loss=0.0834, lr=0.000429]

Step 2900/5000 | Loss: 0.0834 | LR: 0.000429


Training:  59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 2950/5000 [11:29:07<7:43:04, 13.55s/step, loss=0.0826, lr=0.000419]

Step 2950/5000 | Loss: 0.0826 | LR: 0.000419


Training:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 3000/5000 [11:40:35<7:37:53, 13.74s/step, loss=0.0852, lr=0.000408]

Step 3000/5000 | Loss: 0.0852 | LR: 0.000408


Training:  61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 3050/5000 [11:52:00<7:28:20, 13.80s/step, loss=0.0812, lr=0.000398]

Step 3050/5000 | Loss: 0.0812 | LR: 0.000398


Training:  62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 3100/5000 [12:03:37<7:29:55, 14.21s/step, loss=0.0813, lr=0.000388]

Step 3100/5000 | Loss: 0.0813 | LR: 0.000388


Training:  63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 3150/5000 [12:15:22<7:17:52, 14.20s/step, loss=0.0840, lr=0.000378]

Step 3150/5000 | Loss: 0.0840 | LR: 0.000378


Training:  64%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 3200/5000 [12:26:55<6:50:28, 13.68s/step, loss=0.0869, lr=0.000368]

Step 3200/5000 | Loss: 0.0869 | LR: 0.000368


Training:  65%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 3250/5000 [12:38:23<6:39:59, 13.71s/step, loss=0.0814, lr=0.000357]

Step 3250/5000 | Loss: 0.0814 | LR: 0.000357


Training:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 3300/5000 [12:49:53<6:29:29, 13.75s/step, loss=0.0800, lr=0.000347]

Step 3300/5000 | Loss: 0.0800 | LR: 0.000347


Training:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 3350/5000 [13:01:26<6:20:49, 13.85s/step, loss=0.0815, lr=0.000337]

Step 3350/5000 | Loss: 0.0815 | LR: 0.000337


Training:  68%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 3400/5000 [13:13:01<6:14:39, 14.05s/step, loss=0.0845, lr=0.000327]

Step 3400/5000 | Loss: 0.0845 | LR: 0.000327


Training:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 3450/5000 [13:24:32<5:54:01, 13.70s/step, loss=0.0756, lr=0.000317]

Step 3450/5000 | Loss: 0.0756 | LR: 0.000317


Training:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 3500/5000 [13:36:00<5:40:43, 13.63s/step, loss=0.0782, lr=0.000306]

Step 3500/5000 | Loss: 0.0782 | LR: 0.000306


Training:  71%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 3550/5000 [13:47:29<5:34:33, 13.84s/step, loss=0.0769, lr=0.000296]

Step 3550/5000 | Loss: 0.0769 | LR: 0.000296


Training:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 3600/5000 [13:58:59<5:17:52, 13.62s/step, loss=0.0705, lr=0.000286]

Step 3600/5000 | Loss: 0.0705 | LR: 0.000286


Training:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 3650/5000 [14:10:26<5:07:52, 13.68s/step, loss=0.0808, lr=0.000276]

Step 3650/5000 | Loss: 0.0808 | LR: 0.000276


Training:  74%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç  | 3700/5000 [14:21:58<5:01:48, 13.93s/step, loss=0.0751, lr=0.000266]

Step 3700/5000 | Loss: 0.0751 | LR: 0.000266


Training:  75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 3750/5000 [14:33:29<4:49:25, 13.89s/step, loss=0.0714, lr=0.000255]

Step 3750/5000 | Loss: 0.0714 | LR: 0.000255


Training:  76%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 3800/5000 [14:44:55<4:32:46, 13.64s/step, loss=0.0739, lr=0.000245]

Step 3800/5000 | Loss: 0.0739 | LR: 0.000245


Training:  77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 3850/5000 [14:56:21<4:22:54, 13.72s/step, loss=0.0735, lr=0.000235]

Step 3850/5000 | Loss: 0.0735 | LR: 0.000235


Training:  78%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä  | 3900/5000 [15:07:46<4:09:44, 13.62s/step, loss=0.0738, lr=0.000225]

Step 3900/5000 | Loss: 0.0738 | LR: 0.000225


Training:  79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ  | 3950/5000 [15:19:06<3:59:58, 13.71s/step, loss=0.0792, lr=0.000214]

Step 3950/5000 | Loss: 0.0792 | LR: 0.000214


Training:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 4000/5000 [15:30:31<3:44:54, 13.49s/step, loss=0.0765, lr=0.000204]

Step 4000/5000 | Loss: 0.0765 | LR: 0.000204


Training:  81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 4050/5000 [15:41:52<3:37:48, 13.76s/step, loss=0.0708, lr=0.000194]

Step 4050/5000 | Loss: 0.0708 | LR: 0.000194


Training:  82%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 4100/5000 [15:53:16<3:26:40, 13.78s/step, loss=0.0710, lr=0.000184]

Step 4100/5000 | Loss: 0.0710 | LR: 0.000184


Training:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 4150/5000 [16:33:15<44:57:48, 190.43s/step, loss=0.0681, lr=0.000174]

Step 4150/5000 | Loss: 0.0681 | LR: 0.000174


Training:  84%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç | 4200/5000 [17:01:07<3:02:45, 13.71s/step, loss=0.0731, lr=0.000163]  

Step 4200/5000 | Loss: 0.0731 | LR: 0.000163


Training:  85%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 4250/5000 [17:48:38<6:02:59, 29.04s/step, loss=0.0634, lr=0.000153]  

Step 4250/5000 | Loss: 0.0634 | LR: 0.000153


Training:  86%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 4300/5000 [18:56:43<44:32:49, 229.10s/step, loss=0.0666, lr=0.000143]

Step 4300/5000 | Loss: 0.0666 | LR: 0.000143


Training:  87%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã | 4350/5000 [19:50:04<2:25:43, 13.45s/step, loss=0.0662, lr=0.000133]  

Step 4350/5000 | Loss: 0.0662 | LR: 0.000133


Training:  88%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä | 4400/5000 [20:52:28<6:19:52, 37.99s/step, loss=0.0661, lr=0.000123]  

Step 4400/5000 | Loss: 0.0661 | LR: 0.000123


Training:  89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 4450/5000 [21:46:06<2:10:31, 14.24s/step, loss=0.0660, lr=0.000112]  

Step 4450/5000 | Loss: 0.0660 | LR: 0.000112


Training:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 4500/5000 [21:57:44<1:55:41, 13.88s/step, loss=0.0597, lr=0.000102]

Step 4500/5000 | Loss: 0.0597 | LR: 0.000102


Training:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 4550/5000 [22:09:16<1:43:19, 13.78s/step, loss=0.0649, lr=0.000100]

Step 4550/5000 | Loss: 0.0649 | LR: 0.000100


Training:  92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 4600/5000 [22:20:45<1:33:02, 13.96s/step, loss=0.0657, lr=0.000100]

Step 4600/5000 | Loss: 0.0657 | LR: 0.000100


Training:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 4650/5000 [22:32:11<1:19:56, 13.71s/step, loss=0.0687, lr=0.000100]

Step 4650/5000 | Loss: 0.0687 | LR: 0.000100


Training:  94%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 4700/5000 [22:43:39<1:08:18, 13.66s/step, loss=0.0636, lr=0.000100]

Step 4700/5000 | Loss: 0.0636 | LR: 0.000100


Training:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 4750/5000 [22:55:07<57:05, 13.70s/step, loss=0.0625, lr=0.000100]  

Step 4750/5000 | Loss: 0.0625 | LR: 0.000100


Training:  96%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 4800/5000 [23:06:36<46:24, 13.92s/step, loss=0.0717, lr=0.000100]

Step 4800/5000 | Loss: 0.0717 | LR: 0.000100


Training:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 4850/5000 [23:18:03<34:13, 13.69s/step, loss=0.0584, lr=0.000100]

Step 4850/5000 | Loss: 0.0584 | LR: 0.000100


Training:  98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 4900/5000 [23:29:32<22:44, 13.64s/step, loss=0.0648, lr=0.000100]

Step 4900/5000 | Loss: 0.0648 | LR: 0.000100


Training:  99%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 4950/5000 [23:41:06<11:32, 13.85s/step, loss=0.0621, lr=0.000100]

Step 4950/5000 | Loss: 0.0621 | LR: 0.000100


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5000/5000 [23:52:37<00:00, 17.19s/step, loss=0.0652, lr=0.000100]


Step 5000/5000 | Loss: 0.0652 | LR: 0.000100

‚úÖ Reached 5000 steps - stopping training!

PHASE 1 COMPLETED!

‚úÖ Checkpoint saved to: checkpoint_5000.pt
‚úÖ Trained for exactly 5000 steps
‚úÖ Final loss: 0.0652

You can now run PHASE 2 to continue training for 500 more steps!
Memory cleared.


# PHASE 2: Load checkpoint and train for 500 more steps

Run this cell after Phase 1 completes. It will:
- Load the checkpoint from `checkpoint_5000.pt`
- Resume training from step 5000
- Train for 500 more steps (total: 5500 steps)
- Save final checkpoint to `checkpoint_5500.pt`

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
from tqdm import tqdm
import math
import os

# Import DeepSeekV3 from model.py
from model import DeepSeekV3Config, DeepSeekV3ForCausalLM


# Dataset
class TextDataset(Dataset):
    def __init__(self, filepath, tokenizer, block_size=1024):
        with open(filepath, 'r', encoding='utf-8') as f:
            text = f.read()
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.tokens = tokenizer.encode(text)

    def __len__(self):
        return len(self.tokens) - self.block_size

    def __getitem__(self, idx):
        input_ids = self.tokens[idx:idx + self.block_size]
        target_ids = self.tokens[idx + 1:idx + self.block_size + 1]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


# ==========================================
# PHASE 2: LOAD CHECKPOINT AND TRAIN 500 MORE STEPS (Plain PyTorch)
# ==========================================

print("=" * 70)
print("PHASE 2: LOADING CHECKPOINT AND TRAINING 500 MORE STEPS")
print("=" * 70)

# Check if checkpoint exists
checkpoint_path = "checkpoint_5000.pt"
if not os.path.exists(checkpoint_path):
    print(f"\n‚ùå ERROR: Checkpoint not found at {checkpoint_path}")
    print("Please run PHASE 1 first to create the checkpoint!")
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Load checkpoint
print(f"\nüìÇ Loading checkpoint from: {checkpoint_path}")
# Note: weights_only=False is safe here because we created this checkpoint ourselves
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

config = checkpoint['config']
previous_step = checkpoint['global_step']

print(f"‚úÖ Checkpoint loaded successfully!")
print(f"   Previous training: {previous_step} steps")

# Configuration
USE_SMALL_MODEL = False  # Should match Phase 1

if USE_SMALL_MODEL:
    block_size = 512
    batch_size = 4
    max_lr = 3e-4
    accumulate_grad_batches = 4
else:
    block_size = 1024
    batch_size = 2
    max_lr = 1e-3
    accumulate_grad_batches = 8

# Training parameters for PHASE 2
warmup_steps = 100  # Already completed in Phase 1
additional_steps = 500  # Train for 500 MORE steps
new_max_steps = previous_step + additional_steps  # Total: 5500 steps
log_interval = 50  # Print every 50 steps to avoid output overflow

print(f"\nPhase 2 Training Plan:")
print(f"  - Starting from step: {previous_step}")
print(f"  - Training for: {additional_steps} more steps")
print(f"  - Final step will be: {new_max_steps}")
print(f"  - Logging interval: Every {log_interval} steps")

# Device setup
if torch.cuda.is_available():
    device = 'cuda'
    print(f"\nUsing device: {device} ({torch.cuda.get_device_name(0)})")
elif torch.backends.mps.is_available():
    device = 'mps'
    print(f"\nUsing device: {device} (Apple Silicon GPU)")
else:
    device = 'cpu'
    print(f"\nUsing device: {device}")

# Setup
tokenizer = tiktoken.get_encoding("gpt2")

# Dataset and DataLoader
dataset = TextDataset("input-1.txt", tokenizer, block_size)
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=0,
    pin_memory=False
)

# Model
model = DeepSeekV3ForCausalLM(config)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

print(f"\n‚úÖ Model weights loaded from checkpoint")

total_params = sum(p.numel() for p in model.parameters())
print(f"   Total parameters: {total_params:,}")

# Optimizer - separate decay parameters
decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if param.requires_grad:
        if 'bias' in name or 'norm' in name or 'embed' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0}
], lr=max_lr, betas=(0.9, 0.95), eps=1e-8)

# Load optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

print("‚úÖ Optimizer state restored!")

# Learning rate scheduler
def get_lr(step, warmup_steps, max_steps, max_lr):
    if step < warmup_steps:
        return max_lr * (step / warmup_steps)
    else:
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return max_lr * max(0.1, 1.0 - progress)

# Clear cache
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()

torch.set_float32_matmul_precision('high')

print("\n" + "=" * 70)
print("STARTING PHASE 2 TRAINING")
print(f"Resuming from step {previous_step}, training to step {new_max_steps}")
print(f"Will stop at EXACTLY {new_max_steps} steps")
print("=" * 70 + "\n")

# Training loop
model.train()
update_step = previous_step  # Start from where we left off
batch_idx = 0  # Batch counter for gradient accumulation
accumulated_loss = 0.0
optimizer.zero_grad()

# Create infinite dataloader
def cycle(dataloader):
    while True:
        for batch in dataloader:
            yield batch

data_iter = cycle(dataloader)

# Progress bar
pbar = tqdm(total=new_max_steps, initial=previous_step, desc="Training", unit="step")

while update_step < new_max_steps:
    # Get batch
    input_ids, target_ids = next(data_iter)
    input_ids = input_ids.to(device)
    target_ids = target_ids.to(device)

    # Forward pass
    outputs = model(input_ids=input_ids, labels=target_ids)
    loss = outputs['loss']

    # Scale loss for gradient accumulation
    loss = loss / accumulate_grad_batches
    loss.backward()

    # Accumulate the UNSCALED loss for logging (multiply back)
    accumulated_loss += loss.item() * accumulate_grad_batches
    batch_idx += 1

    # Update weights after accumulation
    if batch_idx % accumulate_grad_batches == 0:
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update learning rate
        lr = get_lr(update_step, warmup_steps, new_max_steps, max_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()

        # Calculate average loss over accumulated batches
        avg_loss = accumulated_loss / accumulate_grad_batches
        
        # Update progress bar every step
        pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr:.6f}'})
        
        # Print only at log_interval to avoid output overflow
        if (update_step + 1) % log_interval == 0 or update_step == previous_step:
            print(f"Step {update_step + 1}/{new_max_steps} | Loss: {avg_loss:.4f} | LR: {lr:.6f}")
        
        accumulated_loss = 0.0

        pbar.update(1)
        update_step += 1
        batch_idx = 0  # Reset batch counter

        # EXACT STOP at new_max_steps
        if update_step >= new_max_steps:
            print(f"\n‚úÖ Reached {new_max_steps} steps - stopping training!")
            break

pbar.close()

print("\n" + "=" * 70)
print("PHASE 2 COMPLETED!")
print("=" * 70)

# Save final checkpoint
final_checkpoint_path = "checkpoint_5500.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config,
    'total_params': total_params,
    'global_step': update_step,
}, final_checkpoint_path)

print(f"\n‚úÖ Final checkpoint saved to: {final_checkpoint_path}")
print(f"‚úÖ Total training steps: {update_step}")
print(f"‚úÖ Final loss: {avg_loss:.4f}")
print(f"\nüéâ Training complete! Model trained for {previous_step} + {additional_steps} = {update_step} steps")

# Cleanup
del model, optimizer
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared.")

PHASE 2: LOADING CHECKPOINT AND TRAINING 500 MORE STEPS

üìÇ Loading checkpoint from: checkpoint_5000.pt
‚úÖ Checkpoint loaded successfully!
   Previous training: 5000 steps

Phase 2 Training Plan:
  - Starting from step: 5000
  - Training for: 500 more steps
  - Final step will be: 5500
  - Logging interval: Every 50 steps

Using device: mps (Apple Silicon GPU)

‚úÖ Model weights loaded from checkpoint
   Total parameters: 174,012,288
‚úÖ Optimizer state restored!

STARTING PHASE 2 TRAINING
Resuming from step 5000, training to step 5500
Will stop at EXACTLY 5500 steps



Training:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 5001/5500 [00:31<4:25:49, 31.96s/step, loss=0.0647, lr=0.000100]

Step 5001/5500 | Loss: 0.0647 | LR: 0.000100


Training:  92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 5050/5500 [12:28<1:50:14, 14.70s/step, loss=0.0632, lr=0.000100]

Step 5050/5500 | Loss: 0.0632 | LR: 0.000100


Training:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 5100/5500 [24:36<1:36:33, 14.48s/step, loss=0.0633, lr=0.000100]

Step 5100/5500 | Loss: 0.0633 | LR: 0.000100


Training:  94%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 5150/5500 [36:38<1:24:48, 14.54s/step, loss=0.0597, lr=0.000100]

Step 5150/5500 | Loss: 0.0597 | LR: 0.000100


Training:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 5200/5500 [48:41<1:12:21, 14.47s/step, loss=0.0621, lr=0.000100]

Step 5200/5500 | Loss: 0.0621 | LR: 0.000100


Training:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 5250/5500 [1:00:43<1:00:06, 14.43s/step, loss=0.0579, lr=0.000100]

Step 5250/5500 | Loss: 0.0579 | LR: 0.000100


Training:  96%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 5300/5500 [1:12:45<48:30, 14.55s/step, loss=0.0632, lr=0.000100]  

Step 5300/5500 | Loss: 0.0632 | LR: 0.000100


Training:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 5350/5500 [1:24:46<35:48, 14.33s/step, loss=0.0616, lr=0.000100]

Step 5350/5500 | Loss: 0.0616 | LR: 0.000100


Training:  98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 5400/5500 [1:36:52<24:11, 14.52s/step, loss=0.0641, lr=0.000100]

Step 5400/5500 | Loss: 0.0641 | LR: 0.000100


Training:  99%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 5450/5500 [1:48:56<12:05, 14.52s/step, loss=0.0574, lr=0.000100]

Step 5450/5500 | Loss: 0.0574 | LR: 0.000100


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5500/5500 [2:00:59<00:00, 14.52s/step, loss=0.0609, lr=0.000100]


Step 5500/5500 | Loss: 0.0609 | LR: 0.000100

‚úÖ Reached 5500 steps - stopping training!

PHASE 2 COMPLETED!

‚úÖ Final checkpoint saved to: checkpoint_5500.pt
‚úÖ Total training steps: 5500
‚úÖ Final loss: 0.0609

üéâ Training complete! Model trained for 5000 + 500 = 5500 steps
Memory cleared.
