In [15]:
import xarray as xr
import torch
from torch.utils.data import DataLoader
from xbatcher import BatchGenerator

import matplotlib.pyplot as plt
from utils.general import load_config

# Load the project configuration through the `load_config` helper imported from
# utils.general; the cells below read its `dataset` and `training` sections.
config = load_config()

In [16]:
# Display the loaded configuration so the settings used below are visible.
config

{'model': {'architecture': 'SRResNet',
  'large_kernel_size': 9,
  'small_kernel_size': 3,
  'n_channels': 64,
  'n_blocks': 16,
  'scaling_factor': 8},
 'training': {'streaming': False,
  'learning_rate': 0.01,
  'batch_size': 32,
  'epochs': 100,
  'optimizer': 'Adam',
  'loss_function': 'mse_loss',
  'devices': [0],
  'accelerator': 'gpu',
  'deterministic': True,
  'seed': 42},
 'dataset': {'hr_zarr_url': 'https://cacheb.dcms.destine.eu/d1-climate-dt/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-high-sfc-v0.zarr',
  'lr_zarr_url': 'https://cacheb.dcms.destine.eu/d1-climate-dt/ScenarioMIP-SSP3-7.0-IFS-NEMO-0001-standard-sfc-v0.zarr',
  'time_range': '2024-10',
  'start_date': '2020-01-01',
  'end_date': '2020-01-10',
  'latitude_range': [35.0, 71.0],
  'longitude_range': [-25.0, 40.0],
  'data_variable': ['t2m', 'u10', 'v10'],
  'data_target': ['t2m'],
  'unit': 'Temperature (C)'},
 'validation': {'val_split_ratio': 0.3},
 'checkpoint': {'monitor': 'val_ssim',
  'mode': 'max',
  'filename': 'b

In [22]:

# Gather every dataset setting up front so the cell reads from a single place.
dataset_cfg = config["dataset"]
start_date = dataset_cfg["start_date"]
end_date = dataset_cfg["end_date"]
data_vars = dataset_cfg["data_variable"]
latitude_range = tuple(dataset_cfg["latitude_range"])
longitude_range = tuple(dataset_cfg["longitude_range"])

# Open the high-resolution Zarr store. `chunks={}` keeps the variables as
# dask-backed arrays, so nothing is downloaded until a batch is loaded.
# `trust_env` is forwarded to the HTTP client session and is a boolean flag,
# so pass True instead of the (merely truthy) string "true".
data = xr.open_dataset(
    dataset_cfg["hr_zarr_url"],
    engine="zarr",
    storage_options={"client_kwargs": {"trust_env": True}},
    chunks={},
)

# Restrict to the configured time window and lat/lon box in one selection
# (the time slice was previously applied twice, which was redundant).
data = data.sel(
    time=slice(start_date, end_date),
    latitude=slice(latitude_range[0], latitude_range[1]),
    longitude=slice(longitude_range[0], longitude_range[1]),
)

In [23]:
# Keep only the variables named in the config (`data_vars`).
# To keep every variable in the store instead: data_vars = list(data.data_vars)

data = data[data_vars]
data

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.08 GiB 48.00 MiB Shape (240, 819, 1479) (48, 512, 512) Dask graph 60 chunks in 4 graph layers Data type float32 numpy.ndarray",1479  819  240,

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.08 GiB 48.00 MiB Shape (240, 819, 1479) (48, 512, 512) Dask graph 60 chunks in 4 graph layers Data type float32 numpy.ndarray",1479  819  240,

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.08 GiB 48.00 MiB Shape (240, 819, 1479) (48, 512, 512) Dask graph 60 chunks in 4 graph layers Data type float32 numpy.ndarray",1479  819  240,

Unnamed: 0,Array,Chunk
Bytes,1.08 GiB,48.00 MiB
Shape,"(240, 819, 1479)","(48, 512, 512)"
Dask graph,60 chunks in 4 graph layers,60 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# Size of the latitude dimension of the selected region.
# NOTE(review): the saved output shows the whole `data.sizes` mapping rather than a
# single integer, so this cell was presumably edited after it last ran; re-run to refresh.
data.sizes['latitude']

Frozen({'time': 240, 'latitude': 819, 'longitude': 1479})

In [25]:
# Index of the CUDA device that batches are moved to.
GPU_DEVICE = 0
device = torch.device("cuda", GPU_DEVICE)

# Each batch spans `batch_size` time steps and the full latitude/longitude
# extent of the selected region.
batch_dims = {
    "time": config["training"]["batch_size"],
    "latitude": data.sizes["latitude"],
    "longitude": data.sizes["longitude"],
}
batch_generator = BatchGenerator(data, input_dims=batch_dims)

In [None]:
# Pull a single batch and turn it into a (time, variable, latitude, longitude)
# tensor on `device`. Distinct names are used so that `data` keeps pointing at
# the xarray dataset and the earlier cells stay re-runnable.
for batch in batch_generator:
    # Materialise the lazily-loaded batch in memory.
    batch_ds = batch.load()
    print(batch_ds.sizes)

    # to_array() stacks the data variables along a new leading axis, giving
    # (variable, time, latitude, longitude). as_tensor wraps the numpy array
    # without the extra host copy that torch.tensor would make.
    batch_tensor = torch.as_tensor(batch_ds.to_array().values)

    # Swap the first two axes so the time axis leads.
    batch_tensor = torch.permute(batch_tensor, (1, 0, 2, 3))
    print(batch_tensor.shape)

    # Tensor.to() returns a new tensor instead of moving the data in place,
    # so the result has to be kept.
    batch_tensor = batch_tensor.to(device)
    break

Frozen({'time': 32, 'latitude': 819, 'longitude': 1479})
torch.Size([32, 3, 819, 1479])
