### Import dependencies

In [1]:
import warnings
import torch
import lightning as L
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from utils.general import (get_bbox_from_config,
                           load_config, compute_mean_std,
                           create_gif_from_images,
                           save_best_model_as_pt,
                           save_config_to_log_dir)
from data.loaders import load_data
from data.datasets import CreateDataset
import xarray as xr

### 1. Load Configuration

In [2]:
# Load the run configuration through the project's load_config() helper.
# NOTE(review): no seed is set in this cell or any other visible cell, although
# seed_everything is imported above — confirm whether seeding was intended here.
config = load_config()

In [3]:
# Lazily open the low-resolution Zarr store named in the config.
# `trust_env` is passed as a real boolean: the original string "true" only
# worked because any non-empty string is truthy (client_kwargs are presumably
# forwarded to the aiohttp session by the fsspec HTTP backend, whose
# `trust_env` option is documented as a bool).
lr_data = xr.open_dataset(
    config["dataset"]["lr_zarr_url"],
    engine="zarr",
    storage_options={"client_kwargs": {"trust_env": True}},
    chunks={},
)

lr_data = lr_data.astype("float32")

# Spatial bounds are read from the config but are NOT applied in this cell.
# To subset spatially, add latitude=slice(*latitude_range) and
# longitude=slice(*longitude_range) to a .sel() call (assumes the coordinates
# are named `latitude` / `longitude` — TODO confirm against the store).
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])

start_date = config["dataset"]["start_date"]
end_date = config["dataset"]["end_date"]

# Restrict to the configured time window, then report the in-memory size
# the selection would occupy (nbytes / 1024**2).
lr = lr_data.sel(time=slice(start_date, end_date))
size_in_mb = lr.nbytes / (1024 * 1024)
print(f"Dataset size: {size_in_mb:.2f} MB")
lr

Dataset size: 317397.74 MB


Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 17.22 GiB 36.00 MiB Shape (8808, 512, 1025) (576, 128, 128) Dask graph 576 chunks in 3 graph layers Data type float32 numpy.ndarray",1025  512  8808,

Unnamed: 0,Array,Chunk
Bytes,17.22 GiB,36.00 MiB
Shape,"(8808, 512, 1025)","(576, 128, 128)"
Dask graph,576 chunks in 3 graph layers,576 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [4]:
# Open the low-resolution dataset from a Zarr store given by a relative path.
# NOTE: this rebinds `lr`, replacing the time-sliced dataset assigned to `lr`
# earlier in the notebook.
lr_path = "ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-standard-sfc-v0.zarr"
lr = xr.open_dataset(lr_path, engine="zarr")
lr

In [5]:
# High-resolution (HR) dataset: path and handle
hr_path = "ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-high-sfc-v0.zarr"
hr = xr.open_dataset(hr_path, engine="zarr")

# Low-resolution (LR) dataset: path and handle
lr_path = "ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-standard-sfc-v0.zarr"
lr = xr.open_dataset(lr_path, engine="zarr")

In [6]:
# Show the dimensions of the HR dataset, then the LR dataset, one per line
print(hr.dims, lr.dims, sep="\n")



In [None]:
import torch
import torch.nn.functional as F
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
import cv2
import numpy as np

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def compute_psnr(hr_data, lr_upscaled, data_range=None):
    """
    Compute the Peak Signal-to-Noise Ratio (PSNR) between HR and upscaled LR data.

    PSNR is evaluated as ``10 * log10(peak ** 2 / MSE)``.

    Args:
        hr_data (torch.Tensor): High-resolution reference field.
        lr_upscaled (torch.Tensor): Low-resolution field upscaled to the HR shape.
        data_range (float | torch.Tensor | None): Peak value used in the PSNR
            formula. When ``None`` (the default) the maximum of ``hr_data`` is
            used, which is the original behaviour. Pass an explicit value
            (e.g. ``max - min`` of the variable) to get scores that are
            comparable across frames or for fields with a large offset.

    Returns:
        float: PSNR in dB, or ``float('inf')`` when the two inputs are identical
        (MSE of exactly zero).
    """
    mse = F.mse_loss(hr_data, lr_upscaled)  # Mean Squared Error (MSE)
    if mse == 0:
        # Avoid a division by zero: a perfect match has unbounded PSNR
        return float('inf')

    # Peak value: caller-supplied range, or the reference maximum by default
    peak = torch.max(hr_data) if data_range is None else data_range
    psnr = 10 * torch.log10((peak ** 2) / mse)  # PSNR formula
    return psnr.item()

def compute_ssim(hr_data, lr_upscaled):
    """
    Compute the Structural Similarity Index Measure (SSIM) between HR and upscaled LR data.

    Args:
        hr_data (torch.Tensor): High-resolution reference field.
        lr_upscaled (torch.Tensor): Low-resolution field upscaled to the HR shape.
            Expected to live on the same device as ``hr_data``.

    Returns:
        float: SSIM value.
    """
    # Place the metric on the inputs' own device rather than on the notebook-level
    # `device` global, so the function also works for tensors that were created
    # on a different device and does not depend on hidden notebook state.
    ssim_index = StructuralSimilarityIndexMeasure().to(hr_data.device)

    # Two leading singleton axes are added to each input before calling the
    # metric (presumably batch and channel axes for 2-D (H, W) fields — TODO
    # confirm against the torchmetrics input contract).
    hr_batched = hr_data.unsqueeze(0).unsqueeze(0)
    lr_batched = lr_upscaled.unsqueeze(0).unsqueeze(0)

    ssim_value = ssim_index(hr_batched, lr_batched)
    return ssim_value.item()

# Per-time-step PSNR and SSIM values
psnr_values = []
ssim_values = []

# Iterate over all time steps of the HR "t2m" variable.
# NOTE: `lr_data` below rebinds the name that an earlier cell used for the
# xarray dataset; after the loop it holds the last LR frame as a tensor.
for t in range(hr["t2m"].shape[0]):
    hr_data = torch.tensor(hr["t2m"][t].values, dtype=torch.float32).to(device)  # HR field at time t
    lr_data = torch.tensor(lr["t2m"][t].values, dtype=torch.float32).to(device)  # LR field at time t

    # Resize the LR field to the HR grid using *bilinear* interpolation
    # (the previous comment said "bicubic", which did not match the code).
    # Only the two axes added by unsqueeze are removed again, so a spatial
    # axis of length 1 is never dropped by accident.
    lr_upscaled = F.interpolate(
        lr_data.unsqueeze(0).unsqueeze(0),
        size=(hr_data.shape[0], hr_data.shape[1]),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0).squeeze(0)

    # Compute PSNR and SSIM for this time step
    psnr_values.append(compute_psnr(hr_data, lr_upscaled))
    ssim_values.append(compute_ssim(hr_data, lr_upscaled))

# Mean and standard deviation of both metrics over all time steps
average_psnr = np.mean(psnr_values)
average_ssim = np.mean(ssim_values)

std_psnr = np.std(psnr_values)
std_ssim = np.std(ssim_values)

print(f"Average PSNR over {len(psnr_values)} time steps: {average_psnr:.2f} dB ± {std_psnr:.2f}")
print(f"Average SSIM over {len(ssim_values)} time steps: {average_ssim:.4f} ± {std_ssim:.4f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Plot whatever field `lr_data` currently holds, with the row order reversed.
# When `lr_data` is a torch tensor (as left by the evaluation loop) it must be
# moved to host memory and converted to NumPy first: torch tensors do not
# support negative-step slicing, and matplotlib cannot draw device tensors.
lr_img = lr_data.detach().cpu().numpy() if isinstance(lr_data, torch.Tensor) else np.asarray(lr_data)
plt.figure(figsize=(16,10))
plt.imshow(lr_img[::-1,:])  # reverse the first axis (flip rows)
plt.colorbar()
plt.show()

In [None]:
# Element-wise difference between `hr_down` and `lr_data`.
# NOTE(review): `hr_down` is not assigned in any cell shown in this notebook;
# unless it is defined elsewhere, a fresh-kernel run raises NameError here —
# confirm where it is meant to come from.
diff = hr_down-lr_data

In [None]:
# Largest value of the (hr_down - lr_data) difference
diff.max()

In [None]:
# Smallest value of the (hr_down - lr_data) difference
diff.min()

In [None]:
# Mean of the (hr_down - lr_data) difference
diff.mean()

In [None]:
# Display the difference object itself as the cell output
diff

In [None]:
# Plot the difference field with the row order reversed, colour limits pinned
# to its own min/max. A torch tensor is converted to a host NumPy array first,
# because torch does not support negative-step slicing.
diff_img = diff.detach().cpu().numpy() if isinstance(diff, torch.Tensor) else np.asarray(diff)
plt.figure(figsize=(16,10))
# "Greys" is the canonical Matplotlib colormap name; the "Grays" spelling is
# only accepted as an alias by newer Matplotlib releases.
plt.imshow(diff_img[::-1,:], vmax=diff_img.max(), vmin=diff_img.min(), cmap="Greys")
plt.colorbar()
plt.show()