In [46]:
%cd /kaggle/working
import os

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the Hydra config from the make_webdataset config directory with the
# ``debug=True`` override, keep the hydra section in the result, and print the
# ``exp`` node as YAML.
with initialize(config_path="../preprocess/make_webdataset", version_base=None):
    cfg = compose(
        "config.yaml",
        overrides=["debug=True"],
        return_hydra_config=True,
    )
    print(OmegaConf.to_yaml(cfg.exp))

/kaggle/working
seed: 7



In [43]:
import copy
import glob
import json
import logging
import os
import pickle
import re
import shutil
import string
from glob import glob
from pathlib import Path
from typing import Literal

import datasets
import h5py
import matplotlib.pyplot as plt
import netCDF4
import numpy as np
import pandas as pd
import tensorflow as tf
import webdataset as wds
import xarray as xr
from huggingface_hub import snapshot_download
from tqdm import tqdm
from tqdm.auto import tqdm

# Derive the experiment name from the Hydra runtime choices and make sure the
# matching output directory exists.
runtime_choices = cfg.hydra.runtime.choices
exp_name = f"notebook/{runtime_choices.exp}"
print(f"exp_name: {exp_name}")
output_path = Path(cfg.dir.output_dir) / exp_name
# The label previously read "ouput_path" (typo).
print(f"output_path: {output_path}")
# Same effect as os.makedirs(output_path, exist_ok=True), using the Path API
# that output_path already provides.
output_path.mkdir(parents=True, exist_ok=True)

exp_name: notebook/base
ouput_path: /kaggle/working/output/notebook/base


In [6]:
# Names accepted for the ``ml_backend`` argument of DataUtils.
MLBackendType = Literal["tensorflow", "pytorch"]


class DataUtils:
    """Data-handling helper for an emulator.

    Holds the grid information, the normalization datasets, the variable
    lists and per-split file lists, and converts an input/target netCDF file
    pair into numpy arrays (see ``gen``).
    """

    # Names of the supported data splits.
    _DATA_SPLITS = ("train", "val", "scoring", "test")
    _INVALID_SPLIT_MSG = "Provided data_split is not valid. Available options are train, val, scoring, and test."

    def __init__(
        self,
        grid_info,
        input_mean,
        input_max,
        input_min,
        output_scale,
        ml_backend: "MLBackendType" = "pytorch",
    ):
        """Store the grid/normalization data and initialise all bookkeeping attributes.

        Args:
            grid_info: dataset indexed with "lev", "ncol", "area", "lat",
                "lon", "hyam" and "hybm". NOTE: an "area_wgt" variable is
                added to this object in place.
            input_mean: subtracted from the inputs in ``gen`` when ``normalize`` is True.
            input_max: with ``input_min``, forms the divisor ``input_max - input_min`` in ``gen``.
            input_min: see ``input_max``.
            output_scale: multiplied with the targets in ``gen`` when ``normalize`` is True.
            ml_backend: "tensorflow" or "pytorch"; the matching library is
                imported and stored on ``self.tf`` / ``self.torch``. Any other
                value imports nothing.

        Raises:
            ImportError: if the requested backend library cannot be imported.
        """
        self.data_path = None
        self.input_vars = []
        self.target_vars = []
        self.input_feature_len = None
        self.target_feature_len = None
        self.grid_info = grid_info
        self.level_name = "lev"
        self.sample_name = "sample"
        self.num_levels = len(self.grid_info["lev"])
        # number of unique lat/lon grid points
        self.num_latlon = len(self.grid_info["ncol"])
        # make area-weights: area divided by its mean over "ncol"
        self.grid_info["area_wgt"] = self.grid_info["area"] / self.grid_info[
            "area"
        ].mean(dim="ncol")
        self.area_wgt = self.grid_info["area_wgt"].values
        self.input_mean = input_mean
        self.input_max = input_max
        self.input_min = input_min
        self.output_scale = output_scale
        self.normalize = True

        # Fetch the coordinate arrays once instead of on every lookup.
        lat_values = self.grid_info["lat"].values
        lon_values = self.grid_info["lon"].values
        self.lats, self.lats_indices = np.unique(lat_values, return_index=True)
        self.lons, self.lons_indices = np.unique(lon_values, return_index=True)
        self.sort_lat_key = np.argsort(lat_values[np.sort(self.lats_indices)])
        self.sort_lon_key = np.argsort(lon_values[np.sort(self.lons_indices)])
        # grid-point index -> (lat, lon)
        self.indextolatlon = {
            i: (lat_values[i], lon_values[i]) for i in range(self.num_latlon)
        }

        self.ml_backend = ml_backend
        self.tf = None
        self.torch = None
        # Defined up front so the attribute also exists for an unrecognised backend name.
        self.successful_backend_import = False

        if self.ml_backend == "tensorflow":
            try:
                import tensorflow as tf
            except ImportError as err:
                raise ImportError("Tensorflow is not installed.") from err
            self.tf = tf
            self.successful_backend_import = True
        elif self.ml_backend == "pytorch":
            try:
                import torch
            except ImportError as err:
                raise ImportError("PyTorch is not installed.") from err
            self.torch = torch
            self.successful_backend_import = True

        # Group the grid-point indices by latitude in a single pass (the
        # indices inside each group stay in increasing order), then order the
        # groups by their first index.
        indices_by_lat = {}
        for idx, (lat, _lon) in self.indextolatlon.items():
            indices_by_lat.setdefault(lat, []).append(idx)
        indices_list = [indices_by_lat.get(lat, []) for lat in self.lats]
        indices_list.sort(key=lambda x: x[0])
        self.lat_indices_list = indices_list

        self.hyam = self.grid_info["hyam"].values
        self.hybm = self.grid_info["hybm"].values
        self.p0 = 1e5  # code assumes this will always be a scalar
        self.ps_index = None

        self.pressure_grid_train = None
        self.pressure_grid_val = None
        self.pressure_grid_scoring = None
        self.pressure_grid_test = None

        self.dp_train = None
        self.dp_val = None
        self.dp_scoring = None
        self.dp_test = None

        # Per-split file selection state, filled in by set_regexps,
        # set_stride_sample and set_filelist.
        self.train_regexps = None
        self.train_stride_sample = None
        self.train_filelist = None
        self.val_regexps = None
        self.val_stride_sample = None
        self.val_filelist = None
        self.scoring_regexps = None
        self.scoring_stride_sample = None
        self.scoring_filelist = None
        self.test_regexps = None
        self.test_stride_sample = None
        self.test_filelist = None

        self.full_vars = False

        # physical constants from E3SM_ROOT/share/util/shr_const_mod.F90
        self.grav = 9.80616  # acceleration of gravity ~ m/s^2
        self.cp = 1.00464e3  # specific heat of dry air   ~ J/kg/K
        self.lv = 2.501e6  # latent heat of evaporation ~ J/kg
        self.lf = 3.337e5  # latent heat of fusion      ~ J/kg
        self.lsub = self.lv + self.lf  # latent heat of sublimation ~ J/kg
        self.rho_air = (
            101325 / (6.02214e26 * 1.38065e-23 / 28.966) / 273.15
        )  # density of dry air at STP  ~ kg/m^3
        # ~ 1.2923182846924677
        # SHR_CONST_PSTD/(SHR_CONST_RDAIR*SHR_CONST_TKFRZ)
        # SHR_CONST_RDAIR   = SHR_CONST_RGAS/SHR_CONST_MWDAIR
        # SHR_CONST_RGAS    = SHR_CONST_AVOGAD*SHR_CONST_BOLTZ
        self.rho_h20 = 1.0e3  # density of fresh water     ~ kg/m^ 3

        self.v1_inputs = [
            "state_t",
            "state_q0001",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
        ]

        self.v1_outputs = [
            "ptend_t",
            "ptend_q0001",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        self.v2_inputs = [
            "state_t",
            "state_q0001",
            "state_q0002",
            "state_q0003",
            "state_u",
            "state_v",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
            "pbuf_TAUX",
            "pbuf_TAUY",
            "pbuf_COSZRS",
            "cam_in_ALDIF",
            "cam_in_ALDIR",
            "cam_in_ASDIF",
            "cam_in_ASDIR",
            "cam_in_LWUP",
            "cam_in_ICEFRAC",
            "cam_in_LANDFRAC",
            "cam_in_OCNFRAC",
            "cam_in_SNOWHICE",
            "cam_in_SNOWHLAND",
            "pbuf_ozone",  # outside of the upper troposphere lower stratosphere (UTLS, corresponding to indices 5-21), variance is minimal for these last 3
            "pbuf_CH4",
            "pbuf_N2O",
        ]

        self.v2_outputs = [
            "ptend_t",
            "ptend_q0001",
            "ptend_q0002",
            "ptend_q0003",
            "ptend_u",
            "ptend_v",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        # Number of values each variable contributes: num_levels for the
        # per-level variables, 1 for the single-value ones.
        self.var_lens = {  # inputs
            "state_t": self.num_levels,
            "state_q0001": self.num_levels,
            "state_q0002": self.num_levels,
            "state_q0003": self.num_levels,
            "state_u": self.num_levels,
            "state_v": self.num_levels,
            "state_ps": 1,
            "pbuf_SOLIN": 1,
            "pbuf_LHFLX": 1,
            "pbuf_SHFLX": 1,
            "pbuf_TAUX": 1,
            "pbuf_TAUY": 1,
            "pbuf_COSZRS": 1,
            "cam_in_ALDIF": 1,
            "cam_in_ALDIR": 1,
            "cam_in_ASDIF": 1,
            "cam_in_ASDIR": 1,
            "cam_in_LWUP": 1,
            "cam_in_ICEFRAC": 1,
            "cam_in_LANDFRAC": 1,
            "cam_in_OCNFRAC": 1,
            "cam_in_SNOWHICE": 1,
            "cam_in_SNOWHLAND": 1,
            "pbuf_ozone": self.num_levels,
            "pbuf_CH4": self.num_levels,
            "pbuf_N2O": self.num_levels,
            # outputs
            "ptend_t": self.num_levels,
            "ptend_q0001": self.num_levels,
            "ptend_q0002": self.num_levels,
            "ptend_q0003": self.num_levels,
            "ptend_u": self.num_levels,
            "ptend_v": self.num_levels,
            "cam_out_NETSW": 1,
            "cam_out_FLWDS": 1,
            "cam_out_PRECSC": 1,
            "cam_out_PRECC": 1,
            "cam_out_SOLS": 1,
            "cam_out_SOLL": 1,
            "cam_out_SOLSD": 1,
            "cam_out_SOLLD": 1,
        }

        self.var_short_names = {
            "ptend_t": "$dT/dt$",
            "ptend_q0001": "$dq/dt$",
            "cam_out_NETSW": "NETSW",
            "cam_out_FLWDS": "FLWDS",
            "cam_out_PRECSC": "PRECSC",
            "cam_out_PRECC": "PRECC",
            "cam_out_SOLS": "SOLS",
            "cam_out_SOLL": "SOLL",
            "cam_out_SOLSD": "SOLSD",
            "cam_out_SOLLD": "SOLLD",
        }

        self.target_energy_conv = {
            "ptend_t": self.cp,
            "ptend_q0001": self.lv,
            "ptend_q0002": self.lv,
            "ptend_q0003": self.lv,
            "ptend_wind": None,
            "cam_out_NETSW": 1.0,
            "cam_out_FLWDS": 1.0,
            "cam_out_PRECSC": self.lv * self.rho_h20,
            "cam_out_PRECC": self.lv * self.rho_h20,
            "cam_out_SOLS": 1.0,
            "cam_out_SOLL": 1.0,
            "cam_out_SOLSD": 1.0,
            "cam_out_SOLLD": 1.0,
        }

        # for metrics

        self.input_train = None
        self.target_train = None
        self.preds_train = None
        self.samplepreds_train = None
        self.target_weighted_train = {}
        self.preds_weighted_train = {}
        self.samplepreds_weighted_train = {}
        self.metrics_train = []
        self.metrics_idx_train = {}
        self.metrics_var_train = {}

        self.input_val = None
        self.target_val = None
        self.preds_val = None
        self.samplepreds_val = None
        self.target_weighted_val = {}
        self.preds_weighted_val = {}
        self.samplepreds_weighted_val = {}
        self.metrics_val = []
        self.metrics_idx_val = {}
        self.metrics_var_val = {}

        self.input_scoring = None
        self.target_scoring = None
        self.preds_scoring = None
        self.samplepreds_scoring = None
        self.target_weighted_scoring = {}
        self.preds_weighted_scoring = {}
        self.samplepreds_weighted_scoring = {}
        self.metrics_scoring = []
        self.metrics_idx_scoring = {}
        self.metrics_var_scoring = {}

        self.input_test = None
        self.target_test = None
        self.preds_test = None
        self.samplepreds_test = None
        self.target_weighted_test = {}
        self.preds_weighted_test = {}
        self.samplepreds_weighted_test = {}
        self.metrics_test = []
        self.metrics_idx_test = {}
        self.metrics_var_test = {}

        self.model_names = []
        self.metrics_names = []
        self.num_CRPS = 32
        self.linecolors = ["#0072B2", "#E69F00", "#882255", "#009E73", "#D55E00"]

    def _check_data_split(self, data_split):
        """Assert that ``data_split`` is one of train, val, scoring or test."""
        assert data_split in self._DATA_SPLITS, self._INVALID_SPLIT_MSG

    def set_to_v2_vars(self):
        """
        This function sets the inputs and outputs to the V2 subset.
        It also indicates the index of the surface pressure variable.
        """
        self.input_vars = self.v2_inputs
        self.target_vars = self.v2_outputs
        # These literals equal the sums of var_lens over v2_inputs (557) and
        # v2_outputs (368) when num_levels is 60; 360 is the combined length
        # of the six per-level variables listed before "state_ps".
        self.ps_index = 360
        self.input_feature_len = 557
        self.target_feature_len = 368
        self.full_vars = True

    def get_xrdata(self, file, file_vars=None):
        """
        This function reads in a file and returns an xarray dataset with the variables specified.
        file_vars must be a list of strings (or None to keep every variable).
        The grid's "lat" and "lon" are merged in, and entries whose lat or lon
        is not strictly between -999 and 999 are dropped.
        """
        ds = xr.open_dataset(file, engine="netcdf4")
        if file_vars is not None:
            ds = ds[file_vars]
        ds = ds.merge(self.grid_info[["lat", "lon"]])
        ds = ds.where((ds["lat"] > -999) & (ds["lat"] < 999), drop=True)
        ds = ds.where((ds["lon"] > -999) & (ds["lon"] < 999), drop=True)
        return ds

    def get_input(self, input_file):
        """
        This function reads in a file and returns an xarray dataset with the input variables for the emulator.
        """
        return self.get_xrdata(input_file, self.input_vars)

    def get_target(self, input_file):
        """
        This function reads in a file and returns an xarray dataset with the target variables for the emulator.
        The target file name is the input file name with ".mli." replaced by ".mlo.".
        input_file may be a string or a path-like object.
        """
        ds_input = self.get_input(input_file)
        # str() so that pathlib.Path inputs work as well as plain strings.
        ds_target = self.get_xrdata(str(input_file).replace(".mli.", ".mlo."))
        # each timestep is 20 minutes which corresponds to 1200 seconds
        ds_target["ptend_t"] = (
            ds_target["state_t"] - ds_input["state_t"]
        ) / 1200  # T tendency [K/s]
        ds_target["ptend_q0001"] = (
            ds_target["state_q0001"] - ds_input["state_q0001"]
        ) / 1200  # Q tendency [kg/kg/s]
        if self.full_vars:
            ds_target["ptend_q0002"] = (
                ds_target["state_q0002"] - ds_input["state_q0002"]
            ) / 1200  # Q tendency [kg/kg/s]
            ds_target["ptend_q0003"] = (
                ds_target["state_q0003"] - ds_input["state_q0003"]
            ) / 1200  # Q tendency [kg/kg/s]
            ds_target["ptend_u"] = (
                ds_target["state_u"] - ds_input["state_u"]
            ) / 1200  # U tendency [m/s/s]
            ds_target["ptend_v"] = (
                ds_target["state_v"] - ds_input["state_v"]
            ) / 1200  # V tendency [m/s/s]
        ds_target = ds_target[self.target_vars]
        return ds_target

    def set_regexps(self, data_split, regexps):
        """
        This function sets the patterns used for getting the filelist for train, val, scoring, and test.
        Despite the name, set_filelist passes them to glob, so they are glob patterns.
        """
        self._check_data_split(data_split)
        setattr(self, f"{data_split}_regexps", regexps)

    def set_stride_sample(self, data_split, stride_sample):
        """
        This function sets the stride_sample for train, val, scoring, and test.
        """
        self._check_data_split(data_split)
        setattr(self, f"{data_split}_stride_sample", stride_sample)

    def set_filelist(self, data_split):
        """
        This function sets the filelists corresponding to data splits for train, val, scoring, and test.
        Every pattern of the split is globbed as ``data_path + "*/" + pattern``;
        the sorted matches are then subsampled with the split's stride.
        Raises AssertionError if the split is invalid or its patterns, stride
        or ``data_path`` have not been set.
        """
        # Import the module under a private name: at file level the name
        # ``glob`` may be bound to the function (``from glob import glob``),
        # in which case ``glob.glob`` raises AttributeError.
        import glob as glob_module

        self._check_data_split(data_split)
        patterns = getattr(self, f"{data_split}_regexps")
        stride = getattr(self, f"{data_split}_stride_sample")
        assert patterns is not None, f"regexps for {data_split} is not set."
        assert stride is not None, f"stride_sample for {data_split} is not set."
        assert self.data_path is not None, "data_path is not set."

        filelist = []
        for pattern in patterns:
            filelist.extend(glob_module.glob(str(self.data_path) + "*/" + pattern))
        setattr(self, f"{data_split}_filelist", sorted(filelist)[::stride])

    def get_filelist(self, data_split):
        """
        This function returns the filelist corresponding to data splits for train, val, scoring, and test.
        Raises AssertionError if the split is invalid or its filelist has not been set.
        """
        self._check_data_split(data_split)
        filelist = getattr(self, f"{data_split}_filelist")
        assert filelist is not None, f"filelist for {data_split} is not set."
        return filelist

    def gen(self, file):
        """
        This function reads the input file and its matching target file and
        returns them as a tuple of numpy arrays (inputs, targets).
        When self.normalize is True the inputs become
        (input - input_mean) / (input_max - input_min) and the targets are
        multiplied by output_scale; otherwise only "lat" and "lon" are dropped
        from the inputs.
        """
        # read inputs
        ds_input = self.get_input(file)
        # read targets
        ds_target = self.get_target(file)

        # normalization, scaling
        if self.normalize:
            ds_input = (ds_input - self.input_mean) / (self.input_max - self.input_min)
            ds_target = ds_target * self.output_scale
        else:
            ds_input = ds_input.drop_vars(["lat", "lon"])

        # stack "ncol" into a "batch" dimension and all variables into "mlvar"
        ds_input = ds_input.stack({"batch": ["ncol"]})
        ds_input = ds_input.to_stacked_array("mlvar", sample_dims=["batch"], name="mli")
        ds_target = ds_target.stack({"batch": ["ncol"]})
        ds_target = ds_target.to_stacked_array(
            "mlvar", sample_dims=["batch"], name="mlo"
        )
        return (ds_input.values, ds_target.values)

In [7]:
# Locations of the grid description and of the normalization files.
grid_path = "/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc"
norm_path = "/kaggle/working/misc/preprocessing/normalizations/"

# Open the grid dataset and the four normalization datasets.
grid_info = xr.open_dataset(grid_path)
input_mean, input_max, input_min = (
    xr.open_dataset(f"{norm_path}inputs/input_{stat}.nc")
    for stat in ("mean", "max", "min")
)
output_scale = xr.open_dataset(f"{norm_path}outputs/output_scale.nc")

# Build the helper, switch it to the V2 variable set and turn normalization off.
data = DataUtils(grid_info, input_mean, input_max, input_min, output_scale)
data.set_to_v2_vars()
data.normalize = False

In [8]:
# Directory settings
TMP_DIR = Path("/kaggle/working") / "tmp"

# The real list has more entries than these two.
month_dirs = [
    "train/0001-02",
    "train/0001-03",
]

# The real run would take each entry in a for loop; use the first one here.
month_dir = Path(month_dirs[0])

In [57]:
# Month directories: 0001-02 .. 0001-12, every month of years 0002-0008,
# and finally 0009-01.
month_dirs = [
    *(f"train/0001-{m:02d}" for m in range(2, 13)),
    *(f"train/{y:04d}-{m:02d}" for y in range(2, 9) for m in range(1, 13)),
    "train/0009-01",
]
month_dirs

['train/0001-02',
 'train/0001-03',
 'train/0001-04',
 'train/0001-05',
 'train/0001-06',
 'train/0001-07',
 'train/0001-08',
 'train/0001-09',
 'train/0001-10',
 'train/0001-11',
 'train/0001-12',
 'train/0002-01',
 'train/0002-02',
 'train/0002-03',
 'train/0002-04',
 'train/0002-05',
 'train/0002-06',
 'train/0002-07',
 'train/0002-08',
 'train/0002-09',
 'train/0002-10',
 'train/0002-11',
 'train/0002-12',
 'train/0003-01',
 'train/0003-02',
 'train/0003-03',
 'train/0003-04',
 'train/0003-05',
 'train/0003-06',
 'train/0003-07',
 'train/0003-08',
 'train/0003-09',
 'train/0003-10',
 'train/0003-11',
 'train/0003-12',
 'train/0004-01',
 'train/0004-02',
 'train/0004-03',
 'train/0004-04',
 'train/0004-05',
 'train/0004-06',
 'train/0004-07',
 'train/0004-08',
 'train/0004-09',
 'train/0004-10',
 'train/0004-11',
 'train/0004-12',
 'train/0005-01',
 'train/0005-02',
 'train/0005-03',
 'train/0005-04',
 'train/0005-05',
 'train/0005-06',
 'train/0005-07',
 'train/0005-08',
 'train/00

In [26]:
# Start from an empty TMP_DIR: remove it if it is already there, then create it.
if TMP_DIR.exists():
    shutil.rmtree(TMP_DIR)
TMP_DIR.mkdir(exist_ok=True)

In [27]:
# Download
# The real run uses every file; for debugging only a single day is fetched.
download_path = snapshot_download(
    repo_id="LEAP/ClimSim_low-res",
    repo_type="dataset",
    local_dir=TMP_DIR,
    local_dir_use_symlinks=False,  # do not cache
    allow_patterns=str(month_dir / "*0001-02-01-*.nc"),
)

Fetching 144 files:   0%|          | 0/144 [00:00<?, ?it/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-04800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-02400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-01200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-03600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-07200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-06000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-08400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-00000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-13200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-09600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-10800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-16800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-12000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-15600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-14400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-18000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-20400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-22800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-21600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-24000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-26400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-19200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-25200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-27600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-28800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-36000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-37200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-30000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-31200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-33600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-34800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-32400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-38400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-40800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-46800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-39600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-42000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-43200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-45600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-44400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-48000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-50400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-55200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-51600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-49200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-52800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-56400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-54000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-57600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-64800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-62400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-58800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-66000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-63600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-61200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-60000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-67200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-72000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-70800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-69600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-75600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-74400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-68400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-73200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-78000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-79200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-80400.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-76800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-81600.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-82800.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-84000.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mli.0001-02-01-85200.nc:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-01200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-02400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-00000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-03600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-04800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-06000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-08400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-07200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-09600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-12000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-15600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-13200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-14400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-10800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-16800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-18000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-19200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-22800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-21600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-24000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-25200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-20400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-26400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-27600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-28800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-32400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-31200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-33600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-30000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-34800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-36000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-37200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-38400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-42000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-39600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-40800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-44400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-43200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-45600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-46800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-48000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-49200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-50400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-51600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-54000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-52800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-55200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-56400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-57600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-58800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-60000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-61200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-62400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-63600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-64800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-67200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-69600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-66000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-73200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-68400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-70800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-72000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-74400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-80400.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-78000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-75600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-76800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-81600.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-79200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-82800.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-84000.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

(…)0001-02/E3SM-MMF.mlo.0001-02-01-85200.nc:   0%|          | 0.00/1.13M [00:00<?, ?B/s]

In [28]:
# Collect this month's "mli" input files. glob() returns matches in arbitrary
# order, so sort them to keep the file order (and the shard contents built
# from it) reproducible across runs.
input_paths = sorted(glob(str(TMP_DIR / month_dir / "*.mli.*.nc")))

In [29]:
# Try the generator on the first file only; data.gen() is unpacked into an
# (input_data, output_data) pair for that single file.
input_path = input_paths[0]
input_data, output_data = data.gen(input_path)

In [30]:
# Show the shapes of both arrays produced for the probe file.
input_data.shape, output_data.shape

((384, 557), (384, 368))

In [32]:
# Directory for this month's shards: "shards_<month directory name>" under TMP_DIR.
shard_path = TMP_DIR / f"shards_{month_dir.name}"
shard_path

PosixPath('/kaggle/working/tmp/shards_0001-02')

In [37]:
# Write every row of every input file as one WebDataset sample under shard_path.
# Start from an empty directory so shards from an earlier run are never mixed in.
if os.path.exists(shard_path):
    shutil.rmtree(shard_path)
shard_path.mkdir(parents=True, exist_ok=True)
shard_filename = str(shard_path / "shards-%05d.tar")

shard_size = int(50 * 1000**2)  # 50MB each

dataset_size = 0
with wds.ShardWriter(
    shard_filename, maxsize=shard_size, maxcount=10_000_000
) as sink, tqdm(input_paths) as pbar:
    for input_path in pbar:
        input_data, output_data = data.gen(input_path)

        # The path-derived metadata is identical for every row of a file,
        # so compute it once per file instead of once per row.
        input_path = Path(input_path)
        month_text = input_path.parent.name
        file_name_text = input_path.name

        # WebDataset readers split a tar member name into key and field at the
        # FIRST "." of the base name. The source file names contain dots
        # (e.g. "E3SM-MMF.mli.<timestamp>.nc"), which would collapse every
        # sample onto the same key, so dots are replaced in the key only.
        # The unmodified file name is still stored in "file_name_text.txt".
        key_prefix = file_name_text.replace(".", "_")

        for i in range(len(input_data)):
            write_obj = {
                "__key__": f"{key_prefix}_{i}",
                "file_name_text.txt": file_name_text,
                "id": i,
                "month.txt": month_text,
                "input.npy": input_data[i, :],
                "output.npy": output_data[i, :],
            }

            sink.write(write_obj)

        # Track the total number of samples written across all files.
        dataset_size += len(input_data)

# writing /kaggle/working/tmp/shards_0001-02/shards-00000.tar 0 0.0 GB 0


  0%|          | 0/72 [00:00<?, ?it/s]

# writing /kaggle/working/tmp/shards_0001-02/shards-00001.tar 6496 0.1 GB 6496
# writing /kaggle/working/tmp/shards_0001-02/shards-00002.tar 6496 0.1 GB 12992
# writing /kaggle/working/tmp/shards_0001-02/shards-00003.tar 6496 0.1 GB 19488
# writing /kaggle/working/tmp/shards_0001-02/shards-00004.tar 6496 0.1 GB 25984


In [38]:
# Total number of samples accumulated by the shard-writing loop.
dataset_size

27648

In [39]:
# Store the sample count as "dataset-size.json" inside the shard directory,
# so it is uploaded together with the tar files.
with (shard_path / "dataset-size.json").open("w") as fp:
    json.dump({"dataset size": dataset_size}, fp)

In [40]:
from google.cloud import storage

In [52]:
from tqdm.auto import tqdm


def upload_directory_to_gcs(local_directory, bucket_name, gcs_directory):
    """Recursively upload every file under a local directory to a GCS bucket.

    Each file is stored as ``<gcs_directory>/<path relative to local_directory>``,
    always using "/" as the separator in the object name.

    Args:
        local_directory: Path (str or path-like) of the directory to upload.
        bucket_name: Name of the destination bucket.
        gcs_directory: Object-name prefix inside the bucket.

    Raises:
        NotADirectoryError: If ``local_directory`` is missing or not a directory.
            Without this check an invalid path would upload nothing and still
            report success.
    """
    import posixpath  # local import: GCS object names are "/"-separated on every OS

    if not os.path.isdir(local_directory):
        raise NotADirectoryError(
            f"Local directory to upload does not exist: {local_directory}"
        )

    client = storage.Client()
    bucket = client.get_bucket(bucket_name)

    # Collect all (local path, object name) pairs first so the progress bar
    # knows the total number of files.
    files_to_upload = []
    for root, _dirs, files in os.walk(local_directory):
        for file in files:
            local_file_path = os.path.join(root, file)
            relative_path = os.path.relpath(local_file_path, local_directory)
            gcs_file_path = posixpath.join(
                gcs_directory, relative_path.replace(os.sep, "/")
            )
            files_to_upload.append((local_file_path, gcs_file_path))
    # os.walk order is filesystem-dependent; sort for a deterministic upload order.
    files_to_upload.sort()

    # The context manager closes the progress bar even when an upload raises.
    with tqdm(total=len(files_to_upload), unit="file") as progress_bar:
        for local_file_path, gcs_file_path in files_to_upload:
            blob = bucket.blob(gcs_file_path)
            blob.upload_from_filename(local_file_path, if_generation_match=None)
            progress_bar.update(1)
            progress_bar.set_postfix(file=os.path.basename(local_file_path))

    print(f"All files uploaded successfully to {gcs_directory}.")

In [53]:
# Upload the shard directory to "<gcs_base_dir>/<exp_name>/<shard dir name>".
# as_posix() keeps "/" separators in the object prefix regardless of the host
# OS, whereas str(Path) would produce backslashes on Windows.
upload_directory_to_gcs(
    shard_path,
    cfg.dir.gcs_bucket,
    (Path(cfg.dir.gcs_base_dir) / exp_name / shard_path.name).as_posix(),
)

  0%|          | 0/6 [00:00<?, ?file/s]

All files uploaded successfully to kami/notebook/base/shards_0001-02.


In [51]:
# Preview the GCS object prefix that the shard directory is uploaded under.
str(Path(cfg.dir.gcs_base_dir) / exp_name / shard_path.name)

'kami/notebook/base/shards_0001-02'