In [2]:
import glob
import json
import os
import shutil
import sys
import time
from glob import glob
from pathlib import Path
from typing import Literal

import hydra
import numpy as np
import polars as pl
import webdataset as wds
import xarray as xr
from google.cloud import storage
from huggingface_hub import snapshot_download
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

# Backend names accepted by DataUtils(ml_backend=...).
MLBackendType = Literal["tensorflow", "pytorch"]
# Scratch directory path; not used by DataUtils itself.
TMP_DIR = Path("/kaggle/working/tmp")


class DataUtils:
    """Helper that turns E3SM-MMF ``.mli.``/``.mlo.`` NetCDF files into flat arrays.

    The instance holds the grid description, the normalisation datasets, the
    variable lists (v1 / v2 subsets) with their per-variable feature lengths,
    a few physical constants, and a set of initially-empty containers that are
    reserved for metrics on the train / val / scoring / test splits.
    """

    # Seconds between an input state and its matching output state; tendencies
    # are computed as (output - input) / TIMESTEP_SECONDS (20 minutes).
    TIMESTEP_SECONDS = 1200

    def __init__(
        self,
        grid_info,
        input_mean,
        input_max,
        input_min,
        output_scale,
        ml_backend: MLBackendType = "pytorch",
    ):
        """Store the grid / normalisation data and initialise all bookkeeping.

        Args:
            grid_info: dataset-like object indexed by "lev", "ncol", "area",
                "lat", "lon", "hyam" and "hybm". NOTE: it is modified in place;
                an "area_wgt" entry is written into it.
            input_mean: stored unchanged on the instance.
            input_max: stored unchanged on the instance.
            input_min: stored unchanged on the instance.
            output_scale: stored unchanged on the instance.
            ml_backend: "tensorflow" or "pytorch"; the matching package is
                imported lazily and kept on ``self.tf`` / ``self.torch``.

        Raises:
            ImportError: if the requested backend package is not installed.
        """
        self.data_path = None
        self.input_vars = []
        self.target_vars = []
        self.input_feature_len = None
        self.target_feature_len = None
        self.grid_info = grid_info
        self.level_name = "lev"
        self.sample_name = "sample"
        self.num_levels = len(self.grid_info["lev"])
        # number of unique lat/lon grid points
        self.num_latlon = len(self.grid_info["ncol"])
        # make area-weights (cell area relative to the mean cell area);
        # this assignment writes into the caller's grid_info object
        self.grid_info["area_wgt"] = self.grid_info["area"] / self.grid_info[
            "area"
        ].mean(dim="ncol")
        self.area_wgt = self.grid_info["area_wgt"].values
        self.input_mean = input_mean
        self.input_max = input_max
        self.input_min = input_min
        self.output_scale = output_scale
        self.normalize = True

        lat_values = self.grid_info["lat"].values
        lon_values = self.grid_info["lon"].values
        self.lats, self.lats_indices = np.unique(lat_values, return_index=True)
        self.lons, self.lons_indices = np.unique(lon_values, return_index=True)
        self.sort_lat_key = np.argsort(lat_values[np.sort(self.lats_indices)])
        self.sort_lon_key = np.argsort(lon_values[np.sort(self.lons_indices)])
        # column index -> (lat, lon) of that grid column
        self.indextolatlon = {
            i: (lat_values[i], lon_values[i]) for i in range(self.num_latlon)
        }

        self.ml_backend = ml_backend
        self.tf = None
        self.torch = None
        # Defined up front so the attribute exists whatever backend name is given.
        self.successful_backend_import = False

        if self.ml_backend == "tensorflow":
            try:
                import tensorflow as tf
            except ImportError as err:
                raise ImportError("Tensorflow is not installed.") from err
            self.tf = tf
            self.successful_backend_import = True
        elif self.ml_backend == "pytorch":
            try:
                import torch
            except ImportError as err:
                raise ImportError("PyTorch is not installed.") from err
            self.torch = torch
            self.successful_backend_import = True

        # For every unique latitude, the column indices located at that latitude;
        # the groups are then ordered by their first column index.
        indices_list = [
            [i for i, (lat_i, _) in self.indextolatlon.items() if lat_i == lat]
            for lat in self.lats
        ]
        indices_list.sort(key=lambda x: x[0])
        self.lat_indices_list = indices_list

        self.hyam = self.grid_info["hyam"].values
        self.hybm = self.grid_info["hybm"].values
        self.p0 = 1e5  # code assumes this will always be a scalar
        self.ps_index = None

        self.pressure_grid_train = None
        self.pressure_grid_val = None
        self.pressure_grid_scoring = None
        self.pressure_grid_test = None

        self.dp_train = None
        self.dp_val = None
        self.dp_scoring = None
        self.dp_test = None

        self.train_regexps = None
        self.train_stride_sample = None
        self.train_filelist = None
        self.val_regexps = None
        self.val_stride_sample = None
        self.val_filelist = None
        self.scoring_regexps = None
        self.scoring_stride_sample = None
        self.scoring_filelist = None
        self.test_regexps = None
        self.test_stride_sample = None
        self.test_filelist = None

        self.full_vars = False

        # physical constants from E3SM_ROOT/share/util/shr_const_mod.F90
        self.grav = 9.80616  # acceleration of gravity ~ m/s^2
        self.cp = 1.00464e3  # specific heat of dry air   ~ J/kg/K
        self.lv = 2.501e6  # latent heat of evaporation ~ J/kg
        self.lf = 3.337e5  # latent heat of fusion      ~ J/kg
        self.lsub = self.lv + self.lf  # latent heat of sublimation ~ J/kg
        self.rho_air = (
            101325 / (6.02214e26 * 1.38065e-23 / 28.966) / 273.15
        )  # density of dry air at STP  ~ kg/m^3
        # ~ 1.2923182846924677
        # SHR_CONST_PSTD/(SHR_CONST_RDAIR*SHR_CONST_TKFRZ)
        # SHR_CONST_RDAIR   = SHR_CONST_RGAS/SHR_CONST_MWDAIR
        # SHR_CONST_RGAS    = SHR_CONST_AVOGAD*SHR_CONST_BOLTZ
        self.rho_h20 = 1.0e3  # density of fresh water     ~ kg/m^ 3

        self.v1_inputs = [
            "state_t",
            "state_q0001",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
        ]

        self.v1_outputs = [
            "ptend_t",
            "ptend_q0001",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        self.v2_inputs = [
            "state_t",
            "state_q0001",
            "state_q0002",
            "state_q0003",
            "state_u",
            "state_v",
            "state_ps",
            "pbuf_SOLIN",
            "pbuf_LHFLX",
            "pbuf_SHFLX",
            "pbuf_TAUX",
            "pbuf_TAUY",
            "pbuf_COSZRS",
            "cam_in_ALDIF",
            "cam_in_ALDIR",
            "cam_in_ASDIF",
            "cam_in_ASDIR",
            "cam_in_LWUP",
            "cam_in_ICEFRAC",
            "cam_in_LANDFRAC",
            "cam_in_OCNFRAC",
            "cam_in_SNOWHICE",
            "cam_in_SNOWHLAND",
            # outside of the upper troposphere lower stratosphere (UTLS,
            # corresponding to indices 5-21), variance is minimal for these last 3
            "pbuf_ozone",
            "pbuf_CH4",
            "pbuf_N2O",
        ]

        self.v2_outputs = [
            "ptend_t",
            "ptend_q0001",
            "ptend_q0002",
            "ptend_q0003",
            "ptend_u",
            "ptend_v",
            "cam_out_NETSW",
            "cam_out_FLWDS",
            "cam_out_PRECSC",
            "cam_out_PRECC",
            "cam_out_SOLS",
            "cam_out_SOLL",
            "cam_out_SOLSD",
            "cam_out_SOLLD",
        ]

        # Number of features each variable contributes to a stacked sample.
        self.var_lens = {  # inputs
            "state_t": self.num_levels,
            "state_q0001": self.num_levels,
            "state_q0002": self.num_levels,
            "state_q0003": self.num_levels,
            "state_u": self.num_levels,
            "state_v": self.num_levels,
            "state_ps": 1,
            "pbuf_SOLIN": 1,
            "pbuf_LHFLX": 1,
            "pbuf_SHFLX": 1,
            "pbuf_TAUX": 1,
            "pbuf_TAUY": 1,
            "pbuf_COSZRS": 1,
            "cam_in_ALDIF": 1,
            "cam_in_ALDIR": 1,
            "cam_in_ASDIF": 1,
            "cam_in_ASDIR": 1,
            "cam_in_LWUP": 1,
            "cam_in_ICEFRAC": 1,
            "cam_in_LANDFRAC": 1,
            "cam_in_OCNFRAC": 1,
            "cam_in_SNOWHICE": 1,
            "cam_in_SNOWHLAND": 1,
            "pbuf_ozone": self.num_levels,
            "pbuf_CH4": self.num_levels,
            "pbuf_N2O": self.num_levels,
            # outputs
            "ptend_t": self.num_levels,
            "ptend_q0001": self.num_levels,
            "ptend_q0002": self.num_levels,
            "ptend_q0003": self.num_levels,
            "ptend_u": self.num_levels,
            "ptend_v": self.num_levels,
            "cam_out_NETSW": 1,
            "cam_out_FLWDS": 1,
            "cam_out_PRECSC": 1,
            "cam_out_PRECC": 1,
            "cam_out_SOLS": 1,
            "cam_out_SOLL": 1,
            "cam_out_SOLSD": 1,
            "cam_out_SOLLD": 1,
        }

        self.var_short_names = {
            "ptend_t": "$dT/dt$",
            "ptend_q0001": "$dq/dt$",
            "cam_out_NETSW": "NETSW",
            "cam_out_FLWDS": "FLWDS",
            "cam_out_PRECSC": "PRECSC",
            "cam_out_PRECC": "PRECC",
            "cam_out_SOLS": "SOLS",
            "cam_out_SOLL": "SOLL",
            "cam_out_SOLSD": "SOLSD",
            "cam_out_SOLLD": "SOLLD",
        }

        self.target_energy_conv = {
            "ptend_t": self.cp,
            "ptend_q0001": self.lv,
            "ptend_q0002": self.lv,
            "ptend_q0003": self.lv,
            "ptend_wind": None,
            "cam_out_NETSW": 1.0,
            "cam_out_FLWDS": 1.0,
            "cam_out_PRECSC": self.lv * self.rho_h20,
            "cam_out_PRECC": self.lv * self.rho_h20,
            "cam_out_SOLS": 1.0,
            "cam_out_SOLL": 1.0,
            "cam_out_SOLSD": 1.0,
            "cam_out_SOLLD": 1.0,
        }

        # for metrics

        self.input_train = None
        self.target_train = None
        self.preds_train = None
        self.samplepreds_train = None
        self.target_weighted_train = {}
        self.preds_weighted_train = {}
        self.samplepreds_weighted_train = {}
        self.metrics_train = []
        self.metrics_idx_train = {}
        self.metrics_var_train = {}

        self.input_val = None
        self.target_val = None
        self.preds_val = None
        self.samplepreds_val = None
        self.target_weighted_val = {}
        self.preds_weighted_val = {}
        self.samplepreds_weighted_val = {}
        self.metrics_val = []
        self.metrics_idx_val = {}
        self.metrics_var_val = {}

        self.input_scoring = None
        self.target_scoring = None
        self.preds_scoring = None
        self.samplepreds_scoring = None
        self.target_weighted_scoring = {}
        self.preds_weighted_scoring = {}
        self.samplepreds_weighted_scoring = {}
        self.metrics_scoring = []
        self.metrics_idx_scoring = {}
        self.metrics_var_scoring = {}

        self.input_test = None
        self.target_test = None
        self.preds_test = None
        self.samplepreds_test = None
        self.target_weighted_test = {}
        self.preds_weighted_test = {}
        self.samplepreds_weighted_test = {}
        self.metrics_test = []
        self.metrics_idx_test = {}
        self.metrics_var_test = {}

        self.model_names = []
        self.metrics_names = []
        self.num_CRPS = 32
        self.linecolors = ["#0072B2", "#E69F00", "#882255", "#009E73", "#D55E00"]

    def set_to_v2_vars(self):
        """
        This function sets the inputs and outputs to the V2 subset.
        It also indicates the index of the surface pressure variable.

        The offsets are derived from ``var_lens`` so they stay consistent with
        ``num_levels``; with 60 levels they are 360 (ps index), 557 (input
        features) and 368 (target features).
        """
        self.input_vars = self.v2_inputs
        self.target_vars = self.v2_outputs
        # state_ps starts right after all variables listed before it.
        ps_position = self.input_vars.index("state_ps")
        self.ps_index = sum(
            self.var_lens[var] for var in self.input_vars[:ps_position]
        )
        self.input_feature_len = sum(self.var_lens[var] for var in self.input_vars)
        self.target_feature_len = sum(self.var_lens[var] for var in self.target_vars)
        self.full_vars = True

    def get_xrdata(self, file, file_vars=None):
        """
        This function reads in a file and returns an xarray dataset with the variables specified.
        file_vars must be a list of strings.

        Columns whose lat or lon lies outside (-999, 999) are dropped. The
        result is loaded into memory and the file is closed before returning.
        """
        with xr.open_dataset(file, engine="netcdf4") as raw:
            ds = raw if file_vars is None else raw[file_vars]
            ds = ds.merge(self.grid_info[["lat", "lon"]])
            ds = ds.where((ds["lat"] > -999) * (ds["lat"] < 999), drop=True)
            ds = ds.where((ds["lon"] > -999) * (ds["lon"] < 999), drop=True)
            # Materialise everything while the file handle is still open.
            return ds.load()

    def get_input(self, input_file):
        """
        This function reads in a file and returns an xarray dataset with the input variables for the emulator.
        """
        # read inputs
        return self.get_xrdata(input_file, self.input_vars)

    def get_target(self, ds_input, input_file):
        """
        This function reads in a file and returns an xarray dataset with the target variables for the emulator.

        ``input_file`` is the ``.mli.`` path; the matching ``.mlo.`` file is
        read instead. Tendencies are (output state - input state) divided by
        the 1200 s timestep; ``ds_input`` supplies the input states.
        """
        # read targets
        ds_target = self.get_xrdata(input_file.replace(".mli.", ".mlo."))
        # state variables that are converted into per-second tendencies
        tendency_states = ["state_t", "state_q0001"]
        if self.full_vars:
            tendency_states += ["state_q0002", "state_q0003", "state_u", "state_v"]
        for state_var in tendency_states:
            ptend_var = state_var.replace("state_", "ptend_")
            ds_target[ptend_var] = (
                ds_target[state_var] - ds_input[state_var]
            ) / self.TIMESTEP_SECONDS
        ds_target = ds_target[self.target_vars]
        return ds_target

    def gen(self, file):
        """Return ``(inputs, targets)`` numpy arrays for one ``.mli.`` file.

        Each dataset is stacked so that rows follow the "ncol" dimension and the
        variables are concatenated along a single "mlvar" feature axis.
        """
        # read inputs
        ds_input = self.get_input(file)
        # read targets
        ds_target = self.get_target(ds_input, file)

        # lat/lon were only merged in for filtering; they are not features
        ds_input = ds_input.drop_vars(["lat", "lon"])

        # stack
        ds_input = ds_input.stack({"batch": ["ncol"]})
        ds_input = ds_input.to_stacked_array("mlvar", sample_dims=["batch"], name="mli")
        ds_target = ds_target.stack({"batch": ["ncol"]})
        ds_target = ds_target.to_stacked_array(
            "mlvar", sample_dims=["batch"], name="mlo"
        )
        return (ds_input.values, ds_target.values)

In [3]:
grid_path = "/kaggle/working/misc/grid_info/ClimSim_low-res_grid-info.nc"
norm_path = "/kaggle/working/misc/preprocessing/normalizations/"

# Grid description handed to DataUtils.
grid_info = xr.open_dataset(grid_path)

# Normalisation datasets, opened in the same order as DataUtils' parameters.
input_mean, input_max, input_min, output_scale = (
    xr.open_dataset(norm_path + relative_path)
    for relative_path in (
        "inputs/input_mean.nc",
        "inputs/input_max.nc",
        "inputs/input_min.nc",
        "outputs/output_scale.nc",
    )
)

# Use the v2 variable subset and switch the normalize flag off.
data_utils = DataUtils(grid_info, input_mean, input_max, input_min, output_scale)
data_utils.set_to_v2_vars()
data_utils.normalize = False

In [8]:
# List the contents of the input directory.
!ls /kaggle/working/input

E3SM-MMF.mli.0001-02-01-00000.nc     make_webdataset_batch	test.parquet
leap-atmospheric-physics-ai-climsim  sample_submission.parquet	train.parquet


In [9]:
# Build the flat input/target arrays from one .mli. file; gen() reads the
# matching .mlo. file itself by replacing ".mli." in the path.
input_data, output_data = data_utils.gen(
    "/kaggle/working/input/E3SM-MMF.mli.0001-02-01-00000.nc"
)

# Show both array shapes as the cell output.
input_data.shape, output_data.shape

((384, 557), (384, 368))

In [12]:
# Check the provided data: read as many rows of the training table as the
# arrays generated from the .nc file above contain, so the two can be compared.
df = pl.read_parquet(
    "/kaggle/working/input/train.parquet", n_rows=input_data.shape[0]
)

In [13]:
# Display the loaded rows of the training table.
df

sample_id,state_t_0,state_t_1,state_t_2,state_t_3,state_t_4,state_t_5,state_t_6,state_t_7,state_t_8,state_t_9,state_t_10,state_t_11,state_t_12,state_t_13,state_t_14,state_t_15,state_t_16,state_t_17,state_t_18,state_t_19,state_t_20,state_t_21,state_t_22,state_t_23,state_t_24,state_t_25,state_t_26,state_t_27,state_t_28,state_t_29,state_t_30,state_t_31,state_t_32,state_t_33,state_t_34,state_t_35,…,ptend_v_31,ptend_v_32,ptend_v_33,ptend_v_34,ptend_v_35,ptend_v_36,ptend_v_37,ptend_v_38,ptend_v_39,ptend_v_40,ptend_v_41,ptend_v_42,ptend_v_43,ptend_v_44,ptend_v_45,ptend_v_46,ptend_v_47,ptend_v_48,ptend_v_49,ptend_v_50,ptend_v_51,ptend_v_52,ptend_v_53,ptend_v_54,ptend_v_55,ptend_v_56,ptend_v_57,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""train_0""",213.806117,222.244454,229.259879,245.743959,258.337482,265.772467,263.978456,254.879872,243.946147,236.70699,230.934476,226.459626,222.014096,216.91297,210.688514,205.181518,201.649949,199.354256,196.582987,198.08732,200.30267,204.5107,209.294262,214.299322,219.180883,223.644692,227.899269,231.883467,235.859155,239.75449,243.663599,247.494148,251.145756,254.734832,258.118795,261.322883,…,6.3101e-7,3.8035e-7,2.6436e-8,-7.3957e-8,-1.0226e-7,-1.5508e-7,1.8131e-8,1.8428e-7,-1.9102e-8,1.8966e-7,-3.4925e-8,-2.5049e-7,-3.5119e-7,3.1814e-8,-0.000001,-0.000002,-7.6918e-7,-0.000007,-0.000007,-0.000021,-0.000008,0.000024,0.00002,0.000008,0.000003,2.1486e-7,-0.000001,-6.1321e-7,0.000001,0.0,349.564325,0.0,2.9477e-12,0.0,0.0,0.0,0.0
"""train_1""",213.17743,225.851064,229.663618,246.828333,261.026415,269.078431,267.736565,256.896227,244.169421,236.844423,231.586369,227.873491,224.125186,219.864133,214.768666,209.773682,206.593758,205.144601,202.21968,203.165579,203.691769,206.691885,210.018967,213.651746,217.22671,220.732834,224.271241,227.879259,231.523113,235.206556,238.911192,242.569836,246.095623,249.661886,253.170195,256.545214,…,2.3532e-7,-0.000002,4.7710e-7,-4.6249e-7,-4.7910e-7,3.4284e-7,1.5265e-7,-3.6109e-7,4.7664e-7,2.6672e-7,-1.8301e-7,-0.000001,-3.2145e-8,0.000002,-0.000002,-0.000005,-0.000006,-0.000002,-0.000009,-0.000015,-0.000008,0.000017,0.000023,0.000013,0.000004,7.6476e-7,-8.6454e-7,3.5609e-7,9.9849e-7,0.0,335.204086,0.0,3.1384e-9,0.0,0.0,0.0,0.0
"""train_2""",217.105685,220.448106,234.653398,244.422951,254.023818,259.651472,257.709514,251.064513,241.8796,234.487318,228.294373,223.660529,219.108751,214.820508,208.639566,201.055995,193.462408,190.267832,188.690119,189.957238,196.482059,203.461783,210.517854,217.264497,223.298206,228.871994,234.369621,239.208271,243.596733,247.688112,251.639426,255.233656,258.531051,261.569665,264.470169,267.068563,…,-0.000002,-6.3580e-7,-0.000001,-5.8939e-7,-6.8356e-7,1.3912e-7,-1.8613e-7,-3.5363e-7,5.9231e-7,-5.4433e-7,-4.1161e-7,-7.6074e-7,-3.6474e-7,5.2729e-7,0.000002,0.000006,0.000013,0.000012,0.000016,0.000005,-0.000005,-0.00002,-0.00004,-0.000035,-0.00002,-0.000013,-0.000004,0.000002,0.000062,0.0,401.70934,0.0,7.4242e-9,0.0,0.0,0.0,0.0
"""train_3""",217.773994,225.611775,234.104091,247.745365,257.411402,263.470947,261.131775,253.30325,242.316814,234.396266,227.95502,223.999858,219.658845,215.24492,210.214695,204.137721,196.509274,191.893671,189.929401,190.806367,196.69688,203.68075,210.684974,217.256992,223.168849,228.660408,233.845497,238.535216,242.74732,246.688901,250.429055,253.907015,257.260424,260.436627,263.363255,266.154815,…,-0.000002,-0.000002,-8.0713e-7,-7.1824e-8,-1.6555e-8,-1.6542e-7,3.0559e-7,1.8800e-8,-1.3640e-7,2.0172e-7,-2.4224e-7,2.5207e-7,0.000001,-7.9637e-8,0.000002,0.000005,0.000003,-0.000002,-0.000003,-0.000001,0.000004,0.000013,-0.000001,-0.000009,-0.000008,-0.000006,-0.000005,-0.000009,0.00001,0.0,400.230177,0.0,2.5341e-8,0.0,0.0,0.0,0.0
"""train_4""",216.349337,230.526083,233.650252,248.196013,262.50073,270.055663,268.863606,258.161645,244.44262,236.779096,231.508378,227.968412,224.863747,221.415977,217.325766,213.113691,209.607872,207.829591,204.866553,205.55398,205.036638,207.604275,210.501407,213.802681,217.053629,220.416544,223.803482,227.157857,230.464429,233.723924,236.938842,240.118405,243.26515,246.338046,249.61741,253.012313,…,-5.5779e-7,-1.6262e-7,9.1616e-8,1.4862e-7,-3.6454e-7,-0.000001,0.000001,-6.2991e-8,-4.6701e-8,1.0520e-7,-2.6200e-7,2.5540e-7,-4.7391e-8,2.5992e-7,-7.2915e-8,5.6063e-7,8.0119e-7,0.000001,8.1273e-7,-0.000002,-0.000005,-0.000006,0.00001,0.000009,-0.000001,-0.000002,-0.000001,0.000002,-0.000006,0.0,321.96047,0.0,3.3774e-10,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""train_379""",218.918341,238.736381,235.083042,239.398846,247.69602,250.190803,245.385246,239.039703,233.492732,229.572704,227.514843,226.821063,226.672404,226.232272,225.178405,223.715826,222.177653,220.823551,219.132981,217.754679,215.809467,214.382627,213.345526,212.987632,213.047825,213.494792,214.271168,215.550844,216.935316,218.407531,220.035177,221.746289,223.670077,225.736659,227.944913,230.176195,…,-5.6138e-7,-0.000001,-0.000002,-4.4419e-7,-0.000005,-0.000002,-0.000004,-0.000006,-0.000011,-0.00002,-0.000022,-0.000023,-0.000051,-0.00006,-0.000026,-0.000004,-0.000006,0.000003,0.000003,0.00002,0.000121,0.000062,0.000109,0.00007,0.00004,0.000019,0.000007,-0.000003,-0.000039,270.770143,253.784134,2.6165e-8,2.9391e-8,60.69613,143.212514,74.885183,26.505918
"""train_380""",216.149716,227.89558,228.541675,237.983346,244.243784,248.517983,249.283739,244.65311,238.148522,233.131319,230.515876,229.611207,228.918458,227.713537,225.665493,223.14387,220.615311,218.62191,217.00241,215.17604,212.940687,211.506115,210.36118,209.781008,209.451575,209.665216,210.302009,211.280728,212.40955,213.733524,215.205422,216.871771,218.699373,220.631188,222.726628,224.966437,…,6.2843e-9,-1.4247e-9,-2.1706e-13,5.0873e-11,2.5653e-11,1.3400e-10,2.9150e-11,2.2532e-11,1.3242e-11,-4.6344e-11,-1.4960e-10,1.2900e-10,-5.9406e-11,-1.8364e-10,4.4491e-11,-2.0666e-11,-1.0761e-11,-3.8318e-11,-3.3614e-11,-3.3864e-13,7.8466e-12,1.8776e-9,-2.6124e-9,5.1719e-8,-5.0648e-8,8.1278e-11,-1.1372e-10,3.9984e-8,-3.9542e-8,119.347959,146.71273,1.8775e-9,1.8775e-9,59.341941,111.489698,36.864229,9.006725
"""train_381""",223.078504,229.800374,229.883811,240.159283,247.198049,252.207608,252.542542,245.663732,235.971018,229.462776,226.749173,225.807962,224.761639,223.482462,221.784814,219.884257,217.753344,215.83833,213.907749,212.162737,210.277699,209.063224,208.902412,209.523831,210.306954,211.579001,213.152697,215.024797,216.929116,219.021818,221.267676,223.64581,226.142734,228.602392,231.042,233.456907,…,-6.1129e-8,6.2284e-8,-1.1450e-8,3.8386e-8,-1.5003e-8,1.0541e-8,2.4069e-8,1.6983e-9,-1.2242e-8,1.0734e-7,-7.6397e-8,2.3397e-8,4.7098e-9,2.6074e-8,2.3473e-8,-1.8220e-8,-2.6194e-8,1.3016e-8,-1.2908e-9,8.9186e-9,3.8723e-8,4.4155e-8,2.6523e-8,8.1703e-8,0.000001,-0.000003,9.3715e-7,-0.000001,0.000001,139.179531,174.161188,3.7386e-12,3.7386e-12,67.272294,117.615398,33.032472,4.188111
"""train_382""",222.500171,237.110982,236.301461,241.17534,249.168572,252.31233,248.012214,240.605567,233.376437,228.347834,225.891684,224.881954,224.2397,223.697266,222.977631,221.818979,220.440269,219.329974,218.156109,217.121354,215.815931,214.974466,214.538968,214.606346,214.955326,215.674263,216.598305,217.821772,219.21321,220.700202,222.237867,223.824736,225.502936,227.197501,229.019952,231.008786,…,3.2900e-8,4.3772e-8,1.1370e-7,8.3488e-8,-2.9141e-8,3.0469e-8,1.4317e-7,-3.2929e-7,-8.8086e-7,3.4627e-7,-2.7832e-7,-2.3542e-7,-1.5264e-7,-4.9499e-7,-4.3423e-7,-0.000001,-0.000006,-0.000011,-0.000015,0.000006,0.000009,0.000008,0.00001,0.000016,0.000013,0.000012,0.000009,0.000006,-0.000042,291.343969,211.344276,4.8667e-9,4.8667e-9,56.127465,155.375447,88.187456,33.458496


In [25]:
# Rebuild the training table from the generated arrays.
# The generated inputs carry one feature more than the table has input columns:
# cam_in_SNOWHICE is removed before the table's column names are attached.
# Its column offset follows from the variable order and per-variable lengths.
snowhice_position = data_utils.input_vars.index("cam_in_SNOWHICE")
snowhice_col = sum(
    data_utils.var_lens[var] for var in data_utils.input_vars[:snowhice_position]
)
input_kept = np.delete(input_data, snowhice_col, axis=1)
n_input_cols = input_kept.shape[1]

# Column 0 of df is sample_id; the input columns follow, then the targets.
my_df = pl.concat(
    [
        df.select(["sample_id"]),
        pl.from_numpy(input_kept, schema=df.columns[1 : 1 + n_input_cols]),
        pl.from_numpy(output_data, schema=df.columns[1 + n_input_cols :]),
    ],
    how="horizontal",
)

In [26]:
# Compare the rebuilt table with the provided one.
df.equals(my_df)

True

In [27]:
# Overwrite a single cell (row 1, column 1) of the rebuilt table in place.
my_df[1,1]=0

In [28]:
# Repeat the comparison after the in-place change above.
df.equals(my_df)

False