# WeatherFlow: Colab A100 Quickstart

This notebook gives a GPU-ready demo of WeatherFlow's flow-matching models on synthetic data. Run on a GPU runtime (e.g., Colab A100).

In [None]:
# Install WeatherFlow in a clean venv to avoid Colab preinstalled conflicts
VENV_DIR = "/content/wfenv"
!python -m venv $VENV_DIR
!$VENV_DIR/bin/pip -q install -U pip
!$VENV_DIR/bin/pip -q install -U weatherflow torchdiffeq numpy==1.26.4
import sys
sys.path.insert(0, f"{VENV_DIR}/lib/python3.10/site-packages")

In [None]:
import math
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

from weatherflow.models.flow_matching import WeatherFlowMatch, WeatherFlowODE
from weatherflow.training.flow_trainer import FlowTrainer

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Seed PyTorch's global RNG and switch on cuDNN benchmark mode.
torch.manual_seed(42)
torch.backends.cudnn.benchmark = True

In [None]:
class SyntheticWeather(Dataset):
    """Tiny synthetic dataset with smooth sinusoidal patterns.

    Each item is a ``(current, target)`` pair of float32 tensors of shape
    ``(4, h, w)`` where ``(h, w) = grid``. The four channels are stacked in
    the order ``[z, temp, u, v]``; ``target`` is ``current`` plus a small
    Gaussian perturbation.

    Samples are a deterministic function of ``(seed, idx)``: asking for the
    same index twice returns identical tensors, so a validation set stays
    fixed between passes.

    Args:
        n_samples: Number of items reported by ``len()``.
        grid: ``(height, width)`` of every field.
        seed: Base seed for the per-sample generators. When ``None`` a base
            seed is drawn from PyTorch's global RNG at construction time, so
            two datasets built one after the other get different data while
            remaining reproducible under ``torch.manual_seed``.

    Raises:
        IndexError: From ``__getitem__`` when ``idx`` is outside
            ``[-n_samples, n_samples)``.
    """

    # Spacing between the generator seeds of consecutive base seeds; keeps
    # the per-sample seeds of different base seeds disjoint for any dataset
    # with fewer than this many samples.
    _INDEX_STRIDE = 1_000_003

    def __init__(self, n_samples=64, grid=(32, 64), seed=None):
        self.n_samples = n_samples
        self.grid = grid
        if seed is None:
            seed = torch.randint(0, 2**31 - 1, (1,)).item()
        self.seed = int(seed)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        idx = int(idx)
        if idx < 0:
            idx += self.n_samples
        # Without this check plain iteration (``for item in ds``) never ends,
        # because the sequence protocol stops only on IndexError.
        if not 0 <= idx < self.n_samples:
            raise IndexError(
                f"index out of range for SyntheticWeather of size {self.n_samples}"
            )

        # A private generator per sample: no dependence on access order and
        # no consumption of the global RNG stream.
        gen = torch.Generator()
        gen.manual_seed(self.seed * self._INDEX_STRIDE + idx)

        def noise(scale, shape):
            return scale * torch.randn(shape, generator=gen)

        h, w = self.grid
        y = torch.linspace(0, 2 * math.pi, steps=h)
        x = torch.linspace(0, 2 * math.pi, steps=w)
        yy, xx = torch.meshgrid(y, x, indexing="ij")

        # One random phase shared by both sinusoids of this sample.
        phase = torch.rand(1, generator=gen) * 2 * math.pi
        base = torch.sin(xx + phase) + torch.cos(yy * 2 + phase)

        u = base + noise(0.1, base.shape)
        # v is the base field shifted by one row, plus noise.
        v = base.roll(shifts=1, dims=0) + noise(0.1, base.shape)
        temp = base * 0.5 + noise(0.05, base.shape)
        z = base * 2 + noise(0.1, base.shape)

        current = torch.stack([z, temp, u, v], dim=0)
        target = current + noise(0.02, current.shape)  # simple drift
        return current.float(), target.float()

train_ds = SyntheticWeather(n_samples=128, grid=(32, 64))
val_ds = SyntheticWeather(n_samples=32, grid=(32, 64))

# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Request
# it only when CUDA is present so the CPU fallback runs without a needless
# pinning request and the warning PyTorch emits for it.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, pin_memory=PIN_MEMORY)

# Sanity check: one batch of inputs should be (batch=8, channels=4, 32, 64).
next(iter(train_loader))[0].shape

In [None]:
# Network hyperparameters, collected in one place so they are easy to tweak.
# input_channels=4 matches the four stacked fields produced by the dataset.
model_kwargs = dict(
    input_channels=4,
    hidden_dim=64,
    n_layers=4,
    use_attention=True,
    window_size=8,
    physics_informed=True,
    use_spectral_mixer=True,
    spectral_modes=8,
)
model = WeatherFlowMatch(**model_kwargs)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

# Training-loop options passed straight through to FlowTrainer.
trainer_kwargs = dict(
    use_amp=True,
    grad_clip=1.0,
    ema_decay=0.995,
    noise_std=(0.0, 0.05),
    seed=42,
)
trainer = FlowTrainer(model=model, optimizer=optimizer, device=device, **trainer_kwargs)

# One quick epoch for demo speed, followed by a single validation pass.
train_metrics = trainer.train_epoch(train_loader)
val_metrics = trainer.validate(val_loader)
train_metrics, val_metrics

In [None]:
# Fast inference with the Heun-based path and a tiny ensemble
ode = WeatherFlowODE(trainer.model, fast_mode=True)

# Take the inputs of the first validation batch as initial conditions.
batch = next(iter(val_loader))
x0 = batch[0].to(device)

# Four evenly spaced integration times covering [0, 1].
times = torch.linspace(0, 1, steps=4, device=device)

with torch.no_grad():
    preds = ode.ensemble_forecast(x0, times, num_members=3, noise_std=0.05)

# Visualize one variable of the first member and time step.
# Assumed axis order of preds: (member, time, batch, channel, ...) - confirm
# against ensemble_forecast's documentation.
member_idx, time_idx, batch_idx, channel_idx = 0, 0, 0, 0
sample = preds[member_idx, time_idx, batch_idx, channel_idx].cpu()

plt.figure(figsize=(6, 4))
plt.imshow(sample, cmap="turbo")
plt.title("Predicted geopotential (sample)")
plt.colorbar()
plt.tight_layout()
plt.show()