In [1]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import ERA5Dataset, ERA5_and_Forcing_Dataset, get_forward_data, concat_and_reshape, drop_var_from_dataset
from credit.models import load_model
from credit.transforms import load_transforms
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
config_name = '/glade/u/home/ksha/miles-credit/results/fuxi_norm/model.yml'
#config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_6h/model_rollout.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf = yaml.safe_load(stream)

In [3]:
class PredictForecast(torch.utils.data.IterableDataset):
    def __init__(self,
                 filenames,
                 forecasts,
                 history_len,
                 skip_periods,
                 rank,
                 world_size,
                 shuffle=False,
                 transform=None,
                 rollout_p=0.0,
                 which_forecast=None):

        self.dataset = ERA5Dataset(
            filenames=filenames,
            history_len=history_len,
            forecast_len=1,
            skip_periods=skip_periods,
            transform=transform
        )
        self.meta_data_dict = self.dataset.meta_data_dict
        self.all_files = self.dataset.all_fils
        self.history_len = history_len
        self.filenames = filenames
        self.transform = transform
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.skip_periods = skip_periods
        self.current_epoch = 0
        self.rollout_p = rollout_p
        self.forecasts = forecasts
        self.skip_periods = skip_periods if skip_periods is not None else 1
        self.which_forecast = which_forecast

    def find_start_stop_indices(self, index):
        start_time = self.forecasts[index][0]
        date_object = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
        shifted_hours = self.skip_periods * self.history_len
        date_object = date_object - datetime.timedelta(hours=shifted_hours)
        self.forecasts[index][0] = date_object.strftime('%Y-%m-%d %H:%M:%S')

        datetime_objs = [np.datetime64(date) for date in self.forecasts[index]]
        start_time, stop_time = [str(datetime_obj) + '.000000000' for datetime_obj in datetime_objs]
        self.start_time = np.datetime64(start_time).astype(datetime.datetime)
        self.stop_time = np.datetime64(stop_time).astype(datetime.datetime)

        info = {}

        for idx, dataset in enumerate(self.all_files):
            start_time = np.datetime64(dataset['time'].min().values).astype(datetime.datetime)
            stop_time = np.datetime64(dataset['time'].max().values).astype(datetime.datetime)
            track_start = False
            track_stop = False

            if start_time <= self.start_time <= stop_time:
                # Start time is in this file, use start time index
                dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
                start_idx = np.searchsorted(dataset, self.start_time)
                start_idx = max(0, min(start_idx, len(dataset)-1))
                track_start = True

            elif start_time < self.stop_time and stop_time > self.start_time:
                # File overlaps time range, use full file
                start_idx = 0
                track_start = True

            if start_time <= self.stop_time <= stop_time:
                # Stop time is in this file, use stop time index
                if isinstance(dataset, np.ndarray):
                    pass
                else:
                    dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
                stop_idx = np.searchsorted(dataset, self.stop_time)
                stop_idx = max(0, min(stop_idx, len(dataset)-1))
                track_stop = True

            elif start_time < self.stop_time and stop_time > self.start_time:
                # File overlaps time range, use full file
                stop_idx = len(dataset) - 1
                track_stop = True

            # Only include files that overlap the time range
            if track_start and track_stop:
                info[idx] = ((idx, start_idx), (idx, stop_idx))

        indices = []
        for dataset_idx, (start, stop) in info.items():
            for i in range(start[1], stop[1]+1):
                indices.append((start[0], i))
        return indices

    def __len__(self):
        return len(self.forecasts)

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        sampler = DistributedSampler(self, num_replicas=num_workers*self.world_size, rank=self.rank*num_workers+worker_id, shuffle=self.shuffle)

        for index in sampler:

            data_lookup = self.find_start_stop_indices(index)

            for k, (file_key, time_key) in enumerate(data_lookup):

                if k == 0:
                    concatenated_samples = {'x': [], 'x_surf': []}
                    sliced_x = xr.open_zarr(self.filenames[file_key], consolidated=True).isel(time=slice(time_key, time_key+self.history_len+1))

                    # Check if additional data from the next file is needed
                    if len(sliced_x['time']) < self.history_len + 1:
                        # Load excess data from the next file
                        next_file_idx = self.filenames.index(self.filenames[file_key]) + 1
                        if next_file_idx == len(self.filenames):
                            raise OSError("You have reached the end of the available data. Exiting.")
                        sliced_x_next = xr.open_zarr(
                            self.filenames[next_file_idx],
                            consolidated=True).isel(time=slice(0, self.history_len+1-len(sliced_x['time'])))

                        # Concatenate excess data from the next file with the current data
                        sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')

                    sample_x = {
                        'x': sliced_x.isel(time=slice(0, self.history_len))
                    }

                    if self.transform:
                        sample_x = self.transform(sample_x)
                        # Add static vars, if any, to the return dictionary
                        if "static" in sample_x:
                            concatenated_samples["static"] = []
                        if "TOA" in sample_x:
                            concatenated_samples["TOA"] = []

                    for key in concatenated_samples.keys():
                        concatenated_samples[key] = sample_x[key].squeeze(0) if self.history_len == 1 else sample_x[key]

                    concatenated_samples['forecast_hour'] = k + 1
                    concatenated_samples['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))  # Adjust stopping condition
                    concatenated_samples['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]

                else:
                    concatenated_samples['forecast_hour'] = k + 1
                    concatenated_samples['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))  # Adjust stopping condition

                yield concatenated_samples

                if concatenated_samples['stop_forecast']:
                    break

In [4]:
rank = 0
world_size = 1

In [5]:
if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

# Config settings
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

history_len = conf["data"]["history_len"]
time_step = conf["data"]["time_step"] if "time_step" in conf["data"] else None

# Load paths to all ERA5 data available
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

# <--- !! works for 'std_new' only
transform = load_transforms(conf)

dataset = PredictForecast(
    filenames=all_ERA_files,
    forecasts=load_forecasts(conf),
    history_len=history_len,
    skip_periods=time_step,
    transform=transform,
    rank=0,
    world_size=1,
    shuffle=False,
)

In [6]:
load_forecasts(conf)

[['2018-06-01 00:00:00', '2018-06-01 02:00:00']]

In [7]:
test = next(iter(dataset))

In [8]:
test['x'][0, 0, 0, ...]

tensor([[-0.1052, -0.1064, -0.1076,  ..., -0.1018, -0.1028, -0.1040],
        [-0.1237, -0.1249, -0.1261,  ..., -0.1201, -0.1213, -0.1225],
        [-0.1413, -0.1425, -0.1437,  ..., -0.1378, -0.1390, -0.1402],
        ...,
        [-1.2384, -1.2350, -1.2317,  ..., -1.2483, -1.2450, -1.2417],
        [-1.2458, -1.2425, -1.2394,  ..., -1.2554, -1.2522, -1.2490],
        [-1.2406, -1.2375, -1.2344,  ..., -1.2496, -1.2466, -1.2436]])

In [9]:
# setup the dataloder for this process
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

In [10]:
for test in data_loader:
    print(test['x'][0, 0, 0, 0, 320, 640])
    print('-----------')
    print(test['x_surf'][0, 1, 6, 320, 640])
    print('-----------')
    print(test['TOA'][0, 0, 320, 640])

tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
