# Compare the old and new ERA5Dataset classes

In [1]:
import os
import sys
import yaml
import glob
import numpy as np

In [2]:
import os
from functools import reduce
from itertools import repeat
from dataclasses import dataclass, field
from typing import Optional, Callable, TypedDict, Union, Iterable, NamedTuple, List

# data utils
import datetime
import pandas as pd
import xarray as xr

# Pytorch utils
import torch
import torch.utils.data
from torch.utils.data import get_worker_info
from torch.utils.data.distributed import DistributedSampler

In [3]:
from typing import Dict
import logging
logger = logging.getLogger(__name__)

In [4]:
from torchvision import transforms as tforms
from credit.data import Sample, drop_var_from_dataset, find_key_for_number

In [5]:
from credit.data import ERA5Dataset, ERA5_and_Forcing_Dataset
from credit.transforms import load_transforms
from torch.utils.data.distributed import DistributedSampler

In [6]:
def _read_yaml(path):
    """Parse the YAML file at `path` with yaml.safe_load and return the result."""
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)


# old rollout config
config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_norm/model.yml'
conf_old = _read_yaml(config_name)

# new rollout config
config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_norm/model_new.yml'
conf_new = _read_yaml(config_name)

# rollout config read from model_dyn.yml
config_name = '/glade/work/ksha/repos/global/miles-credit/results/fuxi_norm/model_dyn.yml'
conf_dyn = _read_yaml(config_name)

In [7]:
# Single-process settings: one rank in a world of size one.
rank, world_size = 0, 1

## New dataset

In [8]:
# conf = conf_new
# is_train = True

In [9]:
# Use the model_dyn.yml config for the cells below, with training settings.
conf, is_train = conf_dyn, True

In [10]:
# Year ranges come from the config when present, otherwise from the defaults.
data_conf = conf['data']
train_years_range = data_conf.get('train_years', [1979, 2014])
valid_years_range = data_conf.get('valid_years', [2014, 2018])

# convert year info to str for file name search
# (range() excludes the second entry of each *_years_range)
train_start, train_stop = train_years_range[0], train_years_range[1]
train_years = list(map(str, range(train_start, train_stop)))
valid_start, valid_stop = valid_years_range[0], valid_years_range[1]
valid_years = list(map(str, range(valid_start, valid_stop)))

In [11]:
def _glob_optional_group(data_conf, var_key, loc_key, label):
    """Glob the files of one optional variable group.

    Args:
        data_conf: the `conf['data']` mapping.
        var_key: key holding the group's variable names.
        loc_key: key holding the group's glob pattern.
        label: group name used in the progress message.

    Returns:
        Sorted list of paths matching `data_conf[loc_key]`, or None when
        `var_key` is missing from `data_conf` or holds an empty collection.
    """
    if (var_key in data_conf) and (len(data_conf[var_key]) > 0):
        print(f'collecting {label} files')
        return sorted(glob.glob(data_conf[loc_key]))
    return None


# get file names
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

# Default every optional group to None so that later cells can test
# `<group>_files is not None` even when scaler_type is not 'std_new'
# (previously these names were left undefined in that case).
surface_files = None
dyn_forcing_files = None
diagnostic_files = None

# <------------------------------------------ std_new
if conf['data']['scaler_type'] == 'std_new':

    # check and glob surface files
    surface_files = _glob_optional_group(
        conf['data'], 'surface_variables', 'save_loc_surface', 'surface')

    # check and glob dyn forcing files
    dyn_forcing_files = _glob_optional_group(
        conf['data'], 'dynamic_forcing_variables', 'save_loc_dynamic_forcing', 'dynamic forcing')

    # check and glob diagnostic files
    diagnostic_files = _glob_optional_group(
        conf['data'], 'diagnostic_variables', 'save_loc_diagnostic', 'diagnostic')

collecting surface files
collecting dynamic forcing files


In [12]:
def _split_by_years(files, label, train_years, valid_years, n_train, n_valid):
    """Split `files` into training / validation subsets by year substring.

    Args:
        files: list of file paths, or None when the group is not configured.
        label: group name used in the assertion messages, e.g. 'surface files'.
        train_years, valid_years: year strings searched for in each path.
        n_train, n_valid: number of upper-air training / validation files
            that the corresponding subset must match.

    Returns:
        (train_subset, valid_subset), or (None, None) when `files` is None.

    Raises:
        AssertionError: if a subset's length differs from the upper-air count.
    """
    if files is None:
        return None, None

    train_subset = [file for file in files if any(year in file for year in train_years)]
    valid_subset = [file for file in files if any(year in file for year in valid_years)]

    # ---------------------------- #
    # check total number of files
    assert len(train_subset) == n_train, \
        f'Mismatch between the total number of training set [{label}] and [upper-air files]'
    assert len(valid_subset) == n_valid, \
        f'Mismatch between the total number of validation set [{label}] and [upper-air files]'

    return train_subset, valid_subset


# Filter the files for training / validation
# NOTE(review): `year in file` is a substring test on the whole path, so a
# year string appearing anywhere in the path (not only the file name) matches.
train_files = [file for file in all_ERA_files if any(year in file for year in train_years)]
valid_files = [file for file in all_ERA_files if any(year in file for year in valid_years)]

# Default the auxiliary splits to None so they are always defined,
# including when scaler_type is not 'std_new'.
train_surface_files = valid_surface_files = None
train_dyn_forcing_files = valid_dyn_forcing_files = None
train_diagnostic_files = valid_diagnostic_files = None

# <----------------------------------- std_new
if conf['data']['scaler_type'] == 'std_new':

    train_surface_files, valid_surface_files = _split_by_years(
        surface_files, 'surface files',
        train_years, valid_years, len(train_files), len(valid_files))

    train_dyn_forcing_files, valid_dyn_forcing_files = _split_by_years(
        dyn_forcing_files, 'dynamic forcing files',
        train_years, valid_years, len(train_files), len(valid_files))

    train_diagnostic_files, valid_diagnostic_files = _split_by_years(
        diagnostic_files, 'diagnostic files',
        train_years, valid_years, len(train_files), len(valid_files))

In [13]:
# file names
varname_all = []

data_conf = conf['data']

# upper air
varname_upper_air = data_conf['variables']

# forcing: both the file location and the names are used only when
# 'forcing_variables' is present and non-empty
has_forcing = ('forcing_variables' in data_conf) and (len(data_conf['forcing_variables']) > 0)
forcing_files = data_conf['save_loc_forcing'] if has_forcing else None
varname_forcing = data_conf['forcing_variables'] if has_forcing else None

# static: same rule as forcing
has_static = ('static_variables' in data_conf) and (len(data_conf['static_variables']) > 0)
static_files = data_conf['save_loc_static'] if has_static else None
varname_static = data_conf['static_variables'] if has_static else None

# surface / dynamic forcing / diagnostic names are taken only when the
# matching file list was collected (i.e. is not None)
varname_surface = data_conf['surface_variables'] if surface_files is not None else None
varname_dyn_forcing = data_conf['dynamic_forcing_variables'] if dyn_forcing_files is not None else None
varname_diagnostic = data_conf['diagnostic_variables'] if diagnostic_files is not None else None

# number of previous lead time inputs
history_len = data_conf["history_len"]
valid_history_len = data_conf["valid_history_len"]

# number of lead times to forecast
forecast_len = data_conf["forecast_len"]
valid_forecast_len = data_conf["valid_forecast_len"]

# training keeps the values read above; validation swaps in the valid_* ones
if is_train:
    name = "training"
else:
    history_len, forecast_len = valid_history_len, valid_forecast_len
    name = 'validation'

# optional settings: None when the key is absent from the config
max_forecast_len = data_conf.get("max_forecast_len")
skip_periods = data_conf.get("skip_periods")
one_shot = data_conf.get("one_shot")

# shuffle
shuffle = False

In [14]:
# data preprocessing utils
# The transform object is built from the currently selected `conf` and is
# handed to the dataset through its `transform` argument in the next cell.
transforms = load_transforms(conf)

In [15]:
# Z-score
# Gather the constructor arguments first, then build the dataset.
# Note that the unfiltered file lists (all_ERA_files, surface_files, ...)
# are passed here, not the train_* / valid_* subsets.
dataset_kwargs = dict(
    # variable names per group
    varname_upper_air=varname_upper_air,
    varname_surface=varname_surface,
    varname_dyn_forcing=varname_dyn_forcing,
    varname_forcing=varname_forcing,
    varname_static=varname_static,
    varname_diagnostic=varname_diagnostic,
    # file locations per group
    filenames=all_ERA_files,
    filename_surface=surface_files,
    filename_dyn_forcing=dyn_forcing_files,
    filename_forcing=forcing_files,
    filename_static=static_files,
    filename_diagnostic=diagnostic_files,
    # sequence settings
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=one_shot,
    max_forecast_len=max_forecast_len,
    transform=transforms,
)
dataset = ERA5_and_Forcing_Dataset(**dataset_kwargs)

In [16]:
# Pull only the first item the dataset yields (built with conf_dyn above).
samples_dyn = next(iter(dataset))

In [17]:
# List the entries contained in the first sample.
samples_dyn.keys()

dict_keys(['x_forcing_static', 'x_surf', 'x', 'y_surf', 'y', 'index'])

In [18]:
# Shape of the combined forcing/static input tensor.
samples_dyn['x_forcing_static'].shape

torch.Size([2, 4, 640, 1280])

In [21]:
# Inspect one 2-D field: index 1 on the first axis, index 3 on the second.
# NOTE(review): which variable sits at index 3 is not shown here — confirm
# against the variable ordering in the config.
samples_dyn['x_forcing_static'][1, 3, ...]

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.6970, 0.6970, 0.6970,  ..., 0.6969, 0.6969, 0.6969],
        [0.7052, 0.7052, 0.7052,  ..., 0.7052, 0.7052, 0.7052],
        [0.7133, 0.7133, 0.7133,  ..., 0.7133, 0.7133, 0.7133]])

## Old Dataset

In [12]:
# Switch to the old config (model.yml) for the cells below, training settings.
conf, is_train = conf_old, True

In [13]:
data_conf = conf["data"]

# number of previous lead time inputs
history_len = data_conf["history_len"]
valid_history_len = data_conf["valid_history_len"]

# number of lead times to forecast
forecast_len = data_conf["forecast_len"]
valid_forecast_len = data_conf["valid_forecast_len"]

# validation swaps in the valid_* lengths; training keeps the ones read above
if not is_train:
    history_len = valid_history_len
    forecast_len = valid_forecast_len

# optional settings: None when the key is absent from the config
max_forecast_len = data_conf.get("max_forecast_len")
skip_periods = data_conf.get("skip_periods")
one_shot = data_conf.get("one_shot")

# shuffling is disabled here
shuffle = False
if is_train:
    name = "Train"
else:
    name = "Valid"

# data preprocessing utils
transforms = load_transforms(conf)

In [14]:
# Build the old-style dataset from the same unfiltered upper-air file list
# (all_ERA_files) that the new dataset receives.
dataset_old_kwargs = dict(
    filenames=all_ERA_files,
    history_len=history_len,
    forecast_len=forecast_len,
    skip_periods=skip_periods,
    one_shot=one_shot,
    max_forecast_len=max_forecast_len,
    transform=transforms,
)
dataset_old = ERA5Dataset(**dataset_old_kwargs)

In [15]:
# Pull only the first item the old dataset yields, for comparison below.
samples_old = next(iter(dataset_old))

### Dataset iter comparison

In [16]:
# elevation diff
# NOTE(review): `samples_new` is not assigned in any cell shown in this
# notebook (only `samples_dyn` is); presumably it came from an earlier run
# with `conf = conf_new` — confirm and re-run before trusting the output.
# NOTE(review): this sums signed differences, so offsetting errors can
# cancel; a sum near zero alone does not prove the two tensors match.
np.array(samples_new['x_forcing_static'][:, 0, ...] - samples_old['static'][0, ...] ).sum()

-3.7487491219666616e-05

In [17]:
# land sea mask diff: channel 1 of the new 'x_forcing_static' entry against
# index 1 of the old 'static' entry (sum of signed differences)
np.array(samples_new['x_forcing_static'][:, 1, ...] - samples_old['static'][1, ...] ).sum()

0.0

In [18]:
# upper air diff for the 'x' entry (sum of signed differences)
np.array(samples_new['x'] - samples_old['x']).sum()

0.0

In [19]:
# surface diff for the 'x_surf' entry (sum of signed differences)
np.array(samples_new['x_surf'] - samples_old['x_surf']).sum()

0.0

In [21]:
# upper air diff for the 'y' entry (sum of signed differences)
np.array(samples_new['y'] - samples_old['y']).sum()

0.0

In [22]:
# surface diff for the 'y_surf' entry (sum of signed differences)
np.array(samples_new['y_surf'] - samples_old['y_surf']).sum()

0.0

In [20]:
# List the entries of the new-dataset sample used in the comparisons above.
# NOTE(review): `samples_new` is not assigned in any visible cell — confirm
# where it comes from.
samples_new.keys()

dict_keys(['x_forcing_static', 'x_surf', 'x', 'y_surf', 'y', 'index'])