# Flow Matching for Weather Prediction: Fundamentals

This notebook introduces the foundational concepts of flow matching and how they apply to weather prediction. We will:

1. Understand the mathematical foundations of flow matching
2. Implement a simple flow matching model
3. Visualize flow fields and trajectory evolution
4. Connect these concepts to weather prediction
5. Explore physical constraints in flow matching

Flow matching is a powerful approach for modeling complex dynamical systems like weather, as it allows us to learn continuous transformations between states.

## 1. Setup and Dependencies

In [None]:
# Install WeatherFlow if needed
try:
    import weatherflow
    print(f"WeatherFlow version: {weatherflow.__version__}")
except ImportError:
    !pip install -e ..
    import weatherflow
    print(f"WeatherFlow installed, version: {weatherflow.__version__}")

In [None]:
# Import standard libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm
import os
import warnings
warnings.filterwarnings('ignore')  # Suppress some warnings for cleaner output

# Import specific WeatherFlow components
from weatherflow.models.flow_matching import WeatherFlowMatch
from weatherflow.utils import WeatherVisualizer

# Set up matplotlib
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['figure.dpi'] = 100

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Flow Matching Theory

Flow matching is a technique for learning continuous transformations between probability distributions. It's closely related to continuous normalizing flows and can be used to model complex dynamical systems like weather evolution.

### Key Concepts

1. **Flow Fields**: Continuous vector fields that describe how a system evolves over time
2. **Path Interpolation**: Creating smooth paths between source and target states
3. **Vector Field Learning**: Learning velocity fields that can generate these paths
4. **ODE Integration**: Using learned vector fields to generate new trajectories

### Mathematical Foundation

In flow matching, we learn a continuous-time flow that transforms a source distribution $p_0(\mathbf{x})$ into a target distribution $p_1(\mathbf{x})$. 

The key equation is:

$$\mathbf{v}(\mathbf{x}_t, t) = \frac{d\mathbf{x}_t}{dt}$$

Where $\mathbf{v}(\mathbf{x}_t, t)$ is the velocity field at point $\mathbf{x}_t$ and time $t$.

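For straight-line paths between $\mathbf{x}_0$ and $\mathbf{x}_1$, the interpolated state is $\mathbf{x}_t = (1 - t)\,\mathbf{x}_0 + t\,\mathbf{x}_1$ for $t \in [0, 1]$. Differentiating with respect to $t$ gives a target velocity that is constant along the path:

$$\mathbf{v}_\text{target}(\mathbf{x}_t, t) = \mathbf{x}_1 - \mathbf{x}_0$$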

The goal is to learn a neural network that can approximate this velocity field.
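
Under the straight-line assumption, the training objective can be written out explicitly. The symbols $\mathbf{v}_\theta$ (the network) and $\mathcal{L}$ (the loss) are notation introduced here for illustration, and the pairing of $\mathbf{x}_0$ with $\mathbf{x}_1$ is assumed to be given:

$$\mathbf{x}_t = (1 - t)\,\mathbf{x}_0 + t\,\mathbf{x}_1, \qquad \frac{d\mathbf{x}_t}{dt} = \mathbf{x}_1 - \mathbf{x}_0$$

$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0,1],\; (\mathbf{x}_0, \mathbf{x}_1)} \left\| \mathbf{v}_\theta(\mathbf{x}_t, t) - (\mathbf{x}_1 - \mathbf{x}_0) \right\|^2$$

Up to a constant factor from averaging over vector components, this is the mean squared error that the models in this notebook minimize over random times and pairs.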

### How Flow Matching Differs from Other Approaches

Flow matching has several advantages for weather prediction:

1. **Continuous Time**: Models weather as a continuous process, unlike discrete steps in many ML approaches
2. **Physical Constraints**: Can incorporate physical laws directly into the flow field
3. **Uncertainty Quantification**: Naturally models distributions over possible weather states
4. **Flexible Integration**: Can use different numerical methods for different accuracy/speed trade-offs
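
To make the accuracy/speed trade-off in the last point concrete, here is a minimal pure-Python sketch that integrates the same ODE with explicit Euler and with classical fourth-order Runge-Kutta. The `rotation_field` function is a hypothetical stand-in for a learned velocity field, chosen because its exact solution is known; none of these helper names come from WeatherFlow.

```python
import math

def euler_step(f, t, x, h):
    """One explicit Euler step (first-order accurate, one field evaluation)."""
    return [xi + h * vi for xi, vi in zip(x, f(t, x))]

def rk4_step(f, t, x, h):
    """One classical Runge-Kutta step (fourth-order, four field evaluations)."""
    k1 = f(t, x)
    k2 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k1)])
    k3 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k2)])
    k4 = f(t + h, [xi + h * ki for xi, ki in zip(x, k3)])
    return [xi + h / 6 * (a + 2 * b + 2 * c + d)
            for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]

def integrate(step, f, x0, n_steps, t0=0.0, t1=1.0):
    """Integrate dx/dt = f(t, x) from t0 to t1 with a fixed-step method."""
    h = (t1 - t0) / n_steps
    t, x = t0, list(x0)
    for _ in range(n_steps):
        x = step(f, t, x, h)
        t += h
    return x

def rotation_field(t, x):
    """Hypothetical stand-in velocity field: rigid rotation about the origin."""
    return [-x[1], x[0]]

x0 = [1.0, 0.0]
exact = [math.cos(1.0), math.sin(1.0)]  # analytic solution at t = 1

def error(x):
    """Euclidean distance between a numerical result and the exact solution."""
    return math.hypot(x[0] - exact[0], x[1] - exact[1])

euler_err = error(integrate(euler_step, rotation_field, x0, n_steps=10))
rk4_err = error(integrate(rk4_step, rotation_field, x0, n_steps=10))
print(f"Euler error: {euler_err:.2e}, RK4 error: {rk4_err:.2e}")
```

With the same number of steps, RK4 costs four field evaluations per step instead of one but lands much closer to the exact solution. With a neural network as the velocity field, each evaluation is a forward pass, so the solver choice directly sets the cost of generating a trajectory.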

## 3. Simple Example: 2D Toy Problem

To build intuition, let's start with a simple 2D problem: learning a flow that transforms a Gaussian distribution into a mixture of Gaussians.

In [None]:
# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Function to generate data from a simple Gaussian
def sample_gaussian(n_samples, mean=[0, 0], std=1.0):
    return np.random.normal(mean, std, size=(n_samples, 2))

# Function to generate data from a mixture of Gaussians
def sample_mixture(n_samples, means=[[2, 2], [-2, 2], [0, -2]], std=0.5):
    k = len(means)
    # Randomly choose which Gaussian to sample from
    indices = np.random.choice(k, size=n_samples)
    samples = np.zeros((n_samples, 2))
    
    for i in range(n_samples):
        gaussian_idx = indices[i]
        samples[i] = np.random.normal(means[gaussian_idx], std)
        
    return samples

# Generate samples
n_samples = 1000
source_samples = sample_gaussian(n_samples)  # Simple Gaussian
target_samples = sample_mixture(n_samples)   # Mixture of Gaussians

# Plot the source and target distributions
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Source distribution
axes[0].scatter(source_samples[:, 0], source_samples[:, 1], alpha=0.5, s=10)
axes[0].set_title('Source Distribution: Single Gaussian')
axes[0].set_xlim(-4, 4)
axes[0].set_ylim(-4, 4)
axes[0].grid(True)
axes[0].set_aspect('equal')

# Target distribution
axes[1].scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.5, s=10)
axes[1].set_title('Target Distribution: Mixture of Gaussians')
axes[1].set_xlim(-4, 4)
axes[1].set_ylim(-4, 4)
axes[1].grid(True)
axes[1].set_aspect('equal')

plt.tight_layout()
plt.show()

### 3.1 Creating Training Pairs for Flow Matching

Flow matching trains on pairs $(\mathbf{x}_0, \mathbf{x}_1)$ of source and target points. For this toy example we pair them at random: source sample $i$ is matched with target sample $i$, and the two sets were drawn independently.
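
As a rough illustration of what each pair contributes during training, the sketch below builds one training example from a pair: a random time, the interpolated point on the straight line, and the constant target velocity. It uses only the standard library, and `make_training_example` is a hypothetical helper name rather than part of WeatherFlow.

```python
import random

def make_training_example(x0, x1, rng):
    """Build one flow matching training example from a (source, target) pair."""
    t = rng.random()                                      # time in [0, 1)
    x_t = [(1 - t) * a + t * b for a, b in zip(x0, x1)]   # point on the straight path
    v_target = [b - a for a, b in zip(x0, x1)]            # constant velocity along it
    return t, x_t, v_target

rng = random.Random(0)

# Four independent draws from a toy source and a toy target (assumed distributions)
sources = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(4)]
targets = [[2 + rng.gauss(0, 0.5), 2 + rng.gauss(0, 0.5)] for _ in range(4)]

# Random pairing: index i in the source is matched with index i in the target
examples = [make_training_example(x0, x1, rng) for x0, x1 in zip(sources, targets)]

for t, x_t, v in examples:
    print(f"t={t:.2f}  x_t=({x_t[0]:+.2f}, {x_t[1]:+.2f})  v=({v[0]:+.2f}, {v[1]:+.2f})")
```

Following `v_target` backward from `x_t` for time `t` recovers the source point, and following it forward for time `1 - t` reaches the target point.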

In [None]:
# Convert to PyTorch tensors
source_tensor = torch.tensor(source_samples, dtype=torch.float32)
target_tensor = torch.tensor(target_samples, dtype=torch.float32)

# Create dataset and dataloader
dataset = TensorDataset(source_tensor, target_tensor)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Let's visualize a few pairs of points
n_vis = 10
fig, ax = plt.subplots(figsize=(8, 8))

# Plot all samples as background
ax.scatter(source_samples[:, 0], source_samples[:, 1], alpha=0.2, s=10, color='blue', label='Source')
ax.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.2, s=10, color='red', label='Target')

# Plot a few pairs with connecting lines
for i in range(n_vis):
    ax.plot([source_samples[i, 0], target_samples[i, 0]],
             [source_samples[i, 1], target_samples[i, 1]],
             'k-', alpha=0.3)
    
ax.set_title('Matching Pairs for Flow Learning')
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.grid(True)
ax.legend()
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

### 3.2 Simple Flow Matching Model

Now, let's implement a simple flow matching model for this 2D problem.

In [None]:
class SimpleFlowModel(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        
        # Network architecture
        self.net = nn.Sequential(
            # Input: x and t (2+1=3 dimensions)
            nn.Linear(2 + 1, hidden_dim),
            nn.SiLU(),  # Smooth activation function
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            # Output: velocity vector (2 dimensions)
            nn.Linear(hidden_dim, 2)
        )
    
    def forward(self, x, t):
        """Compute velocity at point x and time t."""
        # Concatenate x and t
        if t.dim() == 1:
            # Add channel dimension to t
            t = t.unsqueeze(1)
        
        xt = torch.cat([x, t], dim=1)
        
        # Compute velocity
        velocity = self.net(xt)
        return velocity
    
    def compute_flow_loss(self, x0, x1, t):
        """Compute flow matching loss."""
        # Compute straight-line velocity target
        v_target = x1 - x0  # For t in [0, 1], velocity = displacement
        
        # Interpolate between x0 and x1 at time t
        x_t = x0 + t.unsqueeze(1) * (x1 - x0)
        
        # Predict velocity
        v_pred = self(x_t, t)
        
        # Compute MSE loss
        loss = F.mse_loss(v_pred, v_target)
        return loss

In [None]:
# Instantiate model
model = SimpleFlowModel(hidden_dim=64).to(device)

# Set up optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
n_epochs = 50
losses = []

for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0
    
    for x0, x1 in dataloader:
        x0, x1 = x0.to(device), x1.to(device)
        
        # Generate random times between 0 and 1
        t = torch.rand(x0.size(0), device=device)
        
        # Compute loss
        loss = model.compute_flow_loss(x0, x1, t)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item() * x0.size(0)
    
    # Average loss for the epoch
    epoch_loss /= len(dataset)
    losses.append(epoch_loss)
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.6f}")

# Plot training loss
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)
plt.yscale('log')
plt.show()

### 3.3 Visualize Learned Flow Field

Now let's visualize the flow field that our model has learned.

In [None]:
# Create a grid of points
grid_size = 20
x = np.linspace(-4, 4, grid_size)
y = np.linspace(-4, 4, grid_size)
X, Y = np.meshgrid(x, y)

# Put model in evaluation mode
model.eval()

# Visualize flow field at different time steps
time_steps = [0.0, 0.25, 0.5, 0.75, 1.0]
fig, axes = plt.subplots(1, len(time_steps), figsize=(20, 4))

with torch.no_grad():
    for i, t_val in enumerate(time_steps):
        # Prepare grid points
        grid_points = np.stack([X.flatten(), Y.flatten()], axis=1)
        grid_tensor = torch.tensor(grid_points, dtype=torch.float32).to(device)
        t = torch.ones(grid_points.shape[0], device=device) * t_val
        
        # Compute velocities
        velocities = model(grid_tensor, t).cpu().numpy()
        
        # Reshape for plotting
        U = velocities[:, 0].reshape(grid_size, grid_size)
        V = velocities[:, 1].reshape(grid_size, grid_size)
        
        # Calculate velocity magnitude for coloring
        speed = np.sqrt(U**2 + V**2)
        
        # Plot
        axes[i].streamplot(X, Y, U, V, density=1.5, color=speed, cmap='viridis',
                          linewidth=1, arrowsize=1.5)
        axes[i].set_title(f"t = {t_val}")
        axes[i].set_xlim(-4, 4)
        axes[i].set_ylim(-4, 4)
        axes[i].grid(True)
        axes[i].set_aspect('equal')

plt.tight_layout()
plt.show()

### 3.4 Generate Trajectories using ODE Solver

Now that we have a learned flow field, we can use an ODE solver to generate trajectories from the source to the target distribution.

In [None]:
# Import ODE solver
from torchdiffeq import odeint

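# Put the model in evaluation mode once, before integration
model.eval()

# Create a function that returns the velocity at a given point and time
def vector_field(t, x):
    """Vector field function for ODE solver."""
    # odeint passes t as a 0-dim tensor; the model expects one time per sample
    t_batch = t.expand(x.shape[0])
    with torch.no_grad():
        return model(x, t_batch)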

# Choose initial points from the source distribution
n_trajectories = 10
initial_points = source_tensor[:n_trajectories].to(device)

# Define time span
time_span = torch.linspace(0, 1, 100).to(device)

# Solve ODE
trajectories = odeint(vector_field, initial_points, time_span, method='dopri5')

# Move trajectories to CPU and convert to numpy
trajectories = trajectories.cpu().numpy()

# Plot the trajectories
fig, ax = plt.subplots(figsize=(8, 8))

# Plot all source and target samples
ax.scatter(source_samples[:, 0], source_samples[:, 1], alpha=0.2, s=10, color='blue', label='Source')
ax.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.2, s=10, color='red', label='Target')

# Plot the trajectories
for i in range(n_trajectories):
    ax.plot(trajectories[:, i, 0], trajectories[:, i, 1], alpha=0.7, linewidth=1)

ax.set_title('Generated Trajectories')
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.grid(True)
ax.legend()
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

## 4. Extending to Weather-Like Data

Now let's move beyond toy problems and extend our approach to more weather-like data. To make the connection to real weather data, we will:

1. Generate synthetic weather-like data
2. Train a flow matching model on this data
3. Visualize the learned flow field

Let's create the synthetic data:

In [None]:
# Generate synthetic weather-like data
def sample_weather(n_samples, base_flow=2.0, perturbation_std=0.5):
    """Generate synthetic weather-like data with a base flow and perturbations."""
    # Base flow (e.g., jet stream); use a new name so the `base_flow`
    # argument is not overwritten
    base_flow_x = np.ones(n_samples) * base_flow
    base_flow_y = np.zeros(n_samples)
    base = np.stack([base_flow_x, base_flow_y], axis=1)
    
    # Perturbations (e.g., storms)
    perturbations = np.random.normal(0, perturbation_std, size=(n_samples, 2))
    
    # Combine base flow and perturbations
    weather_data = base + perturbations
    return weather_data

# Generate training data
n_samples = 1000
weather_source = sample_weather(n_samples)
weather_target = sample_weather(n_samples, base_flow=1.0, perturbation_std=0.75)

# Convert to tensors
weather_source_tensor = torch.tensor(weather_source, dtype=torch.float32)
weather_target_tensor = torch.tensor(weather_target, dtype=torch.float32)

# Create dataset and dataloader
weather_dataset = TensorDataset(weather_source_tensor, weather_target_tensor)
weather_dataloader = DataLoader(weather_dataset, batch_size=64, shuffle=True)

# Plot the synthetic weather data
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot source distribution
axes[0].scatter(weather_source[:, 0], weather_source[:, 1], alpha=0.5, s=10)
axes[0].set_title('Source Weather Data')
axes[0].set_xlim(-4, 4)
axes[0].set_ylim(-4, 4)
axes[0].grid(True)
axes[0].set_aspect('equal')

# Plot target distribution
axes[1].scatter(weather_target[:, 0], weather_target[:, 1], alpha=0.5, s=10)
axes[1].set_title('Target Weather Data')
axes[1].set_xlim(-4, 4)
axes[1].set_ylim(-4, 4)
axes[1].grid(True)
axes[1].set_aspect('equal')

plt.tight_layout()
plt.show()

Now that we have the data, let's train a flow matching model:

In [None]:
# Train the flow matching model
weather_model = SimpleFlowModel(hidden_dim=64).to(device)
weather_optimizer = torch.optim.Adam(weather_model.parameters(), lr=1e-3)

# Training loop
n_epochs = 50
weather_losses = []

for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0
    
    for x0, x1 in weather_dataloader:
        x0, x1 = x0.to(device), x1.to(device)
        
        # Generate random times between 0 and 1
        t = torch.rand(x0.size(0), device=device)
        
        # Compute loss
        loss = weather_model.compute_flow_loss(x0, x1, t)
        
        # Backward pass and optimize
        weather_optimizer.zero_grad()
        loss.backward()
        weather_optimizer.step()
        
        epoch_loss += loss.item() * x0.size(0)
    
    # Average loss for the epoch
    epoch_loss /= len(weather_dataset)
    weather_losses.append(epoch_loss)
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.6f}")

### 4.1 Visualizing the Learned Flow Field

In [None]:
# Put the model in evaluation mode before visualizing
weather_model.eval()

In [None]:
# Evaluate the flow field for the weather-like data on a regular grid
def plot_weather_flow_field(model, time_val, grid_size=32):
    """Evaluate the weather flow field on a grid at a specific time.

    Returns the arrays X, Y, U, V, ready to pass to a streamplot call.
    Drawing is left to the caller so the result can go on any axes.
    """
    # Create grid for visualization
    x = np.linspace(-4, 4, grid_size)
    y = np.linspace(-4, 4, grid_size)
    X, Y = np.meshgrid(x, y)
    grid_points = np.stack([X.flatten(), Y.flatten()], axis=1)
    
    # Convert to tensor
    grid_tensor = torch.tensor(grid_points, dtype=torch.float32).to(device)
    
    # Create time tensor (one time value per grid point)
    t = torch.ones(grid_points.shape[0], device=device) * time_val
    
    # Get flow field
    with torch.no_grad():
        velocities = model(grid_tensor, t).cpu().numpy()
    
    # Reshape for plotting
    U = velocities[:, 0].reshape(grid_size, grid_size)
    V = velocities[:, 1].reshape(grid_size, grid_size)
    
    return X, Y, U, V

# Plot flow fields at different time steps
time_steps = [0.0, 0.25, 0.5, 0.75, 1.0]
fig, axes = plt.subplots(1, len(time_steps), figsize=(20, 4))

for i, t_val in enumerate(time_steps):
    
    # Prepare grid points
    grid_size = 20
    x = np.linspace(-4, 4, grid_size)
    y = np.linspace(-4, 4, grid_size)
    X, Y = np.meshgrid(x, y)
    
    with torch.no_grad():
        # Prepare grid points
        grid_points = np.stack([X.flatten(), Y.flatten()], axis=1)
        grid_tensor = torch.tensor(grid_points, dtype=torch.float32).to(device)
        t = torch.ones(grid_points.shape[0], device=device) * t_val

        # Compute velocities
        velocities = weather_model(grid_tensor, t).cpu().numpy()

        # Reshape for plotting
        U = velocities[:, 0].reshape(grid_size, grid_size)
        V = velocities[:, 1].reshape(grid_size, grid_size)

        # Calculate velocity magnitude for coloring
        speed = np.sqrt(U**2 + V**2)

        # Plot
        axes[i].streamplot(X, Y, U, V, density=1.5, color=speed, cmap='viridis',
                          linewidth=1, arrowsize=1.5)
        axes[i].set_title(f"t = {t_val}")
        axes[i].set_xlim(-4, 4)
        axes[i].set_ylim(-4, 4)
        axes[i].grid(True)
        axes[i].set_aspect('equal')

plt.tight_layout()
plt.show()

## 5. Physics-Constrained Flow Matching

Now let's incorporate physics constraints into our flow model.

In [None]:
class PhysicsConstrainedFlow(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()

        # Same network architecture as before
        self.net = nn.Sequential(
            nn.Linear(2 + 1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, t):
        """Compute velocity at point x and time t with physics constraints."""
        # Concatenate x and t
        if t.dim() == 1:
            t = t.unsqueeze(1)

        xt = torch.cat([x, t], dim=1)

        # Compute raw velocity
        v_raw = self.net(xt)

        # Apply physics constraints
        v_constrained = self._apply_divergence_free_constraint(v_raw, x)

        return v_constrained

    def _apply_divergence_free_constraint(self, v, x):
        """Placeholder for a divergence-free (incompressible flow) constraint.

        For weather, this relates to conservation of mass.
        """
        # Simplified implementation for 2D case
        # In a real implementation, we would compute the curl of a potential function

        # NOTE: splitting v into direction and magnitude and recombining them
        # returns v essentially unchanged, so this step does not yet enforce
        # zero divergence. It only marks where a real constraint would go.
        v_norm = torch.norm(v, dim=1, keepdim=True)
        v_direction = v / (v_norm + 1e-8)

        # Recombine direction and magnitude (simplified placeholder)
        return v_direction * v_norm

    def compute_flow_loss(self, x0, x1, t):
        """Compute flow matching loss with physics regularization."""
        # Compute straight-line velocity target
        v_target = x1 - x0

        # Interpolate between x0 and x1 at time t
        x_t = x0 + t.unsqueeze(1) * (x1 - x0)

        # Predict velocity
        v_pred = self(x_t, t)

        # Compute MSE loss
        flow_loss = F.mse_loss(v_pred, v_target)

        # Add physics-based regularization
        physics_loss = self._compute_physics_loss(v_pred, x_t)

        # Total loss
        total_loss = flow_loss + 0.1 * physics_loss

        return total_loss

    def _compute_physics_loss(self, v, x):
        """Compute physics-based regularization loss.

        For weather, this would include terms for:
        - Divergence-free (continuity equation)
        - Energy conservation
        - Geostrophic balance
        etc.
        """
        # Simple stand-in physics loss for demonstration
        # In a real implementation, we would have more sophisticated terms

        # Velocity magnitude per sample
        v_norm = torch.norm(v, dim=1)

        # Penalize deviation of the speed from 1.0 in either direction
        # (a crude stand-in for an energy constraint, not a smoothness term)
        energy_penalty = torch.mean((v_norm - 1.0)**2)

        return energy_penalty

# Train the physics-constrained model
physics_model = PhysicsConstrainedFlow(hidden_dim=64).to(device)
physics_optimizer = torch.optim.Adam(physics_model.parameters(), lr=1e-3)

# Training loop
n_epochs = 50
physics_losses = []

for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0

    for x0, x1 in weather_dataloader:
        x0, x1 = x0.to(device), x1.to(device)

        # Generate random times between 0 and 1
        t = torch.rand(x0.size(0), device=device)

        # Compute loss
        loss = physics_model.compute_flow_loss(x0, x1, t)

        # Backward pass and optimize
        physics_optimizer.zero_grad()
        loss.backward()
        physics_optimizer.step()

        epoch_loss += loss.item() * x0.size(0)

    # Average loss for the epoch
    epoch_loss /= len(weather_dataset)
    physics_losses.append(epoch_loss)

    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.6f}")


In [None]:
# Compare flow fields from standard and physics-constrained models
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Time points to visualize
vis_times = [0.0, 0.5, 1.0]

for i, t_val in enumerate(vis_times):
    # Standard model
    X, Y, U, V = plot_weather_flow_field(weather_model, t_val)
    speed = np.sqrt(U**2 + V**2)

    axes[0, i].streamplot(X, Y, U, V, density=1.5, color=speed, cmap='viridis')
    axes[0, i].set_title(f"Standard Model, t = {t_val}")
    axes[0, i].set_aspect('equal')
    axes[0, i].grid(True)

    # Physics-constrained model
    X, Y, U, V = plot_weather_flow_field(physics_model, t_val)
    speed = np.sqrt(U**2 + V**2)

    axes[1, i].streamplot(X, Y, U, V, density=1.5, color=speed, cmap='viridis')
    axes[1, i].set_title(f"Physics-Constrained Model, t = {t_val}")
    axes[1, i].set_aspect('equal')
    axes[1, i].grid(True)

plt.tight_layout()
plt.show()
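
The `_apply_divergence_free_constraint` method notes that a real implementation would compute the curl of a potential function. A minimal pure-Python sketch of that idea in 2D follows: the velocity is derived from a scalar stream function $\psi$ as $(u, v) = (\partial\psi/\partial y, -\partial\psi/\partial x)$, which has zero divergence by construction. The particular `stream_function` and the finite-difference derivatives are assumptions for illustration. In a neural model, $\psi$ would be the network output and the derivatives would typically come from automatic differentiation.

```python
import math

def stream_function(x, y):
    """Hypothetical scalar potential psi(x, y); any smooth function would do."""
    return math.sin(x) * math.cos(y)

def velocity_from_stream(psi, x, y, h=1e-4):
    """Velocity (u, v) = (d psi/dy, -d psi/dx), using central differences."""
    u = (psi(x, y + h) - psi(x, y - h)) / (2 * h)
    v = -(psi(x + h, y) - psi(x - h, y)) / (2 * h)
    return u, v

def divergence(psi, x, y, h=1e-3):
    """Numerical divergence du/dx + dv/dy of the stream-function velocity."""
    u_plus, _ = velocity_from_stream(psi, x + h, y)
    u_minus, _ = velocity_from_stream(psi, x - h, y)
    _, v_plus = velocity_from_stream(psi, x, y + h)
    _, v_minus = velocity_from_stream(psi, x, y - h)
    return (u_plus - u_minus) / (2 * h) + (v_plus - v_minus) / (2 * h)

# Check the divergence at a few sample points in the plotting window
points = [(-2.0, 1.5), (0.0, 0.0), (0.7, -3.1), (3.2, 2.4)]
max_div = max(abs(divergence(stream_function, x, y)) for x, y in points)
u0, v0 = velocity_from_stream(stream_function, 0.0, 0.0)

print(f"velocity at origin: ({u0:.4f}, {v0:.4f})")
print(f"largest |divergence| over sample points: {max_div:.2e}")
```

Because the constraint is built into how the velocity is produced, it holds for any $\psi$ and needs no penalty term in the loss, unlike the regularization used in `_compute_physics_loss`.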

In [None]:
# 6. Connection to the WeatherFlow Library

print("""
## WeatherFlow Library Implementation

The WeatherFlow library implements these concepts at scale for real weather data:

1. **WeatherFlowMatch Model**: Neural network for learning weather flow fields
   - Convolutional architecture for spatial structure
   - Time embedding for temporal dynamics
   - Physics-informed constraints for physical consistency

2. **ODE Integration**: Uses torchdiffeq for generating predictions
   - WeatherFlowODE class wraps the flow model with an ODE solver
   - Flexible solver methods (Euler, RK4, Dopri5, etc.)
   - Adjustable tolerances for accuracy vs. speed trade-offs

3. **Spherical Geometry**: Accounts for Earth's spherical surface
   - Proper handling of coordinate systems
   - Accounting for convergence of meridians
   - Managing periodic boundary conditions

4. **Physics Constraints**: Incorporates atmospheric physics
   - Conservation of mass (divergence-free)
   - Energy conservation
   - Geostrophic balance
   - Coriolis effects

In the next notebooks, we'll apply these concepts to real ERA5 weather data.
""")

# Show an example of using the WeatherFlow library for a simple case
from weatherflow.models import WeatherFlowMatch

# Create a toy example input (batch_size=1, channels=2, height=16, width=32)
toy_input = torch.randn(1, 2, 16, 32).to(device)
time_points = torch.tensor([0.5]).to(device)

# Create model under a new name so the 2D toy `model` used by
# vector_field() earlier in the notebook is not overwritten
wf_model = WeatherFlowMatch(
    input_channels=2,
    hidden_dim=64,
    n_layers=3,
    physics_informed=True
).to(device)

# Forward pass
with torch.no_grad():
    velocity = wf_model(toy_input, time_points)
    print(f'Input shape: {toy_input.shape}')
    print(f'Output velocity shape: {velocity.shape}')
    print(f'Velocity statistics: min={velocity.min().item():.4f}, max={velocity.max().item():.4f}, mean={velocity.mean().item():.4f}')

print("""
## Conclusion

In this notebook, we've explored the fundamentals of flow matching and how it applies to weather prediction:

1. We implemented a simple flow matching model for 2D distributions
2. We visualized flow fields and generated trajectories
3. We extended the approach to weather-like data
4. We incorporated physics constraints for more realistic flows
5. We connected these concepts to the WeatherFlow library

In the next notebook, we'll train a full WeatherFlowMatch model on real ERA5 data and evaluate its predictive performance.
""")
