# ERA5 Data Exploration for Weather Flow Matching

This notebook demonstrates how to load, explore, and visualize ERA5 data for weather prediction using the WeatherFlow library. We'll cover:

1. Loading data from WeatherBench2
2. Exploring the data structure
3. Visualizing different variables
4. Preparing data for model training
5. Computing statistics and climatology

Let's get started!

## 1. Setup and Installation

First, let's make sure we have WeatherFlow and all dependencies installed.

In [None]:
# Install WeatherFlow if needed
try:
    import weatherflow
    print(f"WeatherFlow version: {weatherflow.__version__}")
except ImportError:
    !pip install -e ..
    import weatherflow
    print(f"WeatherFlow installed, version: {weatherflow.__version__}")

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import xarray as xr
import cartopy.crs as ccrs
from tqdm.notebook import tqdm
import os
import warnings
warnings.filterwarnings('ignore')  # Suppress some warnings for cleaner output

# Import WeatherFlow modules
from weatherflow.data import ERA5Dataset, create_data_loaders
from weatherflow.utils import WeatherVisualizer

# Use larger default figure sizes for clearer plots
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['figure.dpi'] = 100

## 2. Loading ERA5 Data

WeatherFlow supports loading ERA5 data from multiple sources:

1. WeatherBench2 on Google Cloud Storage
2. Local NetCDF files
3. Custom Zarr datasets

Let's use the WeatherBench2 dataset which contains preprocessed global ERA5 reanalysis data.

In [None]:
# Define variables and pressure levels we're interested in
variables = ['z', 't', 'u', 'v']  # Geopotential, temperature, u-wind, v-wind
pressure_levels = [500]  # 500 hPa level
years = ('2016', '2016')  # Load just one year for faster exploration

# Detailed explanation of variables:
variable_details = {
    'z': 'Geopotential (m²/s²) - Gravitational potential energy per unit mass at the pressure surface; divide by g to get geopotential height',
    't': 'Temperature (K) - Air temperature',
    'u': 'U-component of wind (m/s) - Eastward wind',
    'v': 'V-component of wind (m/s) - Northward wind',
    'q': 'Specific humidity (kg/kg) - Mass of water vapor per unit mass of air',
    'r': 'Relative humidity (%) - Amount of water vapor relative to maximum possible'
}

# Print selected variables and their descriptions
print("Selected variables:")
for var in variables:
    print(f"  - {var}: {variable_details.get(var, 'Unknown variable')}")
print(f"\nPressure level: {pressure_levels[0]} hPa")
print(f"Time period: {years[0]} to {years[1]}")
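
Geopotential is easier to read as geopotential height, which is geopotential divided by standard gravity g₀ = 9.80665 m/s². The helper below is a minimal sketch of that conversion; the name `geopotential_height` is ours, not part of WeatherFlow.

```python
import numpy as np

G0 = 9.80665  # standard gravity (m/s^2)


def geopotential_height(z):
    """Convert geopotential (m^2/s^2) to geopotential height (m)."""
    return np.asarray(z, dtype=float) / G0


# A typical 500 hPa geopotential of ~55,000 m^2/s^2 is a height of roughly 5.6 km
print(geopotential_height([55000.0]))
```

Applied to the `z` field, this gives the familiar "500 hPa height" maps used in synoptic meteorology.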

In [None]:
# Load ERA5 data with progress information
print("Loading ERA5 data from WeatherBench2...")
try:
    # Try loading with default settings
    era5_data = ERA5Dataset(
        variables=variables,
        pressure_levels=pressure_levels,
        time_slice=years,
        normalize=False,  # Keep original values for exploration
        verbose=True
    )
    print(f"Successfully loaded data with {len(era5_data)} time steps")
except Exception as e:
    print(f"Error loading data: {str(e)}")
    print("\nTrying alternative loading method...")
    
    # If the default call fails, retry with the optional physics features disabled
    era5_data = ERA5Dataset(
        variables=variables,
        pressure_levels=pressure_levels,
        time_slice=years,
        normalize=False,
        verbose=True,
        add_physics_features=False
    )
    print(f"Loaded data with {len(era5_data)} time steps (physics features disabled)")

## 3. Exploring Data Structure

Let's examine the structure of the loaded data to better understand what we're working with.

In [None]:
# Get basic dataset information
print("Dataset shape information:")
print(f"  - Number of time steps: {len(era5_data)}")
print(f"  - Variables: {era5_data.variables}")
print(f"  - Pressure levels: {era5_data.pressure_levels}")
print(f"  - Spatial grid size: {era5_data.ds.latitude.size} × {era5_data.ds.longitude.size}")

# Look at the first sample to understand its structure
sample = era5_data[0]
print("\nSample data structure:")
for key, value in sample.items():
    if isinstance(value, dict):
        print(f"  - {key}: {type(value)}")
        for subkey, subvalue in value.items():
            print(f"      {subkey}: {type(subvalue)}")
    else:
        print(f"  - {key}: {type(value)}, shape: {value.shape}")

In [None]:
# Extract coordinate information
coords = era5_data.get_coords()
lats = coords['latitude']
lons = coords['longitude']

print(f"Latitude range: {lats.min():.2f}° to {lats.max():.2f}°, {len(lats)} points")
print(f"Longitude range: {lons.min():.2f}° to {lons.max():.2f}°, {len(lons)} points")

# Show coordinate spacing (important for certain physical calculations)
lat_spacing = np.abs(np.mean(np.diff(lats)))  # abs(): latitudes may be stored north to south
lon_spacing = np.abs(np.mean(np.diff(lons)))
print(f"Grid resolution: {lat_spacing:.2f}° latitude × {lon_spacing:.2f}° longitude")
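
A degree is not a fixed distance. One degree of latitude is about 111 km everywhere, but one degree of longitude shrinks with the cosine of latitude, which matters for the finite-difference derivatives in Section 9. A minimal sketch assuming a spherical Earth; `grid_spacing_m` is a hypothetical helper, not a WeatherFlow function:

```python
import numpy as np

R_EARTH = 6.371e6  # mean Earth radius (m), spherical-Earth assumption


def grid_spacing_m(lat_deg, dlat_deg, dlon_deg):
    """Return (dy, dx) in metres for a regular lat-lon grid cell centred at lat_deg."""
    dy = R_EARTH * np.deg2rad(dlat_deg)
    dx = R_EARTH * np.deg2rad(dlon_deg) * np.cos(np.deg2rad(lat_deg))
    return dy, dx


for lat in (0.0, 45.0, 60.0):
    dy_m, dx_m = grid_spacing_m(lat, 1.0, 1.0)
    print(f"lat {lat:4.1f}°: dy = {dy_m / 1e3:6.1f} km, dx = {dx_m / 1e3:6.1f} km")
```

At 60° latitude, a degree of longitude spans half the distance it does at the equator.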

## 4. Visualizing Weather Variables

Now let's visualize each of our variables to get a feel for the data. We'll use the WeatherVisualizer class from WeatherFlow for this.

In [None]:
# Initialize visualizer
visualizer = WeatherVisualizer(figsize=(14, 8))

# Extract first sample (current state)
sample_data = era5_data[0]['input']

# Create a dictionary for visualization
data_dict = {}
for i, var in enumerate(variables):
    # Each variable has shape [levels, lat, lon], select first level
    data_dict[var] = sample_data[i, 0].numpy()  # Convert tensor to numpy

# Plot each variable (plot_field returns its own fig, ax)
for var_name in variables:
    fig, ax = visualizer.plot_field(
        data_dict[var_name],
        title=f"{var_name} at {pressure_levels[0]} hPa",
        var_name=var_name,
        coastlines=True,
        grid=True
    )
    plt.tight_layout()
    plt.show()

### 4.1 Wind Vector Visualization

Since we have both U and V wind components, we can visualize the vector field to see wind patterns.
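
Besides speed, wind is often reported as a meteorological direction: the compass bearing the wind blows *from* (0° = northerly, 90° = easterly, 270° = westerly). The following is a minimal NumPy sketch; the function name is our own.

```python
import numpy as np


def wind_speed_direction(u, v):
    """Return wind speed (m/s) and meteorological direction (degrees the wind comes FROM)."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    speed = np.hypot(u, v)
    # (-u, -v) points back to the source; arctan2(east, north) gives a bearing clockwise from north
    direction = np.mod(np.degrees(np.arctan2(-u, -v)), 360.0)
    return speed, direction


speed, direction = wind_speed_direction([10.0, 0.0], [0.0, -5.0])
print(speed, direction)  # inputs: a westerly at 10 m/s and a northerly at 5 m/s
```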

In [None]:
# Extract U and V wind components
u_index = variables.index('u')
v_index = variables.index('v')
u_wind = sample_data[u_index, 0].numpy()
v_wind = sample_data[v_index, 0].numpy()

# Use geopotential (z) as the background field
z_index = variables.index('z')
geopotential = sample_data[z_index, 0].numpy()

# Calculate wind speed (magnitude)
wind_speed = np.sqrt(u_wind**2 + v_wind**2)

# Plot wind field with geopotential height as background
fig, ax = visualizer.plot_flow_vectors(
    u_wind, v_wind, 
    background=geopotential, 
    var_name='z',
    title=f"Wind Field at {pressure_levels[0]} hPa",
    scale=1.0, 
    density=1.0
)
plt.tight_layout()
plt.show()

# Plot wind speed (plot_field returns its own fig, ax)
fig, ax = visualizer.plot_field(
    wind_speed,
    title=f"Wind Speed at {pressure_levels[0]} hPa",
    cmap='YlOrRd',
    coastlines=True,
    grid=True
)
plt.tight_layout()
plt.show()

## 5. Temporal Evolution

Let's look at how variables change over time by extracting and visualizing a sequence of states.
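
Later sections measure lags in "time steps," so it helps to know how many hours one step spans. ERA5 itself is hourly, but many WeatherBench2 copies are subsampled, so the spacing should be checked rather than assumed. Below is a small sketch, `time_step_hours`, a hypothetical helper that accepts any sequence of datetime-like values:

```python
import numpy as np


def time_step_hours(timestamps):
    """Return spacings between consecutive timestamps (hours) and whether they are uniform."""
    t = np.asarray(timestamps, dtype="datetime64[s]")
    steps = np.diff(t) / np.timedelta64(1, "h")  # timedelta64 -> float hours
    uniform = bool(steps.size == 0 or np.all(steps == steps[0]))
    return steps, uniform


steps, uniform = time_step_hours(
    ["2016-01-01T00:00", "2016-01-01T06:00", "2016-01-01T12:00", "2016-01-01T18:00"]
)
print(steps, uniform)
```

If the metadata timestamps collected below are datetime-like (an assumption about the `metadata['t0']` format), `time_step_hours(time_stamps)` reports the step length directly.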

In [None]:
# Number of time steps to visualize
n_steps = 5

# Extract a sequence of states for one variable (geopotential)
var_index = variables.index('z')  # Index of the variable to visualize (geopotential)
level_index = 0  # First pressure level

# Collect time sequence
time_sequence = []
time_stamps = []

for i in range(n_steps):
    if i < len(era5_data):
        sample = era5_data[i]
        # Extract the variable
        time_sequence.append(sample['input'][var_index, level_index].numpy())
        # Extract timestamp from metadata
        time_stamps.append(sample['metadata']['t0'])

# Create animation
print(f"Creating animation for {variables[var_index]} at {pressure_levels[0]} hPa...")
anim = visualizer.create_prediction_animation(
    time_sequence,
    var_name=variables[var_index],
    title=f"{variables[var_index]} Evolution",
    interval=800  # Slower animation for better viewing
)

# Display animation
from IPython.display import HTML
HTML(anim.to_jshtml())

## 6. Data Statistics and Climatology

Understanding the statistical properties of each variable is important for normalization and model training.
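
The statistics cell below stores every flattened sample before calling `np.concatenate`, so memory use grows with `n_samples`. A running accumulator avoids this: it merges per-sample counts, means, and sums of squared deviations using the parallel form of Welford's algorithm (Chan et al.), and returns the same mean and standard deviation in constant memory. A minimal sketch; `RunningStats` is a hypothetical class, not part of WeatherFlow:

```python
import numpy as np


class RunningStats:
    """Streaming mean/std (population std, matching np.std with ddof=0)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, values):
        x = np.asarray(values, dtype=np.float64).ravel()
        n_b = x.size
        if n_b == 0:
            return
        mean_b = x.mean()
        m2_b = ((x - mean_b) ** 2).sum()
        n = self.count + n_b
        delta = mean_b - self.mean
        # Merge this batch into the running totals (self.count still holds the old count here)
        self.mean += delta * n_b / n
        self.m2 += m2_b + delta**2 * self.count * n_b / n
        self.count = n

    @property
    def std(self):
        return float(np.sqrt(self.m2 / self.count)) if self.count else float("nan")
```

In the collection loop you could keep `running = {var: RunningStats() for var in variables}` and call `running[var].update(sample['input'][j].numpy())`. The histograms further down still need the raw values, though.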

In [None]:
# Calculate statistics for each variable
stats = {}

# Number of samples to use for statistics (limit for memory efficiency)
n_samples = min(50, len(era5_data))
print(f"Computing statistics from {n_samples} samples...")

# Initialize arrays to collect data
var_data = {var: [] for var in variables}

# Collect data
for i in tqdm(range(n_samples)):
    sample = era5_data[i]
    for j, var in enumerate(variables):
        var_data[var].append(sample['input'][j].numpy().flatten())

# Compute statistics
for var in variables:
    # Concatenate all samples for this variable
    all_data = np.concatenate(var_data[var])
    
    # Calculate statistics
    stats[var] = {
        'mean': np.mean(all_data),
        'std': np.std(all_data),
        'min': np.min(all_data),
        'max': np.max(all_data),
        '5th_percentile': np.percentile(all_data, 5),
        '95th_percentile': np.percentile(all_data, 95)
    }

# Display statistics
print("\nVariable Statistics:")
for var in variables:
    print(f"\n{var} ({variable_details.get(var, '')})")
    for stat_name, stat_value in stats[var].items():
        print(f"  - {stat_name}: {stat_value:.2f}")
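
The means above weight every grid cell equally. On a regular latitude-longitude grid, cells near the poles cover much less area than cells near the equator. An area-weighted mean, with weights proportional to cos(latitude) as in WeatherBench's latitude-weighted metrics, can therefore differ noticeably for fields such as temperature. A minimal sketch for a single `[lat, lon]` field; the function name is our own:

```python
import numpy as np


def lat_weighted_mean(field, lats_deg):
    """Area-weighted mean of a [lat, lon] field using cos(latitude) weights."""
    field = np.asarray(field, dtype=float)
    w = np.cos(np.deg2rad(np.asarray(lats_deg, dtype=float)))
    w2d = np.broadcast_to(w[:, None], field.shape)  # same weight along each latitude row
    return float((field * w2d).sum() / w2d.sum())


# Toy field: 300 K at the equator, 250 K at +-60 degrees latitude
lats_demo = np.array([-60.0, 0.0, 60.0])
field_demo = np.array([[250.0] * 4, [300.0] * 4, [250.0] * 4])
print(field_demo.mean(), lat_weighted_mean(field_demo, lats_demo))
```

Here the unweighted mean is about 266.7 K, while the area-weighted mean is 275 K, because the warm equatorial row covers twice the area of each 60° row.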

In [None]:
# Visualize distributions
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
axes = axes.flatten()

for i, var in enumerate(variables):
    # Get data for histograms
    all_data = np.concatenate(var_data[var])
    
    # Plot histogram
    axes[i].hist(all_data, bins=50, alpha=0.7, density=True)
    axes[i].set_title(f"{var} Distribution")
    axes[i].set_xlabel(variable_details.get(var, var))
    axes[i].set_ylabel("Density")
    
    # Add vertical lines for mean and std range
    mean = stats[var]['mean']
    std = stats[var]['std']
    axes[i].axvline(mean, color='r', linestyle='--', label=f"Mean: {mean:.2f}")
    axes[i].axvline(mean + std, color='g', linestyle=':', label=f"±1 Std: {std:.2f}")
    axes[i].axvline(mean - std, color='g', linestyle=':')
    axes[i].legend()

plt.tight_layout()
plt.show()

## 7. Data Normalization and Preparation for Training

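Next, let's build training and validation data loaders with normalization enabled (`normalize=True`). If the loader applies a z-score normalization, each variable should end up with roughly zero mean and unit standard deviation; the plots below let us check the resulting value ranges.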

In [None]:
# Create normalized data loaders for training
print("Creating data loaders with normalization...")

# Split data into training and validation
train_years = ('2016', '2016-06')  # First half of 2016
val_years = ('2016-07', '2016-12')  # Second half of 2016

# Create data loaders
train_loader, val_loader = create_data_loaders(
    variables=variables,
    pressure_levels=pressure_levels,
    train_slice=train_years,
    val_slice=val_years,
    batch_size=16,
    num_workers=4,
    normalize=True  # Apply normalization
)

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
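
Normalization is typically a per-variable z-score, x' = (x − μ)/σ, which puts every channel on a comparable scale; model outputs are mapped back with x = x'σ + μ. This notebook doesn't show which statistics `create_data_loaders` uses internally. The sketch below instead applies the `stats` dictionary from Section 6 to a `[variable, level, lat, lon]` array. The helper names are hypothetical:

```python
import numpy as np


def _channel_stats(stats, variables, ndim):
    """Stack per-variable mean/std into arrays shaped [n_vars, 1, ..., 1] for broadcasting."""
    shape = (-1,) + (1,) * (ndim - 1)
    mean = np.array([stats[v]["mean"] for v in variables], dtype=float).reshape(shape)
    std = np.array([stats[v]["std"] for v in variables], dtype=float).reshape(shape)
    return mean, std


def normalize(x, stats, variables):
    """z-score each variable channel of an array shaped [var, ...]."""
    x = np.asarray(x, dtype=float)
    mean, std = _channel_stats(stats, variables, x.ndim)
    return (x - mean) / std


def denormalize(x_norm, stats, variables):
    """Invert `normalize`, returning values in physical units."""
    x_norm = np.asarray(x_norm, dtype=float)
    mean, std = _channel_stats(stats, variables, x_norm.ndim)
    return x_norm * std + mean
```

For example, `normalize(era5_data[0]['input'].numpy(), stats, variables)` would standardize one raw sample from the unnormalized dataset loaded earlier.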

In [None]:
# Visualize normalized data
# Get a batch from the training loader
sample_batch = next(iter(train_loader))

# Plot normalized fields for each variable
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
axes = axes.flatten()

for i, var in enumerate(variables):
    # Extract normalized field
    normalized_field = sample_batch['input'][0, i, 0].numpy()
    
    # Plot
    im = axes[i].imshow(normalized_field, cmap=visualizer.VAR_CMAPS.get(var, 'viridis'))
    axes[i].set_title(f"Normalized {var}")
    plt.colorbar(im, ax=axes[i])

plt.tight_layout()
plt.show()

## 8. Temporal Patterns and Lag Correlation

Understanding the temporal correlation in weather data is crucial for flow matching. Let's examine how variables evolve over short time periods.
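
A compact way to summarize temporal persistence is the autocorrelation function (ACF) together with a decorrelation lag, here taken as the first lag at which the ACF drops below 1/e. Comparing decorrelation lags across variables shows which fields persist longest. A minimal sketch follows, using a synthetic AR(1) series in place of a grid-point time series; the function names are ours:

```python
import numpy as np


def autocorrelation(ts, max_lag):
    """Pearson autocorrelation of a 1-D series for lags 0..max_lag."""
    ts = np.asarray(ts, dtype=float)
    acf = [1.0]
    for lag in range(1, max_lag + 1):
        acf.append(np.corrcoef(ts[lag:], ts[:-lag])[0, 1])
    return np.array(acf)


def decorrelation_lag(acf, threshold=np.exp(-1)):
    """First lag where the ACF falls below `threshold` (None if it never does)."""
    below = np.nonzero(np.asarray(acf) < threshold)[0]
    return int(below[0]) if below.size else None


# Synthetic AR(1) series x_t = phi * x_{t-1} + noise; its theoretical ACF is phi**lag
rng = np.random.default_rng(42)
phi, n = 0.8, 5000
x = np.zeros(n)
for step in range(1, n):
    x[step] = phi * x[step - 1] + rng.normal()

acf = autocorrelation(x, max_lag=10)
print(np.round(acf[:4], 2), decorrelation_lag(acf))
```

Applied to `time_series[var]` in the cells below, the same two functions give one decorrelation lag per variable, measured in time steps.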

In [None]:
# Select a specific location (grid point) to examine
lat_idx = len(lats) // 2  # Middle latitude (roughly equator)
lon_idx = len(lons) // 2  # Middle longitude

print(f"Selected location: Latitude {lats[lat_idx]:.2f}°, Longitude {lons[lon_idx]:.2f}°")

# Extract time series for each variable at this location
n_samples = min(100, len(era5_data))
time_series = {var: [] for var in variables}
timestamps = []

for i in tqdm(range(n_samples)):
    sample = era5_data[i]
    timestamps.append(sample['metadata']['t0'])
    
    for j, var in enumerate(variables):
        # Extract value at the selected location
        value = sample['input'][j, 0, lat_idx, lon_idx].item()
        time_series[var].append(value)

# Convert timestamps to datetime objects for better plotting
import pandas as pd
datetimes = pd.to_datetime(timestamps)

# Plot time series
fig, axes = plt.subplots(len(variables), 1, figsize=(14, 12), sharex=True)

for i, var in enumerate(variables):
    axes[i].plot(datetimes, time_series[var], '-o', markersize=4)
    axes[i].set_title(f"{var} - {variable_details.get(var, '')}")
    axes[i].set_ylabel(var)
    axes[i].grid(True)

axes[-1].set_xlabel("Time")
plt.tight_layout()
plt.show()

In [None]:
# Calculate lag correlations to understand predictability
max_lag = 10  # Maximum lag in time steps
lag_corrs = {var: [] for var in variables}

for var in variables:
    # Get the time series data
    ts = np.array(time_series[var])
    
    # Calculate autocorrelation for different lags
    for lag in range(max_lag + 1):
        if lag == 0:
            # Correlation with itself is always 1
            lag_corrs[var].append(1.0)
        else:
            # Compute correlation between original series and lagged series
            corr = np.corrcoef(ts[lag:], ts[:-lag])[0, 1]
            lag_corrs[var].append(corr)

# Plot lag correlations
plt.figure(figsize=(12, 6))
lags = range(max_lag + 1)

for var in variables:
    plt.plot(lags, lag_corrs[var], 'o-', label=var)
    
plt.xlabel('Lag (time steps)')
plt.ylabel('Autocorrelation')
plt.title('Temporal Autocorrelation by Variable')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

# Correlate the variables with each other at this grid point (zero lag).
# Use a new name so the `var_data` dict from Section 6 is not overwritten.
plt.figure(figsize=(14, 10))
point_series = {var: np.array(time_series[var]) for var in variables}

# Create a correlation matrix
corr_matrix = np.zeros((len(variables), len(variables)))
for i, var1 in enumerate(variables):
    for j, var2 in enumerate(variables):
        corr_matrix[i, j] = np.corrcoef(point_series[var1], point_series[var2])[0, 1]

# Plot correlation matrix as a heatmap
plt.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
plt.colorbar(label='Correlation Coefficient')
plt.xticks(range(len(variables)), variables)
plt.yticks(range(len(variables)), variables)
plt.title('Zero-Lag Correlation Between Variables')

# Add correlation values as text
for i in range(len(variables)):
    for j in range(len(variables)):
        plt.text(j, i, f'{corr_matrix[i, j]:.2f}', 
                 ha='center', va='center', 
                 color='white' if abs(corr_matrix[i, j]) > 0.5 else 'black')

plt.tight_layout()
plt.show()
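
The matrix above only compares variables at the same time (zero lag). A lagged cross-correlation, corr(x_t, y_{t+k}), shows whether changes in one variable tend to lead another. The sketch below uses `lagged_cross_correlation`, a hypothetical helper, on a synthetic pair in which y is x delayed by 3 steps plus noise:

```python
import numpy as np


def lagged_cross_correlation(x, y, max_lag):
    """Return lags -max_lag..max_lag and corr(x[t], y[t + lag]) for each lag."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    lags = np.arange(-max_lag, max_lag + 1)
    corrs = []
    for lag in lags:
        if lag >= 0:
            a, b = x[: len(x) - lag], y[lag:]
        else:
            a, b = x[-lag:], y[: len(y) + lag]
        corrs.append(np.corrcoef(a, b)[0, 1])
    return lags, np.array(corrs)


rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = np.roll(x, 3) + 0.3 * rng.normal(size=500)  # y lags x by 3 steps
lags, corrs = lagged_cross_correlation(x, y, max_lag=6)
print(lags[np.argmax(corrs)])  # the peak sits at lag +3
```

With the notebook data, `lagged_cross_correlation(point_series['u'], point_series['z'], max_lag=10)` would be one pair to try.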

## 9. Computing Derived Quantities for Physics Constraints
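
The cell below approximates relative vorticity ζ and horizontal divergence δ with centred finite differences, treating each small patch of the sphere as locally flat. On a sphere of radius R, with longitude λ and latitude φ in radians, the local grid spacings and the planar approximations are:

$$
\Delta x = R\cos\varphi\,\Delta\lambda, \qquad \Delta y = R\,\Delta\varphi,
$$

$$
\zeta \approx \frac{\partial v}{\partial x} - \frac{\partial u}{\partial y}, \qquad
\delta \approx \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y}.
$$

The full spherical forms add metric terms, $\zeta = \frac{\partial v}{\partial x} - \frac{\partial u}{\partial y} + \frac{u\tan\varphi}{R}$ and $\delta = \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} - \frac{v\tan\varphi}{R}$. These terms are small at low latitudes but grow toward the poles.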

In [None]:
print("Computing some important derived physical quantities for flow matching...")

# Extract geopotential, temperature, and wind components
z_index = variables.index('z')
t_index = variables.index('t')
u_index = variables.index('u')
v_index = variables.index('v')

sample = era5_data[0]
z = sample['input'][z_index, 0].numpy()
t = sample['input'][t_index, 0].numpy()
u = sample['input'][u_index, 0].numpy()
v = sample['input'][v_index, 0].numpy()

# Calculate wind speed
wind_speed = np.sqrt(u**2 + v**2)

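# Calculate vorticity and divergence with centred finite differences.
# Simplified planar form: spherical metric terms (e.g. u*tan(lat)/R) are neglected.
R_EARTH = 6.371e6  # mean Earth radius (m)
lat_arr = np.asarray(lats, dtype=float)
lon_arr = np.asarray(lons, dtype=float)

# Meridional spacing (m); keeps its sign, so north-to-south latitude ordering is handled correctly
dy = R_EARTH * np.deg2rad(np.mean(np.diff(lat_arr)))
# Zonal spacing (m) shrinks with cos(latitude), so compute it per row; pole rows become NaN
coslat = np.cos(np.deg2rad(lat_arr))
coslat = np.where(np.abs(coslat) < 1e-6, np.nan, coslat)
dx = (R_EARTH * np.deg2rad(np.mean(np.diff(lon_arr))) * coslat)[:, None]

# Meridional derivatives (axis 0 = latitude); np.gradient is one-sided at the edges
dudy = np.gradient(u, dy, axis=0)
dvdy = np.gradient(v, dy, axis=0)

# Zonal derivatives (axis 1 = longitude); longitude wraps around on a global grid
dudx = (np.roll(u, -1, axis=1) - np.roll(u, 1, axis=1)) / (2 * dx)
dvdx = (np.roll(v, -1, axis=1) - np.roll(v, 1, axis=1)) / (2 * dx)

# Relative vorticity (vertical component of the curl of the wind)
vorticity = dvdx - dudy

# Horizontal divergence (positive = outflow, negative = convergence)
divergence = dudx + dvdy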

# Visualize these derived quantities
fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# Plot original wind field
visualizer.plot_flow_vectors(u, v, background=z, var_name='z',
                             title="Wind Field and Geopotential",
                             ax=axes[0, 0])

# Plot vorticity
cmap = 'RdBu_r'
im = axes[0, 1].imshow(vorticity, cmap=cmap, origin='lower')
axes[0, 1].set_title("Vorticity (1/s)")
plt.colorbar(im, ax=axes[0, 1])

# Plot divergence
im = axes[1, 0].imshow(divergence, cmap=cmap, origin='lower')
axes[1, 0].set_title("Divergence (1/s)")
plt.colorbar(im, ax=axes[1, 0])

# Plot wind speed
im = axes[1, 1].imshow(wind_speed, cmap='viridis', origin='lower')
axes[1, 1].set_title("Wind Speed (m/s)")
plt.colorbar(im, ax=axes[1, 1])

plt.tight_layout()
plt.show()
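
One way to sanity-check finite-difference code is to run it on a field whose answer is known. For solid-body rotation on a plane, u = −Ωy and v = Ωx, the vorticity is exactly 2Ω and the divergence is zero. Centred differences are exact for such linear fields. Here is a self-contained sketch on a small planar grid (not the ERA5 grid); the function name is ours:

```python
import numpy as np


def solid_body_rotation_check(omega=1e-4, spacing=100e3, ny=20, nx=30):
    """Finite-difference vorticity/divergence of u = -omega*y, v = omega*x on a planar grid."""
    # [y, x] layout, matching the [lat, lon] layout of the ERA5 fields
    y, x = np.meshgrid(np.arange(ny) * spacing, np.arange(nx) * spacing, indexing="ij")
    u_rot, v_rot = -omega * y, omega * x
    # np.gradient: centred differences inside, one-sided at edges (both exact for linear fields)
    vort = np.gradient(v_rot, spacing, axis=1) - np.gradient(u_rot, spacing, axis=0)
    div = np.gradient(u_rot, spacing, axis=1) + np.gradient(v_rot, spacing, axis=0)
    return vort, div


vort, div = solid_body_rotation_check()
print(vort.mean() / 1e-4, np.abs(div).max())  # expect about 2 and 0
```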

## 10. Visualizing Sequential Data for Flow Matching
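
In conditional flow matching with a straight-line path between a current state x₀ and the next state x₁, the intermediate state is x_t = (1 − t)x₀ + t·x₁. The regression target for the velocity field is then the constant x₁ − x₀, which makes the differences between consecutive states plotted below natural "flow" targets. This notebook doesn't show whether WeatherFlow uses exactly this path (some variants add noise), so treat the following as a sketch of the standard formulation. `linear_path_sample` is a hypothetical helper:

```python
import numpy as np


def linear_path_sample(x0, x1, t):
    """Point on the straight path from x0 to x1 at time t in [0, 1], and its velocity."""
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    x_t = (1.0 - t) * x0 + t * x1
    target_velocity = x1 - x0  # d x_t / dt, independent of t for a straight path
    return x_t, target_velocity


rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 8))              # stand-in for the state at time step k
x1 = x0 + 0.1 * rng.normal(size=(4, 8))   # stand-in for the state at step k + 1
t_sample = rng.uniform()                  # t is typically drawn uniformly from [0, 1]
x_t, v_target = linear_path_sample(x0, x1, t_sample)
# A model v_theta(x_t, t) would be trained to minimise mean((v_theta - v_target)**2)
```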

In [None]:
print("Visualizing Sequential Data for Flow Matching")

# Extract consecutive states to visualize flow matching targets
n_steps = 4  # Number of consecutive steps
first_sample_idx = 0

states = []
for i in range(first_sample_idx, first_sample_idx + n_steps):
    if i < len(era5_data):
        sample = era5_data[i]
        states.append(sample['input'][z_index, 0].numpy())  # Using geopotential

# Compute the "flow" between consecutive states
flows = []
for i in range(len(states) - 1):
    flow = states[i+1] - states[i]
    flows.append(flow)

# Visualize states and flows
fig, axes = plt.subplots(2, n_steps-1, figsize=(18, 10))

# Plot consecutive states
for i in range(n_steps-1):
    # Plot current state
    im = axes[0, i].imshow(states[i], cmap='viridis', origin='lower')
    axes[0, i].set_title(f"State at t={i}")
    plt.colorbar(im, ax=axes[0, i])
    
    # Plot flow to next state
    im = axes[1, i].imshow(flows[i], cmap='RdBu_r', origin='lower')
    axes[1, i].set_title(f"Flow t={i} → t={i+1}")
    plt.colorbar(im, ax=axes[1, i])

plt.tight_layout()
plt.show()

## 11. Saving Statistics for Model Training

In [None]:
# Create a clean dictionary of normalization statistics
norm_stats = {}
for var in variables:
    # float() converts NumPy scalars (e.g. np.float32), which json cannot serialize
    norm_stats[var] = {
        'mean': float(stats[var]['mean']),
        'std': float(stats[var]['std'])
    }

print("Normalization statistics for model training:")
for var, var_stats in norm_stats.items():
    print(f"  {var}: mean = {var_stats['mean']:.4f}, std = {var_stats['std']:.4f}")

# Save statistics to file (os is already imported in the setup cell)
import json

# Create directory if it doesn't exist
os.makedirs('../data', exist_ok=True)

# Save as JSON
with open('../data/normalization_stats.json', 'w') as f:
    json.dump(norm_stats, f, indent=2)
    
print(f"\nStatistics saved to '../data/normalization_stats.json'")
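
NumPy reductions over `float32` data return `np.float32` scalars, which the standard `json` module cannot serialize, so they should be cast with `float()` before saving. The sketch below round-trips a statistics dictionary through JSON, then turns it back into arrays that broadcast over a `[batch, variable, level, lat, lon]` tensor. The helper names are hypothetical, and the 5-D layout is assumed from the `sample_batch['input'][0, i, 0]` indexing used earlier:

```python
import io
import json

import numpy as np


def stats_to_json(stats, fp):
    """Write {var: {"mean": ..., "std": ...}} as JSON, casting NumPy scalars to float."""
    clean = {v: {k: float(val) for k, val in s.items()} for v, s in stats.items()}
    json.dump(clean, fp, indent=2)


def stats_to_arrays(stats, variables):
    """Return mean/std arrays shaped [1, n_vars, 1, 1, 1] for [batch, var, level, lat, lon] data."""
    mean = np.array([stats[v]["mean"] for v in variables], dtype=np.float32)
    std = np.array([stats[v]["std"] for v in variables], dtype=np.float32)
    shape = (1, len(variables), 1, 1, 1)
    return mean.reshape(shape), std.reshape(shape)


demo = {"z": {"mean": np.float32(54000.0), "std": np.float32(3000.0)},
        "t": {"mean": np.float32(253.0), "std": np.float32(12.0)}}
buf = io.StringIO()  # stands in for an open file
stats_to_json(demo, buf)
loaded = json.loads(buf.getvalue())
mean_arr, std_arr = stats_to_arrays(loaded, ["z", "t"])
print(mean_arr.shape, std_arr.ravel())
```

With the saved file, `with open('../data/normalization_stats.json') as f: loaded = json.load(f)` produces the same kind of dictionary.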

## 12. Conclusion

In this notebook, we've explored ERA5 data for weather prediction with flow matching. We:

1. Loaded and inspected ERA5 data from WeatherBench2
2. Visualized individual weather variables and their relationships
3. Analyzed temporal patterns and correlations
4. Computed physics-relevant derived quantities (vorticity and divergence)
5. Prepared normalized data loaders and saved normalization statistics for model training

The next step is to train a flow matching model on this data, which we'll cover in the next notebook.