# 03_local_training.ipynb - ULTRA FAST VERSION

**Optimizations:**
- Mixed Precision (AMP) - 2x faster
- cudnn.benchmark - 10-15% faster
- Fast batch loading
- GPU verification

**Expected: ~2-5 min per epoch on T4**

In [1]:
# Cell 1: Setup with optimizations
import os, gc, glob, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# CRITICAL: Verify GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type != 'cuda':
    print('⚠️ WARNING: Running on CPU! This will be VERY SLOW!')
    print('Enable GPU: Runtime → Change runtime type → GPU')
else:
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
    torch.backends.cudnn.benchmark = True
    print('✓ cudnn.benchmark enabled')

# Mixed Precision
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
print('✓ Mixed Precision (AMP) enabled')

# Paths
NOTEBOOK_DIR = os.getcwd()
PROJECT_ROOT = os.path.dirname(NOTEBOOK_DIR) if NOTEBOOK_DIR.endswith('notebooks') else NOTEBOOK_DIR
DATA_DIR = os.path.join(PROJECT_ROOT, 'data', 'batched')
CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
FIGURES_DIR = os.path.join(PROJECT_ROOT, 'figures')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(FIGURES_DIR, exist_ok=True)
print(f'Data dir: {DATA_DIR}')

Device: cpu
Enable GPU: Runtime → Change runtime type → GPU
✓ Mixed Precision (AMP) enabled
Data dir: /content/data/batched


  scaler = GradScaler()
  super().__init__(


In [2]:
# Cell 2: FAST Data Loading
BATCH_SIZE = 32

def get_batch_files(split):
    split_dir = os.path.join(DATA_DIR, split)
    x_files = sorted(glob.glob(os.path.join(split_dir, 'X_batch_*.npy')))
    y_files = sorted(glob.glob(os.path.join(split_dir, 'Y_batch_*.npy')))
    return x_files, y_files

def batch_generator(split, batch_size=BATCH_SIZE, shuffle=False):
    x_files, y_files = get_batch_files(split)
    file_indices = list(range(len(x_files)))
    if shuffle:
        np.random.shuffle(file_indices)
    
    for idx in file_indices:
        X = np.load(x_files[idx])
        Y = np.load(y_files[idx])
        n = len(X)
        indices = np.random.permutation(n) if shuffle else np.arange(n)
        
        for start in range(0, n, batch_size):
            batch_idx = indices[start:start+batch_size]
            x = torch.from_numpy(X[batch_idx]).float().permute(0, 1, 4, 2, 3)
            y = torch.from_numpy(Y[batch_idx]).float().permute(0, 1, 4, 2, 3)
            yield x, y

# Verify data and count batches
train_files, _ = get_batch_files('train')
val_files, _ = get_batch_files('val')
SAMPLES_PER_FILE = 500
BATCHES_PER_FILE = (SAMPLES_PER_FILE + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division
n_train_batches = len(train_files) * BATCHES_PER_FILE
n_val_batches = len(val_files) * BATCHES_PER_FILE
print(f'Train: {len(train_files)} files × {BATCHES_PER_FILE} = {n_train_batches} batches')
print(f'Val: {len(val_files)} files × {BATCHES_PER_FILE} = {n_val_batches} batches')

# Quick benchmark
print('\nBenchmarking data loading...')
t0 = time.time()
for i, (x, y) in enumerate(batch_generator('train')):
    if i >= 50:
        break
print(f'50 batches in {time.time()-t0:.2f}s ({(time.time()-t0)/50*1000:.1f}ms/batch)')

Train: 0 files × 16 = 0 batches
Val: 0 files × 16 = 0 batches

Benchmarking data loading...
50 batches in 0.00s (0.0ms/batch)


In [3]:
# Cell 3: Load Stats
stats = np.load(os.path.join(DATA_DIR, 'stats.npz'), allow_pickle=True)
mean, std = stats['mean'], stats['std']
variables = list(stats['variables'])
print(f'Variables: {variables}')

FileNotFoundError: [Errno 2] No such file or directory: '/content/data/batched/stats.npz'

In [None]:
# Cell 4: Model
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size, padding=kernel_size//2)
    def forward(self, x, hidden):
        h, c = hidden
        gates = self.conv(torch.cat([x, h], dim=1))
        i, f, o, g = torch.chunk(gates, 4, dim=1)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)
        return h_next, c_next
    def init_hidden(self, B, H, W, dev):
        return (torch.zeros(B, self.hidden_dim, H, W, device=dev),
                torch.zeros(B, self.hidden_dim, H, W, device=dev))

class WeatherNowcaster(nn.Module):
    def __init__(self, in_ch, hidden_dim, out_ch, n_layers=2):
        super().__init__()
        self.encoder = nn.ModuleList([ConvLSTMCell(in_ch if i==0 else hidden_dim, hidden_dim, 3) for i in range(n_layers)])
        self.decoder = nn.ModuleList([ConvLSTMCell(out_ch if i==0 else hidden_dim, hidden_dim, 3) for i in range(n_layers)])
        self.out_conv = nn.Conv2d(hidden_dim, out_ch, 1)
    
    def forward(self, x, future_steps):
        B, T, C, H, W = x.shape
        hidden = [cell.init_hidden(B, H, W, x.device) for cell in self.encoder]
        for t in range(T):
            inp = x[:, t]
            for i, cell in enumerate(self.encoder):
                h, c = cell(inp, hidden[i])
                hidden[i] = (h, c)
                inp = h
        
        dec_hidden = [(h.clone(), c.clone()) for h, c in hidden]
        outputs = []
        dec_in = self.out_conv(dec_hidden[-1][0])
        for _ in range(future_steps):
            for i, cell in enumerate(self.decoder):
                h, c = cell(dec_in if i==0 else h, dec_hidden[i])
                dec_hidden[i] = (h, c)
            dec_in = self.out_conv(h)
            outputs.append(dec_in)
        return torch.stack(outputs, dim=1)

T_IN, T_OUT = 24, 6
HIDDEN_DIM = 128

model = WeatherNowcaster(2, HIDDEN_DIM, 2, n_layers=2).to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

# Benchmark model
print('\nBenchmarking model...')
x_test = torch.randn(BATCH_SIZE, T_IN, 2, 31, 41).to(device)
torch.cuda.synchronize()
t0 = time.time()
for _ in range(10):
    with autocast():
        _ = model(x_test, T_OUT)
    torch.cuda.synchronize()
print(f'10 forward passes in {time.time()-t0:.2f}s ({(time.time()-t0)/10*1000:.1f}ms/batch)')
del x_test
torch.cuda.empty_cache()

In [None]:
# Cell 5: FAST Training with AMP
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

NUM_EPOCHS = 50
best_val_loss = float('inf')
train_losses, val_losses = [], []

print(f'Training: {NUM_EPOCHS} epochs, batch={BATCH_SIZE}, hidden={HIDDEN_DIM}')
print('=' * 60)

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    
    # Train
    model.train()
    train_loss, n_batches = 0.0, 0
    pbar = tqdm(batch_generator('train', BATCH_SIZE, shuffle=True), 
                total=n_train_batches, desc=f'Epoch {epoch+1} [Train]', leave=False)
    
    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        
        with autocast():
            out = model(x, T_OUT)
            loss = criterion(out, y)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        train_loss += loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    train_loss /= n_batches
    train_losses.append(train_loss)
    
    # Validate
    model.eval()
    val_loss, n_val = 0.0, 0
    with torch.no_grad():
        for x, y in tqdm(batch_generator('val', BATCH_SIZE), total=n_val_batches, 
                         desc=f'Epoch {epoch+1} [Val]', leave=False):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            with autocast():
                val_loss += criterion(model(x, T_OUT), y).item()
            n_val += 1
    
    val_loss /= n_val
    val_losses.append(val_loss)
    scheduler.step(val_loss)
    
    epoch_time = time.time() - epoch_start
    marker = '★ BEST' if val_loss < best_val_loss else ''
    print(f'Epoch {epoch+1:2d} | Train: {train_loss:.6f} | Val: {val_loss:.6f} | Time: {epoch_time:.1f}s {marker}')
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({'model': model.state_dict(), 'val_loss': val_loss}, 
                   os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
    
    gc.collect()
    torch.cuda.empty_cache()

print('=' * 60)
print(f'✓ Done! Best val loss: {best_val_loss:.6f}')

In [None]:
# Cell 6: Plot & Save
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train')
plt.plot(val_losses, label='Val')
plt.xlabel('Epoch'); plt.ylabel('MSE Loss')
plt.legend(); plt.grid(alpha=0.3)
plt.savefig(os.path.join(FIGURES_DIR, 'training_curve.png'), dpi=150)
plt.show()

torch.save({
    'model_state_dict': model.state_dict(),
    'config': {'hidden_dim': HIDDEN_DIM, 'n_layers': 2, 'T_IN': T_IN, 'T_OUT': T_OUT},
    'mean': mean, 'std': std, 'variables': variables,
    'best_val_loss': best_val_loss
}, os.path.join(CHECKPOINT_DIR, 'final_model.pth'))
print('Saved final_model.pth')