In [2]:
import torch
import glob
import yaml
import os

import numpy as np
from credit.data import *
from credit.transforms import load_transforms

In [3]:
# Location of the experiment configuration (YAML).
# The default is the original author's path on Derecho scratch; set the
# CREDIT_CONFIG environment variable to run this notebook against another
# config without editing the cell.
CONFIG_PATH = os.environ.get(
    "CREDIT_CONFIG",
    "/glade/derecho/scratch/schreck/repos/miles-credit/production/multistep/wxformer_6h/model.yml",
)

# NOTE(review): yaml.FullLoader can construct more than plain scalars/containers;
# if the config only holds plain YAML, yaml.safe_load would be the stricter
# choice -- confirm before switching.
with open(CONFIG_PATH) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [4]:
def _glob_optional_files(data_conf, varname_key, save_loc_key):
    """Glob the files of an optional variable group.

    Parameters:
    - data_conf (dict): the ``conf["data"]`` section.
    - varname_key (str): key holding the list of variable names of the group.
    - save_loc_key (str): key holding the glob pattern of the group's files.

    Returns:
    - Sorted list of matching files when ``data_conf[varname_key]`` exists and
      is non-empty, otherwise ``None``.
    """
    if not data_conf.get(varname_key):
        return None
    return sorted(glob.glob(data_conf[save_loc_key]))


def _years_as_str(years_range):
    """Expand ``[start, end)`` into a list of year strings for file-name search."""
    return [str(year) for year in range(years_range[0], years_range[1])]


def _filter_by_years(files, years):
    """Keep the files whose path contains any of the year strings.

    NOTE(review): this is a plain substring test on the whole path, so a year
    appearing in a directory name also matches -- confirm the file layout.
    """
    return [file for file in files if any(year in file for year in years)]


def _split_optional_files(files, label, train_years, valid_years, n_train, n_valid):
    """Split an optional file group into training / validation subsets.

    Parameters:
    - files (list or None): files of the group, ``None`` if the group is unused.
    - label (str): group name used in the assertion messages.
    - train_years, valid_years (list of str): year strings to search for.
    - n_train, n_valid (int): number of upper-air training / validation files.

    Returns:
    - ``(train_subset, valid_subset)``; ``(None, None)`` when ``files`` is None.

    Raises:
    - AssertionError: if a subset does not have as many files as the
      corresponding upper-air subset.
    """
    if files is None:
        return None, None

    train_subset = _filter_by_years(files, train_years)
    valid_subset = _filter_by_years(files, valid_years)

    # every group is expected to provide one file per upper-air file
    assert (
        len(train_subset) == n_train
    ), "Mismatch between the total number of training set [{}] and [upper-air files]".format(
        label
    )
    assert (
        len(valid_subset) == n_valid
    ), "Mismatch between the total number of validation set [{}] and [upper-air files]".format(
        label
    )
    return train_subset, valid_subset


# upper-air files
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

# <------------------------------------------ std_new
# The optional groups are only globbed for the "std_new" scaler. They are
# bound to None otherwise, so that later cells testing `... is not None`
# do not fail with a NameError for other scaler types.
if conf["data"]["scaler_type"] == "std_new":
    surface_files = _glob_optional_files(
        conf["data"], "surface_variables", "save_loc_surface"
    )
    dyn_forcing_files = _glob_optional_files(
        conf["data"], "dynamic_forcing_variables", "save_loc_dynamic_forcing"
    )
    diagnostic_files = _glob_optional_files(
        conf["data"], "diagnostic_variables", "save_loc_diagnostic"
    )
else:
    surface_files = None
    dyn_forcing_files = None
    diagnostic_files = None

# -------------------------------------------------- #
# import training / validation years from conf (end year is exclusive)
train_years_range = conf["data"].get("train_years", [1979, 2014])
valid_years_range = conf["data"].get("valid_years", [2014, 2018])

# convert year info to str for file name search
train_years = _years_as_str(train_years_range)
valid_years = _years_as_str(valid_years_range)

# Filter the files for training / validation
train_files = _filter_by_years(all_ERA_files, train_years)
valid_files = _filter_by_years(all_ERA_files, valid_years)

# <----------------------------------- std_new
# Groups that are None (unused, or scaler is not "std_new") yield (None, None).
train_surface_files, valid_surface_files = _split_optional_files(
    surface_files,
    "surface files",
    train_years,
    valid_years,
    len(train_files),
    len(valid_files),
)
train_dyn_forcing_files, valid_dyn_forcing_files = _split_optional_files(
    dyn_forcing_files,
    "dynamic forcing files",
    train_years,
    valid_years,
    len(train_files),
    len(valid_files),
)
train_diagnostic_files, valid_diagnostic_files = _split_optional_files(
    diagnostic_files,
    "diagnostic files",
    train_years,
    valid_years,
    len(train_files),
    len(valid_files),
)

In [5]:
# Expand environment variables (e.g. $USER) embedded in the save location.
conf["save_loc"] = os.path.expandvars(conf["save_loc"])

# ======================================================== #
# parse inputs: variable names and file locations per group

# upper-air variables are always read
varname_upper_air = conf["data"]["variables"]

# forcing group: active only when at least one variable is listed
if ("forcing_variables" not in conf["data"]) or (
    len(conf["data"]["forcing_variables"]) == 0
):
    forcing_files = None
    varname_forcing = None
else:
    forcing_files = conf["data"]["save_loc_forcing"]
    varname_forcing = conf["data"]["forcing_variables"]

# static group: same rule as the forcing group
if ("static_variables" not in conf["data"]) or (
    len(conf["data"]["static_variables"]) == 0
):
    static_files = None
    varname_static = None
else:
    static_files = conf["data"]["save_loc_static"]
    varname_static = conf["data"]["static_variables"]

# Groups whose files were globbed earlier: names are taken from the config
# only when the corresponding file list exists.
varname_surface = (
    None if surface_files is None else conf["data"]["surface_variables"]
)
varname_dyn_forcing = (
    None if dyn_forcing_files is None else conf["data"]["dynamic_forcing_variables"]
)
varname_diagnostic = (
    None if diagnostic_files is None else conf["data"]["diagnostic_variables"]
)

# number of previous lead time inputs (training / validation)
history_len = conf["data"]["history_len"]
valid_history_len = conf["data"]["valid_history_len"]

# number of lead times to forecast (training / validation)
forecast_len = conf["data"]["forecast_len"]
valid_forecast_len = conf["data"]["valid_forecast_len"]

# optional settings: None when absent from the config
max_forecast_len = conf["data"].get("max_forecast_len")
skip_periods = conf["data"].get("skip_periods")
one_shot = conf["data"].get("one_shot")

In [6]:
# NOTE(review): this cell recomputes the train/validation split that the
# file-discovery cell above already produces; consider keeping only one copy.

# training / validation year ranges, with defaults when not configured
train_years_range = conf["data"].get("train_years", [1979, 2014])
valid_years_range = conf["data"].get("valid_years", [2014, 2018])

# year strings used for the file-name search (end year excluded)
train_years = list(map(str, range(train_years_range[0], train_years_range[1])))
valid_years = list(map(str, range(valid_years_range[0], valid_years_range[1])))

# upper-air files whose path contains one of the year strings
train_files = [
    path for path in all_ERA_files if any(tag in path for tag in train_years)
]
valid_files = [
    path for path in all_ERA_files if any(tag in path for tag in valid_years)
]

# <----------------------------------- std_new
if conf["data"]["scaler_type"] == "std_new":
    # ---------------------------- #
    # surface files
    if surface_files is None:
        train_surface_files = None
        valid_surface_files = None
    else:
        train_surface_files = [
            path for path in surface_files if any(tag in path for tag in train_years)
        ]
        valid_surface_files = [
            path for path in surface_files if any(tag in path for tag in valid_years)
        ]

        # one surface file is expected per upper-air file
        assert len(train_surface_files) == len(train_files), (
            "Mismatch between the total number of training set [surface files] and [upper-air files]"
        )
        assert len(valid_surface_files) == len(valid_files), (
            "Mismatch between the total number of validation set [surface files] and [upper-air files]"
        )

    # ---------------------------- #
    # dynamic forcing files
    if dyn_forcing_files is None:
        train_dyn_forcing_files = None
        valid_dyn_forcing_files = None
    else:
        train_dyn_forcing_files = [
            path
            for path in dyn_forcing_files
            if any(tag in path for tag in train_years)
        ]
        valid_dyn_forcing_files = [
            path
            for path in dyn_forcing_files
            if any(tag in path for tag in valid_years)
        ]

        # one dynamic forcing file is expected per upper-air file
        assert len(train_dyn_forcing_files) == len(train_files), (
            "Mismatch between the total number of training set [dynamic forcing files] and [upper-air files]"
        )
        assert len(valid_dyn_forcing_files) == len(valid_files), (
            "Mismatch between the total number of validation set [dynamic forcing files] and [upper-air files]"
        )

    # ---------------------------- #
    # diagnostic files
    if diagnostic_files is None:
        train_diagnostic_files = None
        valid_diagnostic_files = None
    else:
        train_diagnostic_files = [
            path
            for path in diagnostic_files
            if any(tag in path for tag in train_years)
        ]
        valid_diagnostic_files = [
            path
            for path in diagnostic_files
            if any(tag in path for tag in valid_years)
        ]

        # one diagnostic file is expected per upper-air file
        assert len(train_diagnostic_files) == len(train_files), (
            "Mismatch between the total number of training set [diagnostic files] and [upper-air files]"
        )
        assert len(valid_diagnostic_files) == len(valid_files), (
            "Mismatch between the total number of validation set [diagnostic files] and [upper-air files]"
        )

In [7]:
# Build the transform object from the loaded config via
# credit.transforms.load_transforms.
transforms = load_transforms(conf)

  _pyproj_global_context_initialize()


In [8]:
class ERA5_and_Forcing_Dataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that assembles (input, target) samples from:
        - upper-air variables (time, level, lat, lon)
        - surface variables (time, lat, lon)
        - dynamic forcing variables (time, lat, lon)
        - forcing variables (time, lat, lon)
        - diagnostic variables (time, lat, lon)
        - static variables (lat, lon)

    Upper-air, surface, dynamic forcing and diagnostic data are given as lists
    of files that are opened one dataset per file; forcing and static data
    each come from a single file. Inputs are built from upper-air, surface,
    dynamic forcing, forcing and static data; targets are built from
    upper-air, surface and diagnostic data.
    """

    def __init__(
        self,
        varname_upper_air,
        varname_surface,
        varname_dyn_forcing,
        varname_forcing,
        varname_static,
        varname_diagnostic,
        filenames,
        filename_surface=None,
        filename_dyn_forcing=None,
        filename_forcing=None,
        filename_static=None,
        filename_diagnostic=None,
        history_len=2,
        forecast_len=0,
        transform=None,
        seed=42,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None,
    ):
        """
        Open all input files and build the per-file sample index table.

        Parameters:
        - varname_upper_air (list): List of upper air variable names.
        - varname_surface (list): List of surface variable names.
        - varname_dyn_forcing (list): List of dynamic forcing variable names.
        - varname_forcing (list): List of forcing variable names.
        - varname_static (list): List of static variable names.
        - varname_diagnostic (list): List of diagnostic variable names.
        - filenames (list): List of filenames for upper air data.
        - filename_surface (list, optional): List of filenames for surface data.
        - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing data.
        - filename_forcing (str, optional): Filename for forcing data.
        - filename_static (str, optional): Filename for static data.
        - filename_diagnostic (list, optional): List of filenames for diagnostic data.
        - history_len (int, optional): Length of the history sequence. Default is 2.
        - forecast_len (int, optional): Length of the forecast sequence. Default is 0.
        - transform (callable, optional): Transformation function to apply to the data.
        - seed (int, optional): Seed of the numpy random generator stored on
                                the instance. Default is 42.
        - skip_periods (int, optional): Step used when slicing the time axis of
                                        inputs and targets. None is treated as 1.
        - one_shot (optional): When not None, the target is only the final time
                               step of the sample window. Default is None.
        - max_forecast_len (int, optional): Maximum length of the forecast sequence.
                                            Stored on the instance only.

        Raises:
        - AssertionError: if filename_forcing or filename_static is given but
                          is not an existing file.
        """

        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform

        # skip periods (None falls back to a step of 1)
        self.skip_periods = skip_periods
        if self.skip_periods is None:
            self.skip_periods = 1

        # one shot option
        self.one_shot = one_shot

        # total number of needed forecast lead times
        self.total_seq_len = self.history_len + self.forecast_len

        # random generator seeded for reproducibility
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # upper-air files

        all_files = []
        filenames = sorted(filenames)

        for fn in filenames:
            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_upper_air)

            # collect the per-file datasets within a list
            all_files.append(xarray_dataset)

        self.all_files = all_files

        # get sample indices from ERA5 upper-air files:
        # maps str(file index) -> [number of time steps, ind_start, ind_end]
        ind_start = 0
        self.ERA5_indices = {}
        for ind_file, ERA5_xarray in enumerate(self.all_files):
            # [number of samples, ind_start, ind_end]
            self.ERA5_indices[str(ind_file)] = [
                len(ERA5_xarray["time"]),
                ind_start,
                ind_start + len(ERA5_xarray["time"]),
            ]
            # NOTE(review): the next file starts at ind_end + 1, i.e. each file
            # spans len(time) + 1 indices; presumably find_key_for_number treats
            # [ind_start, ind_end] as inclusive -- confirm against that helper.
            ind_start += len(ERA5_xarray["time"]) + 1

        # ======================================================== #
        # surface files
        if filename_surface is not None:
            surface_files = []
            filename_surface = sorted(filename_surface)

            for fn in filename_surface:
                # drop variables if they are not in the config
                xarray_dataset = get_forward_data(filename=fn)
                xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_surface)

                surface_files.append(xarray_dataset)

            self.surface_files = surface_files

        else:
            # False marks the group as unused (tested by truthiness in __getitem__)
            self.surface_files = False

        # ======================================================== #
        # dynamic forcing files
        if filename_dyn_forcing is not None:
            dyn_forcing_files = []
            filename_dyn_forcing = sorted(filename_dyn_forcing)

            for fn in filename_dyn_forcing:
                # drop variables if they are not in the config
                xarray_dataset = get_forward_data(filename=fn)
                xarray_dataset = drop_var_from_dataset(
                    xarray_dataset, varname_dyn_forcing
                )

                dyn_forcing_files.append(xarray_dataset)

            self.dyn_forcing_files = dyn_forcing_files

        else:
            self.dyn_forcing_files = False

        # ======================================================== #
        # diagnostic files
        self.filename_diagnostic = filename_diagnostic

        if self.filename_diagnostic is not None:
            diagnostic_files = []
            filename_diagnostic = sorted(filename_diagnostic)

            for fn in filename_diagnostic:
                # drop variables if they are not in the config
                xarray_dataset = get_forward_data(filename=fn)
                xarray_dataset = drop_var_from_dataset(
                    xarray_dataset, varname_diagnostic
                )

                diagnostic_files.append(xarray_dataset)

            self.diagnostic_files = diagnostic_files

        else:
            self.diagnostic_files = False

        # ======================================================== #
        # forcing file (a single file, not a list)
        self.filename_forcing = filename_forcing

        if self.filename_forcing is not None:
            assert os.path.isfile(
                filename_forcing
            ), "Cannot find forcing file [{}]".format(filename_forcing)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename_forcing)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_forcing)

            self.xarray_forcing = xarray_dataset
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file (a single file, not a list)
        self.filename_static = filename_static

        if self.filename_static is not None:
            assert os.path.isfile(
                filename_static
            ), "Cannot find static file [{}]".format(filename_static)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename_static)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_static)

            self.xarray_static = xarray_dataset
        else:
            self.xarray_static = False

    def __post_init__(self):
        """
        Recompute total_seq_len from history_len and forecast_len.

        NOTE(review): nothing in the visible code calls this method, and
        __init__ already sets total_seq_len -- it looks like a leftover.
        """
        # Total sequence length of each sample.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        """
        Return the number of samples: the sum over upper-air files of
        (number of time steps - total_seq_len + 1).

        NOTE(review): __getitem__ slices history_len + forecast_len + 1 time
        steps per sample, one more than total_seq_len -- confirm the intended
        count.
        """
        # accumulate the per-file sample counts
        total_len = 0
        for ERA5_xarray in self.all_files:
            total_len += len(ERA5_xarray["time"]) - self.total_seq_len + 1
        return total_len

    def __getitem__(self, index):
        """
        Build the sample for a global index.

        Parameters:
        - index (int): Global sample index across all upper-air files.

        Returns:
        - sample: A Sample built from historical_ERA5_images (inputs),
                  target_ERA5_images (targets) and datetime_index, passed
                  through self.transform when one is set, with "index" and
                  "datetime" entries assigned afterwards.
        """
        # ========================================================================== #
        # cross-file indices --> the index of the file + indices within that file

        # select the ind_file based on the iter index
        # (presumably the key whose [ind_start, ind_end] range contains index)
        ind_file = find_key_for_number(index, self.ERA5_indices)

        # get the ind within the current file
        ind_start = self.ERA5_indices[ind_file][1]
        ind_start_in_file = index - ind_start

        # handle out-of-bounds: clamp the start so that the window of
        # history_len + forecast_len + 1 time steps ends inside the file
        ind_largest = len(self.all_files[int(ind_file)]["time"]) - (
            self.history_len + self.forecast_len + 1
        )
        if ind_start_in_file > ind_largest:
            ind_start_in_file = ind_largest

        # ========================================================================== #
        # subset xarray on time dimension

        ind_end_in_file = ind_start_in_file + self.history_len + self.forecast_len

        ## ERA5_subset: a xarray dataset that contains training input and target (for the current sample)
        ERA5_subset = self.all_files[int(ind_file)].isel(
            time=slice(ind_start_in_file, ind_end_in_file + 1)
        )  # lazy: not loaded into memory yet

        # ========================================================================== #
        # merge surface into the dataset

        if self.surface_files:
            ## subset surface variables on the same time window
            surface_subset = self.surface_files[int(ind_file)].isel(
                time=slice(ind_start_in_file, ind_end_in_file + 1)
            )  # lazy: not loaded into memory yet

            ## merge upper-air and surface here:
            ERA5_subset = ERA5_subset.merge(
                surface_subset
            )  # <-- lazy merge, ERA5 and surface both not loaded

        # ==================================================== #
        # split ERA5_subset into training inputs and targets
        #   + merge with dynamic forcing, forcing, and static

        # number of time steps in ERA5_subset (exclusive end for target slicing)
        ind_end_time = len(ERA5_subset["time"])

        # datetime information as int number (used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype("datetime64[s]").astype(int)

        # ==================================================== #
        # xarray dataset as input
        ## historical_ERA5_images: the final input

        historical_ERA5_images = ERA5_subset.isel(
            time=slice(0, self.history_len, self.skip_periods)
        ).load()  # <-- load into memory

        # ========================================================================== #
        # merge dynamic forcing inputs
        if self.dyn_forcing_files:
            # same time window as ERA5_subset, then the same input slice
            dyn_forcing_subset = self.dyn_forcing_files[int(ind_file)].isel(
                time=slice(ind_start_in_file, ind_end_in_file + 1)
            )
            dyn_forcing_subset = dyn_forcing_subset.isel(
                time=slice(0, self.history_len, self.skip_periods)
            ).load()  # <-- load into memory

            historical_ERA5_images = historical_ERA5_images.merge(dyn_forcing_subset)

        # ========================================================================== #
        # merge forcing inputs
        # (self.xarray_forcing is False when no forcing file was given; otherwise
        #  this relies on the truthiness of the opened dataset object)
        if self.xarray_forcing:
            # ------------------------------------------------------------------------------- #
            # matching month, day, hour between forcing and upper air [time]
            # this approach handles leap year forcing file and non-leap-year upper air file
            month_day_forcing = extract_month_day_hour(
                np.array(self.xarray_forcing["time"])
            )
            month_day_inputs = extract_month_day_hour(
                np.array(historical_ERA5_images["time"])
            )  # <-- upper air
            # indices to subset
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            forcing_subset_input = self.xarray_forcing.isel(
                time=ind_forcing
            ).load()  # <-- load into memory
            # forcing and upper air have different years but the same mon/day/hour
            # safely replace forcing time with upper air time
            forcing_subset_input["time"] = historical_ERA5_images["time"]
            # ------------------------------------------------------------------------------- #

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

        # ========================================================================== #
        # merge static inputs
        if self.xarray_static:
            # expand static var on time dim
            N_time_dims = len(ERA5_subset["time"])
            static_subset_input = self.xarray_static.expand_dims(
                dim={"time": N_time_dims}
            )
            # assign coords 'time'
            static_subset_input = static_subset_input.assign_coords(
                {"time": ERA5_subset["time"]}
            )

            # slice to the input time steps and load
            static_subset_input = static_subset_input.isel(
                time=slice(0, self.history_len, self.skip_periods)
            ).load()  # <-- load into memory

            # align the time coordinate with the inputs
            static_subset_input["time"] = historical_ERA5_images["time"]

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

        # ==================================================== #
        # xarray dataset as target
        ## target_ERA5_images: the final target

        # NOTE(review): the check is `is not None`, so any non-None value
        # (including False) selects the one-shot branch -- confirm intent.
        if self.one_shot is not None:
            # one_shot is on, go straight to the last element
            target_ERA5_images = ERA5_subset.isel(
                time=slice(-1, None)
            ).load()  # <-- load into memory

            ## merge diagnostic variables here:
            if self.diagnostic_files:
                diagnostic_subset = self.diagnostic_files[int(ind_file)].isel(
                    time=slice(ind_start_in_file, ind_end_in_file + 1)
                )

                diagnostic_subset = diagnostic_subset.isel(
                    time=slice(-1, None)
                ).load()  # <-- load into memory

                target_ERA5_images = target_ERA5_images.merge(diagnostic_subset)

        else:
            # one_shot is None (off), take every time step after the history
            target_ERA5_images = ERA5_subset.isel(
                time=slice(self.history_len, ind_end_time, self.skip_periods)
            ).load()  # <-- load into memory

            ## merge diagnostic variables here:
            if self.diagnostic_files:
                # subset diagnostic variables
                diagnostic_subset = self.diagnostic_files[int(ind_file)].isel(
                    time=slice(ind_start_in_file, ind_end_in_file + 1)
                )

                diagnostic_subset = diagnostic_subset.isel(
                    time=slice(self.history_len, ind_end_time, self.skip_periods)
                ).load()  # <-- load into memory

                # merge into the target dataset
                target_ERA5_images = target_ERA5_images.merge(diagnostic_subset)

        # pack the xarray datasets into a Sample
        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number,
        )

        # ==================================== #
        # data normalization
        if self.transform:
            sample = self.transform(sample)

        # assign sample index and datetime
        # (requires the sample to support item assignment)
        sample["index"] = index
        sample["datetime"] = datetime_as_number

        return sample

In [9]:
class ERA5_and_Forcing_MultiStep(torch.utils.data.Dataset):
    """
    A Pytorch Dataset class that works on:
        - upper-air variables (time, level, lat, lon)
        - surface variables (time, lat, lon)
        - dynamic forcing variables (time, lat, lon)
        - forcing variables (time, lat, lon)
        - diagnostic variables (time, lat, lon)
        - static variables (lat, lon)

    Each ``__getitem__`` call returns ONE step of a rollout: ``history_len``
    input times plus the single time that follows them as the target.
    The dataset keeps rollout state between calls: the first call after
    ``set_epoch`` (or after a rollout of ``forecast_len + 1`` steps has
    finished) starts at the index supplied by the sampler; every other call
    ignores the supplied index and advances the current rollout by one.
    """

    def __init__(
        self,
        varname_upper_air,
        varname_surface,
        varname_dyn_forcing,
        varname_forcing,
        varname_static,
        varname_diagnostic,
        filenames,
        filename_surface=None,
        filename_dyn_forcing=None,
        filename_forcing=None,
        filename_static=None,
        filename_diagnostic=None,
        history_len=2,
        forecast_len=0,
        transform=None,
        seed=42,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None,
    ):
        """
        Initialize the ERA5_and_Forcing_MultiStep dataset

        Parameters:
        - varname_upper_air (list): List of upper air variable names.
        - varname_surface (list): List of surface variable names.
        - varname_dyn_forcing (list): List of dynamic forcing variable names.
        - varname_forcing (list): List of forcing variable names.
        - varname_static (list): List of static variable names.
        - varname_diagnostic (list): List of diagnostic variable names.
        - filenames (list): List of filenames for upper air data.
        - filename_surface (list, optional): List of filenames for surface data.
        - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing data.
        - filename_forcing (str, optional): Filename for forcing data.
        - filename_static (str, optional): Filename for static data.
        - filename_diagnostic (list, optional): List of filenames for diagnostic data.
        - history_len (int, optional): Length of the history sequence. Default is 2.
        - forecast_len (int, optional): Length of the forecast sequence. Default is 0.
        - transform (callable, optional): Transformation function to apply to the data.
        - seed (int, optional): Seed of the stored random generator. Default is 42.
        - skip_periods (int, optional): Stride used when slicing the history window.
                                        None is treated as 1.
        - one_shot (bool, optional): Stored on the instance; not read by this class.
        - max_forecast_len (int, optional): Stored on the instance; not read by this class.

        Raises:
        - AssertionError: if filename_forcing / filename_static is given but is not a file.
        """

        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform

        # skip periods
        self.skip_periods = skip_periods
        if self.skip_periods is None:
            self.skip_periods = 1

        # one shot option
        self.one_shot = one_shot

        # total number of needed forecast lead times
        self.total_seq_len = self.history_len + self.forecast_len

        # set random seed
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # upper-air files (always required)
        self.all_files = self._open_datasets(filenames, varname_upper_air)

        # get sample indices from ERA5 upper-air files:
        ind_start = 0
        self.ERA5_indices = {}
        for ind_file, ERA5_xarray in enumerate(self.all_files):
            n_times = len(ERA5_xarray["time"])
            # [number of samples, ind_start, ind_end]
            self.ERA5_indices[str(ind_file)] = [
                n_times,
                ind_start,
                ind_start + n_times,
            ]
            ind_start += n_times + 1

        # ======================================================== #
        # optional per-file datasets; False marks "not provided"
        if filename_surface is not None:
            self.surface_files = self._open_datasets(filename_surface, varname_surface)
        else:
            self.surface_files = False

        if filename_dyn_forcing is not None:
            self.dyn_forcing_files = self._open_datasets(
                filename_dyn_forcing, varname_dyn_forcing
            )
        else:
            self.dyn_forcing_files = False

        self.filename_diagnostic = filename_diagnostic
        if self.filename_diagnostic is not None:
            self.diagnostic_files = self._open_datasets(
                filename_diagnostic, varname_diagnostic
            )
        else:
            self.diagnostic_files = False

        # ======================================================== #
        # forcing file
        self.filename_forcing = filename_forcing
        if self.filename_forcing is not None:
            self.xarray_forcing = self._open_single_dataset(
                filename_forcing, varname_forcing, "forcing"
            )
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file
        self.filename_static = filename_static
        if self.filename_static is not None:
            self.xarray_static = self._open_single_dataset(
                filename_static, varname_static, "static"
            )
        else:
            self.xarray_static = False

        self.start_index = self._get_random_start_index()
        self.total_length = len(self.ERA5_indices)

        # rollout state; initialised here so __getitem__ also works when
        # set_epoch() has not been called yet
        self.current_epoch = 0
        self.forecast_step_count = 0
        self.forecast_step = 0  # kept in sync with forecast_step_count
        self.current_index = None
        self.initial_index = None

    @staticmethod
    def _open_datasets(filenames, varnames):
        """Open each file (in sorted order) and pass it through drop_var_from_dataset with `varnames`."""
        datasets = []
        for fn in sorted(filenames):
            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            datasets.append(drop_var_from_dataset(xarray_dataset, varnames))
        return datasets

    @staticmethod
    def _open_single_dataset(filename, varnames, label):
        """Open one file after asserting it exists; `label` only appears in the error message."""
        assert os.path.isfile(filename), "Cannot find {} file [{}]".format(
            label, filename
        )
        # drop variables if they are not in the config
        xarray_dataset = get_forward_data(filename)
        return drop_var_from_dataset(xarray_dataset, varnames)

    def _get_random_start_index(self):
        """Return the start index; random selection is disabled, so this is always 0."""
        return 0  # random.randint(0, len(self) - 1)

    def __post_init__(self):
        # Not invoked automatically (this is not a dataclass); kept for callers
        # that recompute the total sequence length by hand.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        """Total number of start positions, summed over all upper-air files."""
        total_len = 0
        for ERA5_xarray in self.all_files:
            total_len += len(ERA5_xarray["time"]) - self.total_seq_len + 1
        return total_len

    def set_epoch(self, epoch):
        """Record the epoch and reset the rollout state so the next item starts a new forecast."""
        self.current_epoch = epoch
        self.forecast_step_count = 0
        self.forecast_step = 0
        self.current_index = None
        self.initial_index = None

    def _get_new_start_index(self, worker_id=0, num_workers=1):
        # Divide the data among workers such that there's no overlap
        # NOTE(review): total_steps is 0 when num_workers exceeds the number of
        # files, which makes the modulo below raise ZeroDivisionError.
        total_steps = len(self.ERA5_indices) // num_workers
        worker_offset = worker_id * total_steps
        return worker_offset + (self.start_index % total_steps)

    def _next_index(self, index):
        """
        Decide which global index this call reads.

        A new rollout starts at the sampler-provided `index` when no rollout
        is active or the previous one has produced forecast_len + 1 steps;
        otherwise `index` is ignored and the current rollout advances by one.
        """
        if (self.forecast_step_count == self.forecast_len + 1) or (
            self.current_index is None
        ):
            self.current_index = index
            self.forecast_step_count = 0
            self.forecast_step = 0
            self.initial_index = self.current_index
        else:
            self.current_index += 1
        return self.current_index

    def _finish_step(self):
        """Return (1-based forecast step, stop flag) for the current step, then advance the counter."""
        forecast_step = self.forecast_step_count + 1
        stop_forecast = self.forecast_step_count == self.forecast_len
        self.forecast_step_count += 1
        self.forecast_step = self.forecast_step_count
        return forecast_step, stop_forecast

    def __getitem__(self, index):
        """
        Return one rollout step.

        Parameters:
        - index (int): sampler index; only used when a new rollout starts.

        Returns:
        - sample: the (optionally transformed) Sample built from the input and
                  target datasets, with "datetime", "forecast_step", "index"
                  and "stop_forecast" entries added.
        """
        index = self._next_index(index)

        # select the ind_file based on the iter index
        ind_file = find_key_for_number(index, self.ERA5_indices)

        # get the ind within the current file
        ind_start = self.ERA5_indices[ind_file][1]
        ind_start_in_file = index - ind_start
        file_idx = int(ind_file)

        # handle out-of-bounds: clamp to the last admissible start in this file
        ind_largest = len(self.all_files[file_idx]["time"]) - (
            self.history_len + self.forecast_len + 1
        )
        if ind_start_in_file > ind_largest:
            ind_start_in_file = ind_largest

        # ========================================================================== #
        # subset xarray on time dimension

        ind_end_in_file = ind_start_in_file + self.history_len

        # history_len input times + 1 target time, within the current file
        window = slice(ind_start_in_file, ind_end_in_file + 1)
        # positions of the input times / the target time inside that window
        history_slice = slice(0, self.history_len, self.skip_periods)
        target_slice = slice(-1, None)

        ## ERA5_subset: a xarray dataset that contains training input and target (for the current batch)
        ERA5_subset = self.all_files[file_idx].isel(
            time=window
        )  # .load() NOT load into memory

        # ========================================================================== #
        # merge surface into the dataset

        if self.surface_files:
            ## subset surface variables
            surface_subset = self.surface_files[file_idx].isel(
                time=window
            )  # .load() NOT load into memory

            ## merge upper-air and surface here:
            ERA5_subset = ERA5_subset.merge(
                surface_subset
            )  # <-- lazy merge, ERA5 and surface both not loaded

        # ==================================================== #
        # split ERA5_subset into training inputs and targets
        #   + merge with dynamic forcing, forcing, and static

        # datetime information as int number (used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype("datetime64[s]").astype(int)

        # ==================================================== #
        # xarray dataset as input
        ## historical_ERA5_images: the final input

        historical_ERA5_images = ERA5_subset.isel(
            time=history_slice
        ).load()  # <-- load into memory

        # ========================================================================== #
        # merge dynamic forcing inputs
        if self.dyn_forcing_files:
            dyn_forcing_subset = self.dyn_forcing_files[file_idx].isel(time=window)
            dyn_forcing_subset = dyn_forcing_subset.isel(
                time=history_slice
            ).load()  # <-- load into memory

            historical_ERA5_images = historical_ERA5_images.merge(dyn_forcing_subset)

        # ========================================================================== #
        # merge forcing inputs
        if self.xarray_forcing:
            # ------------------------------------------------------------------------------- #
            # matching month, day, hour between forcing and upper air [time]
            # this approach handles leap year forcing file and non-leap-year upper air file
            month_day_forcing = extract_month_day_hour(
                np.array(self.xarray_forcing["time"])
            )
            month_day_inputs = extract_month_day_hour(
                np.array(historical_ERA5_images["time"])
            )  # <-- upper air
            # indices to subset
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            forcing_subset_input = self.xarray_forcing.isel(
                time=ind_forcing
            ).load()  # <-- load into memory
            # forcing and upper air have different years but the same mon/day/hour
            # safely replace forcing time with upper air time
            forcing_subset_input["time"] = historical_ERA5_images["time"]
            # ------------------------------------------------------------------------------- #

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

        # ========================================================================== #
        # merge static inputs
        if self.xarray_static:
            # expand static var on time dim
            N_time_dims = len(ERA5_subset["time"])
            static_subset_input = self.xarray_static.expand_dims(
                dim={"time": N_time_dims}
            )
            # assign coords 'time'
            static_subset_input = static_subset_input.assign_coords(
                {"time": ERA5_subset["time"]}
            )

            # slice + load into memory
            static_subset_input = static_subset_input.isel(
                time=history_slice
            ).load()  # <-- load into memory

            # update
            static_subset_input["time"] = historical_ERA5_images["time"]

            # merge
            historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

        # ==================================================== #
        # xarray dataset as target
        ## target_ERA5_images: the final target (last time of the window)

        target_ERA5_images = ERA5_subset.isel(
            time=target_slice
        ).load()  # <-- load into memory

        ## merge diagnostic variables into the target here:
        if self.diagnostic_files:
            diagnostic_subset = self.diagnostic_files[file_idx].isel(time=window)

            diagnostic_subset = diagnostic_subset.isel(
                time=target_slice
            ).load()  # <-- load into memory

            target_ERA5_images = target_ERA5_images.merge(diagnostic_subset)

        # pipe xarray datasets to the sampler
        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number,
        )

        # ==================================== #
        # data normalization
        if self.transform:
            sample = self.transform(sample)

        # rollout bookkeeping: one counter drives the step number, the stop
        # flag, and the restart test in _next_index
        forecast_step, stop_forecast = self._finish_step()

        # assign sample index
        sample["datetime"] = datetime_as_number
        sample["forecast_step"] = forecast_step
        sample["index"] = index
        sample["stop_forecast"] = stop_forecast

        return sample

In [10]:
# forecast length handed to both dataset readers constructed below
forecast_len = 7

In [13]:
# Reference reader used for comparison against the multi-step reader.
# NOTE(review): the variable-name lists, file lists, history_len and transforms
# are not defined in this cell; they must come from earlier cells.
dataset = ERA5_and_Forcing_Dataset(
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=1,
    one_shot=None,  # off: the target covers the full forecast sequence
    max_forecast_len=forecast_len + history_len,
    transform=transforms,
)

In [14]:
# Multi-step reader: each __getitem__ call yields a single target step.
# NOTE(review): skip_periods and max_forecast_len are not defined in this cell
# (the reference reader above uses literals instead) — confirm they exist in
# an earlier cell and match the values used for `dataset`.
dataset_multi = ERA5_and_Forcing_MultiStep(
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=False,
    max_forecast_len=max_forecast_len,
    transform=transforms,
)

In [15]:
# first sample from the reference reader (full target sequence)
sample = dataset[0]

In [16]:
# target tensor shape; the leading dim in the recorded output is 8 = forecast_len + 1
sample["y"].shape

torch.Size([8, 4, 15, 640, 1280])

In [17]:
# integer timestamps of the sample; consecutive values in the recorded output differ by 21600
sample["datetime"]

array([283996800, 284018400, 284040000, 284061600, 284083200, 284104800,
       284126400, 284148000, 284169600])

### Get the first sample from the multi-step reader

In [18]:
# reset the rollout state, then read the first step of a new forecast
dataset_multi.set_epoch(0)
multi_sample = dataset_multi[0]

In [19]:
# the multi-step reader returns one target step per call (leading dim 1 in the recorded output)
multi_sample["y"].shape

torch.Size([1, 4, 15, 640, 1280])

In [20]:
# timestamps of the window read for this step (two entries in the recorded output)
multi_sample["datetime"]

array([283996800, 284018400])

In [21]:
# fraction of elements on which the first target of both readers agree (1.0 = identical)
(sample["y"][0] == multi_sample["y"][0]).float().mean()

tensor(1.)

In [22]:
# Restart the rollout and step through it: the single target of step k should
# match slice k of the full-sequence sample (printed fraction 1.0 = identical).
dataset_multi.set_epoch(0)
for k in range(forecast_len + 1):
    multi_sample = dataset_multi[k]
    match_fraction = (sample["y"][k] == multi_sample["y"][0]).float().mean()
    print(k, match_fraction)

0 tensor(1.)
1 tensor(1.)
2 tensor(1.)
3 tensor(1.)
4 tensor(1.)
5 tensor(1.)
6 tensor(1.)
7 tensor(1.)


### Another variation using our worker method

In [23]:
from typing import Any, Callable, Dict, List, Optional, Tuple
from functools import partial


def worker(
    tuple_index: Tuple[int, int],
    ERA5_indices: Dict[str, List[int]],
    all_files: List[Any],
    surface_files: Optional[List[Any]],
    dyn_forcing_files: Optional[List[Any]],
    diagnostic_files: Optional[List[Any]],
    xarray_forcing: Optional[Any],
    xarray_static: Optional[Any],
    history_len: int,
    forecast_len: int,
    skip_periods: int,
    transform: Optional[Callable],
) -> Dict[str, Any]:
    """
    Build the sample for one step of a multi-step forecast.

    Parameters:
    - tuple_index (Tuple[int, int]): ``(index, ind_start_current_step)``. The first
      entry is stored as ``sample["index"]``; the second is the global index that is
      actually read. Their difference determines the forecast step and the stop flag.
    - ERA5_indices (Dict[str, List[int]]): Per-file index metadata; entry ``[1]`` of
      each value is used as the file's first global index.
    - all_files (List[Any]): List of xarray datasets containing upper air data.
    - surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
    - dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
    - diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
    - xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
    - xarray_static (Optional[Any]): xarray dataset containing static data.
    - history_len (int): Length of the history sequence.
    - forecast_len (int): Length of the forecast sequence.
    - skip_periods (int): Stride used when slicing the history window and the target.
    - transform (Optional[Callable]): Transformation function to apply to the data.

    Returns:
    - Dict[str, Any]: The (optionally transformed) sample with "index", "forecast_hour",
      "stop_forecast" and "datetime" entries added; "datetime" holds the first input
      time and the first target time as integer seconds.

    Raises:
    - Exception: any error raised while building the sample is logged through the
      module-level ``logger`` (not defined in this cell) and re-raised.
    """

    index, ind_start_current_step = tuple_index

    try:
        # which file holds this global index, and where that file starts
        ind_file = find_key_for_number(ind_start_current_step, ERA5_indices)
        file_first_index = ERA5_indices[ind_file][1]
        file_idx = int(ind_file)

        # position inside the file, clamped to the last admissible start
        last_valid_start = len(all_files[file_idx]["time"]) - (
            history_len + forecast_len + 1
        )
        ind_start_in_file = min(
            ind_start_current_step - file_first_index, last_valid_start
        )

        # ========================================================================== #
        # time slices shared by every dataset below

        ind_end_in_file = ind_start_in_file + history_len + forecast_len
        window = slice(ind_start_in_file, ind_end_in_file + 1)
        input_times = slice(0, history_len, skip_periods)
        target_times = slice(history_len, history_len + skip_periods, skip_periods)

        # lazy window that contains both the inputs and the target
        ERA5_subset = all_files[file_idx].isel(time=window)

        if surface_files:
            # lazy merge of upper-air and surface variables (neither is loaded yet)
            ERA5_subset = ERA5_subset.merge(surface_files[file_idx].isel(time=window))

        # datetime information as int number (used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype("datetime64[s]").astype(int)

        # ==================================================== #
        # model inputs (loaded into memory)

        historical_ERA5_images = ERA5_subset.isel(time=input_times).load()

        # dynamic forcing inputs
        if dyn_forcing_files:
            dyn_forcing_window = dyn_forcing_files[file_idx].isel(time=window)
            historical_ERA5_images = historical_ERA5_images.merge(
                dyn_forcing_window.isel(time=input_times).load()
            )

        # forcing inputs
        if xarray_forcing:
            # match month, day, hour between forcing and upper air [time];
            # this handles a leap-year forcing file with non-leap-year upper air
            forcing_keys = extract_month_day_hour(np.array(xarray_forcing["time"]))
            input_keys = extract_month_day_hour(
                np.array(historical_ERA5_images["time"])
            )
            ind_forcing, _ = find_common_indices(forcing_keys, input_keys)
            forcing_subset_input = xarray_forcing.isel(time=ind_forcing).load()
            # same mon/day/hour but different years: adopt the upper-air times
            forcing_subset_input["time"] = historical_ERA5_images["time"]

            historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

        # static inputs
        if xarray_static:
            # broadcast the static fields along the window's time axis
            n_window_times = len(ERA5_subset["time"])
            static_subset_input = xarray_static.expand_dims(
                dim={"time": n_window_times}
            )
            static_subset_input = static_subset_input.assign_coords(
                {"time": ERA5_subset["time"]}
            )

            # keep only the input times and load into memory
            static_subset_input = static_subset_input.isel(time=input_times).load()
            static_subset_input["time"] = historical_ERA5_images["time"]

            historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

        # ==================================================== #
        # target: the next forecast step (loaded into memory)

        target_ERA5_images = ERA5_subset.isel(time=target_times).load()

        # diagnostic variables are merged into the target
        if diagnostic_files:
            diagnostic_window = diagnostic_files[file_idx].isel(time=window)
            target_ERA5_images = target_ERA5_images.merge(
                diagnostic_window.isel(time=target_times).load()
            )

        # create a dict object with input/output tensors
        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number,
        )

        # data normalization
        if transform:
            sample = transform(sample)

        # forecast bookkeeping derived from the two indices
        steps_from_start = ind_start_current_step - index
        first_input_time = historical_ERA5_images.time.values[0]
        first_target_time = target_ERA5_images.time.values[0]

        sample["index"] = index
        sample["forecast_hour"] = steps_from_start + 1
        sample["stop_forecast"] = steps_from_start == forecast_len
        sample["datetime"] = [
            int(first_input_time.astype("datetime64[s]").astype(int)),
            int(first_target_time.astype("datetime64[s]").astype(int)),
        ]

    except Exception as e:
        logger.error(f"Error processing index {tuple_index}: {e}")
        raise

    return sample

In [24]:
import os
import torch
from credit.data import (
    drop_var_from_dataset,
    get_forward_data,
    Sample,
    find_key_for_number,
    extract_month_day_hour,
    find_common_indices,
)


class ERA5_and_Forcing_MultiStep(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that steps through a forecast one call at a time.

    It works on:
        - upper-air variables (time, level, lat, lon)
        - surface variables (time, lat, lon)
        - dynamic forcing variables (time, lat, lon)
        - forcing variables (time, lat, lon)
        - diagnostic variables (time, lat, lon)
        - static variables (lat, lon)

    The first ``__getitem__`` call of a forecast uses the index supplied by
    the caller as the initial index. The following ``forecast_len`` calls
    ignore the supplied index and advance the internal index by one each
    time. After ``forecast_len + 1`` samples, the next call starts a new
    forecast from the supplied index. ``set_epoch`` resets this stepping state.
    """

    def __init__(
        self,
        varname_upper_air,
        varname_surface,
        varname_dyn_forcing,
        varname_forcing,
        varname_static,
        varname_diagnostic,
        filenames,
        filename_surface=None,
        filename_dyn_forcing=None,
        filename_forcing=None,
        filename_static=None,
        filename_diagnostic=None,
        history_len=2,
        forecast_len=0,
        transform=None,
        seed=42,
        rank=0,
        world_size=1,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None,
    ):
        """
        Initialize the ERA5_and_Forcing_MultiStep dataset.

        Parameters:
        - varname_upper_air (list): List of upper air variable names.
        - varname_surface (list): List of surface variable names.
        - varname_dyn_forcing (list): List of dynamic forcing variable names.
        - varname_forcing (list): List of forcing variable names.
        - varname_static (list): List of static variable names.
        - varname_diagnostic (list): List of diagnostic variable names.
        - filenames (list): List of filenames for upper air data.
        - filename_surface (list, optional): List of filenames for surface data.
        - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing data.
        - filename_forcing (str, optional): Filename for forcing data.
        - filename_static (str, optional): Filename for static data.
        - filename_diagnostic (list, optional): List of filenames for diagnostic data.
        - history_len (int, optional): Length of the history sequence. Default is 2.
        - forecast_len (int, optional): Length of the forecast sequence. Default is 0.
        - transform (callable, optional): Transformation function to apply to the data.
        - seed (int, optional): Seed of the instance's numpy random generator. Default is 42.
        - rank (int, optional): Stored on the instance; not otherwise used by this class. Default is 0.
        - world_size (int, optional): Stored on the instance; not otherwise used by this class. Default is 1.
        - skip_periods (int, optional): Number of periods to skip between samples.
                                        None is treated as 1.
        - one_shot (bool, optional): Whether to return all states or just
                                     the final state of the training target. Default is None.
                                     Stored on the instance only.
        - max_forecast_len (int, optional): Maximum length of the forecast sequence.
                                            Stored on the instance only.

        Raises:
        - AssertionError: if filename_forcing or filename_static is given but
                          is not an existing file.
        """

        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform
        self.seed = seed
        self.rank = rank
        self.world_size = world_size

        # skip periods (None falls back to 1)
        self.skip_periods = 1 if skip_periods is None else skip_periods

        # one shot option
        self.one_shot = one_shot

        # total number of needed forecast lead times
        self.total_seq_len = self.history_len + self.forecast_len

        # set random seed
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # upper-air files
        self.all_files = self._open_datasets(filenames, varname_upper_air)

        # get sample indices from ERA5 upper-air files:
        # NOTE(review): consecutive files are offset by n_time + 1 rather than
        # n_time; presumably this matches an inclusive range lookup in
        # find_key_for_number -- confirm before changing.
        ind_start = 0
        self.ERA5_indices = {}
        for ind_file, ERA5_xarray in enumerate(self.all_files):
            n_time = len(ERA5_xarray["time"])
            # [number of samples, ind_start, ind_end]
            self.ERA5_indices[str(ind_file)] = [
                n_time,
                ind_start,
                ind_start + n_time,
            ]
            ind_start += n_time + 1

        # ======================================================== #
        # surface files (False when not provided)
        if filename_surface is not None:
            self.surface_files = self._open_datasets(filename_surface, varname_surface)
        else:
            self.surface_files = False

        # dynamic forcing files (False when not provided)
        if filename_dyn_forcing is not None:
            self.dyn_forcing_files = self._open_datasets(
                filename_dyn_forcing, varname_dyn_forcing
            )
        else:
            self.dyn_forcing_files = False

        # ======================================================== #
        # diagnostic files (False when not provided)
        self.filename_diagnostic = filename_diagnostic

        if self.filename_diagnostic is not None:
            self.diagnostic_files = self._open_datasets(
                filename_diagnostic, varname_diagnostic
            )
        else:
            self.diagnostic_files = False

        # ======================================================== #
        # forcing file (False when not provided)
        self.filename_forcing = filename_forcing

        if self.filename_forcing is not None:
            self.xarray_forcing = self._open_single_dataset(
                filename_forcing, varname_forcing, "forcing"
            )
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file (False when not provided)
        self.filename_static = filename_static

        if self.filename_static is not None:
            self.xarray_static = self._open_single_dataset(
                filename_static, varname_static, "static"
            )
        else:
            self.xarray_static = False

        # Bind everything except the (initial_index, current_index) pair, which
        # __getitem__ supplies on every call.
        self.worker = partial(
            worker,
            ERA5_indices=self.ERA5_indices,
            all_files=self.all_files,
            surface_files=self.surface_files,
            dyn_forcing_files=self.dyn_forcing_files,
            diagnostic_files=self.diagnostic_files,
            xarray_forcing=self.xarray_forcing,
            xarray_static=self.xarray_static,
            history_len=self.history_len,
            forecast_len=self.forecast_len,
            skip_periods=self.skip_periods,
            transform=self.transform,
        )

        # Forecast-stepping state read by __getitem__. Initialised here so the
        # dataset can be indexed even if set_epoch() has not been called yet.
        self.current_epoch = None
        self.forecast_step_count = 0
        self.current_index = None
        self.initial_index = None

        # Legacy attributes, kept so existing code that reads them still works.
        self.start_index = None
        self.forecast_step = 0
        self.total_length = len(self.ERA5_indices)  # number of upper-air files
        self.epoch = None

    @staticmethod
    def _open_datasets(filenames, varnames):
        """Open every file in sorted order, keeping only the variables in `varnames`."""
        datasets = []
        for fn in sorted(filenames):
            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            datasets.append(drop_var_from_dataset(xarray_dataset, varnames))
        return datasets

    @staticmethod
    def _open_single_dataset(filename, varnames, label):
        """Open one required file (asserting it exists), keeping only `varnames`."""
        assert os.path.isfile(filename), "Cannot find {} file [{}]".format(
            label, filename
        )

        # drop variables if they are not in the config
        xarray_dataset = get_forward_data(filename)
        return drop_var_from_dataset(xarray_dataset, varnames)

    def __post_init__(self):
        # NOTE: __post_init__ is a dataclass hook; this class is not a dataclass,
        # so Python never calls it automatically. __init__ already sets
        # total_seq_len; the method is kept only for interface compatibility.
        # Total sequence length of each sample.
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        """Return the summed count of `len(time) - total_seq_len + 1` over all upper-air files."""
        return sum(
            len(ERA5_xarray["time"]) - self.total_seq_len + 1
            for ERA5_xarray in self.all_files
        )

    def set_epoch(self, epoch):
        """Record the epoch and reset the stepping state so the next item starts a new forecast."""
        self.current_epoch = epoch
        self.forecast_step_count = 0
        self.current_index = None
        self.initial_index = None

    def __getitem__(self, index):
        """
        Return the sample for the next step of the current forecast.

        Parameters:
        - index (int): Index supplied by the sampler. It is used only when a
                       new forecast starts; otherwise it is ignored.

        Returns:
        - sample: whatever `self.worker` returns for the pair
                  (initial_index, current_index).
        """
        forecast_finished = self.forecast_step_count == self.forecast_len + 1

        if forecast_finished or self.current_index is None:
            # We've completed the last forecast or we're starting for the first time
            # Start a new forecast using the sampler index
            self.current_index = index
            self.initial_index = index
            self.forecast_step_count = 0
        else:
            # Ignore the sampler index and continue the forecast
            self.current_index += 1

        # Worker process
        sample = self.worker((self.initial_index, self.current_index))

        # Advance the counter that the reset condition above actually reads.
        self.forecast_step_count += 1

        return sample

In [25]:
# Instantiate the multi-step dataset from names defined in earlier cells.
dataset = ERA5_and_Forcing_MultiStep(
    # --- files, one argument per variable group ---
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_diagnostic=diagnostic_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    # --- variable names kept from each group ---
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_diagnostic=varname_diagnostic,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    # --- sequence layout ---
    history_len=history_len,
    forecast_len=forecast_len,
    max_forecast_len=max_forecast_len,
    skip_periods=skip_periods,
    one_shot=False,
    # --- transform passed through to the worker ---
    transform=transforms,
)

In [26]:
# batch_size = 1
# num_workers = 1

# dataloader = DataLoader(
#     dataset,
#     batch_size=batch_size,  # Adjust the batch size as needed
#     shuffle=False,   # Shuffle the dataset if needed
#     num_workers=num_workers,  # Number of subprocesses to use for data loading (adjust as needed)
#     drop_last=True,  # Drop the last incomplete batch if not divisible by batch_size,
#     prefetch_factor=4
# )

In [27]:
# dataloader.dataset.set_epoch(0)
# for (k, sample) in enumerate(dataloader):
#     print(k, sample["forecast_hour"], sample["index"], sample["datetime"], sample["stop_forecast"])
#     if k == 25:
#         break

# Reset the stepping state so the first item below starts a new forecast.
dataset_multi.set_epoch(0)

# For each step k, print the fraction of elements where the k-th entry of
# sample["y"] equals the first entry of the multi-step sample's "y"
# (1.0 means every element matches).
# NOTE(review): `sample` and `dataset_multi` come from other cells -- confirm
# they are defined before this cell on a fresh kernel.
for k in range(forecast_len + 1):
    multi_sample = dataset_multi[k]
    print(k, (sample["y"][k] == multi_sample["y"][0]).float().mean())

0 tensor(1.)
1 tensor(1.)
2 tensor(1.)
3 tensor(1.)
4 tensor(1.)
5 tensor(1.)
6 tensor(1.)
7 tensor(1.)
