In [2]:
import xarray as xr
import torch
from torch.utils.data import DataLoader
from xbatcher import BatchGenerator

import matplotlib.pyplot as plt
from utils.general import load_config

# Project configuration returned by the local `load_config` helper; later cells
# read its "dataset" and "training" sections.
config = load_config()

In [None]:
# Open the high-resolution ("high") and low-resolution ("standard") zarr stores
# from local disk, in that order. `chunks={}` requests dask-backed (lazy) arrays.
# NOTE(review): these are absolute, machine-specific paths — consider moving them
# into the project config.
hr_data, lr_data = (
    xr.open_dataset(store_path, engine="zarr", chunks={})
    for store_path in (
        "/home/ubuntu/project/destine-super-resolution/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-high-sfc-v0.zarr",
        "/home/ubuntu/project/destine-super-resolution/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-standard-sfc-v0.zarr",
    )
)

In [16]:
# Open the high-resolution store named in the config as a lazy, dask-backed dataset.
# `trust_env` is forwarded through `client_kwargs` (presumably to the HTTP client
# session — confirm against the URL scheme); it is a boolean flag, so pass a real
# bool rather than the string "true".
data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr",
    storage_options={"client_kwargs": {"trust_env": True}},
    chunks={},
)

# Time window and spatial bounding box used for this experiment.
start_date = "2025-03-01"
end_date = "2025-03-01T15:00:00"
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])

# Single selection over space and time. The earlier time-only `sel` was dead code:
# its result was overwritten immediately by this call.
hr_data = data.sel(
    latitude=slice(latitude_range[0], latitude_range[1]),
    longitude=slice(longitude_range[0], longitude_range[1]),
    time=slice(start_date, end_date),
)
data_vars = list(hr_data.data_vars)

# Keep only the 't2m' variable for the load experiment below.
hr_data = hr_data['t2m']

In [None]:

num_trials = 10
# Materialise the selected subset `num_trials` times.
# `.compute()` returns a new in-memory object and leaves `hr_data` lazy, so every
# trial actually re-reads the data. `.load()` loads in place, which would turn
# every trial after the first into a no-op.
for _ in range(num_trials):
    data = hr_data.compute()



In [None]:
# Narrow each dataset to the variables named in the config: the HR side uses the
# "data_target" entry, the LR side uses "data_variable".
# NOTE(review): this rebinds `hr_data`/`lr_data`, so the cell is not idempotent —
# re-run the cell that opens the datasets before re-running this one.
hr_data = hr_data[config["dataset"]["data_target"]]
lr_data = lr_data[config["dataset"]["data_variable"]]

In [None]:
# Index of the CUDA device that batches are moved to.
GPU_DEVICE = 0
device = torch.device("cuda", GPU_DEVICE)

# Build the LR and HR batch generators with the same window: `batch_size` time
# steps by 512 latitude points by 1025 longitude points.
# NOTE(review): LR and HR share one window size here — confirm that is intended.
batch_generator_lr, batch_generator_hr = (
    BatchGenerator(
        dataset,
        input_dims={
            "time": config["training"]["batch_size"],
            "latitude": 512,
            "longitude": 1025,
        },
    )
    for dataset in (lr_data, hr_data)
)

In [None]:

# Pull one low-resolution batch, convert it to a tensor and move it to `device`.
for batch in batch_generator_lr:
    batch_ds = batch.load()
    print(batch_ds.sizes)

    # `to_array()` stacks the variables along a new leading axis; the permute
    # then swaps the first two axes.
    data = torch.tensor(batch_ds.to_array().values)
    data = torch.permute(data, (1, 0, 2, 3))
    print(data.shape)

    # `Tensor.to` returns a new tensor rather than moving in place, so the
    # result must be rebound or the data stays on the CPU.
    data = data.to(device)
    break  # only the first batch is needed

In [None]:
# Load timings for one (24 x 4096 x 8193) array:
#   cloud-based store: 20.3 s ± 437 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#   unlabeled run (presumably the local store — TODO confirm): 3.23 s ± 41.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [None]:
# Iterate through both batch generators in lockstep (zip stops at the shorter one).
for batch_lr, batch_hr in zip(batch_generator_lr, batch_generator_hr):

    # Load the LR and HR batches into memory as NumPy arrays. Loop-local names
    # are used so the notebook-level `lr_data` / `hr_data` datasets are not
    # overwritten, which would break re-running the earlier cells.
    lr_array = batch_lr.load().to_array().values
    hr_array = batch_hr.load().to_array().values

    # Convert to float32 PyTorch tensors on the target device.
    lr_tensor = torch.tensor(lr_array, dtype=torch.float32).to(device)
    hr_tensor = torch.tensor(hr_array, dtype=torch.float32).to(device)

    # `to_array()` puts the variable axis first; swap the first two axes so the
    # batch (time) axis leads.
    hr_tensor = torch.permute(hr_tensor, (1, 0, 2, 3))
    lr_tensor = torch.permute(lr_tensor, (1, 0, 2, 3))

    # Assuming each variable is laid out as (time, lat, lon), the shapes below
    # are (batch_size, num_vars, lat, lon).
    print("LR Batch Shape:", lr_tensor.shape)
    print("HR Batch Shape:", hr_tensor.shape)

    break  # Remove this if you want to iterate through all batches