# 🌎 Welcome to the CSE151B Spring 2025 Climate Emulation Competition!

Thank you for participating in this exciting challenge focused on building machine learning models to emulate complex climate systems.  
This notebook is provided as a **starter template** to help you:

- Understand how to load and preprocess the dataset  
- Construct a baseline model  
- Train and evaluate predictions using a PyTorch Lightning pipeline  
- Format your predictions for submission to the leaderboard  

You're encouraged to:
- Build on this structure or replace it entirely
- Try more advanced models and training strategies
- Incorporate your own ideas to push the boundaries of what's possible

If you're interested in developing within a repository structure and/or using helpful tools such as configuration management (based on Hydra) and logging (with Weights & Biases), we recommend checking out the following GitHub repo. This kind of structure is useful when running many experiments and trying out different research ideas.

👉 [https://github.com/salvaRC/cse151b-spring2025-competition](https://github.com/salvaRC/cse151b-spring2025-competition)

Good luck, have fun, and we hope you learn a lot through this process!


### 📦 Install Required Libraries
We install the necessary Python packages for data loading, deep learning, and visualization.


In [None]:
!pip install xarray zarr dask lightning matplotlib wandb cftime einops --quiet

import os
from datetime import datetime
import numpy as np
import xarray as xr
import dask.array as da
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

### ⚙️ Configuration Setup  
Define all model, data, and training hyperparameters in one place for easy control and reproducibility.

### 📊 Data Configuration

We define the dataset settings used for training and evaluation. This includes:

- **`path`**: Path to the `.zarr` dataset containing monthly climate variables from CMIP6 simulations.
- **`input_vars`**: Climate forcing variables (e.g., CO₂, CH₄) used as model inputs.
- **`output_vars`**: Target variables to predict — surface air temperature (`tas`) and precipitation (`pr`).
- **`target_member_id`**: Index of the ensemble member (each SSP has 3) whose simulation provides the target variables.
- **`train_ssps`**: SSP scenarios used for training (spanning low to high emissions).
- **`test_ssp`**: Scenario held out for evaluation (must be `ssp245`).
- **`test_months`**: Number of months in the test split (must be `120`).
- **`batch_size`** and **`num_workers`**: Data loading parameters for PyTorch training.

These settings reflect how the challenge is structured: models must learn from some emission scenarios and generalize to unseen ones.

> ⚠️ **Important:** Do **not modify** the following test settings:
>
> - `test_ssp` must remain **`ssp245`**, which is the held-out evaluation scenario.
> - `test_months` must be **`120`**, corresponding to the last 10 years (monthly resolution) of the scenario.
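The `test_months` rule amounts to simple index arithmetic on a monthly time axis. The sketch below shows it; `split_last_months` is a hypothetical helper (not part of the starter code), and the 1032-month axis in the example is an arbitrary illustrative length, not the actual length of `ssp245` in the dataset.

```python
import numpy as np


def split_last_months(n_months, test_months=120):
    """Hypothetical helper: split a monthly time axis into (earlier, test) index arrays.

    The test window is the final `test_months` steps, i.e. the last
    test_months / 12 years of the scenario.
    """
    if not 0 < test_months <= n_months:
        raise ValueError("test_months must be in the range (0, n_months]")
    idx = np.arange(n_months)
    return idx[:-test_months], idx[-test_months:]


earlier_idx, test_idx = split_last_months(n_months=1032, test_months=120)
print(len(earlier_idx), len(test_idx), test_idx[0], test_idx[-1])  # prints: 912 120 912 1031
```

With xarray, the same window could be selected with `ds.isel(time=slice(-test_months, None))`, assuming the dataset's time dimension is named `time`.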



In [None]:
# NOTE: Change the path below to wherever your .zarr dataset is stored
config = {
    "data": {
        "path": "/kaggle/input/cse151b-spring2025-competition/processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "target_member_id": 0,
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "test_months": 120,  # must stay 120 (last 10 years of ssp245)
        "batch_size": 16,
        "num_workers": 4,
    },
    "model": {
        "type": "unetAttention",
        "kernel_size": 3,
        "init_dim": 64,
        "depth": 6,
        "dropout_rate": 0.3,
    },
    "training": {
        "lr": 5e-5,
    },
    "trainer": {
        "max_epochs": 30,
        "accelerator": "auto",
        "devices": "auto",
        "precision": "16-mixed",  # mixed precision; Lightning 2.x name for the legacy value 16
        "deterministic": True,
        "num_sanity_val_steps": 0,
    },
    "seed": 42,
}
pl.seed_everything(config["seed"])  # Set seed for reproducibility

### 🔧 Spatial Weighting Utility Function

This cell defines a utility for spatial weighting:

- **`get_lat_weights(latitude_values)`**: Computes cosine-of-latitude area weights, normalized to a mean of 1, to account for the Earth's curvature. This is critical for evaluating global climate metrics fairly: on a regular latitude-longitude grid, cells near the equator cover much larger surface areas than cells near the poles.


In [None]:
def get_lat_weights(latitude_values):
    lat_rad = np.deg2rad(latitude_values)
    weights = np.cos(lat_rad)
    return weights / np.mean(weights)
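A typical use of these weights is an area-weighted error metric. The sketch below computes a latitude-weighted RMSE over the trailing (lat, lon) axes. `area_weighted_rmse` is a hypothetical helper, and this is not necessarily the exact metric the leaderboard uses.

```python
import numpy as np


def get_lat_weights(latitude_values):
    # Same as the cell above: cos(latitude), normalized to a mean of 1
    lat_rad = np.deg2rad(latitude_values)
    weights = np.cos(lat_rad)
    return weights / np.mean(weights)


def area_weighted_rmse(pred, target, lats):
    """Hypothetical helper: latitude-weighted RMSE over the trailing (lat, lon) axes."""
    w = get_lat_weights(lats)[:, None]  # shape (n_lat, 1), broadcasts across longitude
    sq_err = (np.asarray(pred) - np.asarray(target)) ** 2
    # Because the weights average to 1, a plain mean of w * err is a weighted mean
    return np.sqrt((sq_err * w).mean(axis=(-2, -1)))


lats = np.array([-60.0, 0.0, 60.0])
target = np.zeros((3, 4))
pred = target.copy()
pred[1, :] = 1.0  # error only on the equator row
print(area_weighted_rmse(pred, target, lats))  # ~0.707, i.e. sqrt(0.5)
```

In this toy case the unweighted RMSE would be sqrt(1/3) ≈ 0.577; weighting raises it because the erroneous equator row (weight 1.5) represents more surface area than the polar-side rows (weight 0.75 each).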

### 🧠 SimpleCNN: A Residual Convolutional Baseline

This is a lightweight baseline model designed to capture spatial patterns in global climate data using convolutional layers.

- The architecture starts with a **convolution + batch norm + ReLU** block to process the input channels.
- It then applies a series of **residual blocks** to extract increasingly abstract spatial features; their skip connections help preserve gradient flow during training.
- Finally, after dropout, a 3x3 convolution block and a 1x1 convolution reduce the feature maps to the desired number of output channels (`tas` and `pr`).

This model only serves as a **simple baseline for climate emulation**. 

We encourage you to build and experiment with your own models and ideas.


In [None]:
# Define a Residual Block which is a building block for ResNet-like architectures
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        # First convolution layer with BatchNorm and ReLU activation
        # Padding of kernel_size // 2 keeps the spatial size unchanged (odd kernel, stride 1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        # Second convolution layer followed by BatchNorm (no activation yet)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Skip connection (identity or projection) to match output dimensions
        self.skip = nn.Sequential()
        # If input and output dimensions or stride differ, adjust skip connection via 1x1 conv
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x  # Save input to add back later (skip connection)
        
        # Forward pass through first conv, batch norm, and ReLU
        out = self.relu(self.bn1(self.conv1(x)))
        
        # Forward pass through second conv and batch norm (no activation yet)
        out = self.bn2(self.conv2(out))
        
        # Add skip connection (either identity or adjusted input)
        out += self.skip(identity)
        
        # Final activation after addition
        return self.relu(out)

# Define a simple CNN model using ResidualBlocks
class SimpleCNN(nn.Module):
    def __init__(self, n_input_channels, n_output_channels, kernel_size=3, init_dim=64, depth=4, dropout_rate=0.2):
        super().__init__()
        
        # Initial convolution block: Conv + BatchNorm + ReLU
        # Converts input channels to init_dim feature maps
        self.initial = nn.Sequential(
            nn.Conv2d(n_input_channels, init_dim, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(init_dim),
            nn.ReLU(inplace=True),
        )
        
        # List to hold ResidualBlocks
        self.res_blocks = nn.ModuleList()
        current_dim = init_dim
        
        # Create 'depth' number of ResidualBlocks
        for i in range(depth):
            # Double the number of channels for each block except the last one
            out_dim = current_dim * 2 if i < depth - 1 else current_dim
            self.res_blocks.append(ResidualBlock(current_dim, out_dim))
            if i < depth - 1:
                current_dim *= 2  # Update current_dim after doubling
        
        # Dropout layer for regularization, applied after residual blocks
        self.dropout = nn.Dropout2d(dropout_rate)
        
        # Final convolutional layers to reduce channels to output channels
        # Includes Conv + BatchNorm + ReLU followed by a 1x1 Conv to output channels
        self.final = nn.Sequential(
            nn.Conv2d(current_dim, current_dim // 2, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(current_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(current_dim // 2, n_output_channels, kernel_size=1),
        )

    def forward(self, x):
        # Pass input through initial block
        x = self.initial(x)
        
        # Pass through each residual block sequentially
        for res_block in self.res_blocks:
            x = res_block(x)
        
        # Apply dropout and then final layers to get output
        return self.final(self.dropout(x))
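To see how wide the residual stack gets, the channel bookkeeping in `SimpleCNN.__init__` can be replayed without building any layers. `simplecnn_channel_schedule` is a hypothetical helper that mirrors that loop. Note that SimpleCNN never downsamples (every `ResidualBlock` uses the default stride of 1), so all blocks run at full grid resolution.

```python
def simplecnn_channel_schedule(init_dim=64, depth=4):
    """Hypothetical helper mirroring the channel loop in SimpleCNN.__init__.

    Returns the (in_channels, out_channels) pair of every ResidualBlock and the
    hidden width of the first conv in `self.final` (current_dim // 2).
    """
    schedule, current = [], init_dim
    for i in range(depth):
        out = current * 2 if i < depth - 1 else current  # last block keeps its width
        schedule.append((current, out))
        current = out
    return schedule, current // 2


blocks, final_hidden = simplecnn_channel_schedule(init_dim=64, depth=4)
# Blocks whose channel counts differ use the 1x1 projection skip in ResidualBlock
projected = [pair for pair in blocks if pair[0] != pair[1]]
print(blocks, final_hidden, len(projected))
```

Because a convolution's cost scales with in_channels times out_channels and no block reduces the spatial size, the widest, last blocks dominate the compute of this baseline.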

In [None]:
# A convolutional block used in the UNet encoder and decoder paths
class UNetConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dropout=0.1):
        super().__init__()
        # First convolution layer + ReLU activation
        # Padding is set to maintain spatial size
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.relu1 = nn.ReLU(inplace=True)
        
        # Dropout layer for regularization after first conv+relu
        self.drop = nn.Dropout2d(dropout)
        
        # Second convolution layer + ReLU activation
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.relu2 = nn.ReLU(inplace=True)

        # Skip connection to match input to output channels if needed
        # Uses 1x1 convolution if channels differ, otherwise identity mapping
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.skip(x)  # Prepare skip connection (either conv or identity)
        
        out = self.relu1(self.conv1(x))  # First conv + relu
        out = self.drop(out)              # Apply dropout
        out = self.conv2(out)             # Second conv (no activation yet)
        
        # Add skip connection and apply final activation
        return self.relu2(out + identity)

# UNet architecture implementing encoder-decoder with skip connections
class UNet(nn.Module):
    def __init__(
        self,
        n_input_channels,
        n_output_channels,
        kernel_size=3,
        init_dim=64,
        dropout_rate=0.1,
        depth=4,
    ):
        super().__init__()
        self.depth = depth
        
        # Compute number of channels for each level of the UNet (doubling every step)
        self.dims = [init_dim * (2 ** i) for i in range(depth + 1)]

        # Encoder blocks and downsampling layers
        self.encoders = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        in_ch = n_input_channels
        
        for i in range(depth):
            # Encoder convolution block for current depth level
            self.encoders.append(UNetConvBlock(in_ch, self.dims[i], kernel_size, dropout_rate))
            
            # Downsampling by stride 2 conv to reduce spatial dimensions by half
            self.downsamples.append(
                nn.Conv2d(self.dims[i], self.dims[i], kernel_size=3, stride=2, padding=1)
            )
            
            in_ch = self.dims[i]  # Update input channels for next block

        # Bottleneck block at the bottom of the UNet
        self.bottleneck = UNetConvBlock(self.dims[depth - 1], self.dims[depth], kernel_size, dropout_rate)

        # Decoder upsampling and convolution blocks
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        # Build decoder layers in reverse order
        for i in reversed(range(depth)):
            # Transpose convolution to upsample (double spatial size)
            self.upconvs.append(
                nn.ConvTranspose2d(self.dims[i + 1], self.dims[i], kernel_size=2, stride=2)
            )
            
            # Decoder conv block takes concatenated features (upsampled + skip connection)
            self.decoders.append(
                UNetConvBlock(self.dims[i] * 2, self.dims[i], kernel_size, dropout_rate)
            )

        # Final 1x1 conv maps features to one channel per output variable (e.g., tas and pr)
        self.final = nn.Conv2d(self.dims[0], n_output_channels, kernel_size=1)

    def forward(self, x):
        enc_outputs = []  # To store outputs for skip connections

        # Encoder path: conv blocks + downsampling
        for i in range(self.depth):
            x = self.encoders[i](x)      # Apply encoder conv block
            enc_outputs.append(x)        # Save output for skip connection
            
            x = self.downsamples[i](x)   # Downsample (reduce spatial size)

        # Bottleneck conv block
        x = self.bottleneck(x)

        # Decoder path: upsampling + concatenation with skip features + conv blocks
        for i in range(self.depth):
            x = self.upconvs[i](x)  # Upsample feature map
            
            # Get corresponding encoder output for skip connection
            enc = enc_outputs[self.depth - 1 - i]

            # Handle any size mismatches due to odd dimensions by interpolation
            if x.shape[-2:] != enc.shape[-2:]:
                x = F.interpolate(x, size=enc.shape[-2:], mode='bilinear', align_corners=False)

            # Concatenate along channel dimension
            x = torch.cat([x, enc], dim=1)
            
            # Apply decoder conv block on concatenated features
            x = self.decoders[i](x)

        # Final conv layer to get output predictions
        return self.final(x)
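The interpolation branch in `UNet.forward` only fires when a grid side is not divisible by 2**depth. The sketch below replays the size arithmetic along one spatial axis: a 3x3, stride-2, padding-1 convolution maps a length n to ceil(n / 2), and `ConvTranspose2d(kernel_size=2, stride=2)` doubles it. `unet_size_trace` is a hypothetical helper, and the lengths 48 and 72 are illustrative grid sizes.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    # Output length of nn.Conv2d along one axis (dilation=1)
    return (n + 2 * padding - kernel) // stride + 1


def unet_size_trace(n, depth):
    """Hypothetical helper: track one spatial axis through the UNet above.

    Returns the skip-connection lengths, the bottleneck length, and every
    (upsampled, skip) pair that would trigger F.interpolate in the decoder.
    """
    skips = []
    for _ in range(depth):
        skips.append(n)  # encoder output saved before downsampling
        n = conv_out(n)  # stride-2 conv: ceil(n / 2)
    bottleneck = n  # the bottleneck block itself keeps the size
    mismatches = []
    for skip in reversed(skips):
        up = n * 2  # ConvTranspose2d(kernel_size=2, stride=2)
        if up != skip:
            mismatches.append((up, skip))
        n = skip  # interpolation (if any) restores the skip length
    return skips, bottleneck, mismatches


print(unet_size_trace(48, 4))
print(unet_size_trace(72, 4))
```

For the 72-long axis, the 5 -> 10 upsampling overshoots the saved 9-long encoder map, so `F.interpolate` resizes it before the concatenation; the 48-long axis never needs it at depth 4.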

In [None]:
# Attention block used in UNet to modulate skip connections
class UNetAttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Transform gating signal (decoder features) to intermediate channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        # Transform skip connection features (encoder features) to intermediate channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        # Compute attention coefficients (single channel, sigmoid activation)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, x):
        # Apply transformations
        g1 = self.W_g(g)  # gating signal from decoder
        x1 = self.W_x(x)  # features from encoder skip connection
        
        # Sum, apply ReLU, then project to a single-channel sigmoid attention mask
        psi = self.psi(F.relu(g1 + x1))
        
        # Scale encoder features by attention coefficients
        return x * psi


# Convolutional block with residual skip connection, used in attention UNet
class UNetAttentionConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # First conv + batch norm + ReLU
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        # Second conv + batch norm (no activation yet)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

        # Skip connection to match dimensions if needed
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.skip(x)  # Prepare skip connection

        # Forward through conv layers and activations
        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        
        # Add skip connection and apply final activation
        return self.relu2(out + identity)


# UNet architecture with attention gates on skip connections
class UNetAttention(nn.Module):
    def __init__(self, n_input_channels, n_output_channels, kernel_size=3, init_dim=64, depth=4):
        super().__init__()
        self.depth = depth
        
        # Channel sizes at each depth level (doubling at each step)
        self.dims = [init_dim * (2 ** i) for i in range(depth + 1)]

        # Encoder blocks and downsampling layers
        self.encoders = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        in_ch = n_input_channels
        
        for i in range(depth):
            # Encoder conv block
            self.encoders.append(UNetAttentionConvBlock(in_ch, self.dims[i], kernel_size))
            # Downsample spatial resolution by factor of 2
            self.downsamples.append(
                nn.Conv2d(self.dims[i], self.dims[i], kernel_size=3, stride=2, padding=1)
            )
            in_ch = self.dims[i]

        # Bottleneck conv block at bottom of UNet
        self.bottleneck = UNetAttentionConvBlock(self.dims[depth - 1], self.dims[depth], kernel_size)

        # Decoder upconvs, attention blocks, and conv blocks
        self.upconvs = nn.ModuleList()
        self.attention_blocks = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        # Build decoder path in reverse order
        for i in reversed(range(depth)):
            # Upsample by transpose convolution (double spatial size)
            self.upconvs.append(
                nn.ConvTranspose2d(self.dims[i + 1], self.dims[i], kernel_size=2, stride=2)
            )
            # Attention block modulates encoder skip features based on decoder features
            self.attention_blocks.append(
                UNetAttentionBlock(F_g=self.dims[i], F_l=self.dims[i], F_int=self.dims[i] // 2)
            )
            # Decoder conv block processes concatenated upsampled + attended skip features
            self.decoders.append(UNetAttentionConvBlock(self.dims[i] * 2, self.dims[i], kernel_size))

        # Final 1x1 conv to produce one channel per output variable (e.g., tas and pr)
        self.final = nn.Conv2d(self.dims[0], n_output_channels, kernel_size=1)

    def forward(self, x):
        enc_outputs = []

        # Encoder path
        for i in range(self.depth):
            x = self.encoders[i](x)
            enc_outputs.append(x)    # Save encoder outputs for skip connections
            x = self.downsamples[i](x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path with attention gating on skip connections
        for i in range(self.depth):
            x = self.upconvs[i](x)  # Upsample

            # Get corresponding encoder output
            enc = enc_outputs[self.depth - 1 - i]

            # Fix any spatial size mismatches due to rounding in down/up sampling
            if x.shape[-2:] != enc.shape[-2:]:
                x = F.interpolate(x, size=enc.shape[-2:], mode='bilinear', align_corners=False)

            # Apply attention block to encoder skip features, conditioned on decoder features
            enc = self.attention_blocks[i](x, enc)

            # Concatenate upsampled decoder features with attended skip connection
            x = torch.cat([x, enc], dim=1)

            # Decode concatenated features
            x = self.decoders[i](x)

        # Final output conv layer
        return self.final(x)
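`UNetAttentionBlock` is an additive attention gate: both inputs pass through 1x1 convolutions, are summed, go through ReLU, and are collapsed to a single-channel sigmoid mask that rescales the encoder features. Since a 1x1 convolution is just a per-pixel linear map over channels, the gate fits in a few lines of NumPy. This sketch omits the BatchNorm layers and uses random, hypothetical weights; it illustrates shapes and broadcasting, not trained behavior.

```python
import numpy as np


def conv1x1(x, weight, bias):
    # 1x1 convolution as a per-pixel linear map over channels.
    # x: (C_in, H, W), weight: (C_out, C_in), bias: (C_out,) -> (C_out, H, W)
    return np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]


def attention_gate(g, x, W_g, b_g, W_x, b_x, w_psi, b_psi):
    """Additive attention gate without BatchNorm (illustrative sketch)."""
    q = np.maximum(conv1x1(g, W_g, b_g) + conv1x1(x, W_x, b_x), 0.0)  # ReLU(g1 + x1)
    psi = 1.0 / (1.0 + np.exp(-conv1x1(q, w_psi, b_psi)))  # (1, H, W), values in [0, 1]
    return x * psi, psi  # the single-channel mask broadcasts over every channel of x


rng = np.random.default_rng(0)
C, C_int, H, W = 8, 4, 6, 9  # F_g = F_l = C and F_int = C // 2, as in UNetAttention
g = rng.standard_normal((C, H, W))  # decoder (gating) features
x = rng.standard_normal((C, H, W))  # encoder skip features
params = (
    rng.standard_normal((C_int, C)), np.zeros(C_int),  # W_g, b_g
    rng.standard_normal((C_int, C)), np.zeros(C_int),  # W_x, b_x
    rng.standard_normal((1, C_int)), np.zeros(1),  # psi projection
)
out, psi = attention_gate(g, x, *params)
print(out.shape, psi.shape)
```

Because the mask has one channel, the gate can only decide *where* to pass skip features, not *which* channels; every channel at a given pixel is scaled by the same coefficient.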

In [None]:
# Basic convolutional block used in ResNet, with residual connections and dropout
class ResNetConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dropout=0.1, stride=1, downsample=None):
        super().__init__()
        # First convolution layer with possible stride (for downsampling)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout2d(dropout)  # Spatial dropout for regularization
        
        # Second convolution layer (no stride here, same spatial size)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Optional downsample layer to match identity dimensions (for skip connection)
        self.downsample = downsample

    def forward(self, x):
        identity = x  # Save input for skip connection

        # Forward pass through first conv + batch norm + relu
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.drop(out)  # Apply dropout
        out = self.bn2(self.conv2(out))  # Second conv + batch norm (no activation yet)

        # If downsample layer exists, apply to identity to match shapes
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add skip connection (residual)
        out += identity
        return self.relu(out)  # Final activation


# ResNet model with encoder and decoder parts
class ResNet(nn.Module):
    def __init__(
        self,
        n_input_channels,
        n_output_channels,
        kernel_size=3,
        init_dim=64,
        dropout_rate=0.1,
        depth=4,
    ):
        super().__init__()

        # Define number of residual blocks in each layer based on depth
        if depth == 4:
            layers = [2, 2, 2, 2]       # ResNet-18 style
        elif depth == 5:
            layers = [3, 4, 6, 3]       # ResNet-34 style
        elif depth == 6:
            layers = [3, 4, 23, 3]      # ResNet-101 style
        else:
            raise ValueError("Unsupported depth")

        self.inplanes = init_dim  # Number of channels input to each block

        # Initial convolution and batch norm + ReLU before ResNet layers
        self.initial = nn.Sequential(
            nn.Conv2d(n_input_channels, init_dim, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(init_dim),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Downsample spatial size

        # Residual layers forming the encoder
        self.layer0 = self._make_layer(
            ResNetConvBlock, init_dim, layers[0], stride=1, kernel_size=kernel_size, dropout=dropout_rate
        )
        self.layer1 = self._make_layer(
            ResNetConvBlock, init_dim * 2, layers[1], stride=2, kernel_size=kernel_size, dropout=dropout_rate
        )
        self.layer2 = self._make_layer(
            ResNetConvBlock, init_dim * 4, layers[2], stride=2, kernel_size=kernel_size, dropout=dropout_rate
        )
        self.layer3 = self._make_layer(
            ResNetConvBlock, init_dim * 8, layers[3], stride=2, kernel_size=kernel_size, dropout=dropout_rate
        )

        # Decoder: four 2x transposed-conv blocks (16x total); the encoder downsamples by 32x,
        # so forward() finishes with bilinear interpolation back to the input size
        self.upsample1 = self._upsample_block(init_dim * 8, init_dim * 4)
        self.upsample2 = self._upsample_block(init_dim * 4, init_dim * 2)
        self.upsample3 = self._upsample_block(init_dim * 2, init_dim)
        self.upsample4 = self._upsample_block(init_dim, init_dim // 2)

        # Final layers reduce channels to one per output variable (e.g., tas and pr)
        self.final = nn.Sequential(
            nn.Conv2d(init_dim // 2, init_dim // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(init_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(init_dim // 2, n_output_channels, kernel_size=1),
        )

    # Helper function to create a sequence of residual blocks
    def _make_layer(self, block, planes, blocks, stride=1, kernel_size=3, dropout=0.1):
        downsample = None
        # Create downsampling skip connection if dimensions or stride differ
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        # First block in the layer might downsample input
        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample, kernel_size=kernel_size, dropout=dropout)]
        self.inplanes = planes  # Update inplanes for subsequent blocks

        # Remaining blocks, no downsampling
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, kernel_size=kernel_size, dropout=dropout))

        return nn.Sequential(*layers)

    # Helper function to create upsampling blocks using ConvTranspose2d
    def _upsample_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        input_size = x.shape[-2:]  # Save original spatial size for final interpolation

        # Encoder path
        x = self.initial(x)       # Initial conv layer
        x = self.maxpool(x)       # Downsample

        x = self.layer0(x)        # Residual layers with downsampling as configured
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # Decoder path: upsample step by step to increase spatial resolution
        x = self.upsample1(x)
        x = self.upsample2(x)
        x = self.upsample3(x)
        x = self.upsample4(x)

        # Resize to original input size to ensure consistent output dimensions
        x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)

        # Final conv layers produce one output channel per predicted variable
        return self.final(x)
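Unlike the UNet, this ResNet encoder reduces resolution by a factor of 32 (7x7 stride-2 stem, stride-2 max pool, and three stride-2 stages), while the decoder's four transposed convolutions recover only a factor of 16; the final `F.interpolate` supplies the rest. The hypothetical helper below traces one spatial axis; 48 and 72 are illustrative grid sizes.

```python
def out_len(n, kernel, stride, padding):
    # Output length shared by nn.Conv2d and nn.MaxPool2d (dilation=1, ceil_mode=False)
    return (n + 2 * padding - kernel) // stride + 1


def resnet_size_trace(n, kernel_size=3):
    """Hypothetical helper: one spatial axis through the ResNet above."""
    sizes = [n]
    n = out_len(n, 7, 2, 3)  # stem: 7x7 conv, stride 2
    sizes.append(n)
    n = out_len(n, 3, 2, 1)  # max pool, stride 2
    sizes.append(n)
    # layer0 uses stride 1 (size unchanged); layer1..layer3 each open with a stride-2 block
    for _ in range(3):
        n = out_len(n, kernel_size, 2, kernel_size // 2)
        sizes.append(n)
    decoded = n * 2 ** 4  # four ConvTranspose2d(kernel_size=2, stride=2) blocks
    return sizes, decoded


print(resnet_size_trace(48))
print(resnet_size_trace(72))
```

For a 48 x 72 grid the decoder emits 32 x 48 maps, which bilinear interpolation then stretches to 48 x 72, so in this configuration the finest spatial detail comes from interpolation rather than from learned layers.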

In [None]:
# Single dense block: BN -> ReLU -> Conv -> Dropout -> Concatenate input with output
class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, dropout=0.1):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        # Conv layer outputs growth_rate feature maps
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)
        self.drop = nn.Dropout2d(dropout)

    def forward(self, x):
        # Apply BN -> ReLU -> Conv -> Dropout
        out = self.conv(self.relu(self.bn(x)))
        out = self.drop(out)
        # Concatenate input and output along channel dimension (dense connectivity)
        return torch.cat([x, out], dim=1)

# Multiple DenseBlocks stacked sequentially, input channels grow with each block
class DenseLayer(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate, dropout=0.1):
        super().__init__()
        layers = []
        for i in range(num_layers):
            # Each DenseBlock receives channels increased by growth_rate * number of previous blocks
            layers.append(DenseBlock(in_channels + i * growth_rate, growth_rate, dropout))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

# Transition layer reduces spatial dimensions and channels via BN, ReLU, 1x1 Conv, Dropout, and AvgPool
class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv reduces channel count from in_channels to out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.drop = nn.Dropout2d(dropout)
        # Avg pooling reduces spatial dimensions by 2
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.relu(self.bn(x))
        x = self.conv(x)
        x = self.drop(x)
        return self.pool(x)

# DenseNet architecture combining dense blocks and transition layers, with an upsampling decoder
class DenseNet(nn.Module):
    def __init__(
        self,
        n_input_channels,
        n_output_channels,
        kernel_size=3,
        init_dim=64,
        dropout_rate=0.1,
        depth=4,
        growth_rate=32,
    ):
        super().__init__()

        # Define number of layers per dense block based on depth
        if depth == 3:
            layers = [6, 12, 24, 16]   # DenseNet-121 style
        elif depth == 4:
            layers = [6, 12, 32, 32]   # DenseNet-169 style
        elif depth == 5:
            layers = [6, 12, 48, 32]   # DenseNet-201 style
        elif depth == 6:
            layers = [6, 12, 64, 48]   # DenseNet-264 style
        else:
            raise ValueError("Unsupported depth")

        self.inplanes = init_dim       # Current number of channels, updated after each block
        self.growth_rate = growth_rate # Number of channels added per DenseBlock
        self.dropout_rate = dropout_rate

        # Initial convolution layer to process input image
        self.initial = nn.Sequential(
            nn.Conv2d(n_input_channels, init_dim, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(init_dim),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Downsample spatially

        # Encoder: Dense blocks separated by transition layers (except after last block)
        self.layer0 = self._make_layer(layers[0])
        self.trans0 = self._make_transition()

        self.layer1 = self._make_layer(layers[1])
        self.trans1 = self._make_transition()

        self.layer2 = self._make_layer(layers[2])
        self.trans2 = self._make_transition()

        self.layer3 = self._make_layer(layers[3])  # No transition after final dense block

        # Decoder: upsample progressively to recover spatial resolution
        self.upsample1 = self._upsample_block(self.inplanes, self.inplanes // 2)
        self.upsample2 = self._upsample_block(self.inplanes // 2, self.inplanes // 4)
        self.upsample3 = self._upsample_block(self.inplanes // 4, self.inplanes // 8)
        self.upsample4 = self._upsample_block(self.inplanes // 8, self.inplanes // 16)

        # Final conv layers to produce output with desired number of channels
        self.final = nn.Sequential(
            nn.Conv2d(self.inplanes // 16, self.inplanes // 16, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(self.inplanes // 16),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.inplanes // 16, n_output_channels, kernel_size=1),
        )

    # Create a dense layer with multiple DenseBlocks; update current channels accordingly
    def _make_layer(self, num_layers):
        block = DenseLayer(num_layers, self.inplanes, self.growth_rate, self.dropout_rate)
        self.inplanes += num_layers * self.growth_rate  # Channels increase by growth_rate * num_layers
        return block

    # Create a transition layer to reduce channels and downsample spatially by 2
    def _make_transition(self):
        out_channels = self.inplanes // 2  # Reduce channels by half
        trans = TransitionLayer(self.inplanes, out_channels, self.dropout_rate)
        self.inplanes = out_channels
        return trans

    # Upsampling block using transposed convolution followed by batch norm and ReLU
    def _upsample_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        input_size = x.shape[-2:]  # Save input spatial dimensions

        x = self.initial(x)        # Initial conv + BN + ReLU
        x = self.maxpool(x)        # Downsample spatially

        # Encoder path: dense blocks separated by transitions
        x = self.trans0(self.layer0(x))
        x = self.trans1(self.layer1(x))
        x = self.trans2(self.layer2(x))
        x = self.layer3(x)  # Last block without transition

        # Decoder path: upsample step-by-step to recover spatial resolution
        x = self.upsample1(x)
        x = self.upsample2(x)
        x = self.upsample3(x)
        x = self.upsample4(x)

        # Interpolate to match original input size precisely
        x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)

        # Final conv layers map the decoded features to one output channel per predicted variable
        return self.final(x)
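The channel and spatial bookkeeping above can be traced with a few lines of plain Python. This is a sketch under stated assumptions: `self.initial` and `self.maxpool` are defined earlier and not shown here, so the code assumes the initial conv keeps H and W, the max-pool halves them, and each `TransitionLayer` halves them with floor division. The block sizes, `inplanes=64`, and `growth_rate=32` are hypothetical values.

```python
def trace_densenet(in_hw, inplanes, growth_rate, layers):
    """Trace channel counts and (H, W) through the DenseNet bookkeeping.

    Assumptions: the initial conv keeps H and W; the max-pool and every
    TransitionLayer halve H and W with floor division.
    """
    h, w = in_hw
    h, w = h // 2, w // 2  # assumed max-pool downsampling
    for i, num_layers in enumerate(layers):
        inplanes += num_layers * growth_rate  # _make_layer: growth_rate channels per dense layer
        if i < len(layers) - 1:               # no transition after the final dense block
            inplanes //= 2                    # _make_transition halves the channels
            h, w = h // 2, w // 2             # ...and (assumed) halves H and W
    encoder_channels, encoder_hw = inplanes, (h, w)
    # Four upsample blocks: ConvTranspose2d(kernel_size=2, stride=2) exactly doubles H and W
    for _ in range(4):
        h, w = h * 2, w * 2
    decoder_channels = encoder_channels // 16  # channels after upsample4
    return encoder_channels, encoder_hw, decoder_channels, (h, w)


# Hypothetical DenseNet-121-style block sizes on the 48 x 72 grid
print(trace_densenet((48, 72), inplanes=64, growth_rate=32, layers=(6, 12, 24, 16)))
# → (1024, (3, 4), 64, (48, 64))
```

Under these assumptions, 72 is not divisible by 2⁴ = 16, so the width drops from 9 to 4 at the last transition and the decoder returns 48 × 64 instead of 48 × 72. The `F.interpolate(..., size=input_size)` call in `forward` restores the exact input size.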

### 📐 Normalizer: Z-Score Scaling for Climate Inputs & Outputs

This class handles **Z-score normalization**, a crucial preprocessing step for stable and efficient neural network training:

- **`set_input_statistics(mean, std)` / `set_output_statistics(...)`**: Store the mean and standard deviation computed from the training data for later use.
- **`normalize(data, data_type)`**: Standardizes the data using `(x - mean) / std`. This is applied separately to inputs and outputs.
- **`inverse_transform_output(data)`**: Converts model predictions back to the original physical units (e.g., Kelvin for temperature, mm/day for precipitation).

Normalizing the data ensures the model sees inputs with similar dynamic ranges and avoids biases caused by different variable scales.


In [None]:
class Normalizer:
    def __init__(self):
        self.mean_in, self.std_in = None, None
        self.mean_out, self.std_out = None, None

    def set_input_statistics(self, mean, std):
        self.mean_in = mean
        self.std_in = std

    def set_output_statistics(self, mean, std):
        self.mean_out = mean
        self.std_out = std

    def normalize(self, data, data_type):
        if data_type == "input":
            return (data - self.mean_in) / self.std_in
        elif data_type == "output":
            return (data - self.mean_out) / self.std_out
        raise ValueError(f"data_type must be 'input' or 'output', got {data_type!r}")

    def inverse_transform_output(self, data):
        return data * self.std_out + self.mean_out
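The per-channel statistics and the round trip through `normalize` and `inverse_transform_output` can be illustrated on a toy array. This sketch uses plain NumPy `mean`/`std`, which match the `da.nanmean`/`da.nanstd` calls in `ClimateDataModule.setup` only when the data contains no NaNs; the array sizes and values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy training array shaped (N, C, H, W): channel 0 ~ temperature in K, channel 1 ~ precipitation
train = np.stack(
    [rng.normal(288.0, 10.0, size=(12, 4, 6)), rng.normal(3.0, 2.0, size=(12, 4, 6))],
    axis=1,
)

# Per-channel statistics over (time, y, x); keepdims=True gives shape (1, C, 1, 1),
# which broadcasts back onto any (N, C, H, W) array
mean = train.mean(axis=(0, 2, 3), keepdims=True)
std = train.std(axis=(0, 2, 3), keepdims=True)

normalized = (train - mean) / std   # what normalize(..., "output") computes
restored = normalized * std + mean  # what inverse_transform_output computes

print(mean.shape)                       # → (1, 2, 1, 1)
print(normalized.mean(axis=(0, 2, 3)))  # per-channel means, close to 0
print(normalized.std(axis=(0, 2, 3)))   # per-channel stds, close to 1
```

Because the statistics keep singleton axes, the same `mean`/`std` arrays work for a single batch or for the whole test set, as long as the channel axis is second.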


In [None]:
class ClimateDataset(Dataset):
    def __init__(self, inputs_dask, outputs_dask, output_is_normalized=True):
        self.size = inputs_dask.shape[0]
        print(f"Creating dataset with {self.size} samples...")

        inputs_np = inputs_dask.compute()
        outputs_np = outputs_dask.compute()

        self.inputs = torch.from_numpy(inputs_np).float()
        self.outputs = torch.from_numpy(outputs_np).float()

        if torch.isnan(self.inputs).any() or torch.isnan(self.outputs).any():
            raise ValueError("NaNs found in dataset")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]


class ClimateDataModule(pl.LightningDataModule):
    def __init__(
        self,
        path,
        input_vars,
        output_vars,
        train_ssps,
        test_ssp,
        target_member_id,
        val_months=120,
        test_months=360,
        batch_size=32,
        num_workers=0,
        seed=42,
    ):
        super().__init__()
        self.path = path
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.train_ssps = train_ssps
        self.test_ssp = test_ssp
        self.target_member_id = target_member_id
        self.val_months = val_months
        self.test_months = test_months
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.normalizer = Normalizer()

    def prepare_data(self):
        assert os.path.exists(self.path), f"Data path not found: {self.path}"

    def setup(self, stage=None):
        ds = xr.open_zarr(self.path, consolidated=False, chunks={"time": 24})
        spatial_template = ds["rsdt"].isel(time=0, ssp=0, drop=True)

        def load_ssp(ssp):
            input_dask, output_dask = [], []
            for var in self.input_vars:
                da_var = ds[var].sel(ssp=ssp)
                if "latitude" in da_var.dims:
                    da_var = da_var.rename({"latitude": "y", "longitude": "x"})
                if "member_id" in da_var.dims:
                    da_var = da_var.sel(member_id=self.target_member_id)
                if set(da_var.dims) == {"time"}:
                    da_var = da_var.broadcast_like(spatial_template).transpose("time", "y", "x")
                input_dask.append(da_var.data)

            for var in self.output_vars:
                da_out = ds[var].sel(ssp=ssp, member_id=self.target_member_id)
                if "latitude" in da_out.dims:
                    da_out = da_out.rename({"latitude": "y", "longitude": "x"})
                output_dask.append(da_out.data)

            return da.stack(input_dask, axis=1), da.stack(output_dask, axis=1)

        train_input, train_output, val_input, val_output = [], [], None, None

        for ssp in self.train_ssps:
            x, y = load_ssp(ssp)
            if ssp == "ssp370":
                val_input = x[-self.val_months:]
                val_output = y[-self.val_months:]
                train_input.append(x[:-self.val_months])
                train_output.append(y[:-self.val_months])
            else:
                train_input.append(x)
                train_output.append(y)

        train_input = da.concatenate(train_input, axis=0)
        train_output = da.concatenate(train_output, axis=0)

        self.normalizer.set_input_statistics(
            mean=da.nanmean(train_input, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_input, axis=(0, 2, 3), keepdims=True).compute(),
        )
        self.normalizer.set_output_statistics(
            mean=da.nanmean(train_output, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_output, axis=(0, 2, 3), keepdims=True).compute(),
        )

        train_input_norm = self.normalizer.normalize(train_input, "input")
        train_output_norm = self.normalizer.normalize(train_output, "output")
        val_input_norm = self.normalizer.normalize(val_input, "input")
        val_output_norm = self.normalizer.normalize(val_output, "output")

        test_input, test_output = load_ssp(self.test_ssp)
        test_input = test_input[-self.test_months:]
        test_output = test_output[-self.test_months:]
        test_input_norm = self.normalizer.normalize(test_input, "input")

        self.train_dataset = ClimateDataset(train_input_norm, train_output_norm)
        self.val_dataset = ClimateDataset(val_input_norm, val_output_norm)
        self.test_dataset = ClimateDataset(test_input_norm, test_output, output_is_normalized=False)

        self.lat = spatial_template.y.values
        self.lon = spatial_template.x.values
        self.area_weights = xr.DataArray(get_lat_weights(self.lat), dims=["y"], coords={"y": self.lat})

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def get_lat_weights(self):
        return self.area_weights

    def get_coords(self):
        return self.lat, self.lon

### 🌍 Data Module: Loading, Normalization, and Splitting

This section handles the entire data pipeline, from loading the `.zarr` dataset to preparing PyTorch-ready DataLoaders.

#### `ClimateDataset`
- A simple PyTorch `Dataset` wrapper that materializes the Dask arrays into memory (via `.compute()`) once, when the dataset is created.
- Converts the data to PyTorch tensors and raises an error up front if any `NaN`s are found.

#### `ClimateDataModule`
A PyTorch Lightning `DataModule` that handles:
- ✅ **Loading data** from different SSP scenarios and ensemble members
- ✅ **Broadcasting non-spatial inputs** (like CO₂) to match spatial grid size
- ✅ **Normalization** using mean/std computed from training data only
- ✅ **Splitting** into training, validation, and test sets:
  - Training: all months from the selected SSPs, except the last `val_months` months of SSP370
  - Validation: the last `val_months` months of SSP370 (120 by default, i.e. 10 years)
  - Test: the last `test_months` months of SSP245, an unseen scenario (360 by default, i.e. 30 years)
- ✅ **Batching** and parallelized data loading via PyTorch `DataLoader`s
- ✅ **Latitude-based area weighting** for fair climate metric evaluation
- Input batches have shape `(batch_size, 5, 48, 72)`: 5 input variables on a 48 × 72 (latitude × longitude) grid
- Output batches have shape `(batch_size, 2, 48, 72)`: 2 output variables on the same grid

> ℹ️ **Note:** You likely won’t need to modify this class, but feel free to do so, for example to include additional ensemble members and feed more data to your models.
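A minimal NumPy sketch of two steps in `setup`: broadcasting a global, time-only forcing such as CO₂ onto the spatial grid, and holding out the end of SSP370 for validation. `np.broadcast_to` stands in for xarray's `broadcast_like`, and the helper names `make_inputs` and `split_train_val`, the toy grid, and `val_months=12` are hypothetical.

```python
import numpy as np

H, W = 4, 6  # toy grid; the real grid is 48 x 72


def make_inputs(co2, field):
    """Stack a (time,) global forcing and a (time, H, W) field into (time, 2, H, W)."""
    # NumPy analogue of xarray's broadcast_like: repeat the global value at every grid cell
    co2_map = np.broadcast_to(co2[:, None, None], field.shape)
    return np.stack([co2_map, field], axis=1)


def split_train_val(inputs_by_ssp, val_ssp="ssp370", val_months=12):
    """Hold out the last val_months of val_ssp for validation; train on everything else."""
    train, val = [], None
    for ssp, x in inputs_by_ssp.items():
        if ssp == val_ssp:
            val = x[-val_months:]
            train.append(x[:-val_months])
        else:
            train.append(x)
    return np.concatenate(train, axis=0), val


rng = np.random.default_rng(0)
inputs = {
    ssp: make_inputs(np.linspace(400.0, 500.0, months), rng.normal(size=(months, H, W)))
    for ssp, months in [("ssp126", 24), ("ssp370", 36)]
}
train_x, val_x = split_train_val(inputs)
print(train_x.shape, val_x.shape)  # → (48, 2, 4, 6) (12, 2, 4, 6)
```

Splitting by time rather than at random keeps the validation months strictly after the training months of the same scenario, which is closer to the extrapolation the test set asks for.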


### ⚡ ClimateEmulationModule: Lightning Wrapper for Climate Model Emulation

This is the core model wrapper built with **PyTorch Lightning**, which organizes the training, validation, and testing logic for the climate emulation task. Lightning abstracts away much of the boilerplate code in PyTorch-based deep learning workflows, making it easier to scale models.

#### ✅ Key Features

- **`training_step` / `validation_step` / `test_step`**: Standard Lightning hooks for computing loss and predictions at each stage. The loss used is **Mean Squared Error (MSE)**.

- **Normalization-aware outputs**:
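  - During validation, both predictions and targets are denormalized with the stored mean/std statistics before evaluation. During testing, only the predictions are denormalized, because the test targets are kept in physical units.
  - This ensures evaluation is done in real-world units (Kelvin and mm/day).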

- **Metric Evaluation** via `_evaluate()`:
  For each variable (`tas`, `pr`), it calculates:
  - **Monthly Area-Weighted RMSE**
  - **Time-Mean RMSE** (RMSE of the 10-year time-mean maps)
  - **Time-Stddev MAE** (MAE on 10-year standard deviation; a measure of temporal variability)
    
  These metrics reflect the competition's evaluation criteria and are logged and printed.

- **Kaggle Submission Writer**:
  After testing, predictions are saved to a `.csv` file in the required Kaggle format via `_save_submission()`.

- **Saving Predictions for Visualization**:
  - Validation predictions are saved to `val_preds.npy` and `val_trues.npy`
  - These can be loaded later for visual inspection of the model's performance.
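The three metrics can be written out in plain NumPy to make the weighting explicit. This is a sketch, not the evaluation code: it assumes cosine-of-latitude area weights (the notebook's own `get_lat_weights` helper is defined elsewhere and may scale them differently, which does not change a weighted mean) and arrays without NaNs. The helper names and toy data are hypothetical.

```python
import numpy as np


def weighted_spatial_mean(field, w_lat):
    """Area-weighted mean over the last two axes (y, x); weights vary along y only."""
    w = np.broadcast_to(w_lat[:, None], field.shape[-2:])
    return (field * w).sum(axis=(-2, -1)) / w.sum()


def climate_metrics(pred, true, lat_deg):
    """pred, true: (time, y, x) arrays in physical units."""
    # Assumption: cos(latitude) area weights
    w_lat = np.cos(np.deg2rad(lat_deg))
    # Monthly RMSE: the weights do not depend on time, so averaging over time first is equivalent
    rmse = np.sqrt(weighted_spatial_mean(((pred - true) ** 2).mean(axis=0), w_lat))
    # RMSE of the time-mean maps
    time_mean_rmse = np.sqrt(weighted_spatial_mean((pred.mean(axis=0) - true.mean(axis=0)) ** 2, w_lat))
    # MAE of the time-stddev maps (ddof=0, matching xarray's .std default)
    time_std_mae = weighted_spatial_mean(np.abs(pred.std(axis=0) - true.std(axis=0)), w_lat)
    return float(rmse), float(time_mean_rmse), float(time_std_mae)


lat = np.linspace(-87.5, 87.5, 8)
rng = np.random.default_rng(0)
true = rng.normal(288.0, 5.0, size=(120, 8, 12))
print(climate_metrics(true + 2.0, true, lat))  # a constant +2 K bias
```

With a uniform +2 K bias, the monthly RMSE and the time-mean RMSE both come out as 2 while the time-stddev MAE is 0: a constant offset hurts the first two metrics but leaves temporal variability untouched.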

🔧 **Feel free to modify any part of this module** (loss functions, evaluation, training logic) to better suit your model or training pipeline, or replace it with plain PyTorch.

⚠️ The **final submission `.csv` file must strictly follow the format and naming convention used in `_save_submission()`**, as these `ID`s are used to match predictions to the hidden test set during evaluation.
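The nested loops in `_save_submission` build one Python dict per grid cell, which can be slow for a long test period. A possible vectorized alternative, sketched below with a hypothetical helper `submission_columns` and toy coordinates, generates the same `ID` strings in the same order and flattens the prediction array with `reshape(-1)`. Verify its output against the loop version before relying on it for a submission.

```python
import itertools

import numpy as np


def submission_columns(predictions, output_vars, lat, lon):
    """Build ID/Prediction columns in the same order as the nested loops in _save_submission.

    predictions: (time, n_vars, n_lat, n_lon) array in physical units.
    """
    ids = [
        f"t{t_idx:03d}_{var}_{y:.2f}_{x:.2f}"
        for t_idx, var, y, x in itertools.product(range(predictions.shape[0]), output_vars, lat, lon)
    ]
    # C-order flattening of (time, var, y, x) visits elements in the same order as itertools.product
    values = predictions.reshape(-1)
    return ids, values


lat = np.array([-10.0, 10.0])
lon = np.array([0.0, 5.0, 10.0])
preds = np.arange(2 * 2 * 2 * 3, dtype=float).reshape(2, 2, 2, 3)
ids, values = submission_columns(preds, ["tas", "pr"], lat, lon)
print(ids[0], values[0])    # → t000_tas_-10.00_0.00 0.0
print(ids[-1], values[-1])  # → t001_pr_10.00_10.00 23.0
```

The two columns can then be passed to `pd.DataFrame({"ID": ids, "Prediction": values})` and written with `to_csv(..., index=False)`, as in `_save_submission`.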



In [None]:
import pandas as pd

class ClimateEmulationModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=['model']) # Save all hyperparameters except the model to self.hparams.<param_name>
        self.criterion = nn.MSELoss()
        self.normalizer = None
        self.val_preds, self.val_targets = [], []
        self.test_preds, self.test_targets = [], []
        self.train_losses = []
        self.val_losses = []

    def forward(self, x):
        return self.model(x)

    def on_fit_start(self):
        self.normalizer = self.trainer.datamodule.normalizer  # Get the normalizer from the datamodule (see above)

    def training_step(self, batch, batch_idx):
        x, y = batch  # Unpack inputs and targets (the tuple returned by the Dataset's __getitem__ method above)
        y_hat = self(x)   # Forward pass
        loss = self.criterion(y_hat, y)  # Calculate loss
        self.log("train/loss", loss, prog_bar=True)  # Log loss for tracking
        return loss

    def on_train_epoch_end(self):
        loss = self.trainer.callback_metrics.get("train/loss")
        if loss is not None:
            self.train_losses.append(loss.item())
            print(f"[Epoch {self.current_epoch}] Train Loss: {loss.item():.4f}")
            print(f"[Epoch {self.current_epoch}] Current LR: {self.trainer.optimizers[0].param_groups[0]['lr']:.6f}")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val/loss", loss, prog_bar=True)

        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np = self.normalizer.inverse_transform_output(y.detach().cpu().numpy())
        self.val_preds.append(y_hat_np)
        self.val_targets.append(y_np)

        return loss

    def on_validation_epoch_end(self):
        val_loss = self.trainer.callback_metrics.get("val/loss")
        if val_loss is not None:
            self.val_losses.append(val_loss.item())
            print(f"[Epoch {self.current_epoch}] Val Loss: {val_loss.item():.4f}")
        
        # Concatenate all predictions and ground truths from each val step/batch into one array
        preds = np.concatenate(self.val_preds, axis=0)
        trues = np.concatenate(self.val_targets, axis=0)
        self._evaluate(preds, trues, phase="val")
        np.save("val_preds.npy", preds)
        np.save("val_trues.npy", trues)
        self.val_preds.clear()
        self.val_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np = y.detach().cpu().numpy()
        self.test_preds.append(y_hat_np)
        self.test_targets.append(y_np)

    def on_test_epoch_end(self):
        # Concatenate all predictions and ground truths from each test step/batch into one array
        preds = np.concatenate(self.test_preds, axis=0)
        trues = np.concatenate(self.test_targets, axis=0)
        self._evaluate(preds, trues, phase="test")
        self._save_submission(preds)
        self.test_preds.clear()
        self.test_targets.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-5)

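        # `verbose` is omitted: it is deprecated in recent PyTorch releases, and the LR is already printed in on_train_epoch_end
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, threshold=1e-3, min_lr=1e-7)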

        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler,"monitor": "val/loss","interval": "epoch","frequency": 1}}
        # return optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def _evaluate(self, preds, trues, phase="val"):
        datamodule = self.trainer.datamodule
        area_weights = datamodule.get_lat_weights()
        lat, lon = datamodule.get_coords()
        time = np.arange(preds.shape[0])
        output_vars = datamodule.output_vars

        for i, var in enumerate(output_vars):
            p = preds[:, i]
            t = trues[:, i]
            p_xr = xr.DataArray(p, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
            t_xr = xr.DataArray(t, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

            # RMSE
            rmse = np.sqrt(((p_xr - t_xr) ** 2).weighted(area_weights).mean(("time", "y", "x")).item())
            # RMSE of time-mean
            mean_rmse = np.sqrt(((p_xr.mean("time") - t_xr.mean("time")) ** 2).weighted(area_weights).mean(("y", "x")).item())
            # MAE of time-stddev
            std_mae = np.abs(p_xr.std("time") - t_xr.std("time")).weighted(area_weights).mean(("y", "x")).item()

            print(f"[{phase.upper()}] {var}: RMSE={rmse:.4f}, Time-Mean RMSE={mean_rmse:.4f}, Time-Stddev MAE={std_mae:.4f}")
            self.log_dict({
                f"{phase}/{var}/rmse": rmse,
                f"{phase}/{var}/time_mean_rmse": mean_rmse,
                f"{phase}/{var}/time_std_mae": std_mae,
            })

    def _save_submission(self, predictions):
        datamodule = self.trainer.datamodule
        lat, lon = datamodule.get_coords()
        output_vars = datamodule.output_vars
        time = np.arange(predictions.shape[0])

        rows = []
        for t_idx, t in enumerate(time):
            for var_idx, var in enumerate(output_vars):
                for y_idx, y in enumerate(lat):
                    for x_idx, x in enumerate(lon):
                        row_id = f"t{t_idx:03d}_{var}_{y:.2f}_{x:.2f}"
                        pred = predictions[t_idx, var_idx, y_idx, x_idx]
                        rows.append({"ID": row_id, "Prediction": pred})

        df = pd.DataFrame(rows)
        filename = f"kaggle_submission_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        filepath = os.path.join("/kaggle/working", filename)
        # os.makedirs("submissions", exist_ok=True)
        # filepath = f"submissions/kaggle_submission_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filepath, index=False)
        print(f"✅ Submission saved to: {filepath}")

### ⚡ Training & Evaluation with PyTorch Lightning

This block sets up and runs the training and testing pipeline using **PyTorch Lightning’s `Trainer`**, which abstracts away much of the boilerplate in deep learning workflows.

- **Modular Setup**:
  - `datamodule`: Handles loading, normalization, and batching of climate data.
  - `model`: A convolutional neural network that maps climate forcings to predicted outputs.
  - `lightning_module`: Wraps the model with training/validation/test logic and metric evaluation.

- **Trainer Flexibility**:
  The `Trainer` accepts a wide range of configuration options from `config["trainer"]`, including:
  - Number of epochs
  - Precision (e.g., 16-bit or 32-bit)
  - Device configuration (CPU, GPU, or TPU)
  - Determinism, logging, callbacks, and more
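For reference, a hypothetical `config["trainer"]` might look like the sketch below. The values are illustrative only; each key is a standard `lightning.pytorch.Trainer` argument, and the dict is unpacked into the `Trainer` together with the callbacks defined in the training cell.

```python
# Hypothetical values for config["trainer"]; each key is a standard lightning.pytorch.Trainer argument
trainer_config = {
    "max_epochs": 50,        # upper bound; the EarlyStopping callback may stop training sooner
    "accelerator": "auto",   # use a GPU/TPU when one is available, otherwise the CPU
    "devices": "auto",
    "precision": "32-true",  # "16-mixed" enables mixed precision on GPUs
    "log_every_n_steps": 10,
}

# Usage (requires lightning): trainer = pl.Trainer(callbacks=[...], **trainer_config)
print(sorted(trainer_config))
```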

In [None]:
def plot_loss_curves(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label="Training Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # plt.savefig("loss_curves.png")
    plt.show()

In [None]:
datamodule = ClimateDataModule(**config["data"])

specificModel = config["model"]["type"]

n_input_channels=len(config["data"]["input_vars"])
n_output_channels=len(config["data"]["output_vars"])


# model = SimpleCNN(
#     n_input_channels=len(config["data"]["input_vars"]),
#     n_output_channels=len(config["data"]["output_vars"]),
#     **{k: v for k, v in config["model"].items() if k != "type"}
# )

model_kwargs = {k: v for k, v in config["model"].items() if k != "type"}

if specificModel == "simple_cnn":
    model = SimpleCNN(n_input_channels, n_output_channels, **model_kwargs)
elif specificModel == "unet":
    model = UNet(n_input_channels, n_output_channels, **model_kwargs)
elif specificModel == "unetAttention":
    # dropout_rate is not passed to UNetAttention
    model = UNetAttention(n_input_channels, n_output_channels, **{k: v for k, v in model_kwargs.items() if k != "dropout_rate"})
elif specificModel == "resnet":
    model = ResNet(n_input_channels, n_output_channels, **model_kwargs)
elif specificModel == "densenet":
    model = DenseNet(n_input_channels, n_output_channels, **model_kwargs)
else:
    raise ValueError(f"Unknown model type: {specificModel}")

lightning_module = ClimateEmulationModule(model, learning_rate=config["training"]["lr"])

earlyStop = EarlyStopping(
    monitor="val/loss",
    patience=15,
    mode="min",
    verbose=True
)

goodModelCheckpoint = ModelCheckpoint(
    dirpath="/kaggle/working/checkpoints",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    filename="best-checkpoint",
)

trainer = pl.Trainer(
    callbacks=[earlyStop, goodModelCheckpoint],
    **config["trainer"]
)

# trainer = pl.Trainer(**config["trainer"])
trainer.fit(lightning_module, datamodule=datamodule)   # Training

plot_loss_curves(lightning_module.train_losses, lightning_module.val_losses)

# Test model

**IMPORTANT:** Please note that the test metrics will be bad because the test targets have been corrupted on the public Kaggle dataset.
The purpose of testing below is to generate the Kaggle submission file based on your model's predictions, which you can submit to the competition.

In [None]:
best_model_path = goodModelCheckpoint.best_model_path
best_model = ClimateEmulationModule.load_from_checkpoint(
    best_model_path,
    model=model,  # Pass your model architecture again
    learning_rate=config["training"]["lr"],
)

best_model.normalizer = datamodule.normalizer

# Evaluate using the best model
trainer.test(best_model, datamodule=datamodule)
# trainer.test(lightning_module, datamodule=datamodule) 

### Plotting Utils


In [None]:
def plot_comparison(true_xr, pred_xr, title, cmap='viridis', diff_cmap='RdBu_r', metric=None):
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    vmin = min(true_xr.min().item(), pred_xr.min().item())
    vmax = max(true_xr.max().item(), pred_xr.max().item())

    # Ground truth
    true_xr.plot(ax=axs[0], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[0].set_title(f"{title} (Ground Truth)")

    # Prediction
    pred_xr.plot(ax=axs[1], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[1].set_title(f"{title} (Prediction)")

    # Difference
    diff = pred_xr - true_xr
    abs_max = float(np.abs(diff).max())  # symmetric limits so zero error sits at the centre of the diverging colormap
    diff.plot(ax=axs[2], cmap=diff_cmap, vmin=-abs_max, vmax=abs_max, add_colorbar=True)
    axs[2].set_title(f"{title} (Difference) {f'- {metric:.4f}' if metric is not None else ''}")

    plt.tight_layout()
    plt.show()


### 🖼️ Visualizing Validation Predictions

This cell loads saved validation predictions and compares them to the ground truth using spatial plots. These visualizations help you qualitatively assess your model's performance.

For each output variable (`tas`, `pr`), we visualize:

- **📈 Time-Mean Map**: The 10-year average spatial pattern for both prediction and ground truth. Helps identify long-term biases or spatial shifts.
- **📊 Time-Stddev Map**: Shows the standard deviation across time for each grid cell — useful for assessing how well the model captures **temporal variability** at each location.
- **🕓 Random Timestep Sample**: Visual comparison of prediction vs ground truth for a single month. Useful for spotting fine-grained anomalies or errors in specific months.

> These plots provide intuition beyond metrics and are useful for debugging spatial or temporal model failures.
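The time-mean and time-stddev maps reduce the time axis of each `(time, y, x)` array, and `plot_comparison` centres the difference colorbar on zero. A NumPy sketch of those reductions, using hypothetical helper names, a toy array, and a uniform bias, shows why a constant offset appears in the time-mean difference panel but not in the time-stddev one.

```python
import numpy as np


def summary_maps(pred, true):
    """pred, true: (time, y, x). Returns (truth, prediction) map pairs."""
    return {
        "time_mean": (true.mean(axis=0), pred.mean(axis=0)),  # long-term average per grid cell
        "time_std": (true.std(axis=0), pred.std(axis=0)),     # temporal variability per grid cell (ddof=0)
    }


def symmetric_limits(true_map, pred_map):
    """Colour limits centred on zero for a difference panel."""
    abs_max = float(np.abs(pred_map - true_map).max())
    return -abs_max, abs_max


rng = np.random.default_rng(0)
true = rng.normal(size=(120, 4, 6))
pred = true + 0.5  # a uniform bias shifts the mean map but leaves the stddev map unchanged
maps = summary_maps(pred, true)
print(symmetric_limits(*maps["time_mean"]))
print(symmetric_limits(*maps["time_std"]))
```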


In [None]:
# Load validation predictions
# make sure to have run the validation loop at least once
val_preds = np.load("val_preds.npy")
val_trues = np.load("val_trues.npy")

lat, lon = datamodule.get_coords()
output_vars = datamodule.output_vars
time = np.arange(val_preds.shape[0])

for i, var in enumerate(output_vars):
    pred_xr = xr.DataArray(val_preds[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
    true_xr = xr.DataArray(val_trues[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

    # --- Time Mean ---
    plot_comparison(true_xr.mean("time"), pred_xr.mean("time"), f"{var} Val Time-Mean")

    # --- Time Stddev ---
    plot_comparison(true_xr.std("time"), pred_xr.std("time"), f"{var} Val Time-Stddev", cmap="plasma")

    # --- Random timestep ---
    t_idx = np.random.randint(0, len(time))
    plot_comparison(true_xr.isel(time=t_idx), pred_xr.isel(time=t_idx), f"{var} Val Sample Timestep {t_idx}")

In [None]:
# Load validation predictions
# make sure to have run the validation loop at least once
val_preds = np.load("val_preds.npy")
val_trues = np.load("val_trues.npy")

lat, lon = datamodule.get_coords()
output_vars = datamodule.output_vars
time = np.arange(val_preds.shape[0])

# Plot two independently sampled random timesteps
for _ in range(2):
    t_idx = np.random.randint(0, len(time))
    for i, var in enumerate(output_vars):
        pred_xr = xr.DataArray(val_preds[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
        true_xr = xr.DataArray(val_trues[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
        plot_comparison(true_xr.isel(time=t_idx), pred_xr.isel(time=t_idx), f"{var} Val Sample Timestep {t_idx}")

## 🧪 Final Notes

This notebook is meant to serve as a **baseline template** — a starting point to help you get up and running quickly with the climate emulation challenge.

You are **not** required to stick to this exact setup. In fact, we **encourage** you to:

- 🔁 Build on top of the provided `DataModule`
- 🧠 Use your own model architectures or training pipelines that you’re more comfortable with 
- ⚗️ Experiment with ideas  
- 🥇 Compete creatively to climb the Kaggle leaderboard  
- 🙌 Most importantly: **have fun** and **learn as much as you can** along the way

This challenge simulates a real-world scientific problem, and there’s no single "correct" approach — so be curious, experiment boldly, and make it your own!


In [None]:
data_path = config["data"]["path"]
data = xr.open_zarr(data_path, consolidated=False)  # same option as in ClimateDataModule.setup

# Select relevant SSPs and average over ensemble members
# Set up 2 rows (variables) x 3 columns (SSPs)

ssps = ["ssp126", "ssp370", "ssp585"]
subset = data.sel(ssp=ssps).mean(dim=["member_id"])

target_vars = ["tas", "pr"]

fig, ax = plt.subplots(len(target_vars), len(ssps), figsize=(18, 12))

for i, var in enumerate(target_vars):
    for j, ssp in enumerate(ssps):
        vals = subset[var].sel(ssp=ssp).values.flatten()
        ax[i, j].hist(vals, bins=100)
        ax[i, j].set_title(f"{var} - {ssp}")
        ax[i, j].set_xlabel(f"{var} value")
        ax[i, j].set_ylabel("Frequency")

        mean_val = np.nanmean(vals)
        std_val = np.nanstd(vals)
        print(f"{var} ({ssp}) → mean: {mean_val:.2f}, std: {std_val:.2f}")

plt.tight_layout()
plt.show()

In [None]:
ssps = ["ssp126", "ssp370", "ssp585"]
subset = data.sel(ssp=ssps).mean(dim=["member_id"])

chosen_vars = ["CO2", "CH4", "rsdt"]

fig, axs = plt.subplots(len(chosen_vars), len(ssps), figsize=(18, 12))

for i, var in enumerate(chosen_vars):
    for j, ssp in enumerate(ssps):
        vals = subset[var].sel(ssp=ssp).values.flatten()
        axs[i, j].hist(vals, bins=100)
        axs[i, j].set_title(f"{var} - {ssp}")
        axs[i, j].set_xlabel(f"{var} value")
        axs[i, j].set_ylabel("Frequency")

        mean_val = np.nanmean(vals)
        std_val = np.nanstd(vals)
        print(f"{var} ({ssp}) → mean: {mean_val:.2f}, std: {std_val:.2f}")

plt.tight_layout()
plt.show()