# WeatherFlow: Complete Guide

This notebook provides a comprehensive guide to using the WeatherFlow package with real weather data. It demonstrates:

1. Loading and preprocessing ERA5 reanalysis data
2. Training flow matching models on weather data
3. Making predictions with trained models
4. Visualizing and evaluating results
5. Advanced techniques and customizations

## Setup

First, let's set up our environment and import the required packages.

In [None]:
# Add repository root to Python path if running locally
import sys
import os

# Get absolute path to repo root.
# Notebooks don't define __file__, so start from the current working directory.
notebook_dir = os.getcwd()
repo_root = os.path.abspath(os.path.join(notebook_dir, '..'))

# Add to path if not already there
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
    print(f"Added {repo_root} to Python path")

In [None]:
# Standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
from IPython.display import display
from tqdm.auto import tqdm
import json
import warnings
warnings.filterwarnings("ignore")

# For geographic visualization
try:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    CARTOPY_AVAILABLE = True
except ImportError:
    print("Cartopy not available. Some visualizations will be limited.")
    CARTOPY_AVAILABLE = False

# Set plot style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100

In [None]:
# Import WeatherFlow package
import weatherflow
from weatherflow.data import ERA5Dataset, create_data_loaders
from weatherflow.models import WeatherFlowMatch, WeatherFlowODE
from weatherflow.training import FlowTrainer
from weatherflow.utils import WeatherVisualizer, FlowVisualizer

print(f"WeatherFlow version: {weatherflow.__version__}")

## 1. Loading and Exploring ERA5 Data

ERA5 is a comprehensive reanalysis dataset from the European Centre for Medium-Range Weather Forecasts (ECMWF). It provides hourly estimates of atmospheric, land, and oceanic climate variables. We'll use the WeatherBench 2 version of ERA5 data for this example.
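
Geopotential at 500 hPa is on the order of 5×10⁴ m²/s², while temperature at that level is roughly 250 K, so fields like these are usually standardized per variable before training. This notebook doesn't show whether `ERA5Dataset` already normalizes its output, so the following is only a minimal NumPy sketch of the idea on made-up arrays. The `standardize` and `destandardize` helpers and the `(time, variable, level, lat, lon)` layout are assumptions for illustration, not part of WeatherFlow.

```python
import numpy as np

def standardize(fields, eps=1e-8):
    """Standardize each variable of a (time, variable, level, lat, lon) array.

    Returns the standardized array plus the per-variable mean and std so the
    transform can be inverted after prediction.
    """
    # Statistics are taken over every axis except the variable axis.
    axes = (0, 2, 3, 4)
    mean = fields.mean(axis=axes, keepdims=True)
    std = fields.std(axis=axes, keepdims=True)
    return (fields - mean) / (std + eps), mean, std

def destandardize(normed, mean, std, eps=1e-8):
    """Invert standardize() using the stored statistics."""
    return normed * (std + eps) + mean

# Hypothetical stand-in data: 10 time steps, 2 variables, 1 level, 32 x 64 grid.
rng = np.random.default_rng(0)
z = 5.5e4 + 1.0e3 * rng.standard_normal((10, 1, 1, 32, 64))  # geopotential-like
t = 250.0 + 10.0 * rng.standard_normal((10, 1, 1, 32, 64))   # temperature-like
raw = np.concatenate([z, t], axis=1)

normed, mean, std = standardize(raw)
restored = destandardize(normed, mean, std)
```

In practice the mean and standard deviation would be computed on the training split only and then reused for validation and prediction.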

In [None]:
# Define parameters for data loading
variables = ['z', 't']  # Geopotential and temperature
pressure_levels = [500]  # 500 hPa pressure level (mid-troposphere)
time_slice = ('2016', '2017')  # Start and end of the example period

try:
    # Attempt to load the ERA5 dataset
    dataset = ERA5Dataset(
        variables=variables,
        pressure_levels=pressure_levels,
        time_slice=time_slice
    )
    
    print(f"Successfully loaded ERA5 data")
    print(f"Dataset length: {len(dataset)} time steps")
    print(f"Dataset shape: {dataset.shape}")
    
    # Get a sample from the dataset
    sample = dataset[0]
    print(f"\nSample structure:")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: {value.shape}")
        else:
            print(f"{key}: {type(value)}")
    
    data_available = True
    
except Exception as e:
    print(f"Error loading ERA5 data: {str(e)}")
    print("\nUsing synthetic data instead...")
    
    # Create synthetic dataset
    class SyntheticERA5Dataset:
        def __init__(self, variables, pressure_levels, time_slice):
            self.variables = variables
            self.pressure_levels = pressure_levels
            self.n_lat, self.n_lon = 32, 64
            self.time_steps = 100
            print(f"Created synthetic dataset with {len(variables)} variables, {len(pressure_levels)} levels")
            
        def __len__(self):
            return self.time_steps - 1
        
        def __getitem__(self, idx):
            # Create random tensors for input and target
            input_data = torch.randn(len(self.variables), len(self.pressure_levels), self.n_lat, self.n_lon)
            target_data = input_data + torch.randn_like(input_data) * 0.1  # Slightly perturbed for target
            
            return {
                'input': input_data,
                'target': target_data,
                'metadata': {
                    't0': '2016-01-01',
                    't1': '2016-01-02',
                    'variables': self.variables,
                    'pressure_levels': self.pressure_levels
                }
            }
        
        @property
        def shape(self):
            return (len(self.variables), len(self.pressure_levels), self.n_lat, self.n_lon)
    
    dataset = SyntheticERA5Dataset(variables, pressure_levels, time_slice)
    data_available = False

### Visualizing the data

Let's examine our weather data by visualizing some samples.

In [None]:
# Create a visualizer
visualizer = WeatherVisualizer()

# Get a sample
sample = dataset[0]
input_data = sample['input']
target_data = sample['target']

# Create an empty figure; the subplots are added one by one below so that
# each can get its own projection without overlapping pre-made axes
fig = plt.figure(figsize=(len(variables)*6, 10))

# Plot each variable at the input time
for i, var in enumerate(variables):
    # Input data - first time step
    if CARTOPY_AVAILABLE:
        ax = plt.subplot(2, len(variables), i+1, projection=ccrs.Robinson())
        visualizer.plot_weather_field(
            data=input_data[i, 0].numpy(),  # First pressure level
            ax=ax,
            title=f"{var} (Input)",
            colorbar=True,
            projection=ccrs.Robinson()
        )
    else:
        ax = plt.subplot(2, len(variables), i+1)
        im = ax.imshow(input_data[i, 0].numpy(), cmap='viridis')
        plt.colorbar(im, ax=ax)
        ax.set_title(f"{var} (Input)")
    
    # Target data - next time step
    if CARTOPY_AVAILABLE:
        ax = plt.subplot(2, len(variables), i+1+len(variables), projection=ccrs.Robinson())
        visualizer.plot_weather_field(
            data=target_data[i, 0].numpy(),  # First pressure level
            ax=ax,
            title=f"{var} (Target)",
            colorbar=True,
            projection=ccrs.Robinson()
        )
    else:
        ax = plt.subplot(2, len(variables), i+1+len(variables))
        im = ax.imshow(target_data[i, 0].numpy(), cmap='viridis')
        plt.colorbar(im, ax=ax)
        ax.set_title(f"{var} (Target)")

plt.tight_layout()
plt.show()

## 2. Creating Data Loaders

Now we'll split the dataset 80/20 at random into training and validation sets and wrap each one in a data loader, which handles batching (and shuffling for the training set).
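
`random_split`, used below, assigns samples to the two sets at random, so neighbouring time steps can land on both sides of the split. For time-ordered data a chronological split is a common alternative. Here is a minimal sketch, assuming the samples are indexed in time order; `chronological_split` is a hypothetical helper, not a WeatherFlow function.

```python
def chronological_split(n_samples, val_fraction=0.2):
    """Return (train_indices, val_indices) with validation taken from the end."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    n_val = max(1, int(round(n_samples * val_fraction)))
    n_train = n_samples - n_val
    if n_train < 1:
        raise ValueError("not enough samples to split")
    # Earlier samples train the model; the most recent ones validate it.
    train_idx = list(range(0, n_train))
    val_idx = list(range(n_train, n_samples))
    return train_idx, val_idx

train_idx, val_idx = chronological_split(99, val_fraction=0.2)
print(len(train_idx), len(val_idx))
```

The resulting index lists could be passed to `torch.utils.data.Subset(dataset, train_idx)` and `Subset(dataset, val_idx)` before building the loaders.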

In [None]:
# Create data loaders
batch_size = 8

if hasattr(dataset, 'ds') and hasattr(dataset, 'times'):
    # Using real ERA5 data
    try:
        # Split into train and validation sets (80/20)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        
        from torch.utils.data import random_split, DataLoader
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        
        print(f"Created data loaders with {train_size} training and {val_size} validation samples")
    except Exception as e:
        print(f"Error creating data loaders: {str(e)}")
        # Use weatherflow's built-in function
        train_loader, val_loader = create_data_loaders(
            variables=variables,
            pressure_levels=pressure_levels,
            batch_size=batch_size
        )
else:
    # Using synthetic data
    from torch.utils.data import random_split, DataLoader
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    print(f"Created data loaders with {train_size} training and {val_size} validation samples")

Let's check a batch from the data loader to see if everything is working correctly:

In [None]:
# Get a batch from the train loader
batch = next(iter(train_loader))

# Print batch information
print("Batch structure:")
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {value.shape}")
    else:
        print(f"{key}: {type(value)}")

## 3. Creating and Training the Model

Now we'll create a WeatherFlowMatch model, which implements flow matching for weather prediction, and train it on our data.
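
To make the training objective concrete, here is a minimal NumPy sketch of a flow matching loss. It assumes a straight-line path between the input state `x0` and the target state `x1`, which matches the linear interpolation this notebook uses as a fallback in the prediction section. The exact path and loss inside `WeatherFlowMatch` and `FlowTrainer` are not shown here, so treat `flow_matching_pair` and `flow_matching_loss` as illustrative names only.

```python
import numpy as np

def flow_matching_pair(x0, x1, t):
    """Build the interpolated state and target velocity for a linear path.

    x0, x1: arrays of shape (batch, ...); t: array of shape (batch,) in [0, 1].
    """
    # Reshape t so it broadcasts over all non-batch dimensions.
    t_b = t.reshape((-1,) + (1,) * (x0.ndim - 1))
    x_t = (1.0 - t_b) * x0 + t_b * x1   # point on the straight path
    v_target = x1 - x0                  # constant velocity along that path
    return x_t, v_target

def flow_matching_loss(predict_velocity, x0, x1, t):
    """Mean squared error between predicted and target velocity."""
    x_t, v_target = flow_matching_pair(x0, x1, t)
    v_pred = predict_velocity(x_t, t)
    return float(np.mean((v_pred - v_target) ** 2))

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 2, 1, 32, 64))
x1 = x0 + 0.1 * rng.standard_normal(x0.shape)
t = rng.uniform(0.0, 1.0, size=8)

# Two hypothetical "models": one that predicts zero velocity and an oracle.
zero_model = lambda x_t, t: np.zeros_like(x_t)
oracle = lambda x_t, t: x1 - x0

loss_zero = flow_matching_loss(zero_model, x0, x1, t)
loss_oracle = flow_matching_loss(oracle, x0, x1, t)
```

A trained network plays the role of `predict_velocity`: the closer its output is to `x1 - x0` at the sampled points, the lower the loss.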

In [None]:
# Set up device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create the model
model = WeatherFlowMatch(
    input_channels=len(variables),
    hidden_dim=128,  # Smaller for this example
    n_layers=3,      # Fewer layers for faster training
    use_attention=True,
    physics_informed=True
)

model = model.to(device)

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Learning rate scheduler
# (`verbose` is omitted because it is deprecated in recent PyTorch releases)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3
)

# Create trainer
trainer = FlowTrainer(
    model=model,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler,
    physics_regularization=True,
    physics_lambda=0.1
)

# Summary of the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {total_params:,} parameters")

In [None]:
# Define number of epochs for training
num_epochs = 5  # Small for this example

# Store metrics for plotting
train_losses = []
val_losses = []

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Train for one epoch
    train_metrics = trainer.train_epoch(train_loader)
    train_loss = train_metrics['loss']
    train_losses.append(train_loss)
    
    # Validate
    val_metrics = trainer.validate(val_loader)
    val_loss = val_metrics['val_loss']
    val_losses.append(val_loss)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Print metrics
    print(f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
    
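    # Save best model (in a real scenario, you'd save to disk).
    # state_dict() returns references to the live parameters, so clone each
    # tensor to keep a snapshot that later epochs cannot overwrite.
    if epoch == 0 or val_loss < min(val_losses[:-1]):
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        print("New best model!")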

# Plot training history
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# Load best model (in a real scenario, you'd load from disk)
model.load_state_dict(best_model_state)
print("Loaded best model!")
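
A note on keeping the best weights in memory: in PyTorch, `model.state_dict()` returns tensors that share storage with the live parameters, so a snapshot only stays fixed if it is copied. The pattern can be sketched in plain Python, with a dict of lists standing in for the weights; `BestCheckpoint` is a hypothetical helper written for this illustration.

```python
import copy

class BestCheckpoint:
    """Keep an independent copy of the best-scoring state seen so far."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_state = None

    def update(self, loss, state):
        """Store a deep copy of `state` if `loss` improves; return True if stored."""
        if loss < self.best_loss:
            self.best_loss = loss
            # deepcopy detaches the snapshot from later in-place updates
            self.best_state = copy.deepcopy(state)
            return True
        return False

# Stand-in for model weights: a dict of mutable lists updated in place.
weights = {"layer.weight": [0.0, 0.0]}
tracker = BestCheckpoint()

for epoch, val_loss in enumerate([0.9, 0.5, 0.7]):
    weights["layer.weight"][0] = float(epoch + 1)  # in-place "training" update
    if tracker.update(val_loss, weights):
        print(f"epoch {epoch}: new best {val_loss}")

print(tracker.best_state)
```

With a real model the call would look like `tracker.update(val_loss, model.state_dict())`, followed by `model.load_state_dict(tracker.best_state)` after training.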

## 4. Making Predictions with the Model

Now that we have a trained model, we can use it to make predictions. We'll wrap our flow matching model with the WeatherFlowODE class, which integrates the learned velocity field to generate predictions at specific time points.
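
To illustrate what "integrating the velocity field" means, here is a standalone fixed-step Runge-Kutta 4 sketch in NumPy. It is not the `WeatherFlowODE` implementation, whose internals are not shown in this notebook; `rk4_integrate` and its `substeps` parameter are assumptions made for the example, and a toy velocity field with a known solution stands in for the trained model.

```python
import numpy as np

def rk4_integrate(velocity, x0, times, substeps=4):
    """Integrate dx/dt = velocity(x, t) and return the state at each time.

    velocity: callable (x, t) -> array shaped like x
    x0: initial state array
    times: increasing 1-D sequence; the first entry is the start time
    substeps: fixed RK4 steps taken between consecutive output times
    """
    x = np.asarray(x0, dtype=float)
    states = [x.copy()]
    for t_start, t_end in zip(times[:-1], times[1:]):
        h = (t_end - t_start) / substeps
        t = t_start
        for _ in range(substeps):
            k1 = velocity(x, t)
            k2 = velocity(x + 0.5 * h * k1, t + 0.5 * h)
            k3 = velocity(x + 0.5 * h * k2, t + 0.5 * h)
            k4 = velocity(x + h * k3, t + h)
            x = x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
            t = t + h
        states.append(x.copy())
    return np.stack(states)  # shape: (len(times), *x0.shape)

# Toy velocity field with a known solution: dx/dt = x  =>  x(t) = x0 * exp(t).
x0 = np.ones((2, 3))
times = np.linspace(0.0, 1.0, 5)
trajectory = rk4_integrate(lambda x, t: x, x0, times)
```

The output has one entry per requested time, which is the same layout the prediction cell below indexes as `predictions[i, ...]`.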

In [None]:
# Create ODE model for prediction
ode_model = WeatherFlowODE(
    flow_model=model,
    solver_method='rk4',  # Fixed-step Runge-Kutta 4 (cheaper than adaptive dopri5 for this example)
    rtol=1e-3,  # Tolerances generally only matter for adaptive solvers such as dopri5
    atol=1e-3
)

# Set model to evaluation mode
model.eval()

# Get a batch from validation set
val_batch = next(iter(val_loader))
x0 = val_batch['input'].to(device)
x1_true = val_batch['target'].to(device)

# Define prediction times (0 = start, 1 = end)
times = torch.linspace(0, 1, 5).to(device)  # Generate 5 time steps

# Generate predictions
with torch.no_grad():
    try:
        predictions = ode_model(x0, times)
        print(f"Generated predictions with shape: {predictions.shape}")
        prediction_ok = True
    except Exception as e:
        print(f"Error generating predictions: {str(e)}")
        print("Using simple linear interpolation instead...")
        
        # Simple linear interpolation as fallback
        predictions = []
        for t in times:
            t_val = t.item()
            pred_t = x0 * (1 - t_val) + x1_true * t_val
            predictions.append(pred_t)
        predictions = torch.stack(predictions)
        print(f"Generated predictions with shape: {predictions.shape}")
        prediction_ok = False

## 5. Visualizing Predictions

Now let's visualize our predictions to see how the weather evolves over time according to our model.

In [None]:
# Select a sample from the batch to visualize (first one)
sample_idx = 0

# Setup visualization
num_timesteps = len(times)
var_idx = 0  # First variable (e.g., geopotential)
level_idx = 0  # First pressure level

# Create a figure
fig, axes = plt.subplots(1, num_timesteps, figsize=(4*num_timesteps, 4))

# Plot each time step
for i, t in enumerate(times):
    ax = axes[i]
    
    # Get prediction at this time step
    pred = predictions[i, sample_idx, var_idx, level_idx].cpu().numpy()
    
    # Plot
    im = ax.imshow(pred, cmap='viridis')
    ax.set_title(f"t={t.item():.2f}")
    ax.axis('off')
    
    # Add colorbar to the last plot
    if i == num_timesteps - 1:
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle(f"Evolution of {variables[var_idx]} at {pressure_levels[level_idx]} hPa", fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Compare the final prediction with the ground truth
for var_idx, var_name in enumerate(variables):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Initial state
    x0_np = x0[sample_idx, var_idx, level_idx].cpu().numpy()
    im0 = axes[0].imshow(x0_np, cmap='viridis')
    axes[0].set_title(f"Initial {var_name}")
    plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
    axes[0].axis('off')
    
    # Final prediction
    x1_pred_np = predictions[-1, sample_idx, var_idx, level_idx].cpu().numpy()
    im1 = axes[1].imshow(x1_pred_np, cmap='viridis')
    axes[1].set_title(f"Predicted {var_name}")
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    axes[1].axis('off')
    
    # Ground truth
    x1_true_np = x1_true[sample_idx, var_idx, level_idx].cpu().numpy()
    im2 = axes[2].imshow(x1_true_np, cmap='viridis')
    axes[2].set_title(f"Ground Truth {var_name}")
    plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
    axes[2].axis('off')
    
    plt.suptitle(f"{var_name} at {pressure_levels[level_idx]} hPa", fontsize=16)
    plt.tight_layout()
    plt.show()
    
    # Calculate error metrics
    mse = ((x1_pred_np - x1_true_np) ** 2).mean()
    mae = np.abs(x1_pred_np - x1_true_np).mean()
    
    print(f"{var_name} Metrics:")
    print(f"MSE: {mse:.6f}")
    print(f"MAE: {mae:.6f}")
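
The MSE and MAE above weight every grid cell equally. On a regular latitude-longitude grid the cells shrink toward the poles, so WeatherBench-style evaluations commonly weight errors by the cosine of latitude. Here is a minimal sketch, assuming an equiangular grid whose cell centers are evenly spaced between the poles; `latitude_weights` and `lat_weighted_rmse` are hypothetical helpers, not WeatherFlow functions.

```python
import numpy as np

def latitude_weights(n_lat):
    """cos(latitude) weights for an equiangular grid, normalized to mean 1."""
    # Assumed grid: cell centers evenly spaced between the poles.
    edges = np.linspace(-90.0, 90.0, n_lat + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    w = np.cos(np.deg2rad(centers))
    return w / w.mean()

def lat_weighted_rmse(pred, truth):
    """Area-weighted RMSE for fields shaped (..., lat, lon)."""
    w = latitude_weights(pred.shape[-2])[:, None]  # broadcast over longitude
    return float(np.sqrt(np.mean(w * (pred - truth) ** 2)))

truth = np.zeros((32, 64))
uniform_error = np.ones((32, 64))   # error of 1 everywhere
polar_error = np.zeros((32, 64))
polar_error[0, :] = 1.0             # error only in the row nearest one pole
equator_error = np.zeros((32, 64))
equator_error[16, :] = 1.0          # error only in a row next to the equator

print(lat_weighted_rmse(uniform_error, truth))
print(lat_weighted_rmse(polar_error, truth), lat_weighted_rmse(equator_error, truth))
```

The same error counts for less near the pole than near the equator, which reflects the smaller area those cells cover.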

## 6. Analyzing Model Behavior

Let's examine how the learned velocity field changes with the flow time `t` while the input state is held fixed.

In [None]:
# Select a sample input
x_sample = x0[sample_idx:sample_idx+1]

# Try different time values
time_values = [0.0, 0.25, 0.5, 0.75, 1.0]
velocities = []

# Compute the velocity field for each time value
with torch.no_grad():
    for t in time_values:
        t_tensor = torch.tensor([t], device=device)
        v = model(x_sample, t_tensor)
        velocities.append(v.cpu())

# Visualize the velocity field for the first variable
var_idx = 0
level_idx = 0

fig, axes = plt.subplots(1, len(time_values), figsize=(4*len(time_values), 4))

for i, (t, v) in enumerate(zip(time_values, velocities)):
    ax = axes[i]
    
    # Get velocity at this time step
    v_field = v[0, var_idx, level_idx].numpy()
    
    # Plot
    im = ax.imshow(v_field, cmap='coolwarm')
    ax.set_title(f"t={t:.2f}")
    ax.axis('off')
    
    # Add colorbar to the last plot
    if i == len(time_values) - 1:
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle(f"Velocity Field for {variables[var_idx]} at Different Times", fontsize=16)
plt.tight_layout()
plt.show()

## 7. Advanced Visualization - Flow Field

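Let's visualize the learned velocity field as a vector plot. Note that with `variables = ['z', 't']` the two plotted components are the flow-matching velocities of geopotential and temperature, not physical wind components, so the arrows demonstrate the plotting technique rather than atmospheric circulation. Load wind variables instead for a physically meaningful vector field.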

In [None]:
# Initialize flow visualizer
flow_vis = FlowVisualizer()

# Check if we have multiple variables for vector field visualization
if len(variables) >= 2:
    # Compute the flow at t=0.5
    t_mid = torch.tensor([0.5], device=device)
    with torch.no_grad():
        v_mid = model(x_sample, t_mid).cpu()[0]
    
    # Extract u, v components (first two variables)
    u = v_mid[0, level_idx].numpy()
    v = v_mid[1, level_idx].numpy()
    
    # Create the plot
    plt.figure(figsize=(10, 8))
    
    if CARTOPY_AVAILABLE:
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.coastlines()
        
        # Downsample for clearer visualization
        stride = 2
        lats = np.linspace(-90, 90, u.shape[0])[::stride]
        lons = np.linspace(-180, 180, u.shape[1])[::stride]
        
        lon_grid, lat_grid = np.meshgrid(lons, lats)
        
        # Plot flow vectors
        ax.quiver(lon_grid, lat_grid, 
                 u[::stride, ::stride], v[::stride, ::stride],
                 scale=50, width=0.002)
        
        # Add gridlines and background
        ax.gridlines(draw_labels=True)
        ax.add_feature(cfeature.LAND, facecolor='lightgray')
        ax.add_feature(cfeature.COASTLINE)
        
        plt.title(f"Flow Field at t=0.5")
    else:
        # Simple vector field plot without cartopy
        y, x = np.mgrid[0:u.shape[0]:5, 0:u.shape[1]:5]
        plt.quiver(x, y, u[::5, ::5], v[::5, ::5])
        plt.title(f"Flow Field at t=0.5 (u={variables[0]}, v={variables[1]})")
    
    plt.tight_layout()
    plt.show()
else:
    print("Need at least two variables for vector field visualization")

## 8. Making Multi-step Predictions

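Let's use our model to make multi-step predictions by feeding each forecast back in as the next input. Each step advances the state by one input-to-target interval of the dataset, so `n_steps` steps forecast `n_steps` intervals ahead.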
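
Before the model-specific version, here is the rollout pattern on its own, as a minimal NumPy sketch. `rollout` and the two toy step functions are made up for illustration; the biased step simply shows how a small one-step error compounds with lead time.

```python
import numpy as np

def rollout(step_fn, x0, n_steps):
    """Apply a one-step forecast function repeatedly.

    Returns an array of shape (n_steps + 1, *x0.shape) whose first entry is x0.
    """
    states = [np.asarray(x0, dtype=float)]
    for _ in range(n_steps):
        # Each forecast is fed back in as the next initial condition.
        states.append(step_fn(states[-1]))
    return np.stack(states)

# Hypothetical one-step "models": the truth and a forecast with +1% bias per step.
true_step = lambda x: x
biased_step = lambda x: 1.01 * x

x0 = np.full((2, 4, 4), 10.0)
truth = rollout(true_step, x0, n_steps=5)
forecast = rollout(biased_step, x0, n_steps=5)

# Mean absolute error per lead time: grows as errors compound.
errors = np.abs(forecast - truth).mean(axis=(1, 2, 3))
print(np.round(errors, 3))
```

This is why multi-step forecasts are usually evaluated per lead time rather than with a single number.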

In [None]:
# Function to make multi-step predictions
def multi_step_predict(model, x0, n_steps=3):
    """Make multi-step predictions using the flow model."""
    # Set model to evaluation mode
    model.eval()
    
    # Initial state
    x = x0.clone()
    predictions = [x]
    
    # Make predictions step by step
    with torch.no_grad():
        for step in range(n_steps):
            try:
                # Use ODE model for one step
                times = torch.tensor([0.0, 1.0], device=x.device)
                x_next = ode_model(x, times)[-1]  # Take the final state
            except Exception:
                # Fallback: a single forward-Euler step over [0, 1],
                # using the velocity at the start of the interval (t = 0)
                t = torch.zeros(x.size(0), device=x.device)
                v = model(x, t)
                x_next = x + v
            
            predictions.append(x_next)
            x = x_next
    
    return predictions

In [None]:
# Make multi-step predictions
n_steps = 3  # Predict 3 steps ahead
multi_step_preds = multi_step_predict(model, x_sample, n_steps)

# Visualize multi-step predictions
var_idx = 0  # First variable
level_idx = 0  # First pressure level

fig, axes = plt.subplots(1, n_steps+1, figsize=(4*(n_steps+1), 4))

for i, pred in enumerate(multi_step_preds):
    ax = axes[i]
    
    # Get prediction
    pred_np = pred[0, var_idx, level_idx].cpu().numpy()
    
    # Plot
    im = ax.imshow(pred_np, cmap='viridis')
    ax.set_title(f"Step {i}")
    ax.axis('off')
    
    # Add colorbar to the last plot
    if i == n_steps:
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle(f"Multi-step Prediction of {variables[var_idx]} at {pressure_levels[level_idx]} hPa", fontsize=16)
plt.tight_layout()
plt.show()

## 9. Creating an Animation

Let's create an animation to visualize the flow dynamics more clearly.

In [None]:
# Try to create an animation (if the environment supports it)
try:
    from matplotlib.animation import FuncAnimation
    from IPython.display import HTML
    
    # Generate more dense predictions for smoother animation
    times = torch.linspace(0, 1, 20).to(device)  # 20 time steps
    
    with torch.no_grad():
        try:
            dense_preds = ode_model(x_sample, times)
            animation_data = dense_preds[:, 0, var_idx, level_idx].cpu().numpy()
        except Exception as e:
            # Fallback to simple interpolation
            print(f"Using simple interpolation for animation")
            animation_data = []
            x_start = x_sample[0, var_idx, level_idx].cpu().numpy()
            x_end = x1_true[sample_idx, var_idx, level_idx].cpu().numpy()
            
            for t in times:
                t_val = t.item()
                frame = x_start * (1 - t_val) + x_end * t_val
                animation_data.append(frame)
            animation_data = np.array(animation_data)
    
    # Create figure for animation
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Find global min/max for consistent colorbar
    vmin = animation_data.min()
    vmax = animation_data.max()
    
    # Initial plot
    im = ax.imshow(animation_data[0], cmap='viridis', vmin=vmin, vmax=vmax)
    title = ax.set_title(f"{variables[var_idx]} at t=0.00")
    fig.colorbar(im, ax=ax)
    plt.close(fig)  # Prevent the static figure from also being displayed
    
    # Animation update function
    def update(frame):
        im.set_array(animation_data[frame])
        title.set_text(f"{variables[var_idx]} at t={times[frame].item():.2f}")
        return [im, title]
    
    # Create animation
    anim = FuncAnimation(fig, update, frames=len(animation_data), interval=200, blit=True)
    
    # Display in notebook (an explicit display() is needed because this is
    # not the last expression of the cell)
    display(HTML(anim.to_jshtml()))
    
except Exception as e:
    print(f"Could not create animation: {str(e)}")
    print("Try running this notebook in an environment with animation support.")

## 10. Conclusion and Next Steps

In this notebook, we've demonstrated the complete workflow for using the WeatherFlow package:

1. Loading and exploring ERA5 reanalysis data
2. Creating data loaders for training and validation
3. Building and training a flow matching model
4. Making predictions using the trained model
5. Visualizing and analyzing the results
6. Creating advanced visualizations and animations

### Next Steps

To build on what we've learned:

1. **Try different variables**: Experiment with other atmospheric variables like humidity or winds.
2. **Modify model architecture**: Adjust the number of layers, hidden dimensions, or try without attention.
3. **Experiment with physics constraints**: Toggle physics-informed constraints and see how they affect results.
4. **Train for longer**: Increase the number of epochs for better convergence.
5. **Evaluate on test data**: Use a separate test set to evaluate model performance.
6. **Compare with other methods**: Try simple baselines or other prediction models for comparison.
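
As a starting point for the baseline comparison in the last item, a persistence forecast (predicting that the state does not change) is a common reference. Here is a minimal NumPy sketch on made-up arrays; `skill_vs_persistence` is a hypothetical helper, and the "forecast" is constructed by hand rather than produced by a model.

```python
import numpy as np

def mse(a, b):
    """Mean squared error between two arrays."""
    return float(np.mean((a - b) ** 2))

def skill_vs_persistence(x0, x1_pred, x1_true):
    """1 - MSE(model) / MSE(persistence); positive means the model beats persistence."""
    mse_model = mse(x1_pred, x1_true)
    mse_persist = mse(x0, x1_true)  # persistence: forecast equals the initial state
    if mse_persist == 0.0:
        raise ValueError("persistence error is zero; skill is undefined")
    return 1.0 - mse_model / mse_persist

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 2, 1, 32, 64))
change = 0.1 * rng.standard_normal(x0.shape)
x1_true = x0 + change

# Hypothetical forecast that captures half of the true change.
x1_pred = x0 + 0.5 * change
skill = skill_vs_persistence(x0, x1_pred, x1_true)
print(f"Skill relative to persistence: {skill:.3f}")
```

With the tensors from this notebook, the same function could be applied to `x0`, `predictions[-1]`, and `x1_true` after converting them to NumPy arrays.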

For more information, refer to the WeatherFlow documentation and examples.