In [75]:
# ---------- #
# System
import gc
import os
import sys
import yaml
import glob
import logging
import warnings
import traceback
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
import datetime
import pandas as pd
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
import torch.distributed as dist
from torchvision import transforms
# import wandb

# ---------- #
# credit
from credit.data import *
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.seed import seed_everything
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper
from credit.models.checkpoint import load_model_state
from credit.solar import TOADataLoader
from credit.output import split_and_reshape, load_metadata, make_xarray, save_netcdf_increment
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [2]:
# Location of the YAML configuration of the old rollout experiment.
config_name = '/glade/work/ksha/repos/global/miles-credit/results/wxformer_6h_rollout/model.yml'

# Parse the YAML file into `conf`, the configuration dictionary read by the cells below.
with open(config_name, mode='r') as config_file:
    conf = yaml.safe_load(config_file)

In [3]:
# Single-process settings: this notebook runs as rank 0 of a world of size 1.
rank, world_size = 0, 1

## New rollout

In [4]:
class Predict_Dataset(torch.utils.data.IterableDataset):
    '''
    Same as ERA5_and_Forcing_Dataset() but for prediction only.

    For every forecast window in `fcst_datetime` the dataset yields one dictionary per
    forecast step. The first dictionary of a window carries the (optionally transformed)
    inputs plus 'forecast_hour', 'stop_forecast' and 'datetime'; later steps re-yield the
    same dictionary object with only 'forecast_hour' and 'stop_forecast' updated.

    Args:
        conf: configuration dictionary; conf['data'] may provide 'lead_time_periods'
            and 'skip_periods' (both fall back to 1).
        varname_upper_air: upper-air variable names to keep.
        varname_surface: surface variable names to keep, or None to skip surface data.
        varname_forcing: forcing variable names to keep.
        varname_static: static variable names to keep.
        filenames: list of upper-air zarr files (sorted internally).
        filename_surface: list of surface zarr files (sorted internally), or None.
            NOTE(review): assumed to be index-aligned with `filenames` -- confirm.
        filename_forcing: single forcing file, or None to skip forcing data.
        filename_static: single static file, or None to skip static data.
        fcst_datetime: list of [start, stop] datetime strings ('%Y-%m-%d %H:%M:%S'),
            one pair per forecast.
        history_len: number of input time frames.
        rank, world_size: position of this process among all processes; used to shard
            the forecasts across processes and DataLoader workers.
        transform: optional callable applied to the sample dictionary.
        rollout_p: stored on the instance; not used by the methods of this class.
        which_forecast: stored on the instance; not used by the methods of this class.
    '''
    def __init__(self,
                 conf, 
                 varname_upper_air,
                 varname_surface,
                 varname_forcing,
                 varname_static,
                 filenames,
                 filename_surface,
                 filename_forcing,
                 filename_static,
                 fcst_datetime,
                 history_len,
                 rank,
                 world_size,
                 transform=None,
                 rollout_p=0.0,
                 which_forecast=None):
        
        # ------------------------------------------------------------------------------ #
        # no diagnostics here because they are output only
        
        self.rank = rank
        self.world_size = world_size
        self.transform = transform
        self.history_len = history_len
        self.fcst_datetime = fcst_datetime
        self.which_forecast = which_forecast # <-- inherited from the old roll-out; purpose unclear
        
        # -------------------------------------- #
        self.filenames = sorted(filenames) # <---------------- a list of files
        # a list of files, or None when no surface data is configured
        # (sorted(None) would raise TypeError, so None is passed through)
        self.filename_surface = sorted(filename_surface) if filename_surface is not None else None
        self.filename_forcing = filename_forcing # <-- single file
        self.filename_static = filename_static # <---- single file
        
        # -------------------------------------- #
        self.varname_upper_air = varname_upper_air
        self.varname_surface = varname_surface
        self.varname_forcing = varname_forcing
        self.varname_static = varname_static

        # ====================================== #
        # import all upper air zarr files
        all_files = []
        for fn in self.filenames:
            # drop variables if they are not in the config
            xarray_dataset = get_forward_data(filename=fn)
            xarray_dataset = drop_var_from_dataset(xarray_dataset, self.varname_upper_air)
            # collect yearly datasets within a list
            all_files.append(xarray_dataset)
        self.all_files = all_files
        # ====================================== #

        # -------------------------------------- #
        # other settings
        self.current_epoch = 0
        self.rollout_p = rollout_p
        
        if 'lead_time_periods' in conf['data']:
            self.lead_time_periods = conf['data']['lead_time_periods']
        else:
            self.lead_time_periods = 1
        
        if 'skip_periods' in conf['data']:
            self.skip_periods = conf['data']['skip_periods']
        else:
            self.skip_periods = 1
            
        # an explicit `skip_periods: null` in the config is treated like a missing key
        if self.skip_periods is None:
            self.skip_periods = 1
            

    def ds_read_and_subset(self, filename, time_start, time_end, varnames):
        '''
        Open a zarr store, keep the time steps in [time_start, time_end) by position,
        and pass the result through `drop_var_from_dataset` with `varnames`.
        '''
        sliced_x = xr.open_zarr(filename, consolidated=True)
        sliced_x = sliced_x.isel(time=slice(time_start, time_end))
        sliced_x = drop_var_from_dataset(sliced_x, varnames)
        return sliced_x

    def load_zarr_as_input(self, file_key, time_key):
        '''
        Build one xr.Dataset holding upper-air, surface, forcing and static inputs.

        Args:
            file_key: index into `self.filenames` (and `self.filename_surface`).
            time_key: positional index of the first time step to load; up to
                `history_len + 1` consecutive steps are read from that file.

        Returns:
            The merged xr.Dataset; it may hold fewer than `history_len + 1` steps
            when the file ends early (the caller handles that case).
        '''
        # get the needed file from a list of zarr files
        # open the zarr file as xr.dataset and subset based on the needed time
        
        # sliced_x: the final output, starts with an upper air xr.dataset
        sliced_x = self.ds_read_and_subset(self.filenames[file_key], 
                                           time_key, 
                                           time_key+self.history_len+1, 
                                           self.varname_upper_air)
        # surface variables
        if self.varname_surface is not None:
            sliced_surface = self.ds_read_and_subset(self.filename_surface[file_key], 
                                                     time_key, 
                                                     time_key+self.history_len+1, 
                                                     self.varname_surface)
            # merge surface to sliced_x
            sliced_surface['time'] = sliced_x['time']
            sliced_x = sliced_x.merge(sliced_surface)
            
        # forcing / static
        if self.filename_forcing is not None:
            sliced_forcing = xr.open_dataset(self.filename_forcing)
            sliced_forcing = drop_var_from_dataset(sliced_forcing, self.varname_forcing)

            # See also `ERA5_and_Forcing_Dataset`
            # =============================================================================== #
            # matching month, day, hour between forcing and upper air [time]
            # this approach handles leap year forcing file and non-leap-year upper air file
            month_day_forcing = extract_month_day_hour(np.array(sliced_forcing['time']))
            month_day_inputs = extract_month_day_hour(np.array(sliced_x['time']))
            # indices to subset
            ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
            sliced_forcing = sliced_forcing.isel(time=ind_forcing)
            # forcing and upper air have different years but the same mon/day/hour
            # safely replace forcing time with upper air time
            sliced_forcing['time'] = sliced_x['time']
            # =============================================================================== #
            
            # merge forcing to sliced_x
            sliced_x = sliced_x.merge(sliced_forcing)
            
        if self.filename_static is not None:
            sliced_static = xr.open_dataset(self.filename_static)
            sliced_static = drop_var_from_dataset(sliced_static, self.varname_static)
            # repeat the static fields along a new time axis matching the inputs
            sliced_static = sliced_static.expand_dims(dim={"time": len(sliced_x['time'])})
            sliced_static['time'] = sliced_x['time']
            # merge static to sliced_x
            sliced_x = sliced_x.merge(sliced_static)
        return sliced_x

    @staticmethod
    def _time_axis_as_array(dataset):
        '''Convert every entry of dataset['time'] the same way the start/stop bounds are converted.'''
        return np.array([np.datetime64(x.values).astype(datetime.datetime) for x in dataset['time']])

    def find_start_stop_indices(self, index):
        '''
        Locate the (file index, time index) pairs covered by forecast `index`.

        The start string of the forecast is shifted back by
        `lead_time_periods * skip_periods * history_len` hours before the lookup.
        The shifted value is kept local: `self.fcst_datetime` is left untouched so
        that calling this method again for the same forecast gives the same answer.

        Side effects:
            Sets `self.start_time` and `self.stop_time` to the converted bounds.

        Returns:
            A list of (file_idx, time_idx) tuples; empty when no file overlaps.
        '''
        # convert the first forecasted time to initialization time
        # by subtracting the forecast length (assuming 1 step)
        # other later forecasted time are viewed as init time directly 
        # because their previous step forecasted time are init times of the later forecasted time
        start_str, stop_str = self.fcst_datetime[index] # strings
        date_object = datetime.datetime.strptime(start_str, '%Y-%m-%d %H:%M:%S')
        # =========================================================================== #
        # <--- !! it MAY NOT work when self.skip_period != 1
        shifted_hours = self.lead_time_periods * self.skip_periods * self.history_len
        # =========================================================================== #
        date_object = date_object - datetime.timedelta(hours=shifted_hours)
        start_str = date_object.strftime('%Y-%m-%d %H:%M:%S')

        # convert both strings to np.datetime64 at nanosecond precision, then to the
        # same representation that is used for the file time stamps below
        bounds = [str(np.datetime64(date)) + '.000000000' for date in (start_str, stop_str)]
        self.start_time = np.datetime64(bounds[0]).astype(datetime.datetime)
        self.stop_time = np.datetime64(bounds[1]).astype(datetime.datetime)

        info = {}
        for idx, dataset in enumerate(self.all_files):
            file_start = np.datetime64(dataset['time'].min().values).astype(datetime.datetime)
            file_stop = np.datetime64(dataset['time'].max().values).astype(datetime.datetime)
            file_times = None # converted lazily, only for files that contain a bound
            track_start = False
            track_stop = False
            if file_start <= self.start_time <= file_stop:
                # Start time is in this file, use start time index
                file_times = self._time_axis_as_array(dataset)
                start_idx = np.searchsorted(file_times, self.start_time)
                start_idx = max(0, min(start_idx, len(file_times)-1))
                track_start = True
            elif file_start < self.stop_time and file_stop > self.start_time:
                # File overlaps time range, use full file
                start_idx = 0
                track_start = True

            if file_start <= self.stop_time <= file_stop:
                # Stop time is in this file, use stop time index
                if file_times is None:
                    file_times = self._time_axis_as_array(dataset)
                stop_idx = np.searchsorted(file_times, self.stop_time)
                stop_idx = max(0, min(stop_idx, len(file_times)-1))
                track_stop = True

            elif file_start < self.stop_time and file_stop >= self.start_time:
                # File overlaps time range, use full file.
                # Count the time steps explicitly: len() of a dataset object counts
                # its variables, not its time steps.
                stop_idx = len(dataset['time']) - 1
                track_stop = True

            # Only include files that overlap the time range
            if track_start and track_stop:
                info[idx] = ((idx, start_idx), (idx, stop_idx))

        indices = []
        for dataset_idx, (start, stop) in info.items():
            for i in range(start[1], stop[1]+1):
                indices.append((start[0], i))
        return indices

    def __len__(self):
        '''Number of forecasts (one per entry of `fcst_datetime`).'''
        return len(self.fcst_datetime)

    def __iter__(self):
        '''
        Yield one dictionary per forecast step for the forecasts assigned to this
        process / DataLoader worker.

        Raises:
            OSError: when the inputs of a forecast run past the last available file.
        '''
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        # shard forecasts over (processes x workers) without shuffling
        sampler = DistributedSampler(self, 
                                     num_replicas=num_workers*self.world_size, 
                                     rank=self.rank*num_workers+worker_id, 
                                     shuffle=False)
        for index in sampler:
            # get time indices for inputs
            data_lookup = self.find_start_stop_indices(index)
            for k, (file_key, time_key) in enumerate(data_lookup):
                if k == 0:
                    output_dict = {}
                    # get all inputs (upper air, surface, forcing, static ) in one xr.Dataset
                    sliced_x = self.load_zarr_as_input(file_key, time_key)
                    
                    # Check if additional data from the next file is needed
                    if len(sliced_x['time']) < self.history_len + 1:
                        
                        # Load excess data from the next file
                        next_file_idx = file_key + 1
                        
                        if next_file_idx >= len(self.filenames):
                            # not enough input data to support this forecast
                            raise OSError("You have reached the end of the available data. Exiting.")
                            
                        else:
                            # time_key = 0 because we need the beginning of the next file only
                            sliced_x_next = self.load_zarr_as_input(next_file_idx, 0)
                            
                            # Concatenate excess data from the next file with the current data
                            sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')
                            sliced_x = sliced_x.isel(time=slice(0, self.history_len+1))
                                                     
                    # key 'historical_ERA5_images' is recognized as input in credit.transform
                    sample_x = {
                        'historical_ERA5_images': sliced_x.isel(time=slice(0, self.history_len))
                    }
                    
                    if self.transform:
                        sample_x = self.transform(sample_x)
                        
                    for key in sample_x.keys():
                        output_dict[key] = sample_x[key]
                        
                    output_dict['forecast_hour'] = k + 1
                    # Adjust stopping condition
                    output_dict['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1))
                    # last loaded time stamp, as integer seconds
                    output_dict['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]
                    # debugging aid: show the time stamps that were loaded as inputs
                    print(sliced_x['time'])
                else:
                    # later steps reuse the inputs of step 0; only the counters change
                    output_dict['forecast_hour'] = k + 1
                     # Adjust stopping condition
                    output_dict['stop_forecast'] = (k == (len(data_lookup)-self.history_len-1)) 
                yield output_dict

                if output_dict['stop_forecast']:
                    break

In [8]:
# setup rank and world size for GPU-based rollout
# NOTE(review): `setup` is not imported by name in this notebook; presumably it arrives through
# the `from credit.data import *` star import -- confirm before running in fsdp/ddp mode.
if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# infer device id from rank
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(rank % torch.cuda.device_count())
else:
    device = torch.device("cpu")

# config settings
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

# number of input time frames 
history_len = conf["data"]["history_len"]

# transform and ToTensor class
transform = load_transforms(conf)
if conf["data"]["scaler_type"] == 'std_new':
    state_transformer = Normalize_ERA5_and_Forcing(conf)
else:
    # A bare `raise` outside an `except` block only produces
    # "RuntimeError: No active exception to reraise"; raise an explicit error that
    # carries the message instead (NotImplementedError is a RuntimeError subclass).
    raise NotImplementedError('Scaler type {} not supported'.format(conf["data"]["scaler_type"]))
# ----------------------------------------------------------------- #
# parse varnames and save_locs from config
if 'lead_time_periods' in conf['data']:
    lead_time_periods = conf['data']['lead_time_periods']
else:
    lead_time_periods = 1

## upper air variables
# NOTE(review): `glob` is called as a function although the import cell has `import glob`;
# presumably the later star import rebinds the name to the function -- confirm.
all_ERA_files = sorted(glob(conf["data"]["save_loc"]))
varname_upper_air = conf['data']['variables']

## surface variables
if "save_loc_surface" in conf["data"]:
    surface_files = sorted(glob(conf["data"]["save_loc_surface"]))
    varname_surface = conf['data']['surface_variables']
else:
    surface_files = None
    varname_surface = None 

## forcing variables
if ('forcing_variables' in conf['data']) and (len(conf['data']['forcing_variables']) > 0):
    forcing_files = conf['data']['save_loc_forcing']
    varname_forcing = conf['data']['forcing_variables']
else:
    forcing_files = None
    varname_forcing = None

## static variables
if ('static_variables' in conf['data']) and (len(conf['data']['static_variables']) > 0):
    static_files = conf['data']['save_loc_static']
    varname_static = conf['data']['static_variables']
else:
    static_files = None
    varname_static = None

In [125]:
def generate_datetime(start_time, end_time, interval_hr):
    """
    Build an evenly spaced list of datetimes from `start_time` to `end_time` inclusive.

    Args:
        start_time (datetime.datetime): first entry of the list.
        end_time (datetime.datetime): last admissible entry; it is included only when it
            falls exactly on the grid defined by `start_time` and `interval_hr`.
        interval_hr (int | float): spacing between consecutive entries, in hours.

    Returns:
        list[datetime.datetime]: the generated datetimes; empty when
        `start_time > end_time`.

    Raises:
        ValueError: if `interval_hr` is not positive (the loop would never terminate).
    """
    if interval_hr <= 0:
        raise ValueError(f"interval_hr must be positive, got {interval_hr!r}")

    # Define the time interval (e.g., every hour)
    interval = datetime.timedelta(hours=interval_hr)

    # Generate the list of datetime objects
    datetime_list = []
    current_time = start_time
    while current_time <= end_time:
        datetime_list.append(current_time)
        current_time += interval
    return datetime_list


def hour_to_nanoseconds(input_hr):
    """Convert a duration in hours to nanoseconds (hours -> minutes -> seconds -> ns)."""
    as_minutes = input_hr * 60
    as_seconds = as_minutes * 60
    return as_seconds * 1000000000

def nanoseconds_to_year(nanoseconds_value):
    """Return the calendar year of a time stamp given as nanoseconds since 1970-01-01."""
    stamp = np.datetime64(nanoseconds_value, 'ns')
    # datetime64[Y] cast to int counts whole years since 1970
    years_since_epoch = stamp.astype('datetime64[Y]').astype(int)
    return years_since_epoch + 1970

In [136]:
def ds_read_and_subset(filename, time_start, time_end, varnames):
    """
    Open a zarr store and return the positional time window [time_start, time_end).

    The windowed dataset is passed through `drop_var_from_dataset` with `varnames`
    (presumably keeping only those variables -- confirm against credit.data).
    """
    time_window = slice(time_start, time_end)
    windowed = xr.open_zarr(filename, consolidated=True).isel(time=time_window)
    return drop_var_from_dataset(windowed, varnames)

# Module-level aliases read as globals by the notebook-level `load_zarr_as_input`.
filenames = all_ERA_files  # list of upper-air files
filename_surface = surface_files  # list of surface files, or None
filename_forcing = forcing_files  # forcing file location from the config, or None
filename_static = static_files  # static file location from the config, or None

def load_zarr_as_input(i_file, i_init_start, i_init_end):
    """
    Assemble upper-air, surface, forcing and static inputs into one xr.Dataset.

    Reads the module-level globals `filenames`, `filename_surface`, `filename_forcing`,
    `filename_static` and the matching `varname_*` lists.

    Args:
        i_file: index of the file to read in `filenames` / `filename_surface`.
        i_init_start: positional index of the first time step to load.
        i_init_end: positional index of the last time step to load (inclusive).

    Returns:
        The merged xr.Dataset on the time axis of the upper-air data.
    """
    # inclusive end index -> exclusive slice bound
    time_stop = i_init_end + 1

    # the upper-air dataset is the base every other source is merged into
    merged = ds_read_and_subset(filenames[i_file], i_init_start, time_stop, varname_upper_air)

    # surface variables share the upper-air time axis
    if varname_surface is not None:
        surface = ds_read_and_subset(filename_surface[i_file], i_init_start, time_stop, varname_surface)
        surface['time'] = merged['time']
        merged = merged.merge(surface)

    # forcing: see also `ERA5_and_Forcing_Dataset`
    if filename_forcing is not None:
        forcing = drop_var_from_dataset(xr.open_dataset(filename_forcing), varname_forcing)

        # match month / day / hour between forcing and upper-air time stamps; this
        # copes with a leap-year forcing file and a non-leap-year upper-air file
        forcing_mdh = extract_month_day_hour(np.array(forcing['time']))
        inputs_mdh = extract_month_day_hour(np.array(merged['time']))
        forcing_idx, _ = find_common_indices(forcing_mdh, inputs_mdh)
        forcing = forcing.isel(time=forcing_idx)

        # years differ but month/day/hour agree, so the upper-air time axis can be adopted
        forcing['time'] = merged['time']
        merged = merged.merge(forcing)

    # static fields are repeated along a new time axis before merging
    if filename_static is not None:
        static = drop_var_from_dataset(xr.open_dataset(filename_static), varname_static)
        n_times = len(merged['time'])
        static = static.expand_dims(dim={"time": n_times})
        static['time'] = merged['time']
        merged = merged.merge(static)

    return merged

In [5]:
# Lead time of one forecast step; falls back to 1 when the key is missing, matching the
# default used by the setup cell and by Predict_Dataset.
lead_time_periods = conf['data'].get('lead_time_periods', 1)
# Fixed to 1 here: the hour-shift logic below is flagged as untested for other values.
skip_periods = 1

In [6]:
# Show the forecast windows derived from the config: one [start, stop] pair of datetime strings each.
load_forecasts(conf)

[['2020-01-01 00:00:00', '2020-01-02 18:00:00'],
 ['2020-01-01 12:00:00', '2020-01-03 06:00:00']]

In [72]:
# Open every upper-air file once and pass it through `drop_var_from_dataset` with the
# configured upper-air variables; the result is one dataset per (yearly) file.
all_files = [
    drop_var_from_dataset(get_forward_data(filename=era_file), varname_upper_air)
    for era_file in all_ERA_files
]

In [98]:
# Reload the forecast windows on every run so the in-place edits below start from fresh strings.
init_datetime = load_forecasts(conf)
# Only the first forecast window is processed in this prototype.
index = 0
# ============================================================================ #
# shift hours for history_len > 1, because more than one init time is needed
# <--- !! it MAY NOT work when self.skip_period != 1
shifted_hours = lead_time_periods * skip_periods * (history_len-1)
# ============================================================================ #
# subtract shifted_hours from the 1st & last init times
# convert to datetime objects
init_datetime[index][0] = datetime.datetime.strptime(
    init_datetime[index][0], '%Y-%m-%d %H:%M:%S') - datetime.timedelta(hours=shifted_hours)
init_datetime[index][1] = datetime.datetime.strptime(
    init_datetime[index][1], '%Y-%m-%d %H:%M:%S') - datetime.timedelta(hours=shifted_hours)

# expand the 1st & last init times into the full list of init times,
# spaced lead_time_periods hours apart
init_datetime[index] = generate_datetime(init_datetime[index][0], init_datetime[index][1], lead_time_periods)
# convert datetime objects to np.datetime64, then to nanosecond-precision values
# NOTE(review): the ns-precision `.astype(datetime.datetime)` appears to yield integer
# nanoseconds (they are later passed to nanoseconds_to_year) -- confirm.
init_time_list_dt = [np.datetime64(date.strftime('%Y-%m-%d %H:%M:%S')) for date in init_datetime[index]]
init_time_list_np = [np.datetime64(str(dt_obj) + '.000000000').astype(datetime.datetime) for dt_obj in init_time_list_dt]

In [168]:
# Search the opened upper-air files for the one that holds the first initialization time and
# record [file index, first input index, last input index] in `info`.
# NOTE(review): `info` stays undefined when no file matches, and the loop does not stop
# after the first match.
for i_file, ds in enumerate(all_files):
    # get the year of the current file
    ds_year = int(np.datetime_as_string(ds['time'][0].values, unit='Y'))

    # get the year of the first init time (loop-invariant)
    init_year0 = nanoseconds_to_year(init_time_list_np[0])
    
    # found the right yearly file
    if init_year0 == ds_year:
        
        # convert ds['time'] to a list of nanosecond values
        ds_time_list = [np.datetime64(ds_time.values).astype(datetime.datetime) for ds_time in ds['time']]
        ds_start_time = ds_time_list[0]
        ds_end_time = ds_time_list[-1]
        
        init_time_start = init_time_list_np[0]
        # if the initialization time is within this (yearly) xr.Dataset
        if ds_start_time <= init_time_start <= ds_end_time:

            # index of the first initialization time; list.index raises ValueError
            # when the exact time stamp is absent from the file
            i_init_start = ds_time_list.index(init_time_start)
            
            # for multiple init time inputs (history_len > 1), init_end differs from init_start
            init_time_end = init_time_start + hour_to_nanoseconds(shifted_hours)

            # see if init_time_end is also in this file
            if ds_start_time <= init_time_end <= ds_end_time:
                
                # index of the last initialization time
                i_init_end = ds_time_list.index(init_time_end)
            else:
                # this set of initializations has crossed years
                # get the last element of the current file
                # another section (the rollout loop) loads the missing steps from the next file
                i_init_end = len(ds_time_list) - 1
                
            info = [i_file, i_init_start, i_init_end]

In [175]:
# Prototype of the dataset iteration for one forecast: build the inputs once at k == 0, then only
# advance the step counter and the stop flag for the remaining init times.
data_lookup = info

for k, _ in enumerate(init_time_list_np):
    # the first initialization time: get the initialization from data
    if k == 0:
        i_file, i_init_start, i_init_end = data_lookup
        output_dict = {}
        # get all inputs (upper air, surface, forcing, static ) in one xr.Dataset
        sliced_x = load_zarr_as_input(i_file, i_init_start, i_init_end)
        
        # Check if additional data from the next file is needed
        if len(sliced_x['time']) < history_len:
            
            # Load excess data from the next file
            next_file_idx = filenames.index(filenames[i_file]) + 1
            
            if next_file_idx >= len(filenames):
                # not enough input data to support this forecast
                raise OSError("You have reached the end of the available data. Exiting.")
                
            else:
                # i_init_start = 0 because we need the beginning of the next file only
                sliced_x_next = load_zarr_as_input(next_file_idx, 0, history_len)
                
                # Concatenate excess data from the next file with the current data
                # and keep only the first history_len time steps
                sliced_x = xr.concat([sliced_x, sliced_x_next], dim='time')
                sliced_x = sliced_x.isel(time=slice(0, history_len))
                                         
        # key 'historical_ERA5_images' is recognized as input in credit.transform
        # if len(sliced_x['time']) > history_len:
        #     sliced_x = sliced_x.isel(time=slice(0, history_len))
        sample_x = {'historical_ERA5_images': sliced_x}
        # debugging aid: show which time stamps were loaded as inputs
        print(sliced_x['time'])
        if transform:
            sample_x = transform(sample_x)
            
        for key in sample_x.keys():
            output_dict[key] = sample_x[key]

        # <--- !! 'forecast_hour' is actually "forecast_step" but named by assuming hourly
        output_dict['forecast_hour'] = k + 1 
        # Adjust stopping condition: stop on the last init time
        output_dict['stop_forecast'] = k == (len(init_time_list_np) - 1)
        # last input time stamp, as integer seconds
        output_dict['datetime'] = sliced_x.time.values.astype('datetime64[s]').astype(int)[-1]
        
    # other later initialization times: the same initialization as in k=0, but add more forecast steps
    else:
        output_dict['forecast_hour'] = k + 1
         # Adjust stopping condition
        output_dict['stop_forecast'] = k == (len(init_time_list_np) - 1)
    # in the dataset class this is where the dictionary is yielded
    #yield output_dict
    
    if output_dict['stop_forecast']:
        break

<xarray.DataArray 'time' (time: 1)> Size: 8B
array(['2020-01-01T00:00:00.000000000'], dtype='datetime64[ns]')
Coordinates:
  * time     (time) datetime64[ns] 8B 2020-01-01


In [161]:
# Last input time stamp at second resolution (the rollout loop stores its integer form in output_dict['datetime']).
sliced_x.time.values.astype('datetime64[s]')[-1]

numpy.datetime64('2020-01-01T00:00:00')

In [176]:
# Step counter left in output_dict after the prototype rollout loop finished.
output_dict['forecast_hour']

8

In [157]:
# Keys of the sample after `transform` was applied in the prototype rollout loop.
sample_x.keys()

dict_keys(['x_forcing_static', 'x_surf', 'x'])

In [149]:
# Display the transformed sample; `sample_x` only exists after the prototype rollout loop cell has run.
sample_x

NameError: name 'sample_x' is not defined

In [94]:
# First time stamp of the file matched by the file-search loop cell, in its converted (integer) form.
ds_start_time

283996800000000000

0

In [9]:


# ----------------------------------------------------------------- #
# get dataset
# Build the prediction dataset from the variable names and file locations parsed from the config;
# forecast windows come straight from load_forecasts(conf).
dataset = Predict_Dataset(
    conf, 
    varname_upper_air,
    varname_surface,
    varname_forcing,
    varname_static,
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    fcst_datetime=load_forecasts(conf),
    history_len=history_len,
    rank=rank,
    world_size=world_size,
    transform=transform,
    rollout_p=0.0,
    which_forecast=None
)

In [10]:
# Show the forecast windows again for reference: one [start, stop] pair of datetime strings each.
load_forecasts(conf)

[['2020-01-01 00:00:00', '2020-01-02 18:00:00'],
 ['2020-01-01 12:00:00', '2020-01-03 06:00:00']]

In [16]:
#ds_next_new = next(iter(dataset))

In [11]:
# set up the dataloader
# batch_size=1 and num_workers=0: one forecast step per batch, loaded in the main process;
# shuffle stays False so steps arrive in the order the dataset yields them.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=False,
)

In [12]:
# Pull the first batch from a fresh iterator; the dataset prints the loaded input time stamps as a side effect.
torch_next_new = next(iter(data_loader))

<xarray.DataArray 'time' (time: 2)> Size: 16B
array(['2019-12-31T18:00:00.000000000', '2020-01-01T00:00:00.000000000'],
      dtype='datetime64[ns]')
Coordinates:
  * time     (time) datetime64[ns] 16B 2019-12-31T18:00:00 2020-01-01


In [14]:
# Decode the batch's 'datetime' entry (integer seconds) back into a UTC datetime.
# NOTE(review): `utcfromtimestamp` is deprecated in recent Python versions -- consider a timezone-aware call.
datetime.datetime.utcfromtimestamp(torch_next_new['datetime'])

datetime.datetime(2020, 1, 1, 0, 0)

In [24]:
# forcing is the last in the new dataset class
# layout noted by the author: (batch size = 1, history_len, vars, lat, lon)
# take batch 0, time step 0, variable 0 as a NumPy array
# NOTE(review): variable 0 is compared against TSI below; presumably it is the forcing field -- confirm.
TOA_new = torch_next_new['x_forcing_static'][0, 0, 0, ...].numpy()

In [20]:
# pull the normalized TSI from the forcing file as a NumPy array
# NOTE(review): the path is hard-coded; presumably the same file as conf['data']['save_loc_forcing'] -- confirm.
ds_forcing = xr.open_dataset('/glade/campaign/cisl/aiml/ksha/CREDIT/forcing_norm_6h.nc')
TSI = np.array(ds_forcing['TSI'])

In [31]:
# Compare the last TSI field of the forcing file with the field returned by the dataset.
# Use the largest absolute difference: a signed sum can cancel to zero and hide a mismatch.
np.max(np.abs(TSI[-1, ...] - TOA_new))

0.0

In [32]:
# Last time stamp of the forcing file, i.e. the time of the TSI[-1] field compared above.
ds_forcing['time'][-1]

In [33]:
# get lat/lons from the xarray file referenced by conf["loss"]["latitude_weights"];
# its `latitude` / `longitude` values are passed to make_xarray in the rollout loop
latlons = xr.open_dataset(conf["loss"]["latitude_weights"])

In [None]:
# number of forecasts completed so far (incremented by the rollout loop)
forecast_count = 0

# y_pred allocation
y_pred = None
# NOTE(review): `static` is not read by any cell shown here.
static = None
# pending asynchronous save results collected by the rollout loop
results = []

In [35]:
# Walk through the batches and assemble the model input tensor `x` for the first step of each forecast.
for k, batch in enumerate(data_loader):
    # get the datetime and forecasted hours
    date_time = batch["datetime"].item()
    forecast_hour = batch["forecast_hour"].item()

    if forecast_hour == 1:
        # Initialize x and x_surf with the first time step
        if "x_surf" in batch:
            # combine x and x_surf
            # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon) 
            # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
            x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float()
        else:
            # no x_surf
            x = reshape_only(batch["x"]).to(device).float()

        # `datetime` is the module in this notebook (see `import datetime`), so the
        # class has to be addressed as `datetime.datetime`
        init_datetime_str = datetime.datetime.utcfromtimestamp(date_time)
        init_datetime_str = init_datetime_str.strftime('%Y-%m-%dT%HZ')


        # -------------------------------------------------------------------------------------- #
        # add forcing and static variables (regardless of fcst hours)
        if 'x_forcing_static' in batch:
            
            # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
            x_forcing_batch = batch['x_forcing_static'].to(device).permute(0, 2, 1, 3, 4).float()

            # concat on var dimension
            x = torch.cat((x, x_forcing_batch), dim=1)

TypeError: 'DataLoader' object is not subscriptable

In [None]:


    # Rollout
    # NOTE(review): this cell is an indented fragment of a function body; `model`, `dpf`, `p`,
    # `meta_data` and `distributed` are not defined anywhere in this notebook and must be
    # provided before it can run.
    with torch.no_grad():
        # forecast count = a constant for each run

    
        # model inference loop
        for k, batch in enumerate(data_loader):
    
            # get the datetime and forecasted hours of the current batch; without this
            # the loop would reuse stale values left over from an earlier cell
            date_time = batch["datetime"].item()
            forecast_hour = batch["forecast_hour"].item()

            # initialization on the first forecast hour
            if forecast_hour == 1:
                
                # Initialize x and x_surf with the first time step
                if "x_surf" in batch:
                    # combine x and x_surf
                    # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon) 
                    # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
                    x = concat_and_reshape(batch["x"], batch["x_surf"]).to(device).float()
                else:
                    # no x_surf
                    x = reshape_only(batch["x"]).to(device).float()

                # `datetime` is the module in this notebook, so address the class explicitly
                init_datetime_str = datetime.datetime.utcfromtimestamp(date_time)
                init_datetime_str = init_datetime_str.strftime('%Y-%m-%dT%HZ')

            # -------------------------------------------------------------------------------------- #
            # add forcing and static variables (regardless of fcst hours)
            if 'x_forcing_static' in batch:
                
                # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
                x_forcing_batch = batch['x_forcing_static'].to(device).permute(0, 2, 1, 3, 4).float()

                # concat on var dimension
                x = torch.cat((x, x_forcing_batch), dim=1)

            # -------------------------------------------------------------------------------------- #
            # start prediction
            y_pred = model(x)
            # back to physical space before filtering / saving
            y_pred = state_transformer.inverse_transform(y_pred.cpu())
            
            if ("use_laplace_filter" in conf["predict"] and conf["predict"]["use_laplace_filter"]):
                y_pred = (
                    dpf.diff_lap2d_filt(y_pred.to(device).squeeze())
                    .unsqueeze(0)
                    .unsqueeze(2)
                    .cpu()
                )
    
            # Valid time of this step: init time plus lead time * step number
            # (`timedelta` is not imported by name, so go through the datetime module)
            utc_datetime = datetime.datetime.utcfromtimestamp(date_time) + datetime.timedelta(hours=lead_time_periods*forecast_hour)
    
            # convert the current step result as x-array
            darray_upper_air, darray_single_level = make_xarray(
                y_pred,
                utc_datetime,
                latlons.latitude.values,
                latlons.longitude.values,
                conf,
            )
            
            # Save the current forecast hour data in parallel
            result = p.apply_async(
                save_netcdf_increment,
                (
                    darray_upper_air, 
                     darray_single_level, 
                     init_datetime_str, 
                     lead_time_periods*forecast_hour, 
                     meta_data, 
                     conf
                )
            )
            results.append(result)
            
            # Update the input
            # setup for next iteration, transform to z-space and send to device
            y_pred = state_transformer.transform_array(y_pred).to(device)
    
            if history_len == 1:
                x = y_pred.detach()
            else:
                # use multiple past forecast steps as inputs
                # static channels will get updated on next pass
                static_dim_size = abs(x.shape[1] - y_pred.shape[1])
                
                # if static_dim_size=0 then :0 gives empty range
                x_detach = x[:, :-static_dim_size, 1:].detach() if static_dim_size else x[:, :, 1:].detach()  
                x = torch.cat([x_detach, y_pred.detach()], dim=2)
    
            # Explicitly release GPU memory
            torch.cuda.empty_cache()
            gc.collect()
    
            if batch["stop_forecast"][0]:
                # Wait for all processes to finish in order
                for result in results:
                    result.get()
    
                # Now merge all the files into one and delete leftovers
                # merge_netcdf_files(init_datetime_str, conf)
    
                # forecast count = a constant for each run
                forecast_count += 1
    
                # update lists
                results = []
    
                # y_pred allocation
                y_pred = None
    
                gc.collect()
    
                if distributed:
                    torch.distributed.barrier()
    
    if distributed:
        torch.distributed.barrier()