In [1]:
import datetime
import glob
from datetime import datetime, timezone
from itertools import islice

import numpy as np
import torch
import yaml

from credit.transforms import load_transforms

# Load a config

In [2]:
# Path to the model/run configuration that drives every cell below.
CONFIG_PATH = "/glade/derecho/scratch/schreck/repos/miles-credit/results/wxformer/6hr/model.yml"

# NOTE(review): FullLoader is kept to match the original behaviour; prefer
# yaml.safe_load if this config could ever come from an untrusted source.
with open(CONFIG_PATH) as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
from credit.data import (
    Sample,
    drop_var_from_dataset,
    get_forward_data_netCDF4,
    find_key_for_number,
    extract_month_day_hour,
    find_common_indices,
    ERA5_and_Forcing_Dataset,
    get_forward_data,
)
import os
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

# Load transforms and single-step / one-shot dataset

In [4]:
# Resolve file lists, variable names and sequence settings from conf["data"].
data_conf = conf["data"]

all_ERA_files = sorted(glob.glob(data_conf["save_loc"]))
varname_upper_air = data_conf["variables"]

# sorted(glob(...)) always returns a list, so a "not None" test on it is always
# true. Normalise "no surface files matched" to None so the datasets skip surface
# loading instead of failing their surface/upper-air file-count assertion.
surface_files = sorted(glob.glob(data_conf["save_loc_surface"])) or None
diagnostic_files = None  # sorted(glob.glob(conf["data"]["save_loc_diagnostic"]))
is_train = False

# Forcing and static inputs are enabled only when at least one variable is listed.
if data_conf.get("forcing_variables"):
    forcing_files = data_conf["save_loc_forcing"]
    varname_forcing = data_conf["forcing_variables"]
else:
    forcing_files = None
    varname_forcing = None

if data_conf.get("static_variables"):
    static_files = data_conf["save_loc_static"]
    varname_static = data_conf["static_variables"]
else:
    static_files = None
    varname_static = None

varname_surface = data_conf["surface_variables"] if surface_files is not None else None
varname_diagnostic = (
    data_conf["diagnostic_variables"] if diagnostic_files is not None else None
)

# number of previous lead time inputs / number of lead times to forecast
history_len = data_conf["history_len"]
valid_history_len = data_conf["valid_history_len"]
forecast_len = data_conf["forecast_len"]
valid_forecast_len = data_conf["valid_forecast_len"]

# Training keeps the training lengths; otherwise switch to the validation lengths.
if is_train:
    name = "training"
else:
    history_len = valid_history_len
    forecast_len = valid_forecast_len
    name = "validation"

# Optional settings: None when absent from the config.
max_forecast_len = data_conf.get("max_forecast_len")
skip_periods = data_conf.get("skip_periods")
one_shot = data_conf.get("one_shot")

In [5]:
# data preprocessing utils: build the transform pipeline described by `conf`.
# It is passed as `transform=` to the datasets constructed below.
transforms = load_transforms(conf)

  _pyproj_global_context_initialize()


In [6]:
# Single-step / one-shot dataset built from the settings parsed above.
single_step_kwargs = dict(
    # variables kept from each source
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    # files backing each source
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    # sequence layout
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=one_shot,
    max_forecast_len=max_forecast_len,
    # preprocessing
    transform=transforms,
)
dataset = ERA5_and_Forcing_Dataset(**single_step_kwargs)

In [7]:
# Fetch the first sample via normal indexing (dispatches to __getitem__).
sample = dataset[0]

In [8]:
# List the fields a single sample provides.
sample.keys()

dict_keys(['x_forcing_static', 'x_surf', 'x', 'y_surf', 'y', 'index'])

In [9]:
# The dataset index this sample was drawn from (0 for dataset[0]).
sample["index"]

0

# Load an N-step multi-step dataset, where N is the sequence length (Noah's loss)

In [None]:
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic


class DistributedSequentialDatasetV2(torch.utils.data.IterableDataset):
    """Stream ERA5 samples as autoregressive multi-step sequences.

    For each start index drawn from a ``DistributedSampler`` the iterator yields
    ``forecast_len + 1`` consecutive single-step samples. Each sample contains:

    * a ``history_len`` input window (upper-air + surface, merged with forcing
      and static fields when configured);
    * the single next state as target;
    * the bookkeeping keys ``index``, ``forecast_hour`` (1-based step),
      ``stop_forecast`` (True on the last step of a sequence) and ``datetime``
      ([input start, target] as Unix seconds).

    The sampler is built with ``world_size * num_workers`` replicas, so every
    (rank, DataLoader worker) pair iterates its own share of the start indices.
    Call :meth:`set_epoch` before each epoch to reshuffle.
    """

    # https://colab.research.google.com/drive/1OFLZnX9y5QUFNONuvFsxOizq4M-tFvk-?usp=sharing#scrollTo=CxSCQPOMHgwo

    def __init__(
        self,
        varname_upper_air,
        varname_surface,
        varname_forcing,
        varname_static,
        varname_diagnostic,
        filenames,
        filename_surface=None,
        filename_forcing=None,
        filename_static=None,
        filename_diagnostic=None,
        rank=0,
        world_size=1,
        history_len=2,
        forecast_len=0,
        transform=None,
        seed=42,
        skip_periods=None,
        one_shot=None,
        max_forecast_len=None,
        shuffle=True,
    ):
        """Open all data files and build the global time index.

        Args:
            varname_upper_air, varname_surface, varname_forcing, varname_static,
                varname_diagnostic: variables to keep from each source.
            filenames: upper-air files, one dataset per file; sorted before use.
            filename_surface, filename_diagnostic: per-file companions of
                ``filenames`` (the same number of files is asserted), or None.
            filename_forcing, filename_static: single files, or None.
            rank, world_size: distributed-training coordinates.
            history_len: number of input time steps.
            forecast_len: each sequence yields ``forecast_len + 1`` steps.
            transform: optional callable applied to every ``Sample``.
            seed: seed for ``self.rng`` (not used by the iterator itself).
            skip_periods: time stride inside the input window (None -> 1).
            one_shot: if not None, keep only the first target time.
            max_forecast_len: stored only.
            shuffle: forwarded to the ``DistributedSampler``.
        """
        self.history_len = history_len
        self.forecast_len = forecast_len
        self.transform = transform
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.current_epoch = 0

        # stride inside the input window; default to every time step
        self.skip_periods = 1 if skip_periods is None else skip_periods

        # one shot option
        self.one_shot = one_shot

        # total number of needed forecast lead times
        self.total_seq_len = self.history_len + self.forecast_len

        # set random seed (shuffling itself is done by the DistributedSampler)
        self.rng = np.random.default_rng(seed=seed)

        # max possible forecast len (stored only)
        self.max_forecast_len = max_forecast_len

        # ======================================================== #
        # ERA5 upper-air files
        self.all_files = self._open_per_file(filenames, varname_upper_air)

        # file number (str) -> [n_times, global_start, global_end]
        self.ERA5_indices = self._build_file_index(self.all_files)
        # legacy duplicate of ERA5_indices, kept for backward compatibility
        self.meta_data_dict = self._build_file_index(self.all_files)

        # ======================================================== #
        # forcing file
        self.filename_forcing = filename_forcing

        if self.filename_forcing is not None:
            assert os.path.isfile(
                filename_forcing
            ), "Cannot find forcing file [{}]".format(filename_forcing)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data_netCDF4(filename_forcing)
            self.xarray_forcing = drop_var_from_dataset(xarray_dataset, varname_forcing)
        else:
            self.xarray_forcing = False

        # ======================================================== #
        # static file
        self.filename_static = filename_static

        if self.filename_static is not None:
            assert os.path.isfile(
                filename_static
            ), "Cannot find static file [{}]".format(filename_static)

            # drop variables if they are not in the config
            xarray_dataset = get_forward_data_netCDF4(filename_static)
            self.xarray_static = drop_var_from_dataset(xarray_dataset, varname_static)
        else:
            self.xarray_static = False

        # ======================================================== #
        # diagnostic files (one per upper-air file)
        self.filename_diagnostic = filename_diagnostic

        if self.filename_diagnostic is not None:
            self.diagnostic_files = self._open_per_file(
                filename_diagnostic, varname_diagnostic
            )
            assert (
                len(self.diagnostic_files) == len(self.all_files)
            ), "Mismatch between the total number of diagnostic files and upper-air files"
        else:
            self.diagnostic_files = False

        # ======================================================== #
        # surface files (one per upper-air file)
        if filename_surface is not None:
            self.surface_files = self._open_per_file(filename_surface, varname_surface)
            assert len(self.surface_files) == len(
                self.all_files
            ), "Mismatch between the total number of surface files and upper-air files"
        else:
            self.surface_files = False

    @staticmethod
    def _open_per_file(filenames, varnames):
        """Open each file (in sorted order) and drop variables not in ``varnames``."""
        return [
            drop_var_from_dataset(get_forward_data(filename=fn), varnames)
            for fn in sorted(filenames)
        ]

    @staticmethod
    def _build_file_index(datasets):
        """Map file number (as str) to ``[n_times, global_start, global_end]``.

        Consecutive files are separated by one unused global index (the ``+ 1``).
        This is the table ``find_key_for_number`` searches in ``__iter__``.
        """
        index_map = {}
        ind_start = 0
        for ind_file, ds in enumerate(datasets):
            n_times = len(ds["time"])
            index_map[str(ind_file)] = [n_times, ind_start, ind_start + n_times]
            ind_start += n_times + 1
        return index_map

    def __post_init__(self):
        """Recompute ``total_seq_len``.

        NOTE: this is a dataclass-style hook, but the class is not a dataclass.
        Python never calls it automatically, and ``__init__`` already sets the value.
        """
        self.total_seq_len = self.history_len + self.forecast_len

    def __len__(self):
        """Number of valid start indices summed over all upper-air files."""
        return sum(
            len(ERA5_xarray["time"]) - self.total_seq_len + 1
            for ERA5_xarray in self.all_files
        )

    def set_epoch(self, epoch):
        """Set the epoch used to seed the sampler's shuffle on the next iteration."""
        self.current_epoch = epoch

    def __iter__(self):
        """Yield ``forecast_len + 1`` consecutive samples per sampled start index."""
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        sampler = DistributedSampler(
            self,
            num_replicas=num_workers * self.world_size,
            rank=self.rank * num_workers + worker_id,
            shuffle=self.shuffle,
        )
        sampler.set_epoch(self.current_epoch)

        # Loop-invariant: the forcing file's month/day/hour keys are identical for
        # every sample, so compute them once per iterator rather than per step.
        month_day_forcing = (
            extract_month_day_hour(np.array(self.xarray_forcing["time"]))
            if self.xarray_forcing
            else None
        )

        for index in iter(sampler):
            # one sequence = forecast_len + 1 consecutive single-step samples
            for k in range(self.forecast_len + 1):
                sample, datetime_pair = self._build_step_sample(
                    index + k, month_day_forcing
                )
                sample["index"] = index
                sample["forecast_hour"] = k + 1
                sample["stop_forecast"] = k == self.forecast_len
                sample["datetime"] = datetime_pair
                yield sample

    def _build_step_sample(self, ind, month_day_forcing=None):
        """Build one (input window, next state) sample for global time index ``ind``.

        Returns:
            ``(sample, [input_start_seconds, target_seconds])``. ``sample`` is the
            ``Sample`` after ``self.transform`` (if set).
        """
        # select the file containing `ind` and convert to an in-file position
        ind_file = find_key_for_number(ind, self.ERA5_indices)
        ind_start_in_file = ind - self.ERA5_indices[ind_file][1]

        # Clamp so the full window fits inside the file.
        # NOTE(review): near the end of a file this makes consecutive steps of a
        # sequence reuse the same window; confirm this is intended.
        ind_largest = len(self.all_files[int(ind_file)]["time"]) - (
            self.history_len + self.forecast_len + 1
        )
        ind_start_in_file = min(ind_start_in_file, ind_largest)
        ind_end_in_file = ind_start_in_file + self.history_len + self.forecast_len
        time_window = slice(ind_start_in_file, ind_end_in_file + 1)

        # subset on time and load into memory
        ERA5_subset = self.all_files[int(ind_file)].isel(time=time_window).load()

        if self.surface_files:
            surface_subset = (
                self.surface_files[int(ind_file)].isel(time=time_window).load()
            )
            # merge upper-air and surface
            ERA5_subset = ERA5_subset.merge(surface_subset)

        # datetime information as int seconds (used in some normalization methods)
        datetime_as_number = ERA5_subset.time.values.astype("datetime64[s]").astype(
            int
        )

        input_slice = slice(0, self.history_len, self.skip_periods)
        # the target is the single state right after the input window
        target_slice = slice(
            self.history_len, self.history_len + self.skip_periods, self.skip_periods
        )

        historical_ERA5_images = ERA5_subset.isel(time=input_slice)

        # merge forcing inputs
        if self.xarray_forcing:
            if month_day_forcing is None:
                month_day_forcing = extract_month_day_hour(
                    np.array(self.xarray_forcing["time"])
                )
            # Match month/day/hour rather than full timestamps. This handles a
            # leap-year forcing file paired with non-leap-year upper-air data.
            month_day_inputs = extract_month_day_hour(
                np.array(historical_ERA5_images["time"])
            )
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing).load()
            # Forcing and upper air differ in year only, so take the upper-air times.
            forcing_subset_input["time"] = historical_ERA5_images["time"]
            historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input)

        # merge static inputs (broadcast along time, then slice like the inputs)
        if self.xarray_static:
            static_subset_input = self.xarray_static.expand_dims(
                dim={"time": len(ERA5_subset["time"])}
            )
            static_subset_input = static_subset_input.assign_coords(
                {"time": ERA5_subset["time"]}
            )
            # slice + load into memory
            static_subset_input = static_subset_input.isel(time=input_slice).load()
            static_subset_input["time"] = historical_ERA5_images["time"]
            historical_ERA5_images = historical_ERA5_images.merge(static_subset_input)

        target_ERA5_images = ERA5_subset.isel(time=target_slice)

        # merge diagnostic targets using the SAME single-step slice as the state
        # target, so the merged target keeps a single time step
        if self.diagnostic_files:
            target_diagnostic = (
                self.diagnostic_files[int(ind_file)]
                .isel(time=time_window)
                .isel(time=target_slice)
                .load()
            )
            target_ERA5_images = target_ERA5_images.merge(target_diagnostic)

        if self.one_shot is not None:
            # keep only the first target time
            target_ERA5_images = target_ERA5_images.isel(time=slice(0, 1))

        sample = Sample(
            historical_ERA5_images=historical_ERA5_images,
            target_ERA5_images=target_ERA5_images,
            datetime_index=datetime_as_number,
        )

        # data normalization
        if self.transform:
            sample = self.transform(sample)

        datetime_pair = [
            int(
                historical_ERA5_images.time.values[0]
                .astype("datetime64[s]")
                .astype(int)
            ),
            int(target_ERA5_images.time.values[0].astype("datetime64[s]").astype(int)),
        ]
        return sample, datetime_pair

In [None]:
# Demo override: each sequence yields forecast_len + 1 = 4 single-step samples,
# and the last one carries stop_forecast=True.
forecast_len = 3

# Inputs/targets are normalised by `transforms` (z-score, per the original note).
sequential_dataset_kwargs = dict(
    # variables kept from each source
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    # files backing each source
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    # sequence layout
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=one_shot,
    max_forecast_len=max_forecast_len,
    # preprocessing
    transform=transforms,
    # single-process sharding
    rank=0,
    world_size=1,
    shuffle=True,
)
dataset = DistributedSequentialDatasetV2(**sequential_dataset_kwargs)

In [None]:
# Print the valid times of the first few sequences. Each sequence yields
# forecast_len + 1 samples; stop_forecast marks the last one.
# 4 sequences reproduces the original 16-sample stop for forecast_len == 3.
N_SEQUENCES = 4
n_samples_to_show = (forecast_len + 1) * N_SEQUENCES

for k, result in enumerate(islice(dataset, n_samples_to_show)):
    print(
        k,
        result["stop_forecast"],
        [
            # aware UTC conversion; datetime.utcfromtimestamp is deprecated (3.12+)
            datetime.fromtimestamp(x, tz=timezone.utc).strftime(
                "%B %d, %Y at %I:%M %p UTC"
            )
            for x in result["datetime"]
        ],
    )