In [56]:
# parameters
# "<group>/<variant>" identifier of the experiment to inspect, and the root
# directory that holds each experiment group's hydra configs.
exp_name = "601_attention/small_no_hack"
config_dir = "../experiments"

In [57]:
%cd /kaggle/working

from pathlib import Path

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(
    version_base=None, config_path=f"{config_dir}/{exp_name.split('/')[0]}"
):
    cfg = compose(
        config_name="config.yaml",
        overrides=[f"exp={exp_name.split('/')[-1]}"],
        return_hydra_config=True,
    )
import pickle

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns

# Constant definitions: local output directory and GCS prefix for this experiment.
output_dir = Path(f"output/experiments/{exp_name}")
gcs_path = "gs://{}/{}/experiments/{}/".format(
    cfg.dir.gcs_bucket, cfg.dir.gcs_base_dir, exp_name
)

# Column definitions
# Targets that occupy a single column each.
single_targets = [
    "cam_out_NETSW",
    "cam_out_FLWDS",
    "cam_out_PRECSC",
    "cam_out_PRECC",
    "cam_out_SOLS",
    "cam_out_SOLL",
    "cam_out_SOLSD",
    "cam_out_SOLLD",
]
# Sequence targets; each expands to 60 columns named <name>_0 ... <name>_59
# (presumably one per vertical level).
seq_targets = [
    "ptend_t",
    "ptend_q0001",
    "ptend_q0002",
    "ptend_q0003",
    "ptend_u",
    "ptend_v",
]
# Full target order: all 60 columns of each sequence target (in seq_targets
# order), followed by the single-column targets.
target_columns = [
    f"{col}_{i}" for col in seq_targets for i in range(60)
] + single_targets

/kaggle/working


In [58]:
# This experiment's submission from its GCS output prefix (retries=5 for flaky remote reads).
kami_sub = pl.read_parquet(f"{gcs_path}submission.parquet", retries=5)

In [6]:
# Reference submission to compare against (presumably takoi's ensemble of
# experiments ex123/124/130/131/133/134/135, per the file name).
takoi_sub_path = "gs://kaggle-leap/kami/ex123_124_130_131_133_134_135_ensemble.parquet"
takoi_sub = pl.read_parquet(takoi_sub_path)

In [59]:
import pandas as pd

from utils.metric import score

# Compare this experiment's submission against the reference submission,
# target by target. The reference predictions act as the "labels", so a
# per-target R2 near 1 means the two pipelines agree on that target.
#
# The comparison is row-wise, so both frames must hold the same sample ids in
# the same order, and the same target columns in the same order. Align them
# explicitly instead of assuming both pipelines wrote identical layouts.
id_col = kami_sub.columns[0]  # column 0 is the sample id; it is excluded from scoring
assert kami_sub.height == takoi_sub.height, "submissions have different row counts"
assert set(kami_sub.columns) == set(takoi_sub.columns), "submissions have different columns"
kami_aligned = kami_sub.sort(id_col)
takoi_aligned = takoi_sub.select(kami_sub.columns).sort(id_col)
assert (kami_aligned[id_col] == takoi_aligned[id_col]).all(), "submissions cover different sample ids"

# Label each score with the column it was actually computed on, so the names
# cannot drift from the data if the config's column list differs from the file.
target_names = kami_aligned.columns[1:]
preds = kami_aligned[:, 1:].to_numpy()
labels = takoi_aligned[:, 1:].to_numpy()

_predict_df = pd.DataFrame(
    preds, columns=list(range(preds.shape[1]))
).reset_index()
_label_df = pd.DataFrame(
    labels, columns=list(range(labels.shape[1]))
).reset_index()
r2_scores = score(_label_df, _predict_df, "index", multioutput="raw_values")
# zip() would silently truncate on a length mismatch and mislabel scores.
assert len(r2_scores) == len(target_names), "score returned an unexpected number of values"

r2_score_dict = dict(zip(target_names, r2_scores))

for key, val in r2_score_dict.items():
    print(key, val)

ptend_t_0 0.9185702155384129
ptend_t_1 0.9528992002074299
ptend_t_2 0.9828407864211753
ptend_t_3 0.9942554347994033
ptend_t_4 0.9964783449868281
ptend_t_5 0.9970443050201641
ptend_t_6 0.9971965931322151
ptend_t_7 0.9961321977435666
ptend_t_8 0.9943229028359083
ptend_t_9 0.9930326345700681
ptend_t_10 0.9920720801070758
ptend_t_11 0.9913630090487604
ptend_t_12 0.9902634073922278
ptend_t_13 0.9899976045853742
ptend_t_14 0.9888748566905682
ptend_t_15 0.987185297548569
ptend_t_16 0.9706050034734586
ptend_t_17 0.8465957471441389
ptend_t_18 0.8115275813480717
ptend_t_19 0.7856576237851144
ptend_t_20 0.8223131834272461
ptend_t_21 0.8385552129741242
ptend_t_22 0.866004573468316
ptend_t_23 0.8851055301014309
ptend_t_24 0.9062205762227692
ptend_t_25 0.9305878262490458
ptend_t_26 0.949307397782295
ptend_t_27 0.9599531332408758
ptend_t_28 0.9670261475076736
ptend_t_29 0.969754480474439
ptend_t_30 0.9705633549822104
ptend_t_31 0.9684300358508484
ptend_t_32 0.9642946738284872
ptend_t_33 0.95964836681

In [54]:
# Peek at columns 133-141 (ptend_q0002_12 .. ptend_q0002_20) from row 383 onward.
kami_sub.slice(383).select(kami_sub.columns[133:142]).head()

ptend_q0002_12,ptend_q0002_13,ptend_q0002_14,ptend_q0002_15,ptend_q0002_16,ptend_q0002_17,ptend_q0002_18,ptend_q0002_19,ptend_q0002_20
f64,f64,f64,f64,f64,f64,f64,f64,f64
-4.3269999999999994e-46,-2.489e-51,-1.7290000000000002e-56,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-1.3929999999999998e-44,-3.5045e-50,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.2229e-46,-2.0199999999999998e-53,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.1935e-47,-1.766e-52,-5.7912999999999995e-58,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.1047e-49,-3.0836e-55,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0


In [55]:
# Same window as the kami_sub peek: columns 133-141 from row 383 onward.
takoi_sub.slice(383).select(takoi_sub.columns[133:142]).head()

ptend_q0002_12,ptend_q0002_13,ptend_q0002_14,ptend_q0002_15,ptend_q0002_16,ptend_q0002_17,ptend_q0002_18,ptend_q0002_19,ptend_q0002_20
f64,f64,f64,f64,f64,f64,f64,f64,f64
-4.3269999999999994e-46,-2.489e-51,-1.7290000000000002e-56,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-1.3929999999999998e-44,-3.5045e-50,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.2229e-46,-2.0199999999999998e-53,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.1935e-47,-1.766e-52,-5.7912999999999995e-58,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
-2.1047e-49,-3.0836e-55,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0


In [21]:
# One simulator-data archive; its arrays are stored under keys x, y and sim_x.
# NOTE: np.load on an .npz returns a lazy NpzFile that keeps the file open.
sim_npz_path = "input/sim_data/test/id384.npz"
data = np.load(sim_npz_path)
data

NpzFile 'input/sim_data/test/id384.npz' with keys: x, y, sim_x

In [23]:
# Only the first 500 test rows are needed for the spot checks below.
test_path = "input/test.parquet"
test_df = pl.read_parquet(test_path, n_rows=500)

In [30]:
x_sim = data["x"]  # array stored under key "x" in the loaded archive
x_sim

array([[2.08036108e+02, 2.27158595e+02, 2.30732909e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07],
       [2.13779131e+02, 2.22641431e+02, 2.34139642e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07],
       [2.20849426e+02, 2.22368835e+02, 2.36797136e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07],
       ...,
       [2.11430493e+02, 2.36875820e+02, 2.53411973e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07],
       [2.37398034e+02, 2.25109263e+02, 2.36295219e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07],
       [2.14821313e+02, 2.31346606e+02, 2.38470580e+02, ...,
        4.90858386e-07, 4.90858386e-07, 4.90858386e-07]])

In [28]:
# Test rows from index 384 onward (presumably to compare with the id384 archive).
test_df.slice(384)

sample_id,state_t_0,state_t_1,state_t_2,state_t_3,state_t_4,state_t_5,state_t_6,state_t_7,state_t_8,state_t_9,state_t_10,state_t_11,state_t_12,state_t_13,state_t_14,state_t_15,state_t_16,state_t_17,state_t_18,state_t_19,state_t_20,state_t_21,state_t_22,state_t_23,state_t_24,state_t_25,state_t_26,state_t_27,state_t_28,state_t_29,state_t_30,state_t_31,state_t_32,state_t_33,state_t_34,state_t_35,…,pbuf_N2O_23,pbuf_N2O_24,pbuf_N2O_25,pbuf_N2O_26,pbuf_N2O_27,pbuf_N2O_28,pbuf_N2O_29,pbuf_N2O_30,pbuf_N2O_31,pbuf_N2O_32,pbuf_N2O_33,pbuf_N2O_34,pbuf_N2O_35,pbuf_N2O_36,pbuf_N2O_37,pbuf_N2O_38,pbuf_N2O_39,pbuf_N2O_40,pbuf_N2O_41,pbuf_N2O_42,pbuf_N2O_43,pbuf_N2O_44,pbuf_N2O_45,pbuf_N2O_46,pbuf_N2O_47,pbuf_N2O_48,pbuf_N2O_49,pbuf_N2O_50,pbuf_N2O_51,pbuf_N2O_52,pbuf_N2O_53,pbuf_N2O_54,pbuf_N2O_55,pbuf_N2O_56,pbuf_N2O_57,pbuf_N2O_58,pbuf_N2O_59
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""test_100455""",208.036108,227.158595,230.732909,232.859734,238.974486,242.371338,238.3371,232.301199,225.518653,222.664938,220.931735,219.55473,218.522116,217.81895,216.910092,216.251177,215.317696,215.025365,212.180458,212.413629,212.314927,212.411587,213.060662,214.098834,215.240623,216.654261,218.504662,221.013468,224.031618,227.345943,230.829695,234.36038,237.889765,241.301154,244.48481,247.469463,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100456""",213.779131,222.641431,234.139642,248.09112,259.115334,265.345904,263.341407,255.022751,244.659451,237.155154,231.599541,227.705521,224.966578,223.26269,222.053267,221.173859,220.142456,219.313927,215.369519,214.711592,210.829823,208.767038,208.560237,209.458643,211.015524,212.965759,215.249396,217.730036,220.361509,223.246607,226.559564,230.073359,233.629443,237.156588,240.694411,244.252419,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100458""",220.849426,222.368835,236.797136,247.46863,257.170938,265.812413,261.597356,250.263043,239.940483,233.877353,228.405961,222.707502,216.399564,212.72141,207.88912,200.191655,189.471222,186.1124,189.169765,191.528516,198.259119,204.388877,209.790096,215.67299,221.763447,227.443738,232.92412,237.888033,242.479158,246.826067,251.019566,255.024617,258.622324,261.717114,264.686033,267.593346,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100459""",210.303844,217.963133,230.294393,240.131893,255.06261,265.784184,264.537377,250.079453,237.615815,232.274324,227.024494,222.443966,218.326228,213.79039,206.145042,198.183557,192.019738,186.633447,189.394467,192.748124,197.786222,203.243175,209.533219,216.201298,222.496667,227.791011,232.613324,237.422549,242.126486,246.770446,251.273411,255.392277,259.238923,262.794763,266.115944,269.019767,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_10046""",209.322507,231.10362,241.325142,247.55673,253.734333,259.701721,257.904731,252.741085,243.519753,234.41549,226.515931,221.011199,215.993577,213.310021,210.555016,203.85873,190.23026,183.499425,188.165277,194.902409,202.116503,209.452504,216.132954,220.993722,226.072093,231.817729,236.856361,241.249057,246.381163,251.236683,255.780915,259.843325,263.341175,266.320779,269.192216,272.120027,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""test_100583""",210.896299,226.949524,246.328341,256.774321,255.693123,249.320052,240.589772,231.478061,221.625913,213.835722,209.015114,206.09005,204.614019,204.247701,204.475533,204.950203,205.595503,205.715511,205.036733,206.857188,207.152173,208.374523,209.03579,209.805986,210.707882,212.129655,213.904188,215.813468,218.053704,220.485115,222.971111,225.518342,228.08996,230.781366,233.680939,236.411408,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100585""",219.242308,213.683807,231.43643,241.752144,253.337549,259.810693,259.377023,252.394073,242.842122,236.111489,230.129933,225.202313,220.335361,214.914411,206.708639,196.175395,189.303298,189.538013,194.371118,200.519663,207.496444,213.907788,220.152382,227.363771,233.58624,238.934276,243.109843,246.77621,250.457702,254.143704,257.843862,261.447136,264.800699,267.952459,270.556472,273.150542,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100587""",188.641033,218.682586,251.38984,257.156122,247.991925,239.392321,231.504708,224.324894,215.861296,209.149528,205.285833,203.089489,201.798647,201.328213,201.43404,201.989683,202.856485,203.602909,202.674878,203.443471,203.642181,203.867029,203.985044,204.413502,205.363022,206.859562,208.660719,210.658955,212.837752,215.06403,217.227709,219.481502,221.829504,224.227651,226.681746,229.133459,…,4.7192e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7
"""test_100588""",209.368365,216.826475,229.871042,243.125089,254.889736,261.936689,262.032798,250.627573,237.334386,231.390554,225.344048,219.947537,215.931143,214.46562,210.108701,202.036557,193.034664,187.446768,187.75894,192.271572,199.270371,204.088142,210.325125,216.104642,222.276893,227.444603,232.472156,237.190846,241.610457,245.953051,250.45319,254.673098,258.234254,261.776959,264.855822,267.896128,…,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7,4.9086e-7


In [48]:
import re
from glob import glob


def sort_key(s):
    """Return the integer id embedded in a file name, for natural sorting.

    Only the final path component is searched, so digits in parent directory
    names (e.g. ``sim_data_v2/``) cannot shadow the file's own id. The first
    run of digits in that name is used, e.g. ``".../id384.npz" -> 384``.

    Args:
        s: A file path or bare file name.

    Returns:
        The first integer found in the file name, or 0 if it has no digits.
    """
    # Extract the numeric part of the file name only.
    match = re.search(r"\d+", Path(s).name)
    return int(match.group()) if match else 0


# Numeric order (id0, id384, id768, ...) rather than lexicographic order.
paths = sorted(glob("input/sim_data/test/*"), key=sort_key)

In [49]:
# First three archives in numeric id order.
paths[:3]

['input/sim_data/test/id0.npz',
 'input/sim_data/test/id384.npz',
 'input/sim_data/test/id768.npz']