In [1]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import *
#from credit.data import Predict_Dataset, ERA5Dataset, ERA5_and_Forcing_Dataset, get_forward_data, concat_and_reshape, drop_var_from_dataset
from credit.models import load_model
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
# old rollout config
config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_norm/model.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf_old = yaml.safe_load(stream)

In [3]:
# new rollout config
config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_norm/model_new.yml'
# Read YAML file
with open(config_name, 'r') as stream:
    conf_new = yaml.safe_load(stream)

In [4]:
rank = 0
world_size = 1

## Old rollout

In [5]:
sys.path.insert(0, os.path.realpath('/glade/u/home/ksha/miles-credit/applications'))
import rollout_to_netcdf as old_rollout

In [6]:
conf = conf_old

In [7]:
if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

# Config settings
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

history_len = conf["data"]["history_len"]
time_step = conf["data"]["time_step"] if "time_step" in conf["data"] else None

# Load paths to all ERA5 data available
all_ERA_files = sorted(glob(conf["data"]["save_loc"]))

# <--- !! works for 'std_new' only
transform = load_transforms(conf)

dataset = old_rollout.PredictForecast(
    filenames=all_ERA_files,
    forecasts=load_forecasts(conf),
    history_len=history_len,
    skip_periods=time_step,
    transform=transform,
    rank=0,
    world_size=1,
    shuffle=False,
)

In [8]:
#ds_next_old = next(iter(dataset))

In [9]:
# setup the dataloder for this process
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

In [10]:
torch_next_old = next(iter(data_loader))

## New rollout

In [11]:
conf = conf_new

In [12]:
class Predict_Dataset(torch.utils.data.IterableDataset):
    '''
    Same as ERA5_and_Forcing_Dataset() but for prediction only
    '''
    def __init__(self,
                 conf, 
                 varname_upper_air,
                 varname_surface,
                 varname_forcing,
                 varname_static,
                 filenames,
                 filename_surface,
                 filename_forcing,
                 filename_static,
                 fcst_datetime,
                 history_len,
                 rank,
                 world_size,
                 transform=None,
                 rollout_p=0.0,
                 which_forecast=None):
        
        # ------------------------------------------------------------------------------ #
        
        ## no diagnostics because they are output only
        varname_diagnostic = None
        
        self.rank = rank
        self.world_size = world_size
        self.transform = transform
        self.history_len = history_len
        self.init_datetime = fcst_datetime
        self.which_forecast = which_forecast # <-- got from the old roll-out. Dont know 
        
        # -------------------------------------- #
        self.filenames = sorted(filenames) # <---------------- a list of files
        self.filename_surface = sorted(filename_surface) # <-- a list of files
        self.filename_forcing = filename_forcing # <-- single file
        self.filename_static = filename_static # <---- single file
        
        # -------------------------------------- #
        self.varname_upper_air = varname_upper_air
        self.varname_surface = varname_surface
        self.varname_forcing = varname_forcing
        self.varname_static = varname_static

        # ====================================== #
        # import all upper air zarr files
        all_files = []
        for fn in self.filenames:
            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, self.varname_upper_air)
            # collect yearly datasets within a list
            all_files.append(xarray_dataset)
        self.all_files = all_files
        # ====================================== #

        # -------------------------------------- #
        # other settings
        self.current_epoch = 0
        self.rollout_p = rollout_p
        
        if 'lead_time_periods' in conf['data']:
            self.lead_time_periods = conf['data']['lead_time_periods']
        else:
            self.lead_time_periods = 1
        
        if 'skip_periods' in conf['data']:
            self.skip_periods = conf['data']['skip_periods']
        else:
            self.skip_periods = 1
            
        if self.skip_periods is None:
            self.skip_periods = 1
            

    def ds_read_and_subset(self, filename, time_start, time_end, varnames):
        sliced_x = xr.open_zarr(filename, consolidated=True)
        sliced_x = sliced_x.isel(time=slice(time_start, time_end))
        sliced_x = drop_var_from_dataset(sliced_x, varnames)
        return sliced_x

    def load_zarr_as_input(self, i_file, i_init_start, i_init_end):
        # get the needed file from a list of zarr files
        # open the zarr file as xr.dataset and subset based on the needed time
        
        # sliced_x: the final output, starts with an upper air xr.dataset
        sliced_x = self.ds_read_and_subset(self.filenames[i_file], 
                                           i_init_start,
                                           i_init_end+1,
                                           self.varname_upper_air)
        # surface variables
        if self.varname_surface is not None:
            sliced_surface = self.ds_read_and_subset(self.filename_surface[i_file], 
                                                     i_init_start,
                                                     i_init_end+1,
                                                     self.varname_surface)
            # merge surface to sliced_x
            sliced_surface['time'] = sliced_x['time']
            sliced_x = sliced_x.merge(sliced_surface)
            
        # forcing / static
        if self.filename_forcing is not None:
            sliced_forcing = xr.open_dataset(self.filename_forcing)
            sliced_forcing = drop_var_from_dataset(sliced_forcing, self.varname_forcing)

            # See also `ERA5_and_Forcing_Dataset`
            # =============================================================================== #
            # matching month, day, hour between forcing and upper air [time]
            # this approach handles leap year forcing file and non-leap-year upper air file
            month_day_forcing = extract_month_day_hour(np.array(sliced_forcing['time']))
            month_day_inputs = extract_month_day_hour(np.array(sliced_x['time']))
            # indices to subset
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            sliced_forcing = sliced_forcing.isel(time=ind_forcing)
            # forcing and upper air have different years but the same mon/day/hour
            # safely replace forcing time with upper air time
            sliced_forcing['time'] = sliced_x['time']
            # =============================================================================== #
            
            # merge forcing to sliced_x
            sliced_x = sliced_x.merge(sliced_forcing)
            
        if self.filename_static is not None:
            sliced_static = xr.open_dataset(self.filename_static)
            sliced_static = drop_var_from_dataset(sliced_static, self.varname_static)
            sliced_static = sliced_static.expand_dims(dim={"time": len(sliced_x['time'])})
            sliced_static['time'] = sliced_x['time']
            # merge static to sliced_x
            sliced_x = sliced_x.merge(sliced_static)
        return sliced_x

    
    def find_start_stop_indices(self, index):
        # ============================================================================ #
        # shift hours for history_len > 1, becuase more than one init times are needed
        # <--- !! it MAY NOT work when self.skip_period != 1
        shifted_hours = self.lead_time_periods * self.skip_periods * (self.history_len-1)
        # ============================================================================ #
        # subtrack shifted_hour form the 1st & last init times
        # convert to datetime object
        print(self.init_datetime[index])
        self.init_datetime[index][0] = datetime.datetime.strptime(
            self.init_datetime[index][0], '%Y-%m-%d %H:%M:%S') - datetime.timedelta(hours=shifted_hours)
        self.init_datetime[index][1] = datetime.datetime.strptime(
            self.init_datetime[index][1], '%Y-%m-%d %H:%M:%S') - datetime.timedelta(hours=shifted_hours)
        
        # convert the 1st & last init times to a list of init times
        self.init_datetime[index] = generate_datetime(self.init_datetime[index][0], self.init_datetime[index][1], self.lead_time_periods)
        # convert datetime obj to nanosecondes
        init_time_list_dt = [np.datetime64(date.strftime('%Y-%m-%d %H:%M:%S')) for date in self.init_datetime[index]]
        self.init_time_list_np = [np.datetime64(str(dt_obj) + '.000000000').astype(datetime.datetime) for dt_obj in init_time_list_dt]

        for i_file, ds in enumerate(self.all_files):
            # get the year of the current file
            ds_year = int(np.datetime_as_string(ds['time'][0].values, unit='Y'))
        
            # get the first and last years of init times
            init_year0 = nanoseconds_to_year(self.init_time_list_np[0])
            
            # found the right yearly file
            if init_year0 == ds_year:
                
                # convert ds['time'] to a list of nanosecondes
                ds_time_list = [np.datetime64(ds_time.values).astype(datetime.datetime) for ds_time in ds['time']]
                ds_start_time = ds_time_list[0]
                ds_end_time = ds_time_list[-1]
                
                init_time_start = self.init_time_list_np[0]
                # if initalization time is within this (yearly) xr.Dataset
                if ds_start_time <= init_time_start <= ds_end_time:
        
                    # try getting the index of the first initalization time 
                    i_init_start = ds_time_list.index(init_time_start)
                    
                    # for multiple init time inputs (history_len > 1), init_end is different for init_start
                    init_time_end = init_time_start + hour_to_nanoseconds(shifted_hours)
        
                    # see if init_time_end is alos in this file
                    if ds_start_time <= init_time_end <= ds_end_time:
                        
                        # try getting the index
                        i_init_end = ds_time_list.index(init_time_end)
                    else:
                        # this set of initalizations have crossed years
                        # get the last element of the current file
                        # we have anthoer section that checks additional input data
                        i_init_end = len(ds_time_list) - 1
                        
                    info = [i_file, i_init_start, i_init_end]
                    print(info)
                    return info

    def __len__(self):
        return len(self.init_datetime)

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        sampler = DistributedSampler(self, 
                                     num_replicas=num_workers*self.world_size, 
                                     rank=self.rank*num_workers+worker_id, 
                                     shuffle=False)
        for index in sampler:
            # get the init time info for the current sample
            data_lookup = self.find_start_stop_indices(index)
            
            for k, _ in enumerate(self.init_time_list_np):
                
                # the first initialization time: get initalization from data
                if k == 0:
                    i_file, i_init_start, i_init_end = data_lookup
                    
                    # allocate output dict
                    output_dict = {}

                    # get all inputs in one xr.Dataset
                    sliced_x = self.load_zarr_as_input(i_file, i_init_start, i_init_end)
                    
                    # Check if additional data from the next file is needed
                    if len(sliced_x['time']) < self.history_len:
                        
                        # Load excess data from the next file
                        next_file_idx = self.filenames.index(self.filenames[i_file]) + 1
                        
                        if next_file_idx >= len(self.filenames):
                            # not enough input data to support this forecast
                            raise OSError("You have reached the end of the available data. Exiting.")
                            
                        else:
                            # i_init_start = 0 because we need the beginning of the next file only
                            sliced_x_next = self.load_zarr_as_input(next_file_idx, 0, self.history_len)
                            
                            # Concatenate excess data from the next file with the current data
                            sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')
                            sliced_x = sliced_x.isel(time=slice(0, history_len))
                                                     
                    # key 'historical_ERA5_images' is recongnized as input in credit.transform
                    sample_x = {'historical_ERA5_images': sliced_x}
                    print(sliced_x['time'])
                    if self.transform:
                        sample_x = self.transform(sample_x)
                        
                    for key in sample_x.keys():
                        output_dict[key] = sample_x[key]
            
                    # <--- !! 'forecast_hour' is actually "forecast_step" but named by assuming hourly
                    output_dict['forecast_hour'] = k + 1 
                    # Adjust stopping condition
                    output_dict['stop_forecast'] = k == (len(self.init_time_list_np) - 1)
                    output_dict['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]
                    
                # other later initialization time: the same initalization as in k=0, but add more forecast steps
                else:
                    output_dict['forecast_hour'] = k + 1
                     # Adjust stopping condition
                    output_dict['stop_forecast'] = k == (len(self.init_time_list_np) - 1)
                    
                # return output_dict
                yield output_dict
                
                if output_dict['stop_forecast']:
                    break

In [13]:
# setup rank and world size for GPU-based rollout
if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

# config settings
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

# number of input time frames 
history_len = conf["data"]["history_len"]

# transform and ToTensor class
transform = load_transforms(conf)
if conf["data"]["scaler_type"] == 'std_new':
    state_transformer = Normalize_ERA5_and_Forcing(conf)
else:
    print('Scaler type {} not supported'.format(conf["data"]["scaler_type"]))
    raise
# ----------------------------------------------------------------- #
# parse varnames and save_locs from config
if 'lead_time_periods' in conf['data']:
    lead_time_periods = conf['data']['lead_time_periods']
else:
    lead_time_periods = 1

## upper air variables
all_ERA_files = sorted(glob(conf["data"]["save_loc"]))
varname_upper_air = conf['data']['variables']

## surface variables
if "save_loc_surface" in conf["data"]:
    surface_files = sorted(glob(conf["data"]["save_loc_surface"]))
    varname_surface = conf['data']['surface_variables']
else:
    surface_files = None
    varname_surface = None 

## forcing variables
if ('forcing_variables' in conf['data']) and (len(conf['data']['forcing_variables']) > 0):
    forcing_files = conf['data']['save_loc_forcing']
    varname_forcing = conf['data']['forcing_variables']
else:
    forcing_files = None
    varname_forcing = None

## static variables
if ('static_variables' in conf['data']) and (len(conf['data']['static_variables']) > 0):
    static_files = conf['data']['save_loc_static']
    varname_static = conf['data']['static_variables']
else:
    static_files = None
    varname_static = None

# ----------------------------------------------------------------- #\
# get dataset
dataset = Predict_Dataset(
    conf, 
    varname_upper_air,
    varname_surface,
    varname_forcing,
    varname_static,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    fcst_datetime=load_forecasts(conf),
    history_len=history_len,
    rank=rank,
    world_size=world_size,
    transform=transform,
    rollout_p=0.0,
    which_forecast=None
)

In [14]:
load_forecasts(conf)

[['2020-01-01 00:00:00', '2020-01-01 23:00:00'],
 ['2020-01-01 12:00:00', '2020-01-02 11:00:00']]

In [15]:
#ds_next_new = next(iter(dataset))

In [16]:
# setup the dataloder
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

In [17]:
torch_next_new = next(iter(data_loader))

['2020-01-01 00:00:00', '2020-01-01 23:00:00']
[40, 8759, 8759]
<xarray.DataArray 'time' (time: 2)> Size: 16B
array(['2019-12-31T23:00:00.000000000', '2020-01-01T00:00:00.000000000'],
      dtype='datetime64[ns]')
Coordinates:
  * time     (time) datetime64[ns] 16B 2019-12-31T23:00:00 2020-01-01


### upper air and surface var checks

In [18]:
print(np.sum((torch_next_old['x'].numpy() - torch_next_new['x'].numpy())))
print(np.sum((torch_next_old['x_surf'].numpy() - torch_next_new['x_surf'].numpy())))

81453.53
1764.9458


they pulled the same upper air and surface variables

### forcing variable checks

In [19]:
torch_next_new['datetime']

tensor([1577836800])

In [20]:
torch_next_old['datetime']

tensor([1577836800])

In [21]:
datetime.datetime.utcfromtimestamp(torch_next_old['datetime'])

datetime.datetime(2020, 1, 1, 0, 0)

In [28]:
# forcing has the key TOA
# batch size = 1, history_len, lat, lon
TOA_old = torch_next_old['TOA'][0, 0, ...].numpy()

# forcing is the last in the new dataset class
# batch size = 1, history_len, vars, lat, lon
TOA_new = torch_next_new['x_forcing_static'][0, 1, -1, ...].numpy()

In [23]:
# pull the normalized TSI
ds_forcing = xr.open_dataset('/glade/campaign/cisl/aiml/ksha/CREDIT/forcing_norm_1h.nc')
TSI = np.array(ds_forcing['TSI'])

In [29]:
np.sum(TSI[0, ...] - TOA_new) # the new dataset picked ind=-2

0.0

In [30]:
ds_forcing['TSI']['time'][0]

In [26]:
np.sum(TSI[-26, ...] - TOA_old) # the old dataset picked ind=-26 (wrong)

0.0

In [27]:
ds_forcing['TSI']['time'][-26]

The old TOA sample gen does not consider leap year, so it was 2000-12-30, not 31

In [28]:
old_output = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_test_old/2020-01-01T00Z/pred_2020-01-01T00Z_001.nc')

In [30]:
old_output['time']

In [31]:
old_output = xr.open_dataset('/glade/derecho/scratch/ksha/CREDIT/fuxi_test_old/2020-01-01T00Z/pred_2020-01-01T00Z_024.nc')

In [32]:
old_output['time']

In [33]:
datetime.datetime.utcfromtimestamp(1577836800)

datetime.datetime(2020, 1, 1, 0, 0)