In [1]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import Predict_Dataset
#ERA5Dataset, ERA5_and_Forcing_Dataset, get_forward_data, concat_and_reshape, drop_var_from_dataset
from credit.models import load_model
from credit.transforms import load_transforms
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
# Experiment configuration used throughout this notebook.
# NOTE(review): absolute, user-specific path -- this cell only runs where the file is readable.
config_name = '/glade/u/home/ksha/miles-credit/results/fuxi_norm/model_new.yml'

# Parse the YAML text with the safe loader and keep the result as `conf`
conf = yaml.safe_load(Path(config_name).read_text())

In [3]:
# Settings pulled from the "data" section of the config
data_conf = conf["data"]
history_len = data_conf["history_len"]
time_step = data_conf.get("time_step")  # None when the key is absent

# Every path matching the configured glob pattern, in sorted order
all_ERA_files = sorted(glob.glob(data_conf["save_loc"]))

# Transform object built by credit from the full config
transform = load_transforms(conf)

In [6]:
# class Predict_Dataset(torch.utils.data.IterableDataset):
#     def __init__(self,
#                  conf, 
#                  varname_upper_air,
#                  varname_surface,
#                  varname_forcing,
#                  varname_static,
#                  filenames,
#                  filename_surface,
#                  filename_forcing,
#                  filename_static,
#                  fcst_datetime,
#                  history_len,
#                  rank,
#                  world_size,
#                  transform=None,
#                  rollout_p=0.0,
#                  which_forecast=None):
        
#         # ------------------------------------------------------------------------------ #
        
#         ## no diagnostics because they are output only
#         varname_diagnostic = None
        
#         self.rank = rank
#         self.world_size = world_size
#         self.transform = transform
#         self.history_len = history_len
#         self.fcst_datetime = fcst_datetime
#         self.which_forecast = which_forecast # <-- carried over from the old roll-out; purpose unknown
        
#         # -------------------------------------- #
#         self.filenames = sorted(filenames)
#         self.filename_surface = sorted(filename_surface)
#         self.filename_forcing = filename_forcing
#         self.filename_static = filename_static
        
#         # -------------------------------------- #
#         self.varname_upper_air = varname_upper_air
#         self.varname_surface = varname_surface
#         self.varname_forcing = varname_forcing
#         self.varname_static = varname_static

#         # ====================================== #
#         # import all upper air zarr files
#         all_files = []
#         for fn in self.filenames:
#             # drop variables if they are not in the config
#             xarray_dataset = get_forward_data(filename=fn)
#             xarray_dataset = drop_var_from_dataset(xarray_dataset, self.varname_upper_air)
#             # collect yearly datasets within a list
#             all_files.append(xarray_dataset)
#         self.all_files = all_files
#         # ====================================== #

#         # -------------------------------------- #
#         # other settings
#         self.current_epoch = 0
#         self.rollout_p = rollout_p

#     def ds_read_and_subset(self, filename, time_start, time_end, varnames):
#         sliced_x = xr.open_zarr(filename, consolidated=True)
#         sliced_x = sliced_x.isel(time=slice(time_start, time_end))
#         sliced_x = drop_var_from_dataset(sliced_x, varnames)
#         return sliced_x

#     def load_zarr_as_input(self, file_key, time_key):
#         # get the needed file from a list of zarr files
#         # open the zarr file as xr.dataset and subset based on the needed time
        
#         # sliced_x: the final output, starts with an upper air xr.dataset
#         sliced_x = self.ds_read_and_subset(self.filenames[file_key], 
#                                            time_key, 
#                                            time_key+self.history_len+1, 
#                                            self.varname_upper_air)
#         # surface variables
#         if self.varname_surface is not None:
#             sliced_surface = self.ds_read_and_subset(self.filename_surface[file_key], 
#                                                      time_key, 
#                                                      time_key+self.history_len+1, 
#                                                      self.varname_surface)
#             # merge surface to sliced_x
#             sliced_surface['time'] = sliced_x['time']
#             sliced_x = sliced_x.merge(sliced_surface)
            
#         # forcing / static
#         if self.filename_forcing is not None:
#             sliced_forcing = xr.open_dataset(self.filename_forcing)
#             sliced_forcing = drop_var_from_dataset(sliced_forcing, self.varname_forcing)
#             sliced_forcing = sliced_forcing.isel(time=slice(time_key, time_key+self.history_len+1))
#             sliced_forcing['time'] = sliced_x['time']
#             # merge forcing to sliced_x
#             sliced_x = sliced_x.merge(sliced_forcing)
            
#         if self.filename_static is not None:
#             sliced_static = xr.open_dataset(self.filename_static)
#             sliced_static = drop_var_from_dataset(sliced_static, self.varname_static)
#             sliced_static = sliced_static.expand_dims(dim={"time": len(sliced_x['time'])})
#             sliced_static['time'] = sliced_x['time']
#             # merge static to sliced_x
#             sliced_x = sliced_x.merge(sliced_static)
#         return sliced_x

    
#     def find_start_stop_indices(self, index):
#         # Convert the first forecast time to its initialization time
#         # by subtracting the forecast length (assuming 1 step).
#         # Later forecast times are used as initialization times directly,
#         # because each step's forecast time is the initialization time of the next step.
#         start_time = self.fcst_datetime[index][0] # string
#         date_object = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S')
#         shifted_hours = self.history_len
#         date_object = date_object - datetime.timedelta(hours=shifted_hours)
#         self.fcst_datetime[index][0] = date_object.strftime('%Y-%m-%d %H:%M:%S')

#         # convert all strings to np.datetime64
#         datetime_objs = [np.datetime64(date) for date in self.fcst_datetime[index]]
#         start_time, stop_time = [str(datetime_obj) + '.000000000' for datetime_obj in datetime_objs]
#         self.start_time = np.datetime64(start_time).astype(datetime.datetime)
#         self.stop_time = np.datetime64(stop_time).astype(datetime.datetime)

#         info = {}

#         for idx, dataset in enumerate(self.all_files):
#             start_time = np.datetime64(dataset['time'].min().values).astype(datetime.datetime)
#             stop_time = np.datetime64(dataset['time'].max().values).astype(datetime.datetime)
#             track_start = False
#             track_stop = False

#             if start_time <= self.start_time <= stop_time:
#                 # Start time is in this file, use start time index
#                 dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
#                 start_idx = np.searchsorted(dataset, self.start_time)
#                 start_idx = max(0, min(start_idx, len(dataset)-1))
#                 track_start = True

#             elif start_time < self.stop_time and stop_time > self.start_time:
#                 # File overlaps time range, use full file
#                 start_idx = 0
#                 track_start = True

#             if start_time <= self.stop_time <= stop_time:
#                 # Stop time is in this file, use stop time index
#                 if isinstance(dataset, np.ndarray):
#                     pass
#                 else:
#                     dataset = np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])
#                 stop_idx = np.searchsorted(dataset, self.stop_time)
#                 stop_idx = max(0, min(stop_idx, len(dataset)-1))
#                 track_stop = True

#             elif start_time < self.stop_time and stop_time > self.start_time:
#                 # File overlaps time range, use full file
#                 stop_idx = len(dataset) - 1
#                 track_stop = True

#             # Only include files that overlap the time range
#             if track_start and track_stop:
#                 info[idx] = ((idx, start_idx), (idx, stop_idx))

#         indices = []
#         for dataset_idx, (start, stop) in info.items():
#             for i in range(start[1], stop[1]+1):
#                 indices.append((start[0], i))
#         return indices

#     def __len__(self):
#         return len(self.fcst_datetime)

#     def __iter__(self):
#         worker_info = get_worker_info()
#         num_workers = worker_info.num_workers if worker_info is not None else 1
#         worker_id = worker_info.id if worker_info is not None else 0
#         sampler = DistributedSampler(self, 
#                                      num_replicas=num_workers*self.world_size, 
#                                      rank=self.rank*num_workers+worker_id, 
#                                      shuffle=False)
#         for index in sampler:
#             # get time indices for inputs
#             data_lookup = self.find_start_stop_indices(index)
#             for k, (file_key, time_key) in enumerate(data_lookup):
                
#                 if k == 0:
#                     output_dict = {}
#                     # get all inputs (upper air, surface, forcing, static ) in one xr.Dataset
#                     sliced_x = self.load_zarr_as_input(file_key, time_key)
                    
#                     # Check if additional data from the next file is needed
#                     if len(sliced_x['time']) < self.history_len + 1:
                        
#                         # Load excess data from the next file
#                         next_file_idx = self.filenames.index(self.filenames[file_key]) + 1
                        
#                         if next_file_idx == len(self.filenames):
#                             # not enough input data to support this forecast
#                             raise OSError("You have reached the end of the available data. Exiting.")
                            
#                         else:
#                             # time_key = 0 because we need the beginning of the next file only
#                             sliced_x_next = self.load_zarr_as_input(next_file_idx, 0)
                            
#                             # Concatenate excess data from the next file with the current data
#                             sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')

#                     # key 'historical_ERA5_images' is recognized as input in credit.transform
#                     sample_x = {
#                         'historical_ERA5_images': sliced_x.isel(time=slice(0, self.history_len))
#                     }
                    
#                     if self.transform:
#                         sample_x = self.transform(sample_x)
                        
#                     for key in sample_x.keys():
#                         output_dict[key] = sample_x[key]
                        
#                     output_dict['forecast_hour'] = k + 1
#                     output_dict['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]
                    
#                     # Adjust stopping condition
#                     output_dict['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1)) 
#                 else:
#                     output_dict['forecast_hour'] = k + 1
#                     # Adjust stopping condition
#                     output_dict['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))  
                    
#                 yield output_dict

#                 if output_dict['stop_forecast']:
#                     break

In [7]:
# Single-process settings: this notebook acts as rank 0 in a world of size 1
rank = 0
world_size = 1

# NOTE(review): varname_upper_air, varname_surface, varname_forcing, varname_static,
# surface_files, forcing_files and static_files are not assigned in any cell shown
# above (execution counts jump from In [3] to In [6]); they must be defined before
# this cell can run on a fresh kernel.
dataset = Predict_Dataset(
    conf,
    varname_upper_air,
    varname_surface,
    varname_forcing,
    varname_static,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    fcst_datetime=load_forecasts(conf),
    # Use the value read from conf["data"]["history_len"] rather than a literal 2,
    # so the dataset cannot drift from the config that also builds `transform`.
    history_len=history_len,
    rank=rank,
    world_size=world_size,
    transform=transform,
    rollout_p=0.0,
    which_forecast=None,
)

In [8]:
# Display the forecast windows that load_forecasts returns for this config
load_forecasts(conf)

[['2018-06-01 00:00:00', '2018-06-01 02:00:00']]

In [9]:
# Pull the first item yielded by the dataset so it can be inspected below
test = next(iter(dataset))

In [10]:
# Show the trailing dimensions of the 'x' entry at leading index (0, 0, 0)
test['x'][0, 0, 0, ...]

tensor([[-0.1052, -0.1064, -0.1076,  ..., -0.1018, -0.1028, -0.1040],
        [-0.1237, -0.1249, -0.1261,  ..., -0.1201, -0.1213, -0.1225],
        [-0.1413, -0.1425, -0.1437,  ..., -0.1378, -0.1390, -0.1402],
        ...,
        [-1.2384, -1.2350, -1.2317,  ..., -1.2483, -1.2450, -1.2417],
        [-1.2458, -1.2425, -1.2394,  ..., -1.2554, -1.2522, -1.2490],
        [-1.2406, -1.2375, -1.2344,  ..., -1.2496, -1.2466, -1.2436]])

In [11]:
# Set up the DataLoader for this process: one sample per batch, original order,
# and loading done in the main process (num_workers=0 spawns no worker subprocesses)
loader_options = dict(
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
)
data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

In [12]:
# (key, index) pairs printed for every batch the loader yields.
# NOTE(review): 320 and 640 are presumably grid-point indices -- confirm.
probes = (
    ('x', (0, 0, 0, 0, 320, 640)),
    ('x_surf', (0, 1, 6, 320, 640)),
    ('x_forcing_static', (0, 0, 0, 320, 640)),
)

# The loop variable is `batch` rather than `test` so the unbatched sample stored in
# `test` by the earlier `next(iter(dataset))` cell is not overwritten; otherwise
# re-running the `test['x'][0, 0, 0, ...]` cell after this loop would index a
# batched tensor with an extra leading dimension.
for batch in data_loader:
    for position, (key, index) in enumerate(probes):
        if position > 0:
            # separator between consecutive probes (none after the last one)
            print('-----------')
        print(batch[key][index])

tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
tensor(-0.5057)
-----------
tensor(0.8152)
-----------
tensor(0.6407)
