In [1]:
%cd /kaggle/working

import copy
import glob
import os
import pickle
import re
import string
from typing import Literal

import h5py
import matplotlib.pyplot as plt
import netCDF4
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr
from tqdm import tqdm

/kaggle/working


2024-04-29 00:00:22.732966: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-29 00:00:22.733107: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-29 00:00:22.868582: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-29 00:00:23.142329: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Variable names for the ClimSim v2 subset.

# Inputs, in stacking order: state variables (incl. surface pressure), the
# pbuf_* and cam_in_* scalars, then three trace-gas profiles. Outside of the
# upper troposphere lower stratosphere (UTLS, corresponding to indices 5-21),
# variance is minimal for these last 3.
v2_inputs = (
    ["state_" + name for name in ("t", "q0001", "q0002", "q0003", "u", "v", "ps")]
    + ["pbuf_" + name for name in ("SOLIN", "LHFLX", "SHFLX", "TAUX", "TAUY", "COSZRS")]
    + [
        "cam_in_" + name
        for name in (
            "ALDIF",
            "ALDIR",
            "ASDIF",
            "ASDIR",
            "LWUP",
            "ICEFRAC",
            "LANDFRAC",
            "OCNFRAC",
            "SNOWHICE",
            "SNOWHLAND",
        )
    ]
    + ["pbuf_" + name for name in ("ozone", "CH4", "N2O")]
)

# Outputs: six tendency profiles followed by eight cam_out_* scalars.
v2_outputs = ["ptend_" + name for name in ("t", "q0001", "q0002", "q0003", "u", "v")] + [
    "cam_out_" + name
    for name in ("NETSW", "FLWDS", "PRECSC", "PRECC", "SOLS", "SOLL", "SOLSD", "SOLLD")
]

# Variables that carry one value per vertical level (60 columns each below).
vertically_resolved = (
    ["state_" + name for name in ("t", "q0001", "q0002", "q0003", "u", "v")]
    + ["pbuf_ozone", "pbuf_CH4", "pbuf_N2O"]
    + ["ptend_" + name for name in ("t", "q0001", "q0002", "q0003", "u", "v")]
)

# Tendencies whose top levels are collected into `ablated_col_names` below.
ablated_vars = ["ptend_" + name for name in ("q0001", "q0002", "q0003", "u", "v")]

v2_vars = [*v2_inputs, *v2_outputs]

In [5]:
# One column per level ("<var>_<level>") for vertically resolved variables,
# the bare name for scalars; ablated columns are levels 0-11 of `ablated_vars`.
train_col_names = [
    col
    for var in v2_vars
    for col in (
        [f"{var}_{lev}" for lev in range(60)] if var in vertically_resolved else [var]
    )
]
ablated_col_names = [
    f"{var}_{lev}"
    for var in v2_vars
    if var in vertically_resolved and var in ablated_vars
    for lev in range(12)
]

len(train_col_names), len(ablated_col_names)

(925, 60)

In [6]:
# Flattened input column names, in the same per-level scheme as above.
input_col_names = [
    col
    for var in v2_inputs
    for col in (
        [f"{var}_{lev}" for lev in range(60)] if var in vertically_resolved else [var]
    )
]
len(input_col_names)

557

In [7]:
# Flattened output column names, in the same per-level scheme as above.
output_col_names = [
    col
    for var in v2_outputs
    for col in (
        [f"{var}_{lev}" for lev in range(60)] if var in vertically_resolved else [var]
    )
]
len(output_col_names)

368

In [8]:
# Sanity-check the flattened widths: 60 levels per profile variable,
# 9 profile + 17 scalar inputs, 6 profile + 8 scalar outputs.
assert len(input_col_names) == 60 * 9 + 17
assert len(output_col_names) == 60 * 6 + 8
assert len(train_col_names) == len(input_col_names) + len(output_col_names)
# Ablated columns must be distinct and every one of them an output column.
assert len(set(ablated_col_names)) == len(ablated_col_names)
assert set(ablated_col_names) <= set(output_col_names)

In [10]:
grid_path = "/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc"
norm_path = "/kaggle/working/misc/preprocessing/normalizations/"
grid_info = xr.open_dataset(grid_path)

# Normalization statistics, opened in the order they are unpacked.
input_mean, input_max, input_min, output_scale = (
    xr.open_dataset(norm_path + relative_path)
    for relative_path in (
        "inputs/input_mean.nc",
        "inputs/input_max.nc",
        "inputs/input_min.nc",
        "outputs/output_scale.nc",
    )
)

In [24]:
# Inspect the input normalization means (the Dataset opened from input_mean.nc).
input_mean

In [14]:
# Accepted values for `data_utils(ml_backend=...)`; the choice decides which
# framework is imported and which dataset type `load_ncdata_with_generator` returns.
MLBackendType = Literal["tensorflow", "pytorch"]


class data_utils:
    def __init__(
        self,
        grid_info,
        input_mean,
        input_max,
        input_min,
        output_scale,
        ml_backend: MLBackendType = "pytorch",
    ):
        """
        Set up grid metadata, normalization statistics, variable tables,
        physical constants and metric bookkeeping for ClimSim data handling.

        Args:
            grid_info: grid-description Dataset; must provide ``lev``, ``ncol``,
                ``area``, ``lat``, ``lon``, ``hyam`` and ``hybm``.
                NOTE: an ``area_wgt`` variable is written into this object itself
                (no copy is made), so the caller's dataset gains it as well.
            input_mean, input_max, input_min: input normalization statistics;
                inputs are later normalized as (x - mean) / (max - min).
            output_scale: multiplicative scaling applied to targets.
            ml_backend: "tensorflow" or "pytorch"; selects the framework that is
                imported here and the dataset type ``load_ncdata_with_generator``
                returns.

        Raises:
            ImportError: if the selected backend's package is not installed
                (chained to the original import error).
        """
        self.data_path = None
        self.input_vars = []
        self.target_vars = []
        self.input_feature_len = None
        self.target_feature_len = None
        self.grid_info = grid_info
        self.level_name = "lev"
        self.sample_name = "sample"
        self.num_levels = len(self.grid_info["lev"])
        self.num_latlon = len(
            self.grid_info["ncol"]
        )  # number of unique lat/lon grid points
        # Area weights: each column's area relative to the mean column area.
        self.grid_info["area_wgt"] = self.grid_info["area"] / self.grid_info[
            "area"
        ].mean(dim="ncol")
        self.area_wgt = self.grid_info["area_wgt"].values
        self.input_mean = input_mean
        self.input_max = input_max
        self.input_min = input_min
        self.output_scale = output_scale
        self.normalize = True
        self.lats, self.lats_indices = np.unique(
            self.grid_info["lat"].values, return_index=True
        )
        self.lons, self.lons_indices = np.unique(
            self.grid_info["lon"].values, return_index=True
        )
        self.sort_lat_key = np.argsort(
            self.grid_info["lat"].values[np.sort(self.lats_indices)]
        )
        self.sort_lon_key = np.argsort(
            self.grid_info["lon"].values[np.sort(self.lons_indices)]
        )
        # Column index -> (lat, lon) of that column.
        lat_values = self.grid_info["lat"].values
        lon_values = self.grid_info["lon"].values
        self.indextolatlon = {
            i: (lat_values[i], lon_values[i]) for i in range(self.num_latlon)
        }

        self.ml_backend = ml_backend
        self.tf = None
        self.torch = None
        # Defined for every backend value; stays False for unrecognized ones.
        self.successful_backend_import = False

        if self.ml_backend == "tensorflow":
            try:
                import tensorflow as tf
            except ImportError as err:
                raise ImportError("Tensorflow is not installed.") from err
            self.tf = tf
            self.successful_backend_import = True

        elif self.ml_backend == "pytorch":
            try:
                import torch
            except ImportError as err:
                raise ImportError("PyTorch is not installed.") from err
            self.torch = torch
            self.successful_backend_import = True

        # Column indices grouped by latitude: one list per distinct latitude,
        # indices ascending within each list, lists ordered by their smallest
        # index (dict insertion order = order of first appearance). Single
        # pass instead of rescanning every column once per latitude.
        lat_groups = {}
        for col_index, (col_lat, _col_lon) in self.indextolatlon.items():
            lat_groups.setdefault(col_lat, []).append(col_index)
        self.lat_indices_list = list(lat_groups.values())

        self.hyam = self.grid_info["hyam"].values
        self.hybm = self.grid_info["hybm"].values
        self.p0 = 1e5  # code assumes this will always be a scalar
        self.ps_index = None

        self.pressure_grid_train = None
        self.pressure_grid_val = None
        self.pressure_grid_scoring = None
        self.pressure_grid_test = None

        self.dp_train = None
        self.dp_val = None
        self.dp_scoring = None
        self.dp_test = None

        self.train_regexps = None
        self.train_stride_sample = None
        self.train_filelist = None
        self.val_regexps = None
        self.val_stride_sample = None
        self.val_filelist = None
        self.scoring_regexps = None
        self.scoring_stride_sample = None
        self.scoring_filelist = None
        self.test_regexps = None
        self.test_stride_sample = None
        self.test_filelist = None

        self.full_vars = False

        # physical constants from E3SM_ROOT/share/util/shr_const_mod.F90
        self.grav = 9.80616  # acceleration of gravity ~ m/s^2
        self.cp = 1.00464e3  # specific heat of dry air   ~ J/kg/K
        self.lv = 2.501e6  # latent heat of evaporation ~ J/kg
        self.lf = 3.337e5  # latent heat of fusion      ~ J/kg
        self.lsub = self.lv + self.lf  # latent heat of sublimation ~ J/kg
        self.rho_air = (
            101325 / (6.02214e26 * 1.38065e-23 / 28.966) / 273.15
        )  # density of dry air at STP  ~ kg/m^3
        # ~ 1.2923182846924677
        # SHR_CONST_PSTD/(SHR_CONST_RDAIR*SHR_CONST_TKFRZ)
        # SHR_CONST_RDAIR   = SHR_CONST_RGAS/SHR_CONST_MWDAIR
        # SHR_CONST_RGAS    = SHR_CONST_AVOGAD*SHR_CONST_BOLTZ
        self.rho_h20 = 1.0e3  # density of fresh water     ~ kg/m^ 3

        self.v1_inputs = [
            "state_t",
            "state_q0001",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
        ]

        self.v1_outputs = [
            "ptend_t",
            "ptend_q0001",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        self.v2_inputs = [
            "state_t",
            "state_q0001",
            "state_q0002",
            "state_q0003",
            "state_u",
            "state_v",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
            "pbuf_TAUX",
            "pbuf_TAUY",
            "pbuf_COSZRS",
            "cam_in_ALDIF",
            "cam_in_ALDIR",
            "cam_in_ASDIF",
            "cam_in_ASDIR",
            "cam_in_LWUP",
            "cam_in_ICEFRAC",
            "cam_in_LANDFRAC",
            "cam_in_OCNFRAC",
            "cam_in_SNOWHICE",
            "cam_in_SNOWHLAND",
            "pbuf_ozone",  # outside of the upper troposphere lower stratosphere (UTLS, corresponding to indices 5-21), variance is minimal for these last 3
            "pbuf_CH4",
            "pbuf_N2O",
        ]

        self.v2_outputs = [
            "ptend_t",
            "ptend_q0001",
            "ptend_q0002",
            "ptend_q0003",
            "ptend_u",
            "ptend_v",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        # Number of stacked columns each variable contributes.
        self.var_lens = {  # inputs
            "state_t": self.num_levels,
            "state_q0001": self.num_levels,
            "state_q0002": self.num_levels,
            "state_q0003": self.num_levels,
            "state_u": self.num_levels,
            "state_v": self.num_levels,
            "state_ps": 1,
            "pbuf_SOLIN": 1,
            "pbuf_LHFLX": 1,
            "pbuf_SHFLX": 1,
            "pbuf_TAUX": 1,
            "pbuf_TAUY": 1,
            "pbuf_COSZRS": 1,
            "cam_in_ALDIF": 1,
            "cam_in_ALDIR": 1,
            "cam_in_ASDIF": 1,
            "cam_in_ASDIR": 1,
            "cam_in_LWUP": 1,
            "cam_in_ICEFRAC": 1,
            "cam_in_LANDFRAC": 1,
            "cam_in_OCNFRAC": 1,
            "cam_in_SNOWHICE": 1,
            "cam_in_SNOWHLAND": 1,
            "pbuf_ozone": self.num_levels,
            "pbuf_CH4": self.num_levels,
            "pbuf_N2O": self.num_levels,
            # outputs
            "ptend_t": self.num_levels,
            "ptend_q0001": self.num_levels,
            "ptend_q0002": self.num_levels,
            "ptend_q0003": self.num_levels,
            "ptend_u": self.num_levels,
            "ptend_v": self.num_levels,
            "cam_out_NETSW": 1,
            "cam_out_FLWDS": 1,
            "cam_out_PRECSC": 1,
            "cam_out_PRECC": 1,
            "cam_out_SOLS": 1,
            "cam_out_SOLL": 1,
            "cam_out_SOLSD": 1,
            "cam_out_SOLLD": 1,
        }

        self.var_short_names = {
            "ptend_t": "$dT/dt$",
            "ptend_q0001": "$dq/dt$",
            "cam_out_NETSW": "NETSW",
            "cam_out_FLWDS": "FLWDS",
            "cam_out_PRECSC": "PRECSC",
            "cam_out_PRECC": "PRECC",
            "cam_out_SOLS": "SOLS",
            "cam_out_SOLL": "SOLL",
            "cam_out_SOLSD": "SOLSD",
            "cam_out_SOLLD": "SOLLD",
        }

        # Factors converting each target to a common energy unit.
        self.target_energy_conv = {
            "ptend_t": self.cp,
            "ptend_q0001": self.lv,
            "ptend_q0002": self.lv,
            "ptend_q0003": self.lv,
            "ptend_wind": None,
            "cam_out_NETSW": 1.0,
            "cam_out_FLWDS": 1.0,
            "cam_out_PRECSC": self.lv * self.rho_h20,
            "cam_out_PRECC": self.lv * self.rho_h20,
            "cam_out_SOLS": 1.0,
            "cam_out_SOLL": 1.0,
            "cam_out_SOLSD": 1.0,
            "cam_out_SOLLD": 1.0,
        }

        # for metrics

        self.input_train = None
        self.target_train = None
        self.preds_train = None
        self.samplepreds_train = None
        self.target_weighted_train = {}
        self.preds_weighted_train = {}
        self.samplepreds_weighted_train = {}
        self.metrics_train = []
        self.metrics_idx_train = {}
        self.metrics_var_train = {}

        self.input_val = None
        self.target_val = None
        self.preds_val = None
        self.samplepreds_val = None
        self.target_weighted_val = {}
        self.preds_weighted_val = {}
        self.samplepreds_weighted_val = {}
        self.metrics_val = []
        self.metrics_idx_val = {}
        self.metrics_var_val = {}

        self.input_scoring = None
        self.target_scoring = None
        self.preds_scoring = None
        self.samplepreds_scoring = None
        self.target_weighted_scoring = {}
        self.preds_weighted_scoring = {}
        self.samplepreds_weighted_scoring = {}
        self.metrics_scoring = []
        self.metrics_idx_scoring = {}
        self.metrics_var_scoring = {}

        self.input_test = None
        self.target_test = None
        self.preds_test = None
        self.samplepreds_test = None
        self.target_weighted_test = {}
        self.preds_weighted_test = {}
        self.samplepreds_weighted_test = {}
        self.metrics_test = []
        self.metrics_idx_test = {}
        self.metrics_var_test = {}

        self.model_names = []
        self.metrics_names = []
        self.metrics_dict = {
            "MAE": self.calc_MAE,
            "RMSE": self.calc_RMSE,
            "R2": self.calc_R2,
            "CRPS": self.calc_CRPS,
            "bias": self.calc_bias,
        }
        self.num_CRPS = 32
        self.linecolors = ["#0072B2", "#E69F00", "#882255", "#009E73", "#D55E00"]

    def set_to_v2_vars(self):
        """
        This function sets the inputs and outputs to the V2 subset.
        It also indicates the index of the surface pressure variable.
        """
        self.input_vars = self.v2_inputs
        self.target_vars = self.v2_outputs
        self.ps_index = 360
        self.input_feature_len = 557
        self.target_feature_len = 368
        self.full_vars = True

    def get_xrdata(self, file, file_vars=None):
        """
        Open one netCDF file as an xarray Dataset.

        If ``file_vars`` (a list of variable names) is given, only those
        variables are kept. The grid's ``lat``/``lon`` are merged in. Columns
        whose lat or lon lies outside the open interval (-999, 999) are dropped.
        """
        dataset = xr.open_dataset(file, engine="netcdf4")
        if file_vars is not None:
            dataset = dataset[file_vars]
        dataset = dataset.merge(self.grid_info[["lat", "lon"]])
        for coord_name in ("lat", "lon"):
            within_bounds = (dataset[coord_name] > -999) & (dataset[coord_name] < 999)
            dataset = dataset.where(within_bounds, drop=True)
        return dataset

    def get_input(self, input_file):
        """
        This function reads in a file and returns an xarray dataset with the input variables for the emulator.
        """
        # read inputs
        return self.get_xrdata(input_file, self.input_vars)

    def get_target(self, input_file):
        """
        This function reads in a file and returns an xarray dataset with the target variables for the emulator.
        """
        # read inputs
        ds_input = self.get_input(input_file)
        ds_target = self.get_xrdata(input_file.replace(".mli.", ".mlo."))
        # each timestep is 20 minutes which corresponds to 1200 seconds
        ds_target["ptend_t"] = (
            ds_target["state_t"] - ds_input["state_t"]
        ) / 1200  # T tendency [K/s]
        ds_target["ptend_q0001"] = (
            ds_target["state_q0001"] - ds_input["state_q0001"]
        ) / 1200  # Q tendency [kg/kg/s]
        if self.full_vars:
            ds_target["ptend_q0002"] = (
                ds_target["state_q0002"] - ds_input["state_q0002"]
            ) / 1200  # Q tendency [kg/kg/s]
            ds_target["ptend_q0003"] = (
                ds_target["state_q0003"] - ds_input["state_q0003"]
            ) / 1200  # Q tendency [kg/kg/s]
            ds_target["ptend_u"] = (
                ds_target["state_u"] - ds_input["state_u"]
            ) / 1200  # U tendency [m/s/s]
            ds_target["ptend_v"] = (
                ds_target["state_v"] - ds_input["state_v"]
            ) / 1200  # V tendency [m/s/s]
        ds_target = ds_target[self.target_vars]
        return ds_target

    def set_regexps(self, data_split, regexps):
        """
        This function sets the regular expressions used for getting the filelist for train, val, scoring, and test.
        """
        assert data_split in [
            "train",
            "val",
            "scoring",
            "test",
        ], "Provided data_split is not valid. Available options are train, val, scoring, and test."
        if data_split == "train":
            self.train_regexps = regexps
        elif data_split == "val":
            self.val_regexps = regexps
        elif data_split == "scoring":
            self.scoring_regexps = regexps
        elif data_split == "test":
            self.test_regexps = regexps

    def set_stride_sample(self, data_split, stride_sample):
        """
        This function sets the stride_sample for train, val, scoring, and test.
        """
        assert data_split in [
            "train",
            "val",
            "scoring",
            "test",
        ], "Provided data_split is not valid. Available options are train, val, scoring, and test."
        if data_split == "train":
            self.train_stride_sample = stride_sample
        elif data_split == "val":
            self.val_stride_sample = stride_sample
        elif data_split == "scoring":
            self.scoring_stride_sample = stride_sample
        elif data_split == "test":
            self.test_stride_sample = stride_sample

    def set_filelist(self, data_split):
        """
        This function sets the filelists corresponding to data splits for train, val, scoring, and test.
        """
        filelist = []
        assert data_split in [
            "train",
            "val",
            "scoring",
            "test",
        ], "Provided data_split is not valid. Available options are train, val, scoring, and test."
        if data_split == "train":
            assert self.train_regexps is not None, "regexps for train is not set."
            assert (
                self.train_stride_sample is not None
            ), "stride_sample for train is not set."
            for regexp in self.train_regexps:
                filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
            self.train_filelist = sorted(filelist)[:: self.train_stride_sample]
        elif data_split == "val":
            assert self.val_regexps is not None, "regexps for val is not set."
            assert (
                self.val_stride_sample is not None
            ), "stride_sample for val is not set."
            for regexp in self.val_regexps:
                filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
            self.val_filelist = sorted(filelist)[:: self.val_stride_sample]
        elif data_split == "scoring":
            assert self.scoring_regexps is not None, "regexps for scoring is not set."
            assert (
                self.scoring_stride_sample is not None
            ), "stride_sample for scoring is not set."
            for regexp in self.scoring_regexps:
                filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
            self.scoring_filelist = sorted(filelist)[:: self.scoring_stride_sample]
        elif data_split == "test":
            assert self.test_regexps is not None, "regexps for test is not set."
            assert (
                self.test_stride_sample is not None
            ), "stride_sample for test is not set."
            for regexp in self.test_regexps:
                filelist = filelist + glob.glob(self.data_path + "*/" + regexp)
            self.test_filelist = sorted(filelist)[:: self.test_stride_sample]

    def get_filelist(self, data_split):
        """
        This function returns the filelist corresponding to data splits for train, val, scoring, and test.
        """
        assert data_split in [
            "train",
            "val",
            "scoring",
            "test",
        ], "Provided data_split is not valid. Available options are train, val, scoring, and test."
        if data_split == "train":
            assert self.train_filelist is not None, "filelist for train is not set."
            return self.train_filelist
        elif data_split == "val":
            assert self.val_filelist is not None, "filelist for val is not set."
            return self.val_filelist
        elif data_split == "scoring":
            assert self.scoring_filelist is not None, "filelist for scoring is not set."
            return self.scoring_filelist
        elif data_split == "test":
            assert self.test_filelist is not None, "filelist for test is not set."
            return self.test_filelist

    def load_ncdata_with_generator(self, data_split):
        """
        This function works as a dataloader when training the emulator with raw netCDF files.
        This can be used as a dataloader during training or it can be used to create entire datasets.
        When used as a dataloader for training, I/O can slow down training considerably.
        This function also normalizes the data.
        mli corresponds to input
        mlo corresponds to target
        """
        filelist = self.get_filelist(data_split)

        def gen():
            for file in filelist:
                # read inputs
                ds_input = self.get_input(file)
                # read targets
                ds_target = self.get_target(file)

                # normalization, scaling
                if self.normalize:
                    ds_input = (ds_input - self.input_mean) / (
                        self.input_max - self.input_min
                    )
                    ds_target = ds_target * self.output_scale
                else:
                    ds_input = ds_input.drop(["lat", "lon"])

                # stack
                # ds = ds.stack({'batch':{'sample','ncol'}})
                ds_input = ds_input.stack({"batch": {"ncol"}})
                ds_input = ds_input.to_stacked_array(
                    "mlvar", sample_dims=["batch"], name="mli"
                )
                # dso = dso.stack({'batch':{'sample','ncol'}})
                ds_target = ds_target.stack({"batch": {"ncol"}})
                ds_target = ds_target.to_stacked_array(
                    "mlvar", sample_dims=["batch"], name="mlo"
                )
                yield (ds_input.values, ds_target.values)

        if self.ml_backend == "tensorflow":

            # Removed output_shapes and output_types, converting to output_signature as is
            # recommended in the latest version of TensorFlow.
            return self.tf.data.Dataset.from_generator(
                gen,
                output_signature=(
                    self.tf.TensorSpec(
                        shape=(None, self.input_feature_len), dtype=self.tf.float64
                    ),
                    self.tf.TensorSpec(
                        shape=(None, self.target_feature_len), dtype=self.tf.float64
                    ),
                ),
            )

        elif self.ml_backend == "pytorch":
            if self.successful_backend_import:

                class IterableTorchDataset(self.torch.utils.data.IterableDataset):
                    def __init__(
                        this_self, data_generator, output_types, output_shapes
                    ):
                        this_self.data_generator = data_generator
                        this_self.output_types = output_types
                        this_self.output_shapes = output_shapes

                    def __iter__(this_self):
                        for item in this_self.data_generator:

                            input_array = self.torch.tensor(
                                item[0], dtype=this_self.output_types[0]
                            )
                            target_array = self.torch.tensor(
                                item[1], dtype=this_self.output_types[1]
                            )

                            # Assert final dimensions are correct.
                            assert (
                                input_array.shape[-1] == this_self.output_shapes[0][-1]
                            )
                            assert (
                                target_array.shape[-1] == this_self.output_shapes[1][-1]
                            )

                            yield (input_array, target_array)

                    def as_numpy_iterator(this_self):
                        for item in this_self.data_generator:

                            # Convert item to numpy array
                            input_array = np.array(item[0])
                            target_array = np.array(item[1])

                            # Assert final dimensions are correct.
                            assert (
                                input_array.shape[-1] == this_self.output_shapes[0][-1]
                            )
                            assert (
                                target_array.shape[-1] == this_self.output_shapes[1][-1]
                            )

                            yield (input_array, target_array)

                dataset = IterableTorchDataset(
                    gen(),
                    (self.torch.float64, self.torch.float64),
                    ((None, self.input_feature_len), (None, self.target_feature_len)),
                )

                return dataset

    @staticmethod
    def ls(dir_path=""):
        """
        You can treat this as a Python wrapper for the bash command "ls".
        """
        return os.popen(" ".join(["ls", dir_path])).read().splitlines()

    @staticmethod
    def set_plot_params():
        """
        Close all open figures, reset matplotlib's rcParams to their defaults,
        then apply this project's style:
        - "sans" font family;
        - size 32 for all text;
        - line and axes width 2;
        - pdflatex as the PGF TeX system.
        """
        plt.close("all")
        plt.rcParams.update(plt.rcParamsDefault)
        plt.rc("font", family="sans")
        text_size = 32
        style = {
            key: text_size
            for key in (
                "font.size",
                "axes.labelsize",
                "axes.titlesize",
                "xtick.labelsize",
                "ytick.labelsize",
                "legend.fontsize",
            )
        }
        style["lines.linewidth"] = 2
        style["axes.linewidth"] = 2
        style["pgf.texsystem"] = "pdflatex"
        plt.rcParams.update(style)
        # In a Jupyter notebook, `%config InlineBackend.figure_format = 'retina'`
        # can be added for sharper inline figures.

    @staticmethod
    def load_npy_file(load_path=""):
        """
        This function loads the prediction .npy file.
        """
        with open(load_path, "rb") as f:
            pred = np.load(f)
        return pred

    @staticmethod
    def load_h5_file(load_path=""):
        """
        Load the ``pred`` entry of the HDF5 file at ``load_path`` into memory as
        a numpy array.

        The file is closed before returning; previously the handle was leaked.
        As before, a missing ``pred`` entry makes ``hf.get`` return None, and
        the result is ``np.array(None)``.
        """
        with h5py.File(load_path, "r") as hf:
            # Read the data while the file is still open; h5py dataset handles
            # are invalid once the file is closed.
            pred = np.array(hf.get("pred"))
        return pred

    def output_weighting(self, output, data_split, just_weights=False):
        """
        This function does four transformations, and assumes we are using V1 variables:
        [0] Undos the output scaling
        [1] Weight vertical levels by dp/g
        [2] Weight horizontal area of each grid cell by a[x]/mean(a[x])
        [3] Unit conversion to a common energy unit
        """
        assert data_split in [
            "train",
            "val",
            "scoring",
            "test",
        ], "Provided data_split is not valid. Available options are train, val, scoring, and test."
        num_samples = output.shape[0]
        if just_weights:
            weightings = np.ones(output.shape)

        if not self.full_vars:
            ptend_t = output[:, :60].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_q0001 = output[:, 60:120].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            netsw = output[:, 120].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            flwds = output[:, 121].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            precsc = output[:, 122].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            precc = output[:, 123].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            sols = output[:, 124].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            soll = output[:, 125].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            solsd = output[:, 126].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            solld = output[:, 127].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            if just_weights:
                ptend_t_weight = weightings[:, :60].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_q0001_weight = weightings[:, 60:120].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                netsw_weight = weightings[:, 120].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                flwds_weight = weightings[:, 121].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                precsc_weight = weightings[:, 122].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                precc_weight = weightings[:, 123].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                sols_weight = weightings[:, 124].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                soll_weight = weightings[:, 125].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                solsd_weight = weightings[:, 126].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                solld_weight = weightings[:, 127].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
        else:
            ptend_t = output[:, :60].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_q0001 = output[:, 60:120].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_q0002 = output[:, 120:180].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_q0003 = output[:, 180:240].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_u = output[:, 240:300].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            ptend_v = output[:, 300:360].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon, 60)
            )
            netsw = output[:, 360].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            flwds = output[:, 361].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            precsc = output[:, 362].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            precc = output[:, 363].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            sols = output[:, 364].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            soll = output[:, 365].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            solsd = output[:, 366].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            solld = output[:, 367].reshape(
                (int(num_samples / self.num_latlon), self.num_latlon)
            )
            state_wind = ((ptend_u**2) + (ptend_v**2)) ** 0.5
            self.target_energy_conv["ptend_wind"] = state_wind
            if just_weights:
                ptend_t_weight = weightings[:, :60].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_q0001_weight = weightings[:, 60:120].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_q0002_weight = weightings[:, 120:180].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_q0003_weight = weightings[:, 180:240].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_u_weight = weightings[:, 240:300].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                ptend_v_weight = weightings[:, 300:360].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon, 60)
                )
                netsw_weight = weightings[:, 360].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                flwds_weight = weightings[:, 361].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                precsc_weight = weightings[:, 362].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                precc_weight = weightings[:, 363].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                sols_weight = weightings[:, 364].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                soll_weight = weightings[:, 365].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                solsd_weight = weightings[:, 366].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )
                solld_weight = weightings[:, 367].reshape(
                    (int(num_samples / self.num_latlon), self.num_latlon)
                )

        # ptend_t = ptend_t.transpose((2,0,1))
        # ptend_q0001 = ptend_q0001.transpose((2,0,1))
        # scalar_outputs = scalar_outputs.transpose((2,0,1))

        # [0] Undo output scaling
        if self.normalize:
            ptend_t = (
                ptend_t / self.output_scale["ptend_t"].values[np.newaxis, np.newaxis, :]
            )
            ptend_q0001 = (
                ptend_q0001
                / self.output_scale["ptend_q0001"].values[np.newaxis, np.newaxis, :]
            )
            netsw = netsw / self.output_scale["cam_out_NETSW"].values
            flwds = flwds / self.output_scale["cam_out_FLWDS"].values
            precsc = precsc / self.output_scale["cam_out_PRECSC"].values
            precc = precc / self.output_scale["cam_out_PRECC"].values
            sols = sols / self.output_scale["cam_out_SOLS"].values
            soll = soll / self.output_scale["cam_out_SOLL"].values
            solsd = solsd / self.output_scale["cam_out_SOLSD"].values
            solld = solld / self.output_scale["cam_out_SOLLD"].values
            if just_weights:
                ptend_t_weight = (
                    ptend_t_weight
                    / self.output_scale["ptend_t"].values[np.newaxis, np.newaxis, :]
                )
                ptend_q0001_weight = (
                    ptend_q0001_weight
                    / self.output_scale["ptend_q0001"].values[np.newaxis, np.newaxis, :]
                )
                netsw_weight = netsw_weight / self.output_scale["cam_out_NETSW"].values
                flwds_weight = flwds_weight / self.output_scale["cam_out_FLWDS"].values
                precsc_weight = (
                    precsc_weight / self.output_scale["cam_out_PRECSC"].values
                )
                precc_weight = precc_weight / self.output_scale["cam_out_PRECC"].values
                sols_weight = sols_weight / self.output_scale["cam_out_SOLS"].values
                soll_weight = soll_weight / self.output_scale["cam_out_SOLL"].values
                solsd_weight = solsd_weight / self.output_scale["cam_out_SOLSD"].values
                solld_weight = solld_weight / self.output_scale["cam_out_SOLLD"].values
            if self.full_vars:
                ptend_q0002 = (
                    ptend_q0002
                    / self.output_scale["ptend_q0002"].values[np.newaxis, np.newaxis, :]
                )
                ptend_q0003 = (
                    ptend_q0003
                    / self.output_scale["ptend_q0003"].values[np.newaxis, np.newaxis, :]
                )
                ptend_u = (
                    ptend_u
                    / self.output_scale["ptend_u"].values[np.newaxis, np.newaxis, :]
                )
                ptend_v = (
                    ptend_v
                    / self.output_scale["ptend_v"].values[np.newaxis, np.newaxis, :]
                )
                if just_weights:
                    ptend_q0002_weight = (
                        ptend_q0002_weight
                        / self.output_scale["ptend_q0002"].values[
                            np.newaxis, np.newaxis, :
                        ]
                    )
                    ptend_q0003_weight = (
                        ptend_q0003_weight
                        / self.output_scale["ptend_q0003"].values[
                            np.newaxis, np.newaxis, :
                        ]
                    )
                    ptend_u_weight = (
                        ptend_u_weight
                        / self.output_scale["ptend_u"].values[np.newaxis, np.newaxis, :]
                    )
                    ptend_v_weight = (
                        ptend_v_weight
                        / self.output_scale["ptend_v"].values[np.newaxis, np.newaxis, :]
                    )

        # [1] Weight vertical levels by dp/g
        # only for vertically-resolved variables, e.g. ptend_{t,q0001}
        # dp/g = -\rho * dz

        dp = None
        if data_split == "train":
            dp = self.dp_train
        elif data_split == "val":
            dp = self.dp_val
        elif data_split == "scoring":
            dp = self.dp_scoring
        elif data_split == "test":
            dp = self.dp_test
        assert dp is not None
        ptend_t = ptend_t * dp / self.grav
        ptend_q0001 = ptend_q0001 * dp / self.grav
        if just_weights:
            ptend_t_weight = ptend_t_weight * dp / self.grav
            ptend_q0001_weight = ptend_q0001_weight * dp / self.grav
        if self.full_vars:
            ptend_q0002 = ptend_q0002 * dp / self.grav
            ptend_q0003 = ptend_q0003 * dp / self.grav
            ptend_u = ptend_u * dp / self.grav
            ptend_v = ptend_v * dp / self.grav
            if just_weights:
                ptend_q0002_weight = ptend_q0002_weight * dp / self.grav
                ptend_q0003_weight = ptend_q0003_weight * dp / self.grav
                ptend_u_weight = ptend_u_weight * dp / self.grav
                ptend_v_weight = ptend_v_weight * dp / self.grav

        # [2] weight by area

        ptend_t = ptend_t * self.area_wgt[np.newaxis, :, np.newaxis]
        ptend_q0001 = ptend_q0001 * self.area_wgt[np.newaxis, :, np.newaxis]
        netsw = netsw * self.area_wgt[np.newaxis, :]
        flwds = flwds * self.area_wgt[np.newaxis, :]
        precsc = precsc * self.area_wgt[np.newaxis, :]
        precc = precc * self.area_wgt[np.newaxis, :]
        sols = sols * self.area_wgt[np.newaxis, :]
        soll = soll * self.area_wgt[np.newaxis, :]
        solsd = solsd * self.area_wgt[np.newaxis, :]
        solld = solld * self.area_wgt[np.newaxis, :]
        if just_weights:
            ptend_t_weight = ptend_t_weight * self.area_wgt[np.newaxis, :, np.newaxis]
            ptend_q0001_weight = (
                ptend_q0001_weight * self.area_wgt[np.newaxis, :, np.newaxis]
            )
            netsw_weight = netsw_weight * self.area_wgt[np.newaxis, :]
            flwds_weight = flwds_weight * self.area_wgt[np.newaxis, :]
            precsc_weight = precsc_weight * self.area_wgt[np.newaxis, :]
            precc_weight = precc_weight * self.area_wgt[np.newaxis, :]
            sols_weight = sols_weight * self.area_wgt[np.newaxis, :]
            soll_weight = soll_weight * self.area_wgt[np.newaxis, :]
            solsd_weight = solsd_weight * self.area_wgt[np.newaxis, :]
            solld_weight = solld_weight * self.area_wgt[np.newaxis, :]
        if self.full_vars:
            ptend_q0002 = ptend_q0002 * self.area_wgt[np.newaxis, :, np.newaxis]
            ptend_q0003 = ptend_q0003 * self.area_wgt[np.newaxis, :, np.newaxis]
            ptend_u = ptend_u * self.area_wgt[np.newaxis, :, np.newaxis]
            ptend_v = ptend_v * self.area_wgt[np.newaxis, :, np.newaxis]
            if just_weights:
                ptend_q0002_weight = (
                    ptend_q0002_weight * self.area_wgt[np.newaxis, :, np.newaxis]
                )
                ptend_q0003_weight = (
                    ptend_q0003_weight * self.area_wgt[np.newaxis, :, np.newaxis]
                )
                ptend_u_weight = (
                    ptend_u_weight * self.area_wgt[np.newaxis, :, np.newaxis]
                )
                ptend_v_weight = (
                    ptend_v_weight * self.area_wgt[np.newaxis, :, np.newaxis]
                )

        # [3] unit conversion

        ptend_t = ptend_t * self.target_energy_conv["ptend_t"]
        ptend_q0001 = ptend_q0001 * self.target_energy_conv["ptend_q0001"]
        netsw = netsw * self.target_energy_conv["cam_out_NETSW"]
        flwds = flwds * self.target_energy_conv["cam_out_FLWDS"]
        precsc = precsc * self.target_energy_conv["cam_out_PRECSC"]
        precc = precc * self.target_energy_conv["cam_out_PRECC"]
        sols = sols * self.target_energy_conv["cam_out_SOLS"]
        soll = soll * self.target_energy_conv["cam_out_SOLL"]
        solsd = solsd * self.target_energy_conv["cam_out_SOLSD"]
        solld = solld * self.target_energy_conv["cam_out_SOLLD"]
        if just_weights:
            ptend_t_weight = ptend_t_weight * self.target_energy_conv["ptend_t"]
            ptend_q0001_weight = (
                ptend_q0001_weight * self.target_energy_conv["ptend_q0001"]
            )
            netsw_weight = netsw_weight * self.target_energy_conv["cam_out_NETSW"]
            flwds_weight = flwds_weight * self.target_energy_conv["cam_out_FLWDS"]
            precsc_weight = precsc_weight * self.target_energy_conv["cam_out_PRECSC"]
            precc_weight = precc_weight * self.target_energy_conv["cam_out_PRECC"]
            sols_weight = sols_weight * self.target_energy_conv["cam_out_SOLS"]
            soll_weight = soll_weight * self.target_energy_conv["cam_out_SOLL"]
            solsd_weight = solsd_weight * self.target_energy_conv["cam_out_SOLSD"]
            solld_weight = solld_weight * self.target_energy_conv["cam_out_SOLLD"]
        if self.full_vars:
            ptend_q0002 = ptend_q0002 * self.target_energy_conv["ptend_q0002"]
            ptend_q0003 = ptend_q0003 * self.target_energy_conv["ptend_q0003"]
            ptend_u = ptend_u * self.target_energy_conv["ptend_wind"]
            ptend_v = ptend_v * self.target_energy_conv["ptend_wind"]
            if just_weights:
                ptend_q0002_weight = (
                    ptend_q0002_weight * self.target_energy_conv["ptend_q0002"]
                )
                ptend_q0003_weight = (
                    ptend_q0003_weight * self.target_energy_conv["ptend_q0003"]
                )
                ptend_u_weight = ptend_u_weight * self.target_energy_conv["ptend_wind"]
                ptend_v_weight = ptend_v_weight * self.target_energy_conv["ptend_wind"]

        if just_weights:
            if self.full_vars:
                weightings = np.concatenate(
                    [
                        ptend_t_weight.reshape((num_samples, 60)),
                        ptend_q0001_weight.reshape((num_samples, 60)),
                        ptend_q0002_weight.reshape((num_samples, 60)),
                        ptend_q0003_weight.reshape((num_samples, 60)),
                        ptend_u_weight.reshape((num_samples, 60)),
                        ptend_v_weight.reshape((num_samples, 60)),
                        netsw_weight.reshape((num_samples))[:, np.newaxis],
                        flwds_weight.reshape((num_samples))[:, np.newaxis],
                        precsc_weight.reshape((num_samples))[:, np.newaxis],
                        precc_weight.reshape((num_samples))[:, np.newaxis],
                        sols_weight.reshape((num_samples))[:, np.newaxis],
                        soll_weight.reshape((num_samples))[:, np.newaxis],
                        solsd_weight.reshape((num_samples))[:, np.newaxis],
                        solld_weight.reshape((num_samples))[:, np.newaxis],
                    ],
                    axis=1,
                )
            else:
                weightings = np.concatenate(
                    [
                        ptend_t_weight.reshape((num_samples, 60)),
                        ptend_q0001_weight.reshape((num_samples, 60)),
                        netsw_weight.reshape((num_samples))[:, np.newaxis],
                        flwds_weight.reshape((num_samples))[:, np.newaxis],
                        precsc_weight.reshape((num_samples))[:, np.newaxis],
                        precc_weight.reshape((num_samples))[:, np.newaxis],
                        sols_weight.reshape((num_samples))[:, np.newaxis],
                        soll_weight.reshape((num_samples))[:, np.newaxis],
                        solsd_weight.reshape((num_samples))[:, np.newaxis],
                        solld_weight.reshape((num_samples))[:, np.newaxis],
                    ],
                    axis=1,
                )
            return weightings
        else:
            var_dict = {
                "ptend_t": ptend_t,
                "ptend_q0001": ptend_q0001,
                "cam_out_NETSW": netsw,
                "cam_out_FLWDS": flwds,
                "cam_out_PRECSC": precsc,
                "cam_out_PRECC": precc,
                "cam_out_SOLS": sols,
                "cam_out_SOLL": soll,
                "cam_out_SOLSD": solsd,
                "cam_out_SOLLD": solld,
            }
            if self.full_vars:
                var_dict["ptend_q0002"] = ptend_q0002
                var_dict["ptend_q0003"] = ptend_q0003
                var_dict["ptend_u"] = ptend_u
                var_dict["ptend_v"] = ptend_v

            return var_dict

    def calc_MAE(self, pred, target, avg_grid=True):
        """
        calculate 'globally averaged' mean absolute error
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        """
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        mae = np.abs(pred - target).mean(axis=0)
        if avg_grid:
            return mae.mean(axis=0)  # we decided to average globally at end
        else:
            return mae

    def calc_RMSE(self, pred, target, avg_grid=True):
        """
        calculate 'globally averaged' root mean squared error
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        """
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target) ** 2
        rmse = np.sqrt(sq_diff.mean(axis=0))  # mean over time
        if avg_grid:
            return rmse.mean(axis=0)  # we decided to separately average globally at end
        else:
            return rmse

    def calc_R2(self, pred, target, avg_grid=True):
        """
        calculate 'globally averaged' R-squared
        for vertically-resolved variables, input shape should be time x grid x level
        for scalars, input shape should be time x grid

        returns vector of length level or 1
        """
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target) ** 2
        tss_time = (
            target - target.mean(axis=0)[np.newaxis, ...]
        ) ** 2  # mean over time
        r_squared = 1 - sq_diff.sum(axis=0) / tss_time.sum(axis=0)  # sum over time
        if avg_grid:
            return r_squared.mean(
                axis=0
            )  # we decided to separately average globally at end
        else:
            return r_squared

    def calc_bias(self, pred, target, avg_grid=True):
        """
        calculate bias
        for vertically-resolved variables, input shape should be time x grid x level
        for scalars, input shape should be time x grid

        returns vector of length level or 1
        """
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        bias = pred.mean(axis=0) - target.mean(axis=0)
        if avg_grid:
            return bias.mean(axis=0)  # we decided to separately average globally at end
        else:
            return bias

    def calc_CRPS(self, samplepreds, target, avg_grid=True):
        """
        calculate 'globally averaged' continuous ranked probability score
        for vertically-resolved variables, input shape should be time x grid x level x num_crps_samples
        for scalars, input shape should be time x grid x num_crps_samples

        returns vector of length level or 1
        """
        assert samplepreds.shape[1] == self.num_latlon
        assert len(samplepreds.shape) == len(target.shape) + 1
        assert len(samplepreds.shape) == 3 or len(samplepreds.shape) == 4
        num_crps = samplepreds.shape[-1]
        mae = np.mean(
            np.abs(samplepreds - target[..., np.newaxis]), axis=(0, -1)
        )  # mean over time and crps samples
        samplepreds = np.sort(samplepreds, axis=-1)
        diff = samplepreds[..., 1:] - samplepreds[..., :-1]
        count = np.arange(1, num_crps) * np.arange(num_crps - 1, 0, -1)
        if len(samplepreds.shape) == 4:
            spread = (
                (diff * count[np.newaxis, np.newaxis, np.newaxis, :])
                .sum(axis=-1)
                .mean(axis=0)
            )  # sum over crps samples and mean over time
        elif len(samplepreds.shape) == 3:
            spread = (
                (diff * count[np.newaxis, np.newaxis, :]).sum(axis=-1).mean(axis=0)
            )  # sum over crps samples and mean over time
        crps = mae - spread / (num_crps * (num_crps - 1))
        # count was not multiplied by two so no need to divide by two
        if avg_grid:
            return crps.mean(axis=0)  # we decided to separately average globally at end
        else:
            return crps

    def reshape_daily(self, output):
        """
        This function returns two numpy arrays, one for each vertically resolved variable (ptend_t and ptend_q0001).
        Dimensions of expected input are num_samples by 128 (number of target features).
        Output argument is espected to be have dimensions of num_samples by features.
        ptend_t is expected to be the first feature, and ptend_q0001 is expected to be the second feature.
        Data is expected to use a stride_sample of 6. (12 samples per day, 20 min timestep).
        """
        num_samples = output.shape[0]
        ptend_t = output[:, :60].reshape(
            (int(num_samples / self.num_latlon), self.num_latlon, 60)
        )
        ptend_q0001 = output[:, 60:120].reshape(
            (int(num_samples / self.num_latlon), self.num_latlon, 60)
        )
        ptend_t_daily = np.mean(
            ptend_t.reshape((ptend_t.shape[0] // 12, 12, self.num_latlon, 60)), axis=1
        )  # Nday x lotlonnum x 60
        ptend_q0001_daily = np.mean(
            ptend_q0001.reshape((ptend_q0001.shape[0] // 12, 12, self.num_latlon, 60)),
            axis=1,
        )  # Nday x lotlonnum x 60
        ptend_t_daily_long = []
        ptend_q0001_daily_long = []
        for i in range(len(self.lats)):
            ptend_t_daily_long.append(
                np.mean(ptend_t_daily[:, self.lat_indices_list[i], :], axis=1)
            )
            ptend_q0001_daily_long.append(
                np.mean(ptend_q0001_daily[:, self.lat_indices_list[i], :], axis=1)
            )
        ptend_t_daily_long = np.array(ptend_t_daily_long)  # lat x Nday x 60
        ptend_q0001_daily_long = np.array(ptend_q0001_daily_long)  # lat x Nday x 60
        return ptend_t_daily_long, ptend_q0001_daily_long

    @staticmethod
    def reshape_input_for_cnn(npy_input, save_path=""):
        """
        This function reshapes a numpy input array to be compatible with CNN training.
        Each variable becomes its own channel.
        For the input there are 6 channels, each with 60 vertical levels.
        The last 4 channels correspond to scalars repeated across all 60 levels.
        This is for V1 data only! (V2 data has more variables)
        """
        npy_input_cnn = np.stack(
            [
                npy_input[:, 0:60],
                npy_input[:, 60:120],
                np.repeat(npy_input[:, 120][:, np.newaxis], 60, axis=1),
                np.repeat(npy_input[:, 121][:, np.newaxis], 60, axis=1),
                np.repeat(npy_input[:, 122][:, np.newaxis], 60, axis=1),
                np.repeat(npy_input[:, 123][:, np.newaxis], 60, axis=1),
            ],
            axis=2,
        )

        if save_path != "":
            with open(save_path + "train_input_cnn.npy", "wb") as f:
                np.save(f, np.float32(npy_input_cnn))
        return npy_input_cnn

    @staticmethod
    def reshape_target_for_cnn(npy_target, save_path=""):
        """
        This function reshapes a numpy target array to be compatible with CNN training.
        Each variable becomes its own channel.
        For the input there are 6 channels, each with 60 vertical levels.
        The last 4 channels correspond to scalars repeated across all 60 levels.
        This is for V1 data only! (V2 data has more variables)
        """
        npy_target_cnn = np.stack(
            [
                npy_target[:, 0:60],
                npy_target[:, 60:120],
                np.repeat(npy_target[:, 120][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 121][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 122][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 123][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 124][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 125][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 126][:, np.newaxis], 60, axis=1),
                np.repeat(npy_target[:, 127][:, np.newaxis], 60, axis=1),
            ],
            axis=2,
        )

        if save_path != "":
            with open(save_path + "train_target_cnn.npy", "wb") as f:
                np.save(f, np.float32(npy_target_cnn))
        return npy_target_cnn

    @staticmethod
    def reshape_target_from_cnn(npy_predict_cnn, save_path=""):
        """
        This function reshapes CNN target to (num_samples, 128) for standardized metrics.
        This is for V1 data only! (V2 data has more variables)
        """
        npy_predict_cnn_reshaped = np.concatenate(
            [
                npy_predict_cnn[:, :, 0],
                npy_predict_cnn[:, :, 1],
                np.mean(npy_predict_cnn[:, :, 2], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 3], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 4], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 5], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 6], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 7], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 8], axis=1)[:, np.newaxis],
                np.mean(npy_predict_cnn[:, :, 9], axis=1)[:, np.newaxis],
            ],
            axis=1,
        )

        if save_path != "":
            with open(save_path + "cnn_predict_reshaped.npy", "wb") as f:
                np.save(f, np.float32(npy_predict_cnn_reshaped))
        return npy_predict_cnn_reshaped

In [15]:
# Build the data helper from the grid description and the input/output
# scaling statistics defined in earlier cells.
data = data_utils(
    grid_info=grid_info,
    output_scale=output_scale,
    input_mean=input_mean,
    input_min=input_min,
    input_max=input_max,
)

In [16]:
# Switch the helper to the V2 variable set (presumably the v2_inputs /
# v2_outputs lists defined at the top — see data_utils.set_to_v2_vars).
data.set_to_v2_vars()
# do not normalize: keep raw physical values (e.g. the output-weighting code
# only undoes output scaling when `normalize` is True)
data.normalize = False

In [18]:
"""
    regexps=[
        "E3SM-MMF.mli.000[12345678]-*-*-*.nc",  # years 1 through 7
        "E3SM-MMF.mli.0009-01-*-*.nc",
    ],
"""

data.set_regexps(
    data_split="train",
    regexps=[
        "E3SM-MMF.mli.0009-01-*-*.nc",
    ],
)
# set temporal subsampling
data.set_stride_sample(data_split="train", stride_sample=1000)  # もとは7

In [19]:
# create list of files to extract training data from; presumably built from
# the regexps configured for the "train" split (see data_utils.set_filelist)
data.set_filelist(data_split="train")