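# Data loader workspace

Builds the CREDIT training/validation data loaders from a YAML config and inspects one validation batch (keys, tensor shapes, spot values).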

In [1]:
import os
import sys
import glob
import yaml
import wandb
import optuna
import shutil
import logging
import warnings

from pathlib import Path
from argparse import ArgumentParser
from echo.src.base_objective import BaseObjective

import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from credit.distributed import distributed_model_wrapper

from credit.seed import seed_everything
from credit.loss import VariableTotalLoss2D
from credit.data import ERA5Dataset, ERA5_and_Forcing_Dataset, Dataset_BridgeScaler
from credit.transforms import load_transforms
from credit.scheduler import load_scheduler, annealed_probability
from credit.trainer import Trainer
from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
from credit.models import load_model
from credit.models.checkpoint import (
    FSDPOptimizerWrapper,
    TorchFSDPCheckpointIO
)

In [2]:
import numpy as np

In [3]:
# Silence all Python warnings for this session.
warnings.filterwarnings("ignore")

# Process-wide environment settings: TF log level, OpenMP/MKL thread counts,
# NCCL transport flags and the CUDA caching-allocator config. Applied in order.
# NOTE(review): torch and numpy are already imported above; thread-count
# variables are often read when those libraries initialise, so confirm they
# still take effect here or move this cell before the imports.
os.environ.update({
    "TF_CPP_MIN_LOG_LEVEL": "3",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "NCCL_SHM_DISABLE": "1",
    "NCCL_IB_DISABLE": "1",
    # https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})

In [4]:
# Path to the CREDIT config for this session. Set the CREDIT_CONFIG environment
# variable to point at a different file instead of editing the notebook.
# Other configs used previously:
#   /glade/u/home/ksha/miles-credit/config/fuxi_baseline_ksha_cpu.yml
#   /glade/u/home/ksha/miles-credit/config/example_for_data_checks.yml
config_name = os.environ.get("CREDIT_CONFIG", '/glade/work/ksha/CREDIT_runs/diag_o_tcw/model.yml')

# Read YAML file (safe_load, so no arbitrary Python objects are constructed)
with open(config_name, 'r') as stream:
    conf = yaml.safe_load(stream)

In [32]:
def setup(rank, world_size, mode):
    logging.info(f"Running {mode.upper()} on rank {rank} with world_size {world_size}.")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

In [33]:
def load_dataset_and_sampler(conf, files, world_size, rank, is_train, seed=42):
    """Build the train or valid dataset together with a ``DistributedSampler``.

    Parameters
    ----------
    conf : dict
        Parsed config. ``conf["data"]`` must contain ``history_len``,
        ``valid_history_len``, ``forecast_len``, ``valid_forecast_len`` and
        ``scaler_type``. ``max_forecast_len``, ``skip_periods`` and
        ``one_shot`` are optional (``None`` when absent).
        ``conf['save_loc']`` is expanded in place.
    files : list of str
        Input files for ``ERA5Dataset``. The "quantile-cached" branch does not
        use them.
    world_size, rank : int
        Forwarded to ``DistributedSampler``.
    is_train : bool
        Chooses training vs. validation lengths and turns on shuffling.
    seed : int, default 42
        Shuffle seed for the sampler.

    Returns
    -------
    tuple
        ``(dataset, sampler)``.
    """
    # Expand environment variables such as $USER; the result is written back into conf.
    conf['save_loc'] = os.path.expandvars(conf['save_loc'])

    data_conf = conf["data"]

    # Both train and valid lengths are looked up, so a missing key raises
    # either way. When not training, the valid values replace the train values.
    history_len, valid_history_len = data_conf["history_len"], data_conf["valid_history_len"]
    forecast_len, valid_forecast_len = data_conf["forecast_len"], data_conf["valid_forecast_len"]
    if not is_train:
        history_len = valid_history_len
        forecast_len = valid_forecast_len

    # Optional settings; absent keys become None.
    max_forecast_len = data_conf.get("max_forecast_len")
    skip_periods = data_conf.get("skip_periods")
    one_shot = data_conf.get("one_shot")

    if is_train:
        name = "Train"
    else:
        name = "Valid"

    # data preprocessing utils
    transforms = load_transforms(conf)

    if conf["data"]["scaler_type"] == "quantile-cached":
        # Quantile transform via BridgeScaler; the config section is chosen by is_train.
        bs_section = 'bs_years_train' if is_train else 'bs_years_val'
        dataset = Dataset_BridgeScaler(conf, conf_dataset=bs_section, transform=transforms)
    else:
        # Z-score
        dataset = ERA5Dataset(
            filenames=files,
            history_len=history_len,
            forecast_len=forecast_len,
            skip_periods=skip_periods,
            one_shot=one_shot,
            max_forecast_len=max_forecast_len,
            transform=transforms,
        )

    # Shuffling only for training; trailing samples are dropped for equal shards.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        seed=seed,
        shuffle=is_train,
        drop_last=True,
    )
    logging.info(f" Loaded a {name} ERA dataset, and a distributed sampler (forecast length = {forecast_len + 1})")

    return dataset, sampler

In [34]:
def load_dataset_and_sampler_zscore_only(conf, all_ERA_files, surface_files, diagnostic_files, world_size, rank, is_train, seed=42):
    """Build an ``ERA5_and_Forcing_Dataset`` (z-score pipeline) and its ``DistributedSampler``.

    Parameters
    ----------
    conf : dict
        Parsed config. The following ``conf['data']`` keys are used:
        * Always required: ``variables``, ``history_len``,
          ``valid_history_len``, ``forecast_len`` and ``valid_forecast_len``.
        * ``surface_variables`` / ``diagnostic_variables``: required when the
          matching file list is not ``None``.
        * ``forcing_variables`` / ``static_variables``: optional. When one is
          non-empty, ``save_loc_forcing`` / ``save_loc_static`` must exist.
        ``conf['save_loc']`` is expanded in place.
    all_ERA_files : list of str
        Upper-air input files.
    surface_files, diagnostic_files : list of str or None
        Optional surface / diagnostic inputs; ``None`` disables that input.
    world_size, rank : int
        Forwarded to ``DistributedSampler``.
    is_train : bool
        Selects training vs. validation lengths and enables shuffling.
    seed : int, default 42
        Sampler shuffle seed.

    Returns
    -------
    tuple
        ``(dataset, sampler)``.
    """
    # convert $USER to the actual user name (written back into conf)
    conf['save_loc'] = os.path.expandvars(conf['save_loc'])
    data_conf = conf['data']

    # ======================================================== #
    # parse inputs

    varname_upper_air = data_conf['variables']

    # Forcing / static inputs are enabled only by a non-empty variable list.
    # A key that is present but null now counts as "disabled" instead of
    # raising TypeError from len(None).
    if data_conf.get('forcing_variables'):
        forcing_files = data_conf['save_loc_forcing']
        varname_forcing = data_conf['forcing_variables']
    else:
        forcing_files, varname_forcing = None, None

    if data_conf.get('static_variables'):
        static_files = data_conf['save_loc_static']
        varname_static = data_conf['static_variables']
    else:
        static_files, varname_static = None, None

    # Surface / diagnostic variable names are only looked up when files are given.
    varname_surface = data_conf['surface_variables'] if surface_files is not None else None
    varname_diagnostic = data_conf['diagnostic_variables'] if diagnostic_files is not None else None

    # Number of previous lead-time inputs and of lead times to forecast.
    # All four keys are required regardless of is_train.
    history_len = data_conf["history_len"]
    valid_history_len = data_conf["valid_history_len"]
    forecast_len = data_conf["forecast_len"]
    valid_forecast_len = data_conf["valid_forecast_len"]

    if is_train:
        name = "training"
    else:
        history_len, forecast_len = valid_history_len, valid_forecast_len
        name = 'validation'

    # Optional settings; absent keys become None.
    max_forecast_len = data_conf.get("max_forecast_len")
    skip_periods = data_conf.get("skip_periods")
    one_shot = data_conf.get("one_shot")

    # data preprocessing utils
    transforms = load_transforms(conf)

    # Z-score
    dataset = ERA5_and_Forcing_Dataset(
        varname_upper_air=varname_upper_air,
        varname_surface=varname_surface,
        varname_forcing=varname_forcing,
        varname_static=varname_static,
        varname_diagnostic=varname_diagnostic,
        filenames=all_ERA_files,
        filename_surface=surface_files,
        filename_forcing=forcing_files,
        filename_static=static_files,
        filename_diagnostic=diagnostic_files,
        history_len=history_len,
        forecast_len=forecast_len,
        skip_periods=skip_periods,
        one_shot=one_shot,
        max_forecast_len=max_forecast_len,
        transform=transforms
    )

    # Pytorch sampler: shuffle only when training
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        seed=seed,
        shuffle=is_train,
        drop_last=True
    )

    logging.info(f" Loaded a {name} ERA dataset, and a distributed sampler (forecast length = {forecast_len + 1})")

    return dataset, sampler

In [35]:
# Interactive, single-process session: this notebook runs as rank 0 of 1.
rank, world_size = 0, 1

In [36]:
# convert $USER to the actual user name
conf['save_loc'] = os.path.expandvars(conf['save_loc'])

if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
    setup(rank, world_size, conf["trainer"]["mode"])

# Infer the device from the rank. All CUDA calls are guarded: on a CPU-only
# node torch.cuda.device_count() is 0, and the previously unguarded
# `rank % torch.cuda.device_count()` raised ZeroDivisionError.
if torch.cuda.is_available():
    device_id = rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{device_id}")
    torch.cuda.set_device(device_id)
else:
    device = torch.device("cpu")

# Config settings
seed = conf.get("seed", 1000)
seed_everything(seed)

train_batch_size = conf['trainer']['train_batch_size']
valid_batch_size = conf['trainer']['valid_batch_size']
thread_workers = conf['trainer']['thread_workers']
valid_thread_workers = conf['trainer'].get('valid_thread_workers', thread_workers)

# The 'std_new' (z-score) pipeline also consumes surface / diagnostic files.
use_zscore_pipeline = conf['data']['scaler_type'] == 'std_new'

# get file names
all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"]))

if use_zscore_pipeline:
    # None means "not configured"; load_dataset_and_sampler_zscore_only accepts None.
    surface_files = (
        sorted(glob.glob(conf["data"]["save_loc_surface"]))
        if "save_loc_surface" in conf["data"] else None
    )
    diagnostic_files = (
        sorted(glob.glob(conf["data"]["save_loc_diagnostic"]))
        if "save_loc_diagnostic" in conf["data"] else None
    )

# ============================================================================== #
# Train/valid split.
# TODO: replace the hard-coded year ranges below with a configurable split
# (e.g. explicit year lists or split ratios in conf["data"]).
train_years = [str(year) for year in range(1979, 2014)]
valid_years = [str(year) for year in range(2014, 2018)]


def _files_for_years(files, years):
    """Keep the paths in ``files`` that contain any of ``years`` as a substring.

    ``None`` (an input that is not configured) is passed through unchanged
    instead of raising ``TypeError`` on iteration.
    NOTE(review): this is a plain substring test on the full path, so a
    4-digit run elsewhere in the path (e.g. a directory name) also matches.
    """
    if files is None:
        return None
    return [file for file in files if any(year in file for year in years)]


# Filter the files for each set
train_files = _files_for_years(all_ERA_files, train_years)
valid_files = _files_for_years(all_ERA_files, valid_years)

if use_zscore_pipeline:
    train_surface_files = _files_for_years(surface_files, train_years)
    valid_surface_files = _files_for_years(surface_files, valid_years)

    train_diagnostic_files = _files_for_years(diagnostic_files, train_years)
    valid_diagnostic_files = _files_for_years(diagnostic_files, valid_years)

# load dataset and sampler
if use_zscore_pipeline:
    train_dataset, train_sampler = load_dataset_and_sampler_zscore_only(
        conf,
        train_files,
        train_surface_files,
        train_diagnostic_files,
        world_size, rank, is_train=True,
    )
    valid_dataset, valid_sampler = load_dataset_and_sampler_zscore_only(
        conf,
        valid_files,
        valid_surface_files,
        valid_diagnostic_files,
        world_size, rank, is_train=False,
    )
else:
    train_dataset, train_sampler = load_dataset_and_sampler(conf, train_files, world_size, rank, is_train=True)
    valid_dataset, valid_sampler = load_dataset_and_sampler(conf, valid_files, world_size, rank, is_train=False)

# Set up the dataloaders. Ordering/shuffling is delegated to the
# DistributedSampler, so DataLoader's own shuffle stays False.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=False,
    sampler=train_sampler,
    pin_memory=True,
    persistent_workers=thread_workers > 0,
    num_workers=thread_workers,
    drop_last=True,
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=valid_batch_size,
    shuffle=False,
    sampler=valid_sampler,
    pin_memory=False,
    num_workers=valid_thread_workers,
    drop_last=True,
)

In [37]:
def concat_and_reshape(x1, x2):
    """Merge dims 2 and 3 of ``x1``, append ``x2`` along that axis, then swap dims 1 and 2.

    Parameters
    ----------
    x1 : torch.Tensor
        6-D tensor ``(d0, d1, d2, d3, d4, d5)``. In this notebook it is the
        upper-air ``x``/``y`` with separate variable and level axes
        (presumably; confirm against the dataset).
    x2 : torch.Tensor
        5-D tensor ``(d0, d1, c2, d4, d5)`` concatenated after the merged axis.

    Returns
    -------
    torch.Tensor
        Shape ``(d0, d2 * d3 + c2, d1, d4, d5)``.

    Raises
    ------
    ValueError
        If ``x1`` is not 6-D.

    Notes
    -----
    ``flatten(2, 3)`` is used instead of ``view``. It gives the same result on
    contiguous tensors but also accepts non-contiguous inputs, where ``view``
    raises.
    """
    if x1.dim() != 6:
        raise ValueError(f"x1 must be 6-D, got shape {tuple(x1.shape)}")
    merged = torch.cat((x1.flatten(2, 3), x2), dim=2)
    return merged.permute(0, 2, 1, 3, 4)

In [48]:
def reshape_only(x1):
    """Merge dims 2 and 3 of ``x1`` and swap the merged axis with dim 1.

    Parameters
    ----------
    x1 : torch.Tensor
        6-D tensor ``(d0, d1, d2, d3, d4, d5)``.

    Returns
    -------
    torch.Tensor
        Shape ``(d0, d2 * d3, d1, d4, d5)``.

    Raises
    ------
    ValueError
        If ``x1`` is not 6-D.

    Notes
    -----
    ``flatten(2, 3)`` replaces ``view``. The result is the same for contiguous
    inputs, and non-contiguous inputs (which made ``view`` raise) are also
    accepted.
    """
    if x1.dim() != 6:
        raise ValueError(f"x1 must be 6-D, got shape {tuple(x1.shape)}")
    return x1.flatten(2, 3).permute(0, 2, 1, 3, 4)

In [38]:
# Which pipeline the loaders above were built with ('std_new' selects
# load_dataset_and_sampler_zscore_only).
conf['data']['scaler_type']

'std_new'

In [39]:
# Pull one validation batch. The valid sampler is built with shuffle=False,
# so each fresh iterator starts from the same first batch.
test = next(iter(valid_loader))

In [40]:
# Keys of a batch from the current dataset.
test.keys()

dict_keys(['x_surf', 'x', 'x_forcing_static', 'y_diag', 'y_surf', 'y', 'index'])

In [49]:
# Upper-air input with dims 2 and 3 merged and moved to dim 1 by reshape_only.
reshape_only(test['x']).shape

torch.Size([1, 60, 2, 640, 1280])

In [42]:
# Spot value; the same index is checked again after a later refetch.
test['x_surf'][0, 1, -1, 300, 600]

tensor(-0.6509)

In [43]:
# Target: merged upper-air channels followed by the surface channels.
concat_and_reshape(test['y'], test['y_surf']).shape

torch.Size([1, 67, 1, 640, 1280])

In [47]:
# Diagnostic target with dims 1 and 2 swapped to match the merged layout.
test['y_diag'].permute(0, 2, 1, 3, 4).shape

torch.Size([1, 1, 1, 640, 1280])

In [46]:
# Forcing/static input as returned by the dataset (not permuted).
test['x_forcing_static'].shape

torch.Size([1, 2, 3, 640, 1280])

In [27]:
# Refetch a validation batch.
test = next(iter(valid_loader))

In [28]:
# NOTE(review): the saved output below lists 'TOA', 'datetime' and 'static',
# which differs from the key listing in In [40]. It appears to come from an
# earlier dataset version; re-run to refresh.
test.keys()

dict_keys(['x_surf', 'x', 'y_surf', 'y', 'TOA', 'datetime', 'static', 'index'])

In [29]:
test['x_surf'].shape

torch.Size([1, 2, 7, 640, 1280])

In [30]:
# Same index as In [42]; matching values indicate the same sample was loaded.
test['x_surf'][0, 1, -1, 300, 600]

tensor(-0.6509)

In [29]:
# Merge upper-air input (dims 2 and 3 flattened) with surface input into one
# channel axis at dim 1.
batch_merge = concat_and_reshape(test["x"], test["x_surf"])

In [30]:
batch_merge.shape

torch.Size([1, 67, 2, 640, 1280])

In [45]:
# The combined forcing/static input is returned under 'x_forcing_static'
# (see the key listing in In [40]); 'forcing_static' is not a key of the batch.
test['x_forcing_static'].shape

torch.Size([1, 2, 3, 640, 1280])

In [42]:
# Forcing/static input with dims 1 and 2 swapped, matching the merged layout
# produced by concat_and_reshape. The key is 'x_forcing_static' (In [40]).
test['x_forcing_static'].permute(0, 2, 1, 3, 4).shape

torch.Size([1, 3, 2, 640, 1280])

torch.Size([1, 70, 2, 640, 1280])

In [36]:
# x contributes 60 merged channels (In [49]), so channel 59 of batch_merge is
# the last entry of x's flattened dims 2 and 3. The next cell reads the same
# element directly from x; the two values should match.
batch_merge[0, 59, 0, 10, 10]

tensor(-1.1295)

In [37]:
test["x"][0, 0, -1, -1, 10, 10]

tensor(-1.1295)

In [56]:
# NOTE(review): x_surf_new is not used by any later visible cell.
x_surf_new = test['x_surf']

In [14]:
# Batch 0, all entries of dim 1, channel 0 of the forcing/static input.
# The key is 'x_forcing_static' (In [40]).
# NOTE(review): channel 0 is assumed to be the TOA forcing; confirm the
# channel order against conf['data']['forcing_variables'].
TOA_new = test['x_forcing_static'][0, :, 0, ...]

In [21]:
TOA_new.shape

torch.Size([2, 640, 1280])

In [22]:
# NOTE(review): this refetch is meant to produce the *legacy* batch layout.
# The cell two below reads 'TOA', which appears only in the stale key
# listing (In [28]), not in the current one (In [40]). With the current
# loader, a Restart & Run All raises KeyError there.
test = next(iter(valid_loader))

In [23]:
# NOTE(review): x_surf_old is not used by any later visible cell.
x_surf_old = test['x_surf']

In [26]:
TOA_old = test['TOA'][0, ...]

In [27]:
TOA_old.shape

torch.Size([2, 640, 1280])

In [32]:
# NOTE(review): a zero sum of differences does not prove equality, because
# positive and negative errors can cancel. Consider torch.equal(TOA_new, TOA_old)
# or (TOA_new - TOA_old).abs().max() instead.
(TOA_new - TOA_old).sum()

tensor(0., dtype=torch.float64)