In [1]:
import xarray as xr
import torch
from torch.utils.data import DataLoader
from xbatcher import BatchGenerator

import matplotlib.pyplot as plt
from utils.general import load_config

# Load the project configuration (model / training / dataset / checkpoint sections, displayed in the next cell).
config = load_config()

In [2]:
# Display the loaded configuration for reference.
config

{'model': {'architecture': 'SRResNet',
  'large_kernel_size': 9,
  'small_kernel_size': 3,
  'n_channels': 64,
  'n_blocks': 16,
  'scaling_factor': 8},
 'training': {'streaming': False,
  'learning_rate': 0.01,
  'batch_size': 32,
  'epochs': 100,
  'optimizer': 'Adam',
  'loss_function': 'mse_loss',
  'devices': [0],
  'accelerator': 'gpu',
  'deterministic': True,
  'seed': 42},
 'dataset': {'hr_zarr_url': 'https://cacheb.dcms.destine.eu/d1-climate-dt/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-high-sfc-v0.zarr',
  'lr_zarr_url': 'https://cacheb.dcms.destine.eu/d1-climate-dt/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-standard-sfc-v0.zarr',
  'time_range': '2024-10',
  'start_date': '2020-01-01',
  'end_date': '2020-01-10',
  'latitude_range': [35.0, 71.0],
  'longitude_range': [-25.0, 40.0],
  'data_variable': ['t2m'],
  'data_target': ['t2m'],
  'unit': 'Temperature (C)'},
 'validation': {'val_split_ratio': 0.3},
 'checkpoint': {'monitor': 'val_ssim',
  'mode': 'max',
  'filename': 'best-val-ssim-{

In [3]:

# Open the high-resolution store lazily (chunks={} -> dask-backed arrays using the
# engine's chunking) and crop it to the configured date window and lat/lon box.
start_date = config['dataset']['start_date']
end_date = config['dataset']['end_date']
data_vars = config['dataset']['data_variable']
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])

data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr",
    # aiohttp's `trust_env` is a boolean (honour proxy settings from env vars).
    storage_options={"client_kwargs": {"trust_env": True}},
    chunks={})

# A single selection over all three dimensions.
# NOTE(review): slice(lo, hi) on latitude assumes ascending latitude values;
# a descending coordinate would return an empty selection — confirm per store.
data = data.sel(time=slice(start_date, end_date),
                latitude=slice(*latitude_range),
                longitude=slice(*longitude_range))

In [17]:
# Keep only the variables listed in config['dataset']['data_variable'] (read into
# `data_vars` when the dataset was loaded). Re-running is harmless: selecting the
# same variables again yields the same Dataset.
# To keep every variable in the store instead, first set:
# data_vars = list(data.data_vars)

data = data[data_vars]
data.sizes

Frozen({'time': 24, 'latitude': 819, 'longitude': 1479})

In [5]:
# Number of latitude points after cropping; used as the full-field patch height for the BatchGenerator below.
data.sizes['latitude']

819

In [6]:
import warnings

# Target device: the first device listed in the training config. Fall back to
# the CPU when CUDA is unavailable so the notebook still runs on a CPU-only kernel.
GPU_DEVICE = config['training']['devices'][0]
device = torch.device("cuda", GPU_DEVICE) if torch.cuda.is_available() else torch.device("cpu")

batch_size = config['training']['batch_size']
if data.sizes['time'] < batch_size:
    # A time window shorter than one batch may produce no full batch at all.
    warnings.warn(
        f"Only {data.sizes['time']} time steps selected but batch_size={batch_size}; "
        "the BatchGenerator may yield no batches."
    )

# Each batch covers `batch_size` time steps and the full cropped lat/lon field.
batch_generator = BatchGenerator(data, input_dims={"time": batch_size,
                                                   "latitude":  data.sizes['latitude'],
                                                   "longitude": data.sizes['longitude']})

In [7]:
# Iterate through one batch
# Push one batch end to end (xarray -> numpy -> torch -> device) as a smoke test.
# Distinct names leave the `data` Dataset intact so the cells above can be re-run.
first_batch = next(iter(batch_generator), None)
if first_batch is None:
    print("batch_generator yielded no batches; check the number of time steps vs. the batch size")
else:
    batch_ds = first_batch.load()
    print(batch_ds.sizes)
    # to_array() stacks the data variables along a new leading axis.
    batch_array = batch_ds.to_array().values
    # Reorder to (time, variable, latitude, longitude) so time is the batch axis.
    batch_tensor = torch.permute(torch.as_tensor(batch_array), (1, 0, 2, 3))
    print(batch_tensor.shape)
    # Tensor.to returns a new tensor; it must be reassigned to take effect.
    batch_tensor = batch_tensor.to(device)

Frozen({'time': 32, 'latitude': 819, 'longitude': 1479})
torch.Size([32, 1, 819, 1479])


## Benchmark: batch load time vs. number of variables

In [8]:
import xarray as xr
import time
import numpy as np
from xbatcher import BatchGenerator
from utils.general import load_config

# Re-read the config, presumably so the benchmark section can run without the cells above.
# NOTE(review): this discards any in-kernel edits made to `config` earlier.
config = load_config()

In [None]:
import time
import numpy as np
import xarray as xr

# Benchmark settings.
num_trials = 10  # Number of repetitions per variable count
batch = 8  # time steps per batch; also read by the throughput table cell
max_vars = 9  # largest variable count to benchmark (capped by what the store has)
bench_start, bench_end = "2025-03-01", "2025-03-01"  # single-day time window

# Open dataset once; each trial calls .load() on fresh lazy batches from the generator.
data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr", storage_options={"client_kwargs": {"trust_env": True}},
    chunks={}
)
data = data.sel(
    time=slice(bench_start, bench_end),
    latitude=slice(*config["dataset"]["latitude_range"]),
    longitude=slice(*config["dataset"]["longitude_range"])
)

data_vars = list(data.data_vars)  # List of available variables

num_vars_list = []
time_avg_list = []  # mean seconds per batch
time_std_list = []  # std of seconds per batch across trials
size_data = []      # MiB per batch

# `+ 1` so the largest count is actually included; slicing data_vars never overruns.
for num_vars in range(1, min(len(data_vars), max_vars) + 1):
    selected_vars = data_vars[:num_vars]  # Select `num_vars` variables
    hr_data_subset = data[selected_vars]
    print(selected_vars)

    batch_generator_hr = BatchGenerator(hr_data_subset, input_dims={
        "time": batch,
        "latitude": hr_data_subset.sizes["latitude"],
        "longitude": hr_data_subset.sizes["longitude"]
    })

    times = []         # seconds per batch, one entry per trial
    batch_nbytes = []  # bytes of every batch loaded across all trials

    for _ in range(num_trials):  # Run multiple trials
        n_batches = 0
        start_time = time.perf_counter()  # monotonic, high-resolution timer
        for batch_data in batch_generator_hr:
            data_batch = batch_data.load()
            batch_nbytes.append(data_batch.nbytes)
            n_batches += 1
        elapsed_time = time.perf_counter() - start_time

        if n_batches == 0:
            raise ValueError(
                f"BatchGenerator produced no batches ({hr_data_subset.sizes['time']} "
                f"time steps selected, batch={batch})"
            )
        # Normalise to one batch so time and size describe the same amount of data.
        times.append(elapsed_time / n_batches)

    avg_time = np.mean(times)
    std_time = np.std(times)
    size = np.mean(batch_nbytes) / (1024 * 1024)  # Size in MiB per batch

    num_vars_list.append(num_vars)
    time_avg_list.append(avg_time)
    time_std_list.append(std_time)
    size_data.append(size)

    print(f"Num Vars: {num_vars}, Avg Time: {avg_time:.4f} sec, Std Dev: {std_time:.4f} sec")


Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude': 1479})
Frozen({'time': 8, 'latitude': 819, 'longitude':

In [15]:
# NOTE(review): superseded exploratory draft of the benchmark cell above (it always
# selects "t2m" and re-opens the store on every iteration). Kept for reference only;
# consider deleting this cell.
# num_trials = 10  # Number of repetitions
# num_vars_list = []
# time_avg_list = []
# time_std_list = []
# size_data = []
# batch = 8
# for num_vars in range(1, 10):

#     data = xr.open_dataset(
#     config["dataset"]["hr_zarr_url"],
#     engine="zarr", storage_options={"client_kwargs": {"trust_env": "true"}},
#     chunks={})
#     data_vars = list(data.data_vars)
#     start_date = "2025-03-01"
#     end_date = "2025-03-01"
#     latitude_range = tuple(config["dataset"]["latitude_range"])
#     longitude_range = tuple(config["dataset"]["longitude_range"])
#     data = data.sel(time=slice(start_date, end_date))
#     data = data.sel(latitude=slice(latitude_range[0],latitude_range[1]),
#                     longitude=slice(longitude_range[0],longitude_range[1]))

#     selected_vars = "t2m"

#     hr_data_subset = data[selected_vars]  # Subset dataset
#     batch_generator_hr = BatchGenerator(hr_data_subset, input_dims={
#         "time": batch,
#         "latitude": hr_data_subset.sizes["latitude"],
#         "longitude": hr_data_subset.sizes["longitude"]
#     })

#     times = []

#     for _ in range(num_trials):  # Run multiple trials
#         start_time = time.time()

#         for batch in batch_generator_hr:

#             data_batch = batch.load()
#             print(data_batch.shape)

#         elapsed_time = time.time() - start_time
#         times.append(elapsed_time)
#         size = data_batch.nbytes / (1024*1024)
#         # print(size)
#         data_batch = 0
#     avg_time = np.mean(times)
#     std_time = np.std(times)

#     num_vars_list.append(num_vars)
#     time_avg_list.append(avg_time)
#     time_std_list.append(std_time)
#     size_data.append(size)
#     print(f"Num Vars: {num_vars}, Avg Time: {avg_time:.4f} sec, Std Dev: {std_time:.4f} sec")

In [None]:
# Mean load time (± one std over trials) against the number of variables loaded,
# drawn on an explicit figure/axes pair and saved to disk before display.
fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(
    num_vars_list,
    time_avg_list,
    yerr=time_std_list,
    fmt="-o",
    capsize=4,
    label="Avg Time ± Std Dev",
)
ax.set_xlabel("Number of Data Variables")
ax.set_ylabel("Time to Load (seconds)")
ax.set_title("Time to Load vs Number of Data Variables")
ax.grid()
ax.legend()
fig.savefig("load_vs_parameters.png")
plt.show()

In [None]:
import pandas as pd


# Throughput summary built from the benchmark result lists.
# NOTE(review): the formulas treat time_avg / size_data as describing ONE batch of
# `batch` time steps; confirm the benchmark cell records them that way.
max_bandwidth_Mbps = 25000  # link capacity (Mbit/s) to extrapolate fps to

# `batch` must be the integer batch size the benchmark ran with, not a leftover
# object bound to the same name by another cell.
if not isinstance(batch, int):
    raise TypeError(f"expected `batch` to be the benchmark batch size (int), got {type(batch).__name__}")

df = pd.DataFrame(data={"climate_variables": num_vars_list,
                        "time_avg": time_avg_list,
                        "time_std": time_std_list,
                        "size_data": size_data})

# Use the batch size that was actually benchmarked; a mismatched constant scales fps and max_fps.
df["batch"] = batch
# Frames per second, where one frame is one variable at one time step.
df["fps"] = df["climate_variables"] * df["batch"] / df["time_avg"]
df["bandwidth_MBps"] = df["size_data"] / df["time_avg"]
# NOTE(review): size_data is in MiB (bytes / 1024**2), so these columns are really
# MiB/s and Mibit/s, while max_bandwidth_Mbps is presumably decimal Mbit/s (~5% gap).
df["bandwidth_Mbps"] = df["bandwidth_MBps"] * 8
df["max_bandwidth_Mbps"] = max_bandwidth_Mbps
# Extrapolate fps linearly from the observed bandwidth to the link capacity.
df["max_fps"] = df["max_bandwidth_Mbps"] * df["fps"] / df["bandwidth_Mbps"]