# Data pipeline workspace

In [1]:
import os
from glob import glob
from torch.utils.data.distributed import DistributedSampler

In [2]:
import logging
from typing import Dict

import numpy as np
import pandas as pd
import xarray as xr
import netCDF4 as nc

import torch
from torchvision import transforms as tforms

from credit.data import Sample, extract_month_day_hour, find_common_indices
from bridgescaler import read_scaler

In [3]:
import yaml

In [4]:
# Module-level logger named after this module (used by the transform classes below).
logger = logging.getLogger(name=__name__)

## Config example

In [5]:
# YAML config that drives this notebook; point this at another run's config to inspect it.
config_name = '/glade/u/home/ksha/miles-credit/results/fuxi_norm/model_new.yml'

# Read YAML file (safe_load only builds plain Python data types; explicit encoding
# avoids depending on the platform's default text encoding)
with open(config_name, 'r', encoding='utf-8') as stream:
    conf = yaml.safe_load(stream)

## PyTorch dataset

In [6]:
def get_forward_data(filename) -> xr.Dataset:
    """Lazily open a consolidated Zarr store (e.g., on the GLADE filesystem).

    Parameters:
    - filename: path to the Zarr store.

    Returns:
    - xr.Dataset returned by ``xr.open_zarr(..., consolidated=True)``.
    """
    return xr.open_zarr(filename, consolidated=True)

def get_forward_data_netCDF4(filename) -> xr.Dataset:
    """Lazily open a netCDF4 file.

    Parameters:
    - filename: path to the netCDF4 file.

    Returns:
    - xr.Dataset returned by ``xr.open_dataset``.
    """
    return xr.open_dataset(filename)
    
def generate_integer_list_around(number, spacing=10):
    """
    Return every integer from ``number - spacing`` to ``number + spacing``, inclusive.

    Parameters:
    - number (int): centre of the range.
    - spacing (int): how far the range extends on each side of ``number``. Default is 10.

    Returns:
    - list of int: ``2 * spacing + 1`` consecutive integers centred on ``number``
      (empty when ``spacing`` is negative).
    """
    first, last = number - spacing, number + spacing
    return [*range(first, last + 1)]


def find_key_for_number(input_number, data_dict):
    """
    Return the first key whose value range contains ``input_number``.

    Parameters:
    - input_number (int): the number to locate.
    - data_dict (dict): maps keys to sequences whose elements at positions 1 and 2
      are the inclusive lower and upper bounds (position 0 is not used here).

    Returns:
    - the first key, in dict iteration order, with ``bounds[1] <= input_number <= bounds[2]``;
      None when no range contains the number.
    """
    matches = (
        key
        for key, bounds in data_dict.items()
        if bounds[1] <= input_number <= bounds[2]
    )
    return next(matches, None)

In [7]:
def drop_var_from_dataset(xarray_dataset, varname_keep):
    """Keep only the requested variables of a dataset and check none are missing.

    Parameters:
    - xarray_dataset: dataset to filter; only the names returned by its ``keys()``
      are candidates for dropping.
    - varname_keep: iterable of variable names that must remain.

    Returns:
    - the filtered dataset (a new object when anything was dropped).

    Raises:
    - ValueError: if any name in ``varname_keep`` is absent from the dataset.
      (An explicit raise instead of ``assert`` so the check survives ``python -O``.)
    """
    keep = set(varname_keep)

    # drop everything not requested in a single call instead of one call per variable
    to_drop = [varname for varname in xarray_dataset.keys() if varname not in keep]
    if to_drop:
        xarray_dataset = xarray_dataset.drop_vars(to_drop)

    # sorted so the error message is deterministic
    varname_diff = sorted(keep - set(xarray_dataset.keys()))
    if varname_diff:
        raise ValueError('Variable name: {} missing'.format(varname_diff))

    return xarray_dataset

In [13]:
class ERA5_and_Forcing_Dataset(torch.utils.data.Dataset):
    '''
    A PyTorch Dataset class that works on:
        - ERA5 variables (time, level, lat, lon)
        - forcing variables (time, lat, lon)
        - static variables (lat, lon)

    Each sample is a window of ``history_len + forecast_len + 1`` consecutive time
    steps taken from a single upper-air file (windows never cross file boundaries):
    the first ``history_len`` steps (strided by ``skip_periods``) form the input and
    the remaining steps (strided by ``skip_periods``) form the target.

    Parameters:
    - varname_upper_air / varname_surface / varname_forcing / varname_static /
      varname_diagnostic: variable names to keep from the corresponding files.
    - filenames: ERA5 upper-air file paths as *.zarr (e.g., /user/ERA5/*.zarr).
    - filename_surface: None or a list of *.zarr surface files (one per upper-air file).
    - filename_forcing: None or a netCDF4 file that contains all the forcing variables.
    - filename_static: None or a netCDF4 file that contains all the static variables.
    - filename_diagnostic: None or a list of *.zarr diagnostic files (one per upper-air file).
    - history_len: number of input time steps.
    - forecast_len: number of target steps beyond the first one.
    - transform: optional callable applied to each Sample.
    - seed: seed for ``self.rng`` (the generator is not used inside this class).
    - skip_periods: stride on the time dimension (default 1).
    - one_shot: if not None, keep only the first target time step.
    - max_forecast_len: stored as-is (not used inside this class).
    '''
    def __init__(
        self,
        varname_upper_air,
        varname_surface,
        varname_forcing,
        varname_static,
        varname_diagnostic,
        filenames,
        filename_surface=None,
        filename_forcing=None,
        filename_static=None,
        filename_diagnostic=None,
        history_len=2,
        forecast_len=0,
        transform=None,
        seed=42,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None
    ):
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform

        # skip periods (stride on the time dimension), defaults to 1
        self.skip_periods = skip_periods
        if self.skip_periods is None:
            self.skip_periods = 1

        # one shot option
        self.one_shot = one_shot

        # history + forecast steps; each sample actually reads total_seq_len + 1 time steps
        self.total_seq_len = self.history_len + self.forecast_len

        # set random seed
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # ERA5 operations
        all_files = []
        filenames = sorted(filenames)

        for fn in filenames:

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_upper_air)

            # collect yearly datasets within a list
            all_files.append(xarray_dataset)

        self.all_files = all_files

        # get sample indices from ERA5 upper-air files:
        # {str(file index): [number of time steps, first global index, last global index (inclusive)]}
        # The per-file ranges are contiguous and together cover exactly range(len(self)),
        # so every valid window is reachable exactly once.
        self.ERA5_indices = self._build_ERA5_indices(
            [len(ERA5_xarray['time']) for ERA5_xarray in self.all_files],
            self.total_seq_len + 1)

        # ======================================================== #
        # forcing file
        self.filename_forcing = filename_forcing

        if self.filename_forcing is not None:
            assert os.path.isfile(filename_forcing), 'Cannot find forcing file [{}]'.format(filename_forcing)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data_netCDF4(filename_forcing)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_forcing)

            self.xarray_forcing = xarray_dataset
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file
        self.filename_static = filename_static

        if self.filename_static is not None:
            assert os.path.isfile(filename_static), 'Cannot find static file [{}]'.format(filename_static)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data_netCDF4(filename_static)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_static)

            self.xarray_static = xarray_dataset
        else:
            self.xarray_static = False

        # ======================================================== #
        # diagnostic file
        self.filename_diagnostic = filename_diagnostic

        if self.filename_diagnostic is not None:

            diagnostic_files = []
            filename_diagnostic = sorted(filename_diagnostic)

            for fn in filename_diagnostic:

                # drop variables if they are not in the config
                xarray_dataset = get_forward_data(filename=fn)
                xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_diagnostic)

                diagnostic_files.append(xarray_dataset)

            self.diagnostic_files = diagnostic_files

            assert len(self.diagnostic_files)==len(self.all_files), \
                'Mismatch between the total number of diagnostic files and upper-air files'
        else:
            self.diagnostic_files = False

        # ======================================================== #
        # surface files
        if filename_surface is not None:

            surface_files = []
            filename_surface = sorted(filename_surface)

            for fn in filename_surface:

                # drop variables if they are not in the config
                xarray_dataset = get_forward_data(filename=fn)
                xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_surface)

                surface_files.append(xarray_dataset)

            self.surface_files = surface_files

            assert len(self.surface_files)==len(self.all_files), \
                'Mismatch between the total number of surface files and upper-air files'
        else:
            self.surface_files = False

    @staticmethod
    def _build_ERA5_indices(time_lengths, window_len):
        '''
        Map each file to a contiguous, inclusive range of global sample indices.

        Parameters:
        - time_lengths: number of time steps of each file, in file order.
        - window_len: number of consecutive time steps one sample needs.

        Returns:
        - {str(file index): [time steps, first index, last index (inclusive)]}.
          A file shorter than window_len gets an empty range (last < first).
        '''
        indices = {}
        ind_start = 0
        for ind_file, n_time in enumerate(time_lengths):
            # number of complete windows that fit inside this file
            n_windows = max(n_time - window_len + 1, 0)
            indices[str(ind_file)] = [n_time, ind_start, ind_start + n_windows - 1]
            ind_start += n_windows
        return indices

    def __post_init__(self):
        # NOTE: this class is not a dataclass, so this hook is never called
        # automatically; total_seq_len is already set in __init__.
        # Total sequence length of each sample.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        # number of complete (total_seq_len + 1)-step windows, summed over files;
        # matches the index ranges built by _build_ERA5_indices
        window_len = self.total_seq_len + 1
        return sum(max(len(ERA5_xarray['time']) - window_len + 1, 0)
                   for ERA5_xarray in self.all_files)

    def __getitem__(self, index):
        # ========================================================================== #
        # cross-year indices --> the index of the year + indices within that year

        # select the ind_file based on the iter index
        ind_file = find_key_for_number(index, self.ERA5_indices)
        if ind_file is None:
            raise IndexError('Sample index {} is out of range'.format(index))

        # get the ind within the current file
        ind_start = self.ERA5_indices[ind_file][1]
        ind_start_in_file = index - ind_start

        # safety clamp so the window ends inside the file
        # (a no-op for indices produced by _build_ERA5_indices)
        ind_largest = len(self.all_files[int(ind_file)]['time'])-(self.history_len+self.forecast_len+1)
        if ind_start_in_file > ind_largest:
            ind_start_in_file = ind_largest
        # ========================================================================== #
        # subset xarray on time dimension & load it to the memory

        ind_end_in_file = ind_start_in_file+self.history_len+self.forecast_len

        ## ERA5_subset: a xarray dataset that contains training input and target (for the current index)
        ERA5_subset = self.all_files[int(ind_file)].isel(
            time=slice(ind_start_in_file, ind_end_in_file+1)).load()

        if self.surface_files:
            ## subset surface variables
            surface_subset = self.surface_files[int(ind_file)].isel(
                time=slice(ind_start_in_file, ind_end_in_file+1)).load()

            ## merge upper-air and surface here:
            ERA5_subset = ERA5_subset.merge(surface_subset)

        # ==================================================== #
        # split ERA5_subset into training inputs and targets + merge with forcing and static

        # the ind_end of the ERA5_subset
        ind_end_time = len(ERA5_subset['time'])

        # datetime information as int number (seconds; used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype('datetime64[s]').astype(int)

        # ==================================================== #
        # xarray dataset as input
        ## historical_ERA5_images: the final input

        historical_ERA5_images = ERA5_subset.isel(time=slice(0, self.history_len, self.skip_periods))

        # merge forcing inputs
        if self.xarray_forcing:
            # =============================================================================== #
            # matching month, day, hour between forcing and upper air [time]
            month_day_forcing = extract_month_day_hour(np.array(self.xarray_forcing['time']))
            month_day_inputs = extract_month_day_hour(np.array(historical_ERA5_images['time']))
            # indices to subset
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing).load()
            # forcing and upper air have different years but the same mon/day/hour
            # safely replace forcing time with upper air time
            forcing_subset_input['time'] = historical_ERA5_images['time']
            # =============================================================================== #

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

        # merge static inputs
        if self.xarray_static:
            # expand static var on time dim
            N_time_dims = len(ERA5_subset['time'])
            static_subset_input = self.xarray_static.expand_dims(dim={"time": N_time_dims})
            # assign coords 'time'
            static_subset_input = static_subset_input.assign_coords({'time': ERA5_subset['time']})

            # slice the same input steps as the history + load into memory
            static_subset_input = static_subset_input.isel(time=slice(0, self.history_len, self.skip_periods)).load()

            # update
            static_subset_input['time'] = historical_ERA5_images['time']

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

        # ==================================================== #
        # xarray dataset as target
        ## target_ERA5_images: the final target

        target_ERA5_images = ERA5_subset.isel(time=slice(self.history_len, ind_end_time, self.skip_periods))

        ## merge diagnostic variables into the target here:
        if self.diagnostic_files:

            # subset diagnostic variables
            diagnostic_subset = self.diagnostic_files[int(ind_file)].isel(
                time=slice(ind_start_in_file, ind_end_in_file+1)).load()

            # merge into the target dataset
            target_diagnostic = diagnostic_subset.isel(time=slice(self.history_len, ind_end_time, self.skip_periods))
            target_ERA5_images = target_ERA5_images.merge(target_diagnostic)

        if self.one_shot is not None:
            # keep only the first target time step
            # NOTE(review): an earlier comment said "final state" — confirm which step is intended
            target_ERA5_images = target_ERA5_images.isel(time=slice(0, 1))

        # pipe xarray datasets to the sampler
        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number
        )

        # ==================================== #
        # data normalization
        if self.transform:
            sample = self.transform(sample)

        # assign sample index
        sample["index"] = index

        return sample

## Transform class

In [14]:
class Normalize_ERA5_and_Forcing:
    '''
    Z-score (mean/std) normalization for ERA5 training samples and model tensors.

    - ``transform`` normalizes the xarray datasets of a training sample.
    - ``transform_array`` / ``inverse_transform`` normalize / de-normalize a stacked
      channel tensor laid out as upper-air channels, then surface, then diagnostic.

    Forcing and static variables are never normalized here; they are expected to be
    stored already normalized on disk.
    '''
    def __init__(self, conf):

        # import the variable mean
        self.mean_ds = xr.open_dataset(conf['data']['mean_path'])

        # import the variable std
        self.std_ds = xr.open_dataset(conf['data']['std_path'])

        # variable names, channel counts and presence flags
        self._parse_var_config(conf)

        logger.info("Loading stored mean and std data for z-score-based transform and inverse transform")

    def _parse_var_config(self, conf):
        '''Set variable-name lists, channel counts and presence flags from ``conf``.'''
        data_conf = conf['data']

        # get levels and upper air variables
        self.levels = conf['model']['levels']
        self.varname_upper_air = data_conf['variables']
        self.num_upper_air = len(self.varname_upper_air) * self.levels

        def _has_vars(key):
            # a group exists only if the key is present and lists at least one variable
            return (key in data_conf) and (len(data_conf[key]) > 0)

        # identify the existence of other variables
        self.flag_surface = _has_vars('surface_variables')
        self.flag_diagnostic = _has_vars('diagnostic_variables')
        self.flag_forcing = _has_vars('forcing_variables')
        self.flag_static = _has_vars('static_variables')

        # varnames of each group ([] when the group is absent)
        self.varname_surface = data_conf['surface_variables'] if self.flag_surface else []
        self.varname_diagnostic = data_conf['diagnostic_variables'] if self.flag_diagnostic else []
        self.varname_forcing = data_conf['forcing_variables'] if self.flag_forcing else []
        self.varname_static = data_conf['static_variables'] if self.flag_static else []

        # channel counts used to slice stacked tensors (previously never defined,
        # which made transform_array / inverse_transform raise AttributeError)
        self.num_surface = len(self.varname_surface)
        self.num_diagnostic = len(self.varname_diagnostic)

        self.has_forcing_static = self.flag_forcing or self.flag_static
        self.varname_forcing_static = self.varname_forcing + self.varname_static
        self.num_forcing_static = len(self.varname_forcing_static)

    def __call__(self, sample: "Sample", inverse: bool = False) -> "Sample":
        '''Apply ``inverse_transform`` when ``inverse`` is True, otherwise ``transform``.'''
        if inverse:
            # inverse transformation
            return self.inverse_transform(sample)
        # transformation
        return self.transform(sample)

    def _split_channels(self, x: torch.Tensor):
        '''
        Return clones of the (upper-air, surface, diagnostic) channel groups of ``x``.

        ``x`` is indexed as a 4-D tensor with channels on dim 1 (assumed
        (batch, channel, H, W)). Absent groups are returned as None; diagnostics are
        the last ``num_diagnostic`` channels of the stack.
        '''
        upper_air = x[:, :self.num_upper_air, :, :].clone()

        surface = None
        if self.flag_surface:
            surface = x[:, self.num_upper_air:(self.num_upper_air + self.num_surface), :, :].clone()

        diagnostic = None
        if self.flag_diagnostic:
            diagnostic = x[:, -self.num_diagnostic:, :, :].clone()

        return upper_air, surface, diagnostic

    @staticmethod
    def _concat_channels(upper_air, surface, diagnostic):
        '''Concatenate the present channel groups along dim 1.'''
        parts = [part for part in (upper_air, surface, diagnostic) if part is not None]
        return torch.cat(parts, dim=1)

    def transform_array(self, x: torch.Tensor) -> torch.Tensor:
        '''
        This function applies to y_pred, so there won't be forcing and static variables.
        Consider its usage (standardize y_pred as input of the next iteration),
            diagnostics don't need to be transformed and are passed through.
        '''
        # get the current device
        device = x.device

        upper_air, surface, diagnostic = self._split_channels(x)

        # standardize upper air variables
        # upper air variable structure: var 1 [all levels] --> var 2 [all levels]
        k = 0
        for name in self.varname_upper_air:
            for level in range(self.levels):
                var_mean = self.mean_ds[name].values[level]
                var_std = self.std_ds[name].values[level]
                upper_air[:, k] = (upper_air[:, k] - var_mean) / var_std
                k += 1

        # standardize surface variables
        if self.flag_surface:
            for k, name in enumerate(self.varname_surface):
                var_mean = self.mean_ds[name].values
                var_std = self.std_ds[name].values
                surface[:, k] = (surface[:, k] - var_mean) / var_std

        return self._concat_channels(upper_air, surface, diagnostic).to(device)

    def transform(self, sample: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        '''
        This function transforms training batches, it handles forcing & static as follows:
            - forcing & static don't need to be transformed; users should transform them and save them to the file
            - other variables (upper-air, surface, diagnostics) need to be transformed
        Only xr.Dataset entries of ``sample`` are kept in the returned dict.
        '''
        normalized_sample = {}
        for key, value in sample.items():
            # key: 'historical_ERA5_images', 'target_ERA5_images'
            # value: the xarray datasets
            if not isinstance(value, xr.Dataset):
                continue

            if self.has_forcing_static and key == 'historical_ERA5_images':
                # training input: z-score every variable except forcing and static.
                # Iterate over a snapshot of the names because the dataset is updated in place.
                for varname in list(value.keys()):
                    if varname not in self.varname_forcing_static:
                        value[varname] = (value[varname] - self.mean_ds[varname]) / self.std_ds[varname]
                normalized_sample[key] = value
            else:
                # target fields (and inputs without forcing/static) are normalized wholesale
                normalized_sample[key] = (value - self.mean_ds) / self.std_ds

        return normalized_sample

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        '''
        This function applies to y_pred, so there won't be forcing and static variables here.
        Upper-air, surface and diagnostic channels are all de-normalized.
        '''
        # get the current device
        device = x.device

        upper_air, surface, diagnostic = self._split_channels(x)

        # reverse upper air variables
        k = 0
        for name in self.varname_upper_air:
            for level in range(self.levels):
                mean = self.mean_ds[name].values[level]
                std = self.std_ds[name].values[level]
                upper_air[:, k] = upper_air[:, k] * std + mean
                k += 1

        # reverse surface variables
        if self.flag_surface:
            for k, name in enumerate(self.varname_surface):
                mean = self.mean_ds[name].values
                std = self.std_ds[name].values
                surface[:, k] = surface[:, k] * std + mean

        # reverse diagnostic variables
        if self.flag_diagnostic:
            for k, name in enumerate(self.varname_diagnostic):
                mean = self.mean_ds[name].values
                std = self.std_ds[name].values
                diagnostic[:, k] = diagnostic[:, k] * std + mean

        return self._concat_channels(upper_air, surface, diagnostic).to(device)

## xarray to tensor class

In [15]:
class ToTensor_ERA5_and_Forcing:
    """Convert ERA5 input/target ``xr.Dataset`` samples into PyTorch tensors.

    For the input key (``'historical_ERA5_images'`` or ``'x'``) the output dict
    gets ``'x'`` (upper-air), and optionally ``'x_surf'`` and ``'x_forcing_static'``.
    For the target key (``'target_ERA5_images'`` or ``'y'``) it gets ``'y'``,
    and optionally ``'y_surf'`` and ``'y_diag'``.

    Upper-air tensors are ``[time, upper_var, level, lat, lon]`` assuming each
    upper-air variable is ``(time, level, lat, lon)``. The other groups are
    ``[time, var, lat, lon]``. Sample entries that are not Datasets stored under
    an input/target key are not returned.
    """

    def __init__(self, conf):
        self.conf = conf
        data_conf = conf["data"]
        self.hist_len = int(data_conf["history_len"])
        self.for_len = int(data_conf["forecast_len"])

        # a variable group is active only when it is configured and non-empty
        self.flag_surface = bool(data_conf.get('surface_variables'))
        self.flag_diagnostic = bool(data_conf.get('diagnostic_variables'))
        self.flag_forcing = bool(data_conf.get('forcing_variables'))
        self.flag_static = bool(data_conf.get('static_variables'))

        self.varname_upper_air = data_conf["variables"]

        # inactive groups get an empty list so they can always be iterated/concatenated
        self.varname_surface = data_conf["surface_variables"] if self.flag_surface else []
        self.varname_diagnostic = data_conf["diagnostic_variables"] if self.flag_diagnostic else []
        self.varname_forcing = data_conf["forcing_variables"] if self.flag_forcing else []
        self.varname_static = data_conf["static_variables"] if self.flag_static else []

        # forcing and static variables are stacked together into one input tensor
        self.has_forcing_static = self.flag_forcing or self.flag_static

    @staticmethod
    def _stack_time_first(arrays, num_vars):
        """Stack per-variable arrays into a 4-D ``[time, var, lat, lon]`` tensor.

        ``arrays`` holds one array per variable, presumably ``(time, lat, lon)``.
        Size-1 axes are squeezed and then restored, so single-time and/or
        single-variable inputs still come out 4-D.
        """
        stacked = torch.as_tensor(np.array(arrays)).squeeze()

        if stacked.ndim == 4:
            # [var, time, lat, lon] --> [time, var, lat, lon]
            return stacked.permute(1, 0, 2, 3)

        if stacked.ndim == 3:
            # Exactly one of (var, time) was squeezed away; the variable count
            # tells which one to restore.
            if num_vars > 1:
                # single time, multi-vars
                return stacked.unsqueeze(0)
            # multi-time, single var
            return stacked.unsqueeze(1)

        # single var, single time: [lat, lon] --> [1, 1, lat, lon]
        return stacked.unsqueeze(0).unsqueeze(0)

    def __call__(self, sample: "Sample") -> "Sample":
        return_dict = {}

        for key, value in sample.items():
            is_input = key in ('historical_ERA5_images', 'x')
            is_target = key in ('target_ERA5_images', 'y')

            # Only Datasets stored under an input/target key produce tensors.
            # Skipping everything else up front ensures arrays built for one key
            # are never reused for another.
            if not isinstance(value, xr.Dataset) or not (is_input or is_target):
                continue

            # upper-air: per var [time, level, lat, lon] --> [time, upper_var, level, lat, lon]
            x_upper_air = torch.as_tensor(np.stack(
                [value[name].values for name in self.varname_upper_air], axis=1))

            if self.flag_surface:
                x_surf = self._stack_time_first(
                    [value[name].values for name in self.varname_surface],
                    len(self.varname_surface))

            if is_input:
                # forcing and static variables are inputs only
                if self.has_forcing_static:
                    names = self.varname_forcing + self.varname_static
                    return_dict['x_forcing_static'] = self._stack_time_first(
                        [value[name].values for name in names], len(names))

                if self.flag_surface:
                    return_dict['x_surf'] = x_surf

                return_dict['x'] = x_upper_air
            else:
                # diagnostic variables are targets only
                if self.flag_diagnostic:
                    return_dict['y_diag'] = self._stack_time_first(
                        [value[name].values for name in self.varname_diagnostic],
                        len(self.varname_diagnostic))

                if self.flag_surface:
                    return_dict['y_surf'] = x_surf

                return_dict['y'] = x_upper_air

        return return_dict

## Testing the dataset

In [16]:
is_train = False
# convert $USER to the actual user name
conf['save_loc'] = os.path.expandvars(conf['save_loc'])

# ======================================================== #
# parse inputs
data_conf = conf['data']

# file names
all_ERA_files = sorted(glob(data_conf['save_loc']))
varname_upper_air = data_conf['variables']

# Optional variable groups are used only when configured with a non-empty list;
# otherwise both the file spec and the variable names are None.
if data_conf.get('forcing_variables'):
    forcing_files = data_conf['save_loc_forcing']
    varname_forcing = data_conf['forcing_variables']
else:
    forcing_files = varname_forcing = None

if data_conf.get('static_variables'):
    static_files = data_conf['save_loc_static']
    varname_static = data_conf['static_variables']
else:
    static_files = varname_static = None

if data_conf.get('surface_variables'):
    surface_files = sorted(glob(data_conf['save_loc_surface']))
    varname_surface = data_conf['surface_variables']
else:
    surface_files = varname_surface = None

if data_conf.get('diagnostic_variables'):
    diagnostic_files = sorted(glob(data_conf['save_loc_diagnostic']))
    varname_diagnostic = data_conf['diagnostic_variables']
else:
    diagnostic_files = varname_diagnostic = None

# number of previous lead time inputs / number of lead times to forecast
history_len = data_conf["history_len"]
valid_history_len = data_conf["valid_history_len"]
forecast_len = data_conf["forecast_len"]
valid_forecast_len = data_conf["valid_forecast_len"]

if is_train:
    name = "training"
else:
    # validation runs use their own history / forecast lengths
    history_len = valid_history_len
    forecast_len = valid_forecast_len
    name = 'validation'

# optional settings default to None when absent from the config
max_forecast_len = data_conf.get("max_forecast_len")
skip_periods = data_conf.get("skip_periods")
one_shot = data_conf.get("one_shot")

# shuffle only when training
shuffle = is_train

# raw dataset: transform=None, so no normalization or tensor conversion is applied
dataset = ERA5_and_Forcing_Dataset(
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=one_shot,
    max_forecast_len=max_forecast_len,
    transform=None
)

In [17]:
# Pull one sample from the dataset built with transform=None
# (presumably index 0, via the dataset's iteration protocol).
samples = next(iter(dataset))

In [18]:
# Display the raw sample dict (no transform applied).
samples

{'historical_ERA5_images': <xarray.Dataset> Size: 465MB
 Dimensions:     (time: 2, level: 15, latitude: 640, longitude: 1280,
                  half_level: 138)
 Coordinates:
   * half_level  (half_level) int32 552B 1 2 3 4 5 6 ... 133 134 135 136 137 138
   * latitude    (latitude) float64 5kB 89.78 89.51 89.23 ... -89.51 -89.78
   * level       (level) int32 60B 10 30 40 50 60 70 ... 100 105 110 120 130 136
   * longitude   (longitude) float64 10kB 0.0 0.2812 0.5625 ... 359.2 359.4 359.7
   * time        (time) datetime64[ns] 16B 1979-01-01 1979-01-01T01:00:00
 Data variables: (12/14)
     Q           (time, level, latitude, longitude) float32 98MB 3.235e-06 ......
     T           (time, level, latitude, longitude) float32 98MB 283.3 ... 249.9
     U           (time, level, latitude, longitude) float32 98MB -36.52 ... -1...
     V           (time, level, latitude, longitude) float32 98MB 41.93 ... -6.26
     Q500        (time, latitude, longitude) float32 7MB 0.0001128 ... 0.0001881

## Testing the z-score transform

In [46]:
# Build the experimental normalizer under test from the config.
transform_scaler = Normalize_ERA5_and_Forcing(conf)

In [47]:
# Normalize the raw sample.
test_transform = transform_scaler(samples)

<xarray.Dataset> Size: 210MB
Dimensions:     (time: 1, level: 15, latitude: 640, longitude: 1280,
                 half_level: 138)
Coordinates:
  * half_level  (half_level) int32 552B 1 2 3 4 5 6 ... 133 134 135 136 137 138
  * latitude    (latitude) float64 5kB 89.78 89.51 89.23 ... -89.51 -89.78
  * level       (level) int32 60B 10 30 40 50 60 70 ... 100 105 110 120 130 136
  * longitude   (longitude) float64 10kB 0.0 0.2812 0.5625 ... 359.2 359.4 359.7
  * time        (time) datetime64[ns] 8B 1979-01-01T02:00:00
Data variables:
    Q           (time, level, latitude, longitude) float32 49MB 3.272e-06 ......
    T           (time, level, latitude, longitude) float32 49MB 282.3 ... 249.7
    U           (time, level, latitude, longitude) float32 49MB -45.89 ... -1...
    V           (time, level, latitude, longitude) float32 49MB 34.42 ... -6.284
    SP          (time, latitude, longitude) float32 3MB 1.028e+05 ... 7.012e+04
    t2m         (time, latitude, longitude) float32 3MB 244

In [48]:
# Keys of the normalized sample.
test_transform.keys()

dict_keys(['historical_ERA5_images', 'target_ERA5_images'])

In [49]:
# Inspect the normalized input (history) fields.
test_transform['historical_ERA5_images']

In [50]:
# Inspect the normalized target fields.
test_transform['target_ERA5_images']

## Testing the xarray-to-tensor conversion

In [24]:
def concat_and_reshape(x1, x2):
    """Merge two tensors along a shared channel axis and move channels before time.

    Parameters:
    - x1 (torch.Tensor): 6-D tensor whose dims 2 and 3 (presumably variable and
      level, after a leading batch dim) are flattened into one channel dim.
    - x2 (torch.Tensor): 5-D tensor concatenated with the flattened ``x1`` along dim 2.

    Returns:
    - torch.Tensor: the 5-D concatenation with dims 1 and 2 swapped, so the
      combined channel axis is dim 1.
    """
    # flatten() returns a view when possible and copies only for non-contiguous
    # inputs, where the previous .view() call raised a RuntimeError.
    x1 = x1.flatten(2, 3)
    x_concat = torch.cat((x1, x2), dim=2)
    return x_concat.permute(0, 2, 1, 3, 4)

In [25]:
# Build the xarray -> tensor converter from the config.
to_tensor_scaler = ToTensor_ERA5_and_Forcing(conf)

In [26]:
# NOTE: converts the raw `samples`, not the normalized `test_transform`;
# this is enough for checking tensor shapes.
test_tensor = to_tensor_scaler(samples)

In [27]:
# Tensor keys produced by the converter.
test_tensor.keys()

dict_keys(['x_surf', 'x', 'x_forcing_static', 'y_diag', 'y_surf', 'y'])

In [28]:
# expected layout: [time, surface_var, lat, lon]
test_tensor['x_surf'].shape

torch.Size([2, 2, 640, 1280])

In [29]:
# expected layout: [time, forcing + static var, lat, lon]
test_tensor['x_forcing_static'].shape

torch.Size([2, 3, 640, 1280])

In [30]:
# expected layout: [time, upper_var, level, lat, lon]
test_tensor['x'].shape

torch.Size([2, 4, 15, 640, 1280])

In [31]:
# expected layout: [time, diagnostic_var, lat, lon]
test_tensor['y_diag'].shape

torch.Size([1, 1, 640, 1280])

In [73]:
# expected layout: [time, surface_var, lat, lon]
# NOTE(review): the captured output (7 surface vars) disagrees with x_surf above
# (2 vars); it likely comes from a different run/config. Re-run to confirm.
test_tensor['y_surf'].shape

torch.Size([1, 7, 640, 1280])

In [74]:
# expected layout: [time, upper_var, level, lat, lon]
test_tensor['y'].shape

torch.Size([1, 4, 15, 640, 1280])

## Wrapper functions: transforms, dataset, and sampler

In [75]:
def load_transforms(conf):
    """Build the preprocessing pipeline: a normalizer followed by a to-tensor converter.

    Parameters:
    - conf (dict): configuration; ``conf["data"]["scaler_type"]`` selects the
      normalizer and the matching tensor converter.

    Returns:
    - torchvision ``Compose`` that applies the normalizer, then the converter.

    Raises:
    - ValueError: if ``scaler_type`` is not one of 'quantile', 'quantile-cached',
      'std' or 'std_new'.
    """
    scaler_type = conf["data"]["scaler_type"]

    # Each class name is resolved only in its own branch, so only the selected
    # scaler class has to be defined when this function runs.
    if scaler_type == 'quantile':
        transform_scaler = NormalizeState_Quantile(conf)
    elif scaler_type == 'quantile-cached':
        transform_scaler = NormalizeState_Quantile_Bridgescalar(conf)
    elif scaler_type == 'std':
        transform_scaler = NormalizeState(conf)
    elif scaler_type == 'std_new':
        # experimental
        transform_scaler = Normalize_ERA5_and_Forcing(conf)
    else:
        # Raise a real exception carrying the message. A bare `raise` outside an
        # except block only produces "No active exception to reraise".
        raise ValueError(
            f"scaler type '{scaler_type}' not supported; "
            "check data: scaler_type in config file"
        )

    if scaler_type == 'quantile-cached':
        to_tensor_scaler = ToTensor_BridgeScaler(conf)
    elif scaler_type == 'std_new':
        # experimental
        to_tensor_scaler = ToTensor_ERA5_and_Forcing(conf)
    else:
        to_tensor_scaler = ToTensor(conf=conf)

    return tforms.Compose([
        transform_scaler,
        to_tensor_scaler,
    ])

In [79]:
def load_dataset_and_sampler_zscore_only(conf, all_ERA_files, surface_files, diagnostic_files, world_size, rank, is_train, seed=42):
    """Build an ``ERA5_and_Forcing_Dataset`` with config-selected transforms plus a ``DistributedSampler``.

    Parameters:
    - conf (dict): configuration. ``conf['save_loc']`` is expanded in place
      (e.g. ``$USER``), so the caller's dict is modified.
    - all_ERA_files (list): upper-air file names.
    - surface_files (list or None): surface file names, or None if surface
      variables are not used.
    - diagnostic_files (list or None): diagnostic file names, or None if
      diagnostic variables are not used.
    - world_size (int): number of replicas for the sampler.
    - rank (int): rank of this replica.
    - is_train (bool): True uses the training history/forecast lengths and a
      shuffling sampler; False uses the validation lengths without shuffling.
    - seed (int): sampler shuffle seed.

    Returns:
    - (dataset, sampler)
    """
    data_conf = conf['data']

    # convert $USER to the actual user name
    conf['save_loc'] = os.path.expandvars(conf['save_loc'])

    # ======================================================== #
    # parse inputs
    varname_upper_air = data_conf['variables']

    # forcing / static groups are used only when configured with a non-empty list
    if data_conf.get('forcing_variables'):
        forcing_files = data_conf['save_loc_forcing']
        varname_forcing = data_conf['forcing_variables']
    else:
        forcing_files = varname_forcing = None

    if data_conf.get('static_variables'):
        static_files = data_conf['save_loc_static']
        varname_static = data_conf['static_variables']
    else:
        static_files = varname_static = None

    # surface / diagnostic groups are enabled by the caller passing file lists
    varname_surface = data_conf['surface_variables'] if surface_files is not None else None
    varname_diagnostic = data_conf['diagnostic_variables'] if diagnostic_files is not None else None

    # number of previous lead time inputs / number of lead times to forecast
    if is_train:
        name = 'training'
        history_len = data_conf['history_len']
        forecast_len = data_conf['forecast_len']
    else:
        name = 'validation'
        history_len = data_conf['valid_history_len']
        forecast_len = data_conf['valid_forecast_len']

    # data preprocessing utils
    transforms = load_transforms(conf)

    dataset = ERA5_and_Forcing_Dataset(
        varname_upper_air=varname_upper_air,
        varname_surface=varname_surface,
        varname_forcing=varname_forcing,
        varname_static=varname_static,
        varname_diagnostic=varname_diagnostic,
        filenames=all_ERA_files,
        filename_surface=surface_files,
        filename_forcing=forcing_files,
        filename_static=static_files,
        filename_diagnostic=diagnostic_files,
        history_len=history_len,
        forecast_len=forecast_len,
        # optional settings default to None when absent from the config
        skip_periods=data_conf.get('skip_periods'),
        one_shot=data_conf.get('one_shot'),
        max_forecast_len=data_conf.get('max_forecast_len'),
        transform=transforms
    )

    # Pytorch sampler; shuffle only when training
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        seed=seed,
        shuffle=is_train,
        drop_last=True
    )

    logger.info(
        " Loaded a %s ERA dataset, and a distributed sampler (forecast length = %s)",
        name, forecast_len + 1,
    )

    return dataset, sampler

In [80]:
# WORLD_SIZE = int(os.environ["WORLD_SIZE"])
# RANK = int(os.environ["RANK"])

In [81]:
# file names
# file names; an optional group's file list is None when that group is not
# configured, because the loader treats None as "group unused" and otherwise
# reads conf['data'][...'_variables'] for it.
all_ERA_files = sorted(glob(conf['data']['save_loc']))
surface_files = (sorted(glob(conf['data']['save_loc_surface']))
                 if conf['data'].get('surface_variables') else None)
diagnostic_files = (sorted(glob(conf['data']['save_loc_diagnostic']))
                    if conf['data'].get('diagnostic_variables') else None)

In [82]:
# Single-process setup (world_size=1, rank 0) using the training lengths.
dataset, sampler = load_dataset_and_sampler_zscore_only(
    conf=conf,
    all_ERA_files=all_ERA_files,
    surface_files=surface_files,
    diagnostic_files=diagnostic_files,
    world_size=1,
    rank=0,
    is_train=True,
    seed=42,
)

In [83]:
# Pull one sample with the config-selected transforms (from load_transforms) applied.
samples = next(iter(dataset))

In [84]:
# Keys of the transformed sample. NOTE(review): 'index' is not produced by the
# tensor converter and is presumably added by the dataset itself; confirm.
samples.keys()

dict_keys(['x_surf', 'x', 'x_forcing_static', 'y_diag', 'y_surf', 'y', 'index'])