
# 🤖 Tutorial 04 — Basic ML Model (PyTorch)

We train a simple multilayer perceptron (**MLP**) on the preprocessed arrays produced in Tutorial 03 to predict `t2m`.


## 0. Imports & device

In [None]:

import os
from pathlib import Path
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Paths — all relative to the working directory the notebook is run from.
DATA_DIR = Path("data/processed")
FIG_DIR = Path("figures")
MODEL_DIR = Path("models")
# Create the output folders up front so later savefig / torch.save calls
# cannot fail on a missing directory.
for out_dir in (FIG_DIR, MODEL_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# Reproducibility: weight initialisation, DataLoader shuffling and dropout all
# draw from torch's global RNG; NumPy's global RNG is seeded as well so that
# any NumPy-based sampling is repeatable under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device: prefer a CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.__version__


## 1. Load preprocessed arrays

In [None]:

# Feature matrices (X_*) and targets (y_*) for each split, read from DATA_DIR.
X_train, y_train = np.load(DATA_DIR / "X_train.npy"), np.load(DATA_DIR / "y_train.npy")
X_val, y_val = np.load(DATA_DIR / "X_val.npy"), np.load(DATA_DIR / "y_val.npy")
X_test, y_test = np.load(DATA_DIR / "X_test.npy"), np.load(DATA_DIR / "y_test.npy")

# Archive holding the 'mu', 'sigma' and 'lags' entries printed below.
stats = np.load(DATA_DIR / "stats.npz")

print("Loaded shapes:")
for x_label, x_arr, y_label, y_arr in (
    ("  X_train", X_train, "y_train", y_train),
    ("  X_val  ", X_val, "y_val", y_val),
    ("  X_test ", X_test, "y_test", y_test),
):
    print(x_label, x_arr.shape, y_label, y_arr.shape)

for caption, key in (("Feature means:", "mu"), ("Feature stds :", "sigma")):
    print(caption, stats[key].ravel())
print("Lags used    :", stats['lags'])


## 2. Build PyTorch datasets & loaders

In [None]:

def _to_float_tensor(array, as_column=False):
    """Copy `array` into a float32 tensor; optionally insert a size-1 axis at position 1."""
    tensor = torch.tensor(array, dtype=torch.float32)
    return tensor.unsqueeze(1) if as_column else tensor


# Features keep their shape; targets get an extra axis at position 1.
X_train_t, y_train_t = _to_float_tensor(X_train), _to_float_tensor(y_train, as_column=True)
X_val_t, y_val_t = _to_float_tensor(X_val), _to_float_tensor(y_val, as_column=True)
X_test_t, y_test_t = _to_float_tensor(X_test), _to_float_tensor(y_test, as_column=True)

# Pair each feature tensor with its targets.
train_ds = TensorDataset(X_train_t, y_train_t)
val_ds = TensorDataset(X_val_t, y_val_t)
test_ds = TensorDataset(X_test_t, y_test_t)

BATCH_SIZE = 2048
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
    for ds in (val_ds, test_ds)
)

X_train_t.shape, y_train_t.shape


## 3. Define a simple MLP model

In [None]:

# Model hyperparameters.
input_dim = X_train.shape[1]  # size of axis 1 of the training feature matrix
hidden_dim, dropout_p = 128, 0.1

class MLP(nn.Module):
    """Fully connected network: two Linear -> ReLU -> Dropout stages, then a 1-unit linear head."""

    def __init__(self, in_dim, hidden=128, dropout=0.0):
        """Assemble the layer stack.

        Args:
            in_dim: Number of input features of the first linear layer.
            hidden: Width of both hidden layers.
            dropout: Dropout probability applied after each hidden activation.
        """
        super().__init__()
        widths = (in_dim, hidden, hidden)
        layers = []
        # One Linear/ReLU/Dropout stage per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        # Output head: a single unit with no activation.
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to `x` and return its output."""
        return self.net(x)

# Build the network, then move its parameters to the selected device.
model = MLP(input_dim, hidden=hidden_dim, dropout=dropout_p)
model = model.to(device)
model


## 4. Train with early stopping

In [None]:

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

EPOCHS = 50
PATIENCE = 5       # consecutive non-improving epochs tolerated before stopping
MIN_DELTA = 1e-6   # smallest drop in validation loss that counts as an improvement

best_val = float('inf')
best_state = None
train_losses, val_losses = [], []
no_improve = 0

for epoch in range(1, EPOCHS+1):
    # Train
    model.train()
    running = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        # The criterion returns a batch mean; weight by batch size so the
        # epoch figure is a per-sample mean even when the last batch is short.
        running += loss.item() * xb.size(0)
    train_loss = running / len(train_loader.dataset)
    train_losses.append(train_loss)

    # Validate (eval mode disables dropout; no_grad skips autograd bookkeeping)
    model.eval()
    running = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            preds = model(xb)
            loss = criterion(preds, yb)
            running += loss.item() * xb.size(0)
    val_loss = running / len(val_loader.dataset)
    val_losses.append(val_loss)

    print(f"Epoch {epoch:03d} | train {train_loss:.6f} | val {val_loss:.6f}")

    # Early stopping
    if val_loss < best_val - MIN_DELTA:
        best_val = val_loss
        # state_dict() hands back references to the live parameter tensors,
        # which later optimizer steps overwrite in place. Clone them so the
        # snapshot really is the best epoch's weights.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print("Early stopping triggered.")
            break

# Load best model (the snapshot with the lowest validation loss)
if best_state is not None:
    model.load_state_dict(best_state)


## 5. Plot loss curves

In [None]:

# One-based epoch axis so x positions match the "Epoch NNN" lines printed
# during training.
epochs_run = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(epochs_run, train_losses, label="train")
ax.plot(epochs_run, val_losses, label="val")
ax.set(title="Loss curves (MSE)", xlabel="Epoch", ylabel="MSE")
ax.legend()
fig.tight_layout()
# Save via FIG_DIR (created in the setup cell) rather than a hard-coded path,
# so changing the output folder in one place is enough.
fig.savefig(FIG_DIR / "loss_curve.png", dpi=150)
plt.show()


## 6. Evaluate on the test set

In [None]:

# Collect per-batch predictions and targets as flat NumPy arrays.
model.eval()
batch_preds, batch_targets = [], []
with torch.no_grad():
    for features, labels in test_loader:
        batch_preds.append(model(features.to(device)).cpu().numpy().ravel())
        batch_targets.append(labels.numpy().ravel())

y_pred = np.concatenate(batch_preds)
y_true = np.concatenate(batch_targets)

# Error metrics over the whole test set.
residuals = y_pred - y_true
mse = float(np.mean(residuals**2))
rmse = float(np.sqrt(mse))
mae = float(np.mean(np.abs(residuals)))

for line in (
    f"Test MSE : {mse:.6f}",
    f"Test RMSE: {rmse:.6f}",
    f"Test MAE : {mae:.6f}",
):
    print(line)


## 7. Predicted vs True scatter

In [None]:

# Sample up to 5000 points for plotting
n = y_true.shape[0]
idx = np.random.choice(n, size=min(5000, n), replace=False) if n > 5000 else np.arange(n)

plt.figure(figsize=(5,5))
plt.scatter(y_true[idx], y_pred[idx], s=5, alpha=0.5)
plt.plot([y_true[idx].min(), y_true[idx].max()], [y_true[idx].min(), y_true[idx].max()])
plt.xlabel("True")
plt.ylabel("Predicted")
plt.title("Predicted vs True")
plt.tight_layout()
plt.savefig("figures/pred_vs_true.png", dpi=150)
plt.show()


## 8. Save model and predictions

In [None]:

# Build both output paths from the directories configured in the setup cell,
# so the model lands in the folder that cell actually created.
model_path = MODEL_DIR / "mlp_t2m.pt"
preds_path = DATA_DIR / "y_pred_test.npy"

# Only the weights (state_dict) are saved; reloading requires the MLP class.
torch.save(model.state_dict(), model_path)
np.save(preds_path, y_pred)
print("Saved model to", model_path)
print("Saved predictions to", preds_path)


## 9. Inference example

In [None]:

def predict(model, X_np, batch_size=4096):
    """Run batched inference and return the outputs as a flat NumPy array.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode and left
            in eval mode when the function returns.
        X_np: Input array, sliced along axis 0 into batches.
        batch_size: Maximum number of rows per forward pass; must be >= 1.

    Returns:
        1-D ``np.ndarray`` with the outputs of all batches concatenated in
        input order. An empty float32 array when ``X_np`` has no rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    # Feed inputs on the device the model's weights live on, so the function
    # does not depend on a module-level `device` variable. Parameter-free
    # models fall back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    preds = []
    with torch.no_grad():
        for start in range(0, X_np.shape[0], batch_size):
            xb = torch.tensor(
                X_np[start:start + batch_size], dtype=torch.float32, device=target_device
            )
            preds.append(model(xb).cpu().numpy().ravel())

    # np.concatenate rejects an empty list, so handle zero-row input explicitly.
    if not preds:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(preds)

# Example: run inference on the first 100 standardized test samples,
# then display the first ten predictions.
demo_inputs = X_test[:100]
demo_preds = predict(model, demo_inputs)
demo_preds[:10]



---

## ✅ You trained your first EO ML model!
- Built and trained an **MLP** with PyTorch
- Used **early stopping** and evaluated on a held-out test set
- Visualized training dynamics and prediction quality
- Saved the model and predictions for later use

**Next ideas:** Add more variables (winds, humidity), engineer features, or try CNNs on gridded data.
