# Develop and test a PyTorch IterableDataset for multi-step training

In [1]:
import os
import glob
import yaml
import logging
from functools import partial
from concurrent.futures import ProcessPoolExecutor as Pool
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

import torch
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

from credit.data import (Sample, find_key_for_number, get_forward_data, 
                         drop_var_from_dataset, extract_month_day_hour, find_common_indices)

from credit.transforms import load_transforms

In [2]:
# Module-level logger; the dataset class and the worker function below log through it.
logger = logging.getLogger(__name__)

In [3]:
# Location of the new rollout config (YAML)
config_name = '/glade/work/ksha/repos/global/miles-credit/results/wxformer_6h/model.yml'
# Parse the YAML file into a plain Python object
with open(config_name) as stream:
    conf_dyn = yaml.safe_load(stream)

In [4]:
# Use the rollout config loaded above for the rest of the notebook.
conf = conf_dyn
# False makes the setup cell below pick the validation history/forecast lengths.
is_train = False

## Load transforms and single-step / one-shot dataset

In [5]:
# Year ranges come from the config when present; otherwise use the defaults.
train_years_range = conf['data'].get('train_years', [1979, 2014])
valid_years_range = conf['data'].get('valid_years', [2014, 2018])

# convert year info to str for file name search (range() excludes the end year)
train_years = [str(year) for year in range(train_years_range[0], train_years_range[1])]
valid_years = [str(year) for year in range(valid_years_range[0], valid_years_range[1])]

# get file names
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))


def _glob_optional_files(varname_key, save_loc_key, label):
    '''
    Glob the files of an optional variable group.

    Parameters:
    - varname_key (str): Key in conf['data'] holding the variable names of the group.
    - save_loc_key (str): Key in conf['data'] holding the glob pattern of the group.
    - label (str): Human-readable group name used in the progress message.

    Returns:
    - Sorted list of matching file names, or None when the group has no
      variables configured (key missing or empty).
    '''
    if not conf['data'].get(varname_key):
        return None
    print('collecting {} files'.format(label))
    return sorted(glob.glob(conf['data'][save_loc_key]))


def _filter_by_years(files, years):
    '''Keep the file names that contain any of the given year strings.'''
    return [file for file in files if any(year in file for year in years)]


def _split_train_valid(files, label):
    '''
    Split an optional file list into training / validation subsets.

    Returns (None, None) when `files` is None. Otherwise both subsets must
    have as many files as the upper-air subsets (`train_files`, `valid_files`);
    a mismatch raises AssertionError.
    '''
    if files is None:
        return None, None

    train_subset = _filter_by_years(files, train_years)
    valid_subset = _filter_by_years(files, valid_years)

    # ---------------------------- #
    # check total number of files
    assert len(train_subset) == len(train_files), \
        'Mismatch between the total number of training set [{} files] and [upper-air files]'.format(label)
    assert len(valid_subset) == len(valid_files), \
        'Mismatch between the total number of validation set [{} files] and [upper-air files]'.format(label)

    return train_subset, valid_subset


# Default to "not available" so these names always exist, even when the
# scaler type is not 'std_new' (later cells read them unconditionally).
surface_files = None
dyn_forcing_files = None
diagnostic_files = None

# <------------------------------------------ std_new
if conf['data']['scaler_type'] == 'std_new':
    surface_files = _glob_optional_files(
        'surface_variables', 'save_loc_surface', 'surface')
    dyn_forcing_files = _glob_optional_files(
        'dynamic_forcing_variables', 'save_loc_dynamic_forcing', 'dynamic forcing')
    diagnostic_files = _glob_optional_files(
        'diagnostic_variables', 'save_loc_diagnostic', 'diagnostic')

# Filter the files for training / validation
train_files = _filter_by_years(all_ERA_files, train_years)
valid_files = _filter_by_years(all_ERA_files, valid_years)

# Optional groups: (None, None) when the group was not collected above.
train_surface_files, valid_surface_files = _split_train_valid(
    surface_files, 'surface')
train_dyn_forcing_files, valid_dyn_forcing_files = _split_train_valid(
    dyn_forcing_files, 'dynamic forcing')
train_diagnostic_files, valid_diagnostic_files = _split_train_valid(
    diagnostic_files, 'diagnostic')

collecting surface files
collecting dynamic forcing files


In [6]:
# container for variable names
varname_all = []

# upper-air variable names
varname_upper_air = conf['data']['variables']

# forcing: file location and variable names, only when variables are configured
if 'forcing_variables' in conf['data'] and len(conf['data']['forcing_variables']) > 0:
    forcing_files, varname_forcing = (conf['data']['save_loc_forcing'],
                                      conf['data']['forcing_variables'])
else:
    forcing_files = varname_forcing = None

# static: file location and variable names, only when variables are configured
if 'static_variables' in conf['data'] and len(conf['data']['static_variables']) > 0:
    static_files, varname_static = (conf['data']['save_loc_static'],
                                    conf['data']['static_variables'])
else:
    static_files = varname_static = None

# variable names of the optional groups follow the availability of their files
varname_surface = conf['data']['surface_variables'] if surface_files is not None else None
varname_dyn_forcing = conf['data']['dynamic_forcing_variables'] if dyn_forcing_files is not None else None
varname_diagnostic = conf['data']['diagnostic_variables'] if diagnostic_files is not None else None

# number of previous lead time inputs (training / validation)
history_len = conf["data"]["history_len"]
valid_history_len = conf["data"]["valid_history_len"]

# number of lead times to forecast (training / validation)
forecast_len = conf["data"]["forecast_len"]
valid_forecast_len = conf["data"]["valid_forecast_len"]

# pick the training or the validation lengths, and a matching label
history_len = history_len if is_train else valid_history_len
forecast_len = forecast_len if is_train else valid_forecast_len
name = "training" if is_train else 'validation'

# optional settings: None when the key is absent from the config
max_forecast_len = conf["data"].get("max_forecast_len")
skip_periods = conf["data"].get("skip_periods")
one_shot = conf["data"].get("one_shot")

# shuffle
shuffle = False

In [7]:
# data preprocessing utils: build the transforms from the loaded config
# (load_transforms is imported from credit.transforms)
transforms = load_transforms(conf)

In [8]:
class DistributedSequentialDataset(torch.utils.data.IterableDataset):
    '''
    Iterable dataset that yields multi-step forecast sequences.

    Start indices are split across ranks and DataLoader workers with a
    DistributedSampler. For every start index the dataset yields the samples
    built by the module-level `worker` function, one per time step, until a
    sample flags `stop_forecast` or the steps of the sequence are exhausted.
    '''

    def __init__(
        self,
        varname_upper_air: List[str],
        varname_surface: List[str],
        varname_dyn_forcing: List[str],
        varname_forcing: List[str],
        varname_static: List[str],
        varname_diagnostic: List[str],
        filenames: List[str],
        filename_surface: Optional[List[str]] = None,
        filename_dyn_forcing: Optional[List[str]] = None,
        filename_forcing: Optional[str] = None,
        filename_static: Optional[str] = None,
        filename_diagnostic: Optional[List[str]] = None,
        rank: int = 0,
        world_size: int = 1,
        history_len: int = 2,
        forecast_len: int = 0,
        transform: Optional[Callable] = None,
        seed: int = 42,
        skip_periods: Optional[int] = None,
        max_forecast_len: Optional[int] = None,
        shuffle: bool = True,
        num_workers: int = 0
    ):

        '''
        Initialize the DistributedSequentialDataset.

        Parameters:
        - varname_upper_air (list): List of upper air variable names.
        - varname_surface (list): List of surface variable names.
        - varname_dyn_forcing (list): List of dynamic forcing variable names.
        - varname_forcing (list): List of forcing variable names.
        - varname_static (list): List of static variable names.
        - varname_diagnostic (list): List of diagnostic variable names.
        - filenames (list): List of filenames for upper air data.
        - filename_surface (list, optional): List of filenames for surface data.
        - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing.
        - filename_forcing (str, optional): Filename for forcing data.
        - filename_static (str, optional): Filename for static data.
        - filename_diagnostic (list, optional): List of filenames for diagnostic data.
        - rank (int, optional): Rank of the current process. Default is 0.
        - world_size (int, optional): Total number of processes. Default is 1.
        - history_len (int, optional): Length of the history sequence. Default is 2.
        - forecast_len (int, optional): Length of the forecast sequence. Default is 0.
        - transform (callable, optional): Transformation function to apply to the data.
        - seed (int, optional): Random seed for reproducibility. Default is 42.
        - skip_periods (int, optional): Number of periods to skip between samples.
          None is treated as 1.
        - max_forecast_len (int, optional): Maximum length of the forecast sequence.
        - shuffle (bool, optional): Whether to shuffle the data. Default is True.
        - num_workers (int, optional): Number of processes used inside __iter__
          to build samples; values <= 1 build samples in the calling process.
          Default is 0.

        Raises:
        - AssertionError: If filename_forcing or filename_static is given but
          is not an existing file.
        '''

        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.current_epoch = 0
        self.num_workers = num_workers

        logger.info(f"Using {num_workers} workers in the iterable dataset")

        # skip periods (None means "use every time step")
        self.skip_periods = 1 if skip_periods is None else skip_periods

        # total number of needed forecast lead times
        self.total_seq_len = self.history_len + self.forecast_len

        # set random seed
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # ERA5 operations: one dataset per (yearly) upper-air file
        self.all_files = self._open_datasets(filenames, varname_upper_air)

        # get sample indices from ERA5 upper-air files:
        # NOTE(review): each range is [ind_start, ind_start + n_times] and the
        # next file starts one past that end, which looks like an inclusive
        # upper bound -- confirm against find_key_for_number.
        ind_start = 0
        self.ERA5_indices = {}
        for ind_file, ERA5_xarray in enumerate(self.all_files):
            n_times = len(ERA5_xarray['time'])
            # [number of samples, ind_start, ind_end]
            self.ERA5_indices[str(ind_file)] = [n_times, ind_start, ind_start + n_times]
            ind_start += n_times + 1

        # ======================================================== #
        # surface files (False marks "not provided")
        if filename_surface is not None:
            self.surface_files = self._open_datasets(filename_surface, varname_surface)
        else:
            self.surface_files = False

        # ======================================================== #
        # dynamic forcing files (False marks "not provided")
        if filename_dyn_forcing is not None:
            self.dyn_forcing_files = self._open_datasets(filename_dyn_forcing, varname_dyn_forcing)
        else:
            self.dyn_forcing_files = False

        # ======================================================== #
        # diagnostic files (False marks "not provided")
        self.filename_diagnostic = filename_diagnostic

        if filename_diagnostic is not None:
            self.diagnostic_files = self._open_datasets(filename_diagnostic, varname_diagnostic)
        else:
            self.diagnostic_files = False

        # ======================================================== #
        # forcing file (single file; False marks "not provided")
        self.filename_forcing = filename_forcing

        if filename_forcing is not None:
            assert os.path.isfile(filename_forcing), 'Cannot find forcing file [{}]'.format(filename_forcing)

            # drop variables if they are not in the config
            self.xarray_forcing = drop_var_from_dataset(
                get_forward_data(filename_forcing), varname_forcing)
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file (single file; False marks "not provided")
        self.filename_static = filename_static

        if filename_static is not None:
            assert os.path.isfile(filename_static), 'Cannot find static file [{}]'.format(filename_static)

            # drop variables if they are not in the config
            self.xarray_static = drop_var_from_dataset(
                get_forward_data(filename_static), varname_static)
        else:
            self.xarray_static = False

    @staticmethod
    def _open_datasets(filenames, varnames):
        # Open every file in sorted filename order and drop the variables
        # that are not listed in `varnames`.
        return [
            drop_var_from_dataset(get_forward_data(filename=fn), varnames)
            for fn in sorted(filenames)
        ]

    def __post_init__(self):
        # Total sequence length of each sample.
        # This class is not a dataclass, so nothing calls this hook
        # automatically; __init__ already sets the same attribute.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self) -> int:
        # Sum, over all upper-air files, the number of start positions that
        # leave room for a full sequence of total_seq_len steps.
        return sum(
            len(ERA5_xarray['time']) - self.total_seq_len + 1
            for ERA5_xarray in self.all_files
        )

    def set_epoch(self, epoch: int) -> None:
        # Stored here and forwarded to the DistributedSampler in __iter__.
        self.current_epoch = epoch

    def __iter__(self):
        '''
        Yield the samples of every forecast sequence assigned to this
        rank / DataLoader worker.

        For each start index drawn from the sampler, samples are yielded step
        by step and the sequence ends right after the first sample whose
        'stop_forecast' entry is true; iteration then moves on to the next
        start index.
        '''

        # ------------------------------------------------------------------- #
        # get worker info
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        # distributed sampler with worker info: every (rank, DataLoader worker)
        # pair is treated as its own replica
        sampler = DistributedSampler(self, num_replicas=num_workers * self.world_size,
                                     rank=self.rank * num_workers + worker_id, shuffle=self.shuffle)
        sampler.set_epoch(self.current_epoch)

        # ------------------------------------------------------------------- #
        # worker process
        process_index_partial = partial(
            worker,
            ERA5_indices=self.ERA5_indices,
            all_files=self.all_files,
            surface_files=self.surface_files,
            dyn_forcing_files=self.dyn_forcing_files,
            diagnostic_files=self.diagnostic_files,
            xarray_forcing=self.xarray_forcing,
            xarray_static=self.xarray_static,
            history_len=self.history_len,
            forecast_len=self.forecast_len,
            skip_periods=self.skip_periods,
            transform=self.transform
        )

        # Dont use multi-processing
        if self.num_workers <= 1:
            for index in sampler:
                # Explicit inner (time step) loop
                for ind_start_current_step in range(index, index + self.history_len + self.forecast_len):
                    sample = process_index_partial((index, ind_start_current_step))
                    yield sample
                    if sample['stop_forecast']:
                        break

        else:  # use multi-processing
            with Pool(self.num_workers) as p:
                batch_size = 2 * self.num_workers  # limit the size of the "queue"
                for index in sampler:
                    indices = list(range(index, index + self.history_len + self.forecast_len))
                    sequence_done = False

                    # Process indices in batches to avoid potential memory problems if indices is very long
                    for i in range(0, len(indices), batch_size):
                        batch_tasks = [
                            (index, ind_start_current_step)
                            for ind_start_current_step in indices[i:i + batch_size]
                        ]

                        # Process the batch and yield its results in order
                        for sample in p.map(process_index_partial, batch_tasks):
                            yield sample
                            if sample['stop_forecast']:
                                # End only this sequence (as the single-process
                                # branch does); returning here would end the
                                # whole epoch after the first start index.
                                sequence_done = True
                                break

                        if sequence_done:
                            break


def worker(
    tuple_index: Tuple[int, int],
    ERA5_indices: Dict[str, List[int]],
    all_files: List[Any],
    surface_files: Optional[List[Any]],
    dyn_forcing_files: Optional[List[Any]],
    diagnostic_files: Optional[List[Any]],
    xarray_forcing: Optional[Any],
    xarray_static: Optional[Any],
    history_len: int,
    forecast_len: int,
    skip_periods: int,
    transform: Optional[Callable]
) -> Dict[str, Any]:

    '''
    Build the input / target sample for one time step of a forecast sequence.

    Parameters:
    - tuple_index (Tuple[int, int]): (index, ind_start_current_step); the start
      index of the sequence and the global index of the current step.
    - ERA5_indices (Dict[str, List[int]]): Per-file index metadata; entry [1]
      is read as the global start index of the file.
    - all_files (List[Any]): List of xarray datasets containing upper air data.
    - surface_files (Optional[List[Any]]): xarray datasets with surface data,
      or a falsy value to skip them.
    - dyn_forcing_files (Optional[List[Any]]): xarray datasets with dynamic
      forcing data, or a falsy value to skip them.
    - diagnostic_files (Optional[List[Any]]): xarray datasets with diagnostic
      data, or a falsy value to skip them.
    - xarray_forcing (Optional[Any]): xarray dataset with forcing data, or a
      falsy value to skip it.
    - xarray_static (Optional[Any]): xarray dataset with static data, or a
      falsy value to skip it.
    - history_len (int): Length of the history sequence.
    - forecast_len (int): Length of the forecast sequence.
    - skip_periods (int): Step of the time slices taken from the window.
    - transform (Optional[Callable]): Applied to the sample when truthy.

    Returns:
    - Dict[str, Any]: The (optionally transformed) sample, extended with the
      keys 'index', 'forecast_hour', 'stop_forecast' and 'datetime'.

    Raises:
    - Any exception raised while building the sample is logged and re-raised.
    '''

    index, ind_start_current_step = tuple_index

    try:
        # locate the file that holds the current step
        ind_file = find_key_for_number(ind_start_current_step, ERA5_indices)

        # position of the current step inside that file
        offset_in_file = ind_start_current_step - ERA5_indices[ind_file][1]
        file_id = int(ind_file)

        # clamp so that the full window still fits inside the file
        window_len = history_len + forecast_len
        last_valid_offset = len(all_files[file_id]['time']) - (window_len + 1)
        offset_in_file = min(offset_in_file, last_valid_offset)

        # time selections reused below
        window = slice(offset_in_file, offset_in_file + window_len + 1)
        input_steps = slice(0, history_len, skip_periods)
        target_steps = slice(history_len, history_len + skip_periods, skip_periods)

        # ========================================================================== #
        # lazy window over the upper-air data (nothing is loaded yet)
        ERA5_subset = all_files[file_id].isel(time=window)

        if surface_files:
            # lazy merge: neither upper-air nor surface data is loaded here
            ERA5_subset = ERA5_subset.merge(surface_files[file_id].isel(time=window))

        # datetime information as int number (used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype('datetime64[s]').astype(int)

        # ========================================================================== #
        # model input: the history part of the window, loaded into memory
        historical_ERA5_images = ERA5_subset.isel(time=input_steps).load()

        # dynamic forcing inputs, same window and history steps
        if dyn_forcing_files:
            dyn_forcing_inputs = (
                dyn_forcing_files[file_id].isel(time=window).isel(time=input_steps).load()
            )
            historical_ERA5_images = historical_ERA5_images.merge(dyn_forcing_inputs)

        # forcing inputs, matched on month / day / hour so that a leap-year
        # forcing file can serve a non-leap-year upper-air file
        if xarray_forcing:
            forcing_mdh = extract_month_day_hour(np.array(xarray_forcing['time']))
            inputs_mdh = extract_month_day_hour(np.array(historical_ERA5_images['time']))
            ind_forcing, _ = find_common_indices(forcing_mdh, inputs_mdh)
            forcing_inputs = xarray_forcing.isel(time=ind_forcing).load()
            # years differ but month/day/hour agree: adopt the upper-air time
            forcing_inputs['time'] = historical_ERA5_images['time']
            historical_ERA5_images = historical_ERA5_images.merge(forcing_inputs)

        # static inputs, broadcast along the window's time axis first
        if xarray_static:
            static_inputs = xarray_static.expand_dims(dim={"time": len(ERA5_subset['time'])})
            static_inputs = static_inputs.assign_coords({'time': ERA5_subset['time']})
            # keep the history steps and load them into memory
            static_inputs = static_inputs.isel(time=input_steps).load()
            static_inputs['time'] = historical_ERA5_images['time']
            historical_ERA5_images = historical_ERA5_images.merge(static_inputs)

        # ========================================================================== #
        # model target: the next forecast step, loaded into memory
        target_ERA5_images = ERA5_subset.isel(time=target_steps).load()

        # diagnostic variables only belong to the target
        if diagnostic_files:
            diagnostic_targets = (
                diagnostic_files[file_id].isel(time=window).isel(time=target_steps).load()
            )
            target_ERA5_images = target_ERA5_images.merge(diagnostic_targets)

        # bundle inputs and targets
        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number
        )

        # data normalization
        if transform:
            sample = transform(sample)

        # bookkeeping for the multi-step loop
        steps_done = ind_start_current_step - index
        sample["index"] = index
        sample['forecast_hour'] = steps_done + 1
        sample['stop_forecast'] = (steps_done == forecast_len)
        sample["datetime"] = [
            int(images.time.values[0].astype('datetime64[s]').astype(int))
            for images in (historical_ERA5_images, target_ERA5_images)
        ]

    except Exception as e:
        logger.error(f"Error processing index {tuple_index}: {e}")
        raise

    return sample


class DistributedSequentialDatasetBasic(DistributedSequentialDataset):
    """Variant of ``DistributedSequentialDataset`` that assembles every sample inline.

    Only ``__iter__`` is overridden. All attributes read here (opened files,
    window sizes, rank / world size, transform, ...) are not defined in this
    class and are expected to come from the parent class.
    """

    def __iter__(self):
        """Yield one sample per forecast step for every start index from the sampler.

        For each start index drawn from the ``DistributedSampler`` the
        generator walks consecutive time indices and yields a sample per step.
        The inner loop stops right after the step with
        ``k == self.forecast_len``, i.e. ``forecast_len + 1`` samples per start
        index as long as ``history_len >= 1``.

        Yields:
            The ``Sample`` built for the current step (passed through
            ``self.transform`` when one is set) with the extra keys
            ``index``, ``forecast_hour``, ``stop_forecast`` and ``datetime``.
        """

        # ------------------------------------------------------------------- #
        # get worker info
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        # distributed sampler with worker info:
        # every (rank, dataloader worker) pair is treated as its own replica
        sampler = DistributedSampler(self, num_replicas=num_workers * self.world_size,
                                     rank=self.rank * num_workers + worker_id, shuffle=self.shuffle)
        sampler.set_epoch(self.current_epoch)

        # month/day/hour of the forcing time axis does not depend on the sample,
        # so it is computed once here instead of once per yielded step
        if self.xarray_forcing:
            month_day_forcing = extract_month_day_hour(np.array(self.xarray_forcing['time']))
        else:
            month_day_forcing = None

        for index in sampler:

            indices = range(index, index + self.history_len + self.forecast_len)

            for k, ind_start_current_step in enumerate(indices):

                # select the ind_file based on the iter index
                ind_file = find_key_for_number(ind_start_current_step, self.ERA5_indices)

                # get the ind within the current file
                ind_start = self.ERA5_indices[ind_file][1]
                ind_start_in_file = ind_start_current_step - ind_start

                # handle out-of-bounds: clamp the start so the full window
                # (history + forecast + 1 time steps) stays inside the file
                ind_largest = len(self.all_files[int(ind_file)]['time'])-(self.history_len+self.forecast_len+1)
                if ind_start_in_file > ind_largest:
                    ind_start_in_file = ind_largest
                # ========================================================================== #
                # subset xarray on time dimension & load it to the memory

                ind_end_in_file = ind_start_in_file+self.history_len+self.forecast_len

                ## ERA5_subset: a xarray dataset that contains training input and target (for the current batch)
                ERA5_subset = self.all_files[int(ind_file)].isel(
                    time=slice(ind_start_in_file, ind_end_in_file+1)) #.load() NOT load into memory

                if self.surface_files:
                    ## subset surface variables
                    surface_subset = self.surface_files[int(ind_file)].isel(
                        time=slice(ind_start_in_file, ind_end_in_file+1)) #.load() NOT load into memory

                    ## merge upper-air and surface here:
                    ERA5_subset = ERA5_subset.merge(surface_subset) # <-- lazy merge, ERA5 and surface both not loaded

                # ==================================================== #
                # split ERA5_subset into training inputs and targets + merge with forcing and static

                # datetime information as int number (used in some normalization methods)
                datetime_as_number = ERA5_subset.time.values.astype('datetime64[s]').astype(int)

                # ==================================================== #
                # xarray dataset as input
                # historical_ERA5_images: the final input

                historical_ERA5_images = ERA5_subset.isel(
                    time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory

                # ========================================================================== #
                # merge dynamic forcing inputs
                if self.dyn_forcing_files:
                    dyn_forcing_subset = self.dyn_forcing_files[int(ind_file)].isel(
                        time=slice(ind_start_in_file, ind_end_in_file+1))
                    dyn_forcing_subset = dyn_forcing_subset.isel(
                        time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory

                    historical_ERA5_images = historical_ERA5_images.merge(dyn_forcing_subset)

                # ========================================================================== #
                # merge forcing inputs
                if self.xarray_forcing:
                    # =============================================================================== #
                    # matching month, day, hour between forcing and upper air [time]
                    # this approach handles leap year forcing file and non-leap-year upper air file
                    month_day_inputs = extract_month_day_hour(np.array(historical_ERA5_images['time']))  # <-- upper air
                    # indices to subset
                    ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
                    forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing).load()
                    # forcing and upper air have different years but the same mon/day/hour
                    # safely replace forcing time with upper air time
                    forcing_subset_input['time'] = historical_ERA5_images['time']
                    # =============================================================================== #

                    # merge
                    historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

                # ========================================================================== #
                # merge static inputs
                if self.xarray_static:
                    # expand static var on time dim
                    N_time_dims = len(ERA5_subset['time'])
                    static_subset_input = self.xarray_static.expand_dims(dim={"time": N_time_dims})
                    # assign coords 'time'
                    static_subset_input = static_subset_input.assign_coords({'time': ERA5_subset['time']})

                    # slice + load into memory
                    static_subset_input = static_subset_input.isel(
                        time=slice(0, self.history_len, self.skip_periods)).load()

                    # update
                    static_subset_input['time'] = historical_ERA5_images['time']

                    # merge
                    historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

                # ==================================================== #
                # xarray dataset as target
                # target_ERA5_images: the final target

                # get the next forecast step
                target_ERA5_images = ERA5_subset.isel(
                    time=slice(self.history_len, self.history_len+self.skip_periods, self.skip_periods)
                ).load() # <-- load into memory

                ## merge diagnostic input here:
                if self.diagnostic_files:

                    # subset diagnostic variables
                    diagnostic_subset = self.diagnostic_files[int(ind_file)].isel(
                        time=slice(ind_start_in_file, ind_end_in_file+1))

                    # get the next forecast step
                    diagnostic_subset = diagnostic_subset.isel(
                        time=slice(
                            self.history_len, self.history_len+self.skip_periods, self.skip_periods)
                    ).load() # <-- load into memory

                    # merge into the target dataset
                    target_ERA5_images = target_ERA5_images.merge(diagnostic_subset)

                # pipe xarray datasets to the sampler
                sample = Sample(
                    historical_ERA5_images=historical_ERA5_images,
                    target_ERA5_images=target_ERA5_images,
                    datetime_index=datetime_as_number
                )

                # ==================================== #
                # data normalization
                if self.transform:
                    sample = self.transform(sample)

                # the last step of this sequence is the one with k == forecast_len
                stop_forecast = (k == self.forecast_len)

                # bookkeeping attached to every yielded sample
                sample['index'] = index
                sample['forecast_hour'] = k + 1
                sample['stop_forecast'] = stop_forecast
                sample["datetime"] = [
                    int(historical_ERA5_images.time.values[0].astype('datetime64[s]').astype(int)),
                    int(target_ERA5_images.time.values[0].astype('datetime64[s]').astype(int))
                ]

                # # print out to check input and target datetimes
                # print('Input time: {}'.format(np.array(historical_ERA5_images['time'])))
                # print('Target time: {}'.format(np.array(target_ERA5_images['time'])))

                yield sample

                if stop_forecast:
                    break

In [9]:
# each start index yields forecast_len + 1 samples, i.e. 4 here
forecast_len = 3

# Z-score: dataset with `transforms` applied to the samples
dataset = DistributedSequentialDataset(
    # --- variable names ---
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    # --- file lists ---
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    # --- time window ---
    history_len=history_len,
    forecast_len=forecast_len,
    max_forecast_len=max_forecast_len,
    skip_periods=skip_periods,
    # --- sample processing ---
    transform=transforms,
    # --- single-process run ---
    world_size=1,
    rank=0,
    shuffle=False,
)

In [10]:
# print the stop flag and input shape of the first forecast_len * 5 + 1 samples
for k, result in enumerate(dataset):
    print(k, result['stop_forecast'], result['x'].shape)
    if k == forecast_len * 5:
        break

0 False torch.Size([1, 4, 15, 640, 1280])
1 False torch.Size([1, 4, 15, 640, 1280])
2 False torch.Size([1, 4, 15, 640, 1280])
3 True torch.Size([1, 4, 15, 640, 1280])
4 False torch.Size([1, 4, 15, 640, 1280])
5 False torch.Size([1, 4, 15, 640, 1280])
6 False torch.Size([1, 4, 15, 640, 1280])
7 True torch.Size([1, 4, 15, 640, 1280])
8 False torch.Size([1, 4, 15, 640, 1280])
9 False torch.Size([1, 4, 15, 640, 1280])
10 False torch.Size([1, 4, 15, 640, 1280])
11 True torch.Size([1, 4, 15, 640, 1280])
12 False torch.Size([1, 4, 15, 640, 1280])
13 False torch.Size([1, 4, 15, 640, 1280])
14 False torch.Size([1, 4, 15, 640, 1280])
15 True torch.Size([1, 4, 15, 640, 1280])


In [15]:
# inspect index 0 along the first two axes of the forcing/static tensor of the last sample
# NOTE(review): the meaning of the division by 6 is not shown in this notebook — confirm
result['x_forcing_static'][0, 0, ...]/6

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.7914, 0.7915, 0.7916,  ..., 0.7912, 0.7913, 0.7914],
        [0.7858, 0.7859, 0.7859,  ..., 0.7857, 0.7857, 0.7858],
        [0.7803, 0.7803, 0.7803,  ..., 0.7802, 0.7802, 0.7803]])

In [10]:
# each start index yields forecast_len + 1 samples, i.e. 4 here
forecast_len = 3

# same inputs as before, but using the "Basic" dataset class and no transform
dataset = DistributedSequentialDatasetBasic(
    # --- variable names ---
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    # --- file lists ---
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    # --- time window ---
    history_len=history_len,
    forecast_len=forecast_len,
    max_forecast_len=max_forecast_len,
    skip_periods=skip_periods,
    # --- sample processing ---
    transform=None,
    # --- single-process run ---
    world_size=1,
    rank=0,
    shuffle=False,
)

In [11]:
# pull only the first sample from a fresh iterator over the dataset
samples_dyn = next(iter(dataset))

In [12]:
# display the sample; the recorded output is a dict holding xarray Datasets
samples_dyn

{'historical_ERA5_images': <xarray.Dataset> Size: 233MB
 Dimensions:     (time: 1, level: 15, latitude: 640, longitude: 1280,
                  half_level: 138)
 Coordinates:
   * half_level  (half_level) int32 552B 1 2 3 4 5 6 ... 133 134 135 136 137 138
   * latitude    (latitude) float64 5kB 89.78 89.51 89.23 ... -89.51 -89.78
   * level       (level) int32 60B 10 30 40 50 60 70 ... 100 105 110 120 130 136
   * longitude   (longitude) float64 10kB 0.0 0.2812 0.5625 ... 359.2 359.4 359.7
   * time        (time) datetime64[ns] 8B 1979-01-01
 Data variables: (12/14)
     Q           (time, level, latitude, longitude) float32 49MB -2.874 ... -1...
     T           (time, level, latitude, longitude) float32 49MB 4.688 ... -1.385
     U           (time, level, latitude, longitude) float32 49MB -1.259 ... -0...
     V           (time, level, latitude, longitude) float32 49MB 2.979 ... -1.274
     Q500        (time, latitude, longitude) float32 3MB -0.6879 ... -0.63
     SP          (time, 

In [13]:
from credit.transforms import Normalize_ERA5_and_Forcing

In [14]:
# instantiate the Normalize_ERA5_and_Forcing transform from the loaded config
transform_scaler = Normalize_ERA5_and_Forcing(conf)

In [15]:
# apply the transform to the single sample pulled above;
# the recorded run of this cell ended with KeyError: "No variable named 'tsi'"
transform_scaler(samples_dyn)

<xarray.Dataset> Size: 233MB
Dimensions:     (time: 1, level: 15, latitude: 640, longitude: 1280,
                 half_level: 138)
Coordinates:
  * half_level  (half_level) int32 552B 1 2 3 4 5 6 ... 133 134 135 136 137 138
  * latitude    (latitude) float64 5kB 89.78 89.51 89.23 ... -89.51 -89.78
  * level       (level) int32 60B 10 30 40 50 60 70 ... 100 105 110 120 130 136
  * longitude   (longitude) float64 10kB 0.0 0.2812 0.5625 ... 359.2 359.4 359.7
  * time        (time) datetime64[ns] 8B 1979-01-01
Data variables: (12/14)
    Q           (time, level, latitude, longitude) float32 49MB -2.874 ... -1...
    T           (time, level, latitude, longitude) float32 49MB 4.688 ... -1.385
    U           (time, level, latitude, longitude) float32 49MB -1.259 ... -0...
    V           (time, level, latitude, longitude) float32 49MB 2.979 ... -1.274
    Q500        (time, latitude, longitude) float32 3MB -0.6879 ... -0.63
    SP          (time, latitude, longitude) float32 3MB 0.6142 0.

KeyError: "No variable named 'tsi'. Variables on the dataset include ['half_level', 'Q', 'T', 'U', 'V', ..., 'T500', 'U500', 'Q500', 'V500', 'Z500']"

In [14]:
# print the stop flag and input shape of the first forecast_len * 5 + 1 samples
for k, result in enumerate(dataset):
    print(k, result['stop_forecast'], result['x'].shape)
    if k == forecast_len * 5:
        break

KeyError: "No variable named 'tsi'. Variables on the dataset include ['half_level', 'Q', 'T', 'U', 'V', ..., 'T500', 'U500', 'Q500', 'V500', 'Z500']"