In [1]:
!rm -r /kaggle/working/*
%cd /kaggle/working

/kaggle/working


In [2]:
import os
import sys

PACKAGE_DIR = "/kaggle/src"
sys.path.append(PACKAGE_DIR)
sys.path.append(os.path.join(PACKAGE_DIR, "Penguin-ML-Library"))

In [3]:
import yaml
from penguinml.utils.logger import get_logger, init_logger
from penguinml.utils.set_seed import seed_base

MODEL_NAME = "xgboost"
CFG = yaml.safe_load(open(os.path.join(PACKAGE_DIR, "config.yaml"), "r"))
print(CFG[MODEL_NAME]["execution"]["exp_id"])
CFG["output_dir"] = f"/kaggle/output/{CFG[MODEL_NAME]['execution']['exp_id']}"
!rm -r {CFG["output_dir"]}
os.makedirs(CFG["output_dir"], exist_ok=True)

init_logger(f"{ CFG[MODEL_NAME]['execution']['exp_id']}.log")
logger = get_logger("main")
seed_base(CFG[MODEL_NAME]["execution"]["seed"])

2024-11-16 08:25:00.471629: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-16 08:25:00.498894: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


exp_004


  pid, fd = os.forkpty()
set seed: 46


In [4]:
import warnings

import numpy as np
import polars as pl

warnings.filterwarnings("ignore")

In [5]:
train = pl.read_csv(os.path.join(CFG["dataset"]["competition_dir"], "train_features.csv"))
train = (
    train.with_columns(
        pl.col("ID").str.split_exact("_", n=1).struct.rename_fields(["sceneID", "offset"]).alias("fields")
    )
    .unnest("fields")
    .with_columns(pl.col("offset").cast(pl.Float32))
)
print(train.shape)
train.head(1)

(43371, 32)


ID,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,x_0,y_0,z_0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5,sceneID,offset
str,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f32
"""00066be8e20318869c38c66be46663…",5.701526,1.538456,-2.165777,-139.0,0.0,False,0.25,True,"""drive""",False,False,2.82959,0.032226,0.045187,6.231999,0.065895,0.107974,9.785009,0.124972,0.203649,13.485472,0.163448,0.302818,17.574227,0.174289,0.406331,21.951269,0.199503,0.485079,"""00066be8e20318869c38c66be46663…",320.0


## 特徴量生成


In [6]:
from penguinml.utils.contena import FeatureContena

features = FeatureContena()
features.add_num_features(
    [
        "vEgo",
        "aEgo",
        "steeringAngleDeg",
        "steeringTorque",
        "brake",
        "brakePressed",
        "gas",
        "gasPressed",
        "leftBlinker",
        "rightBlinker",
        "offset",
    ]
)
features.add_cat_features(["gearShifter"])

In [7]:
# train = train.with_columns(
#     (pl.col("vEgo") / pl.col("aEgo")).alias("vEgo/aEgo"),
# )
# features.add_num_features(["vEgo/aEgo"])

## ターゲット列を分解


In [8]:
from const import TARGET_COLS

train = (
    train.unpivot(index="ID", on=TARGET_COLS, variable_name="target_name", value_name="target")
    .join(
        train.drop(TARGET_COLS),
        on="ID",
        how="left",
    )
    .with_columns(
        (
            pl.col("target_name").map_elements(lambda x: float(x.split("_")[1]), return_dtype=pl.Float32) * 0.5 + 0.5
        ).alias("dt"),
        pl.col("target_name").map_elements(lambda x: x.split("_")[0], return_dtype=str).alias("xyz"),
    )
    .with_columns(
        pl.col("target_name").alias("target_name_original"),
    )
)
features.add_cat_features(["target_name", "xyz"])
train.head(1)

ID,target_name,target,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,sceneID,offset,dt,xyz,target_name_original
str,str,f64,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,str,f32,f32,str,str
"""00066be8e20318869c38c66be46663…","""x_0""",2.82959,5.701526,1.538456,-2.165777,-139.0,0.0,False,0.25,True,"""drive""",False,False,"""00066be8e20318869c38c66be46663…",320.0,0.5,"""x""","""x_0"""


In [9]:
# dt秒後の特徴
train = train.with_columns(
    # vt
    (pl.col("vEgo") * pl.col("dt").cast(pl.Float32)).alias("linear_movement@dt"),
    # vt + 0.5at^2
    ((pl.col("vEgo") + 0.5 * pl.col("aEgo") * pl.col("dt").cast(pl.Float32) ** 2).alias("movement@dt")),
    # v + at
    (pl.col("vEgo") + pl.col("aEgo") * pl.col("dt").cast(pl.Float32)).alias("velocity@dt"),
    # # cos
    # (pl.col("steeringAngleDeg").map_elements(lambda x: np.cos(np.deg2rad(x)), return_dtype=pl.Float32)).alias("cos"),
    # # sin
    # (pl.col("steeringAngleDeg").map_elements(lambda x: np.sin(np.deg2rad(x)), return_dtype=pl.Float32)).alias("sin"),
).with_columns(
    # (pl.col("movement@dt") * pl.col("cos")).alias("movement@dt*cos"),
    # (pl.col("movement@dt") * pl.col("sin")).alias("movement@dt*sin"),
    # (pl.col("linear_movement@dt") * pl.col("cos")).alias("linear_movement@dt*cos"),
    # (pl.col("linear_movement@dt") * pl.col("sin")).alias("linear_movement@dt*sin"),
    # (pl.col("velocity@dt") * pl.col("cos")).alias("velocity@dt*cos"),
    # (pl.col("velocity@dt") * pl.col("sin")).alias("velocity@dt*sin"),
)

features.add_num_features(
    [
        "linear_movement@dt",
        "movement@dt",
        "velocity@dt",
        # "cos",
        # "sin",
        # "movement@dt*cos",
        # "movement@dt*sin",
        # "linear_movement@dt*cos",
        # "linear_movement@dt*sin",
        # "velocity@dt*cos",
        # "velocity@dt*sin",
    ]
)

In [10]:
# シーン内の集約特徴量
for c in features.num_features():
    train = train.with_columns(
        pl.col(c).mean().over("sceneID").alias(f"{c}_mean"),
        pl.col(c).std().over("sceneID").alias(f"{c}_std"),
        pl.col(c).max().over("sceneID").alias(f"{c}_max"),
        pl.col(c).min().over("sceneID").alias(f"{c}_min"),
    )
    features.add_num_features([f"{c}_mean", f"{c}_std", f"{c}_max", f"{c}_min"])

In [11]:
from const import CATEGORY_MAPPING

for c in features.num_features():
    train = train.with_columns(pl.col(c).cast(pl.Float32))

for c in features.cat_features():
    mapping = CATEGORY_MAPPING[c]
    train = train.with_columns(pl.col(c).replace_strict(mapping).cast(pl.Int32))

## CV Split


In [12]:
train_folds = pl.read_csv(CFG["dataset"]["train_fold_path"])
train = train.join(train_folds, on="sceneID", how="left")
assert train["fold"].null_count() == 0

## Training


In [13]:
from penguinml.gbdt.xgboost import fit_xgb, inference_xgb
from tqdm import tqdm

dfs = []
for c in tqdm(TARGET_COLS):
    this_df = train.filter(pl.col("target_name_original") == c)
    oof, models = fit_xgb(
        data=this_df,
        features=features,
        params=CFG[MODEL_NAME]["params"],
        target_col="target",
        fold_col="fold",
        target_type="regression",
        verbose=500,
    )
    this_df = this_df.with_columns(pl.Series("oof", oof))
    dfs.append(this_df)
train = pl.concat(dfs)

  0%|          | 0/18 [00:00<?, ?it/s]

== fold 0 ==
[0]	validation_0-mae:2.69955
[500]	validation_0-mae:0.06571
[1000]	validation_0-mae:0.06534
[1500]	validation_0-mae:0.06530
[1629]	validation_0-mae:0.06531
== fold 1 ==
[0]	validation_0-mae:2.70526
[500]	validation_0-mae:0.06490
[1000]	validation_0-mae:0.06452
[1500]	validation_0-mae:0.06436
[1607]	validation_0-mae:0.06434
== fold 2 ==
[0]	validation_0-mae:2.73485
[500]	validation_0-mae:0.06744
[1000]	validation_0-mae:0.06713
[1500]	validation_0-mae:0.06701
[1900]	validation_0-mae:0.06698
== fold 3 ==
[0]	validation_0-mae:2.71804
[500]	validation_0-mae:0.06608
[1000]	validation_0-mae:0.06564
[1500]	validation_0-mae:0.06543
[1899]	validation_0-mae:0.06543
== fold 4 ==
[0]	validation_0-mae:2.65464
[500]	validation_0-mae:0.06628
[1000]	validation_0-mae:0.06595
[1432]	validation_0-mae:0.06587


  6%|▌         | 1/18 [00:38<11:02, 38.98s/it]

== fold 0 ==
[0]	validation_0-mae:0.05748
[500]	validation_0-mae:0.03227
[1000]	validation_0-mae:0.03221
[1089]	validation_0-mae:0.03220
== fold 1 ==
[0]	validation_0-mae:0.05904
[500]	validation_0-mae:0.03154
[843]	validation_0-mae:0.03153
== fold 2 ==
[0]	validation_0-mae:0.06048
[500]	validation_0-mae:0.03267
[1000]	validation_0-mae:0.03265
[1012]	validation_0-mae:0.03266
== fold 3 ==
[0]	validation_0-mae:0.05639
[500]	validation_0-mae:0.03193
[1000]	validation_0-mae:0.03183
[1113]	validation_0-mae:0.03184
== fold 4 ==
[0]	validation_0-mae:0.06161
[500]	validation_0-mae:0.03337
[905]	validation_0-mae:0.03332


 11%|█         | 2/18 [01:04<08:14, 30.90s/it]

== fold 0 ==
[0]	validation_0-mae:0.02615
[368]	validation_0-mae:0.02575
== fold 1 ==
[0]	validation_0-mae:0.02584
[365]	validation_0-mae:0.02544
== fold 2 ==
[0]	validation_0-mae:0.02683
[404]	validation_0-mae:0.02637
== fold 3 ==
[0]	validation_0-mae:0.02685
[445]	validation_0-mae:0.02646
== fold 4 ==
[0]	validation_0-mae:0.02610
[333]	validation_0-mae:0.02565


 17%|█▋        | 3/18 [01:15<05:27, 21.83s/it]

== fold 0 ==
[0]	validation_0-mae:5.69810
[500]	validation_0-mae:0.14304
[1000]	validation_0-mae:0.14147
[1500]	validation_0-mae:0.14100
[1636]	validation_0-mae:0.14098
== fold 1 ==
[0]	validation_0-mae:5.70668
[500]	validation_0-mae:0.14323
[1000]	validation_0-mae:0.14185
[1500]	validation_0-mae:0.14131
[2000]	validation_0-mae:0.14103
[2500]	validation_0-mae:0.14090
[2563]	validation_0-mae:0.14095
== fold 2 ==
[0]	validation_0-mae:5.77049
[500]	validation_0-mae:0.14468
[1000]	validation_0-mae:0.14346
[1500]	validation_0-mae:0.14297
[2000]	validation_0-mae:0.14255
[2500]	validation_0-mae:0.14241
[3000]	validation_0-mae:0.14225
[3500]	validation_0-mae:0.14215
[4000]	validation_0-mae:0.14207
[4500]	validation_0-mae:0.14197
[5000]	validation_0-mae:0.14188
[5439]	validation_0-mae:0.14185
== fold 3 ==
[0]	validation_0-mae:5.73449
[500]	validation_0-mae:0.14256
[1000]	validation_0-mae:0.14088
[1500]	validation_0-mae:0.14023
[2000]	validation_0-mae:0.13987
[2500]	validation_0-mae:0.13967
[300

 22%|██▏       | 4/18 [02:29<09:55, 42.55s/it]

== fold 0 ==
[0]	validation_0-mae:0.16069
[500]	validation_0-mae:0.07347
[1000]	validation_0-mae:0.07305
[1182]	validation_0-mae:0.07308
== fold 1 ==
[0]	validation_0-mae:0.16734
[500]	validation_0-mae:0.07146
[1000]	validation_0-mae:0.07125
[1500]	validation_0-mae:0.07115
[1926]	validation_0-mae:0.07110
== fold 2 ==
[0]	validation_0-mae:0.17003
[500]	validation_0-mae:0.07486
[1000]	validation_0-mae:0.07460
[1177]	validation_0-mae:0.07459
== fold 3 ==
[0]	validation_0-mae:0.15821
[500]	validation_0-mae:0.07340
[1000]	validation_0-mae:0.07302
[1468]	validation_0-mae:0.07302
== fold 4 ==
[0]	validation_0-mae:0.17257
[500]	validation_0-mae:0.07549
[1000]	validation_0-mae:0.07488
[1500]	validation_0-mae:0.07470
[1939]	validation_0-mae:0.07461


 28%|██▊       | 5/18 [03:07<08:50, 40.78s/it]

== fold 0 ==
[0]	validation_0-mae:0.05456
[317]	validation_0-mae:0.05368
== fold 1 ==
[0]	validation_0-mae:0.05379
[431]	validation_0-mae:0.05269
== fold 2 ==
[0]	validation_0-mae:0.05571
[399]	validation_0-mae:0.05446
== fold 3 ==
[0]	validation_0-mae:0.05645
[500]	validation_0-mae:0.05524
[1000]	validation_0-mae:0.05520
[1063]	validation_0-mae:0.05522
== fold 4 ==
[0]	validation_0-mae:0.05473
[396]	validation_0-mae:0.05365


 33%|███▎      | 6/18 [03:20<06:18, 31.55s/it]

== fold 0 ==
[0]	validation_0-mae:8.69555
[500]	validation_0-mae:0.24530
[1000]	validation_0-mae:0.24098
[1500]	validation_0-mae:0.23964
[1769]	validation_0-mae:0.23963
== fold 1 ==
[0]	validation_0-mae:8.70621
[500]	validation_0-mae:0.24888
[1000]	validation_0-mae:0.24543
[1500]	validation_0-mae:0.24483
[2000]	validation_0-mae:0.24444
[2500]	validation_0-mae:0.24411
[3000]	validation_0-mae:0.24391
[3500]	validation_0-mae:0.24359
[4000]	validation_0-mae:0.24332
[4500]	validation_0-mae:0.24305
[4899]	validation_0-mae:0.24302
== fold 2 ==
[0]	validation_0-mae:8.80491
[500]	validation_0-mae:0.24573
[1000]	validation_0-mae:0.24308
[1500]	validation_0-mae:0.24247
[2000]	validation_0-mae:0.24216
[2500]	validation_0-mae:0.24201
[3000]	validation_0-mae:0.24186
[3500]	validation_0-mae:0.24165
[4000]	validation_0-mae:0.24151
[4199]	validation_0-mae:0.24154
== fold 3 ==
[0]	validation_0-mae:8.74942
[500]	validation_0-mae:0.24774
[1000]	validation_0-mae:0.24364
[1500]	validation_0-mae:0.24246
[200

 39%|███▉      | 7/18 [04:46<09:01, 49.26s/it]

== fold 0 ==
[0]	validation_0-mae:0.30753
[500]	validation_0-mae:0.13082
[1000]	validation_0-mae:0.12968
[1388]	validation_0-mae:0.12966
== fold 1 ==
[0]	validation_0-mae:0.32486
[500]	validation_0-mae:0.13081
[1000]	validation_0-mae:0.12968
[1500]	validation_0-mae:0.12936
[1918]	validation_0-mae:0.12937
== fold 2 ==
[0]	validation_0-mae:0.32624
[500]	validation_0-mae:0.13499
[1000]	validation_0-mae:0.13342
[1500]	validation_0-mae:0.13324
[1886]	validation_0-mae:0.13362
== fold 3 ==
[0]	validation_0-mae:0.30421
[500]	validation_0-mae:0.13240
[1000]	validation_0-mae:0.13131
[1361]	validation_0-mae:0.13125
== fold 4 ==
[0]	validation_0-mae:0.33226
[500]	validation_0-mae:0.13671
[1000]	validation_0-mae:0.13477
[1500]	validation_0-mae:0.13417
[2000]	validation_0-mae:0.13400
[2332]	validation_0-mae:0.13399


 44%|████▍     | 8/18 [05:29<07:52, 47.26s/it]

== fold 0 ==
[0]	validation_0-mae:0.08426
[417]	validation_0-mae:0.08258
== fold 1 ==
[0]	validation_0-mae:0.08308
[424]	validation_0-mae:0.08088
== fold 2 ==
[0]	validation_0-mae:0.08595
[382]	validation_0-mae:0.08387
== fold 3 ==
[0]	validation_0-mae:0.08784
[500]	validation_0-mae:0.08563
[660]	validation_0-mae:0.08569
== fold 4 ==
[0]	validation_0-mae:0.08493
[500]	validation_0-mae:0.08266
[758]	validation_0-mae:0.08277


 50%|█████     | 9/18 [05:43<05:32, 36.89s/it]

== fold 0 ==
[0]	validation_0-mae:11.69217
[500]	validation_0-mae:0.39172
[1000]	validation_0-mae:0.38525
[1500]	validation_0-mae:0.38393
[1768]	validation_0-mae:0.38390
== fold 1 ==
[0]	validation_0-mae:11.70499
[500]	validation_0-mae:0.40406
[1000]	validation_0-mae:0.39811
[1500]	validation_0-mae:0.39672
[2000]	validation_0-mae:0.39595
[2500]	validation_0-mae:0.39517
[3000]	validation_0-mae:0.39472
[3265]	validation_0-mae:0.39474
== fold 2 ==
[0]	validation_0-mae:11.84051
[500]	validation_0-mae:0.39708
[1000]	validation_0-mae:0.39211
[1500]	validation_0-mae:0.39082
[2000]	validation_0-mae:0.39014
[2500]	validation_0-mae:0.38971
[3000]	validation_0-mae:0.38959
[3307]	validation_0-mae:0.38953
== fold 3 ==
[0]	validation_0-mae:11.76079
[500]	validation_0-mae:0.39801
[1000]	validation_0-mae:0.39259
[1500]	validation_0-mae:0.39122
[2000]	validation_0-mae:0.39015
[2500]	validation_0-mae:0.38931
[3000]	validation_0-mae:0.38877
[3500]	validation_0-mae:0.38836
[4000]	validation_0-mae:0.38808


 56%|█████▌    | 10/18 [06:59<06:31, 48.88s/it]

== fold 0 ==
[0]	validation_0-mae:0.49677
[500]	validation_0-mae:0.21494
[1000]	validation_0-mae:0.21305
[1500]	validation_0-mae:0.21270
[1686]	validation_0-mae:0.21280
== fold 1 ==
[0]	validation_0-mae:0.52904
[500]	validation_0-mae:0.21897
[1000]	validation_0-mae:0.21691
[1500]	validation_0-mae:0.21650
[1700]	validation_0-mae:0.21655
== fold 2 ==
[0]	validation_0-mae:0.52863
[500]	validation_0-mae:0.22272
[1000]	validation_0-mae:0.22044
[1500]	validation_0-mae:0.21992
[2000]	validation_0-mae:0.21941
[2221]	validation_0-mae:0.21949
== fold 3 ==
[0]	validation_0-mae:0.49479
[500]	validation_0-mae:0.21963
[1000]	validation_0-mae:0.21795
[1388]	validation_0-mae:0.21776
== fold 4 ==
[0]	validation_0-mae:0.54055
[500]	validation_0-mae:0.22573
[1000]	validation_0-mae:0.22346
[1500]	validation_0-mae:0.22260
[2000]	validation_0-mae:0.22227
[2357]	validation_0-mae:0.22226


 61%|██████    | 11/18 [07:44<05:33, 47.65s/it]

== fold 0 ==
[0]	validation_0-mae:0.11534
[366]	validation_0-mae:0.11256
== fold 1 ==
[0]	validation_0-mae:0.11421
[500]	validation_0-mae:0.11061
[582]	validation_0-mae:0.11063
== fold 2 ==
[0]	validation_0-mae:0.11769
[399]	validation_0-mae:0.11458
== fold 3 ==
[0]	validation_0-mae:0.12032
[500]	validation_0-mae:0.11689
[667]	validation_0-mae:0.11696
== fold 4 ==
[0]	validation_0-mae:0.11716
[500]	validation_0-mae:0.11377
[816]	validation_0-mae:0.11385


 67%|██████▋   | 12/18 [07:58<03:45, 37.64s/it]

== fold 0 ==
[0]	validation_0-mae:14.68608
[500]	validation_0-mae:0.59145
[1000]	validation_0-mae:0.58154
[1500]	validation_0-mae:0.57963
[2000]	validation_0-mae:0.57862
[2500]	validation_0-mae:0.57793
[3000]	validation_0-mae:0.57762
[3500]	validation_0-mae:0.57711
[4000]	validation_0-mae:0.57671
[4488]	validation_0-mae:0.57647
== fold 1 ==
[0]	validation_0-mae:14.70070
[500]	validation_0-mae:0.60543
[1000]	validation_0-mae:0.59702
[1500]	validation_0-mae:0.59468
[2000]	validation_0-mae:0.59316
[2500]	validation_0-mae:0.59241
[3000]	validation_0-mae:0.59204
[3500]	validation_0-mae:0.59164
[3983]	validation_0-mae:0.59145
== fold 2 ==
[0]	validation_0-mae:14.87083
[500]	validation_0-mae:0.60132
[1000]	validation_0-mae:0.59362
[1500]	validation_0-mae:0.59159
[2000]	validation_0-mae:0.59105
[2500]	validation_0-mae:0.59053
[3000]	validation_0-mae:0.59037
[3168]	validation_0-mae:0.59043
== fold 3 ==
[0]	validation_0-mae:14.76740
[500]	validation_0-mae:0.59452
[1000]	validation_0-mae:0.58655


 72%|███████▏  | 13/18 [09:27<04:24, 52.98s/it]

== fold 0 ==
[0]	validation_0-mae:0.72713
[500]	validation_0-mae:0.33063
[1000]	validation_0-mae:0.32799
[1500]	validation_0-mae:0.32714
[1827]	validation_0-mae:0.32713
== fold 1 ==
[0]	validation_0-mae:0.77953
[500]	validation_0-mae:0.34125
[1000]	validation_0-mae:0.33862
[1446]	validation_0-mae:0.33787
== fold 2 ==
[0]	validation_0-mae:0.77445
[500]	validation_0-mae:0.34251
[1000]	validation_0-mae:0.33860
[1281]	validation_0-mae:0.33937
== fold 3 ==
[0]	validation_0-mae:0.72922
[500]	validation_0-mae:0.33790
[1000]	validation_0-mae:0.33574
[1500]	validation_0-mae:0.33516
[1941]	validation_0-mae:0.33529
== fold 4 ==
[0]	validation_0-mae:0.79549
[500]	validation_0-mae:0.35447
[1000]	validation_0-mae:0.34803
[1500]	validation_0-mae:0.34642
[1831]	validation_0-mae:0.34636


 78%|███████▊  | 14/18 [10:07<03:15, 48.99s/it]

== fold 0 ==
[0]	validation_0-mae:0.14814
[453]	validation_0-mae:0.14429
== fold 1 ==
[0]	validation_0-mae:0.14715
[500]	validation_0-mae:0.14199
[541]	validation_0-mae:0.14203
== fold 2 ==
[0]	validation_0-mae:0.15076
[468]	validation_0-mae:0.14589
== fold 3 ==
[0]	validation_0-mae:0.15401
[500]	validation_0-mae:0.14945
[839]	validation_0-mae:0.14947
== fold 4 ==
[0]	validation_0-mae:0.15034
[411]	validation_0-mae:0.14609


 83%|████████▎ | 15/18 [10:21<01:55, 38.54s/it]

== fold 0 ==
[0]	validation_0-mae:17.67537
[500]	validation_0-mae:0.83620
[1000]	validation_0-mae:0.82434
[1500]	validation_0-mae:0.82233
[2000]	validation_0-mae:0.82141
[2500]	validation_0-mae:0.82077
[3000]	validation_0-mae:0.82027
[3500]	validation_0-mae:0.81991
[3692]	validation_0-mae:0.81995
== fold 1 ==
[0]	validation_0-mae:17.69783
[500]	validation_0-mae:0.86434
[1000]	validation_0-mae:0.85402
[1500]	validation_0-mae:0.85206
[2000]	validation_0-mae:0.85031
[2500]	validation_0-mae:0.84950
[2774]	validation_0-mae:0.84933
== fold 2 ==
[0]	validation_0-mae:17.91073
[500]	validation_0-mae:0.85734
[1000]	validation_0-mae:0.84613
[1500]	validation_0-mae:0.84415
[2000]	validation_0-mae:0.84307
[2500]	validation_0-mae:0.84230
[3000]	validation_0-mae:0.84107
[3500]	validation_0-mae:0.84069
[3519]	validation_0-mae:0.84069
== fold 3 ==
[0]	validation_0-mae:17.77506
[500]	validation_0-mae:0.84382
[1000]	validation_0-mae:0.83398
[1500]	validation_0-mae:0.83066
[1653]	validation_0-mae:0.83085


 89%|████████▉ | 16/18 [11:33<01:37, 48.64s/it]

== fold 0 ==
[0]	validation_0-mae:0.99782
[500]	validation_0-mae:0.48073
[1000]	validation_0-mae:0.47642
[1500]	validation_0-mae:0.47552
[1682]	validation_0-mae:0.47558
== fold 1 ==
[0]	validation_0-mae:1.07299
[500]	validation_0-mae:0.49871
[1000]	validation_0-mae:0.49562
[1500]	validation_0-mae:0.49432
[1601]	validation_0-mae:0.49444
== fold 2 ==
[0]	validation_0-mae:1.06145
[500]	validation_0-mae:0.50198
[1000]	validation_0-mae:0.49429
[1500]	validation_0-mae:0.49330
[1586]	validation_0-mae:0.49344
== fold 3 ==
[0]	validation_0-mae:1.00633
[500]	validation_0-mae:0.49400
[1000]	validation_0-mae:0.48981
[1500]	validation_0-mae:0.48816
[2000]	validation_0-mae:0.48802
[2049]	validation_0-mae:0.48800
== fold 4 ==
[0]	validation_0-mae:1.09491
[500]	validation_0-mae:0.51652
[1000]	validation_0-mae:0.50598
[1500]	validation_0-mae:0.50400
[2000]	validation_0-mae:0.50345
[2209]	validation_0-mae:0.50345


 94%|█████████▍| 17/18 [12:17<00:47, 47.27s/it]

== fold 0 ==
[0]	validation_0-mae:0.18290
[383]	validation_0-mae:0.17759
== fold 1 ==
[0]	validation_0-mae:0.18177
[455]	validation_0-mae:0.17507
== fold 2 ==
[0]	validation_0-mae:0.18544
[447]	validation_0-mae:0.17912
== fold 3 ==
[0]	validation_0-mae:0.18894
[432]	validation_0-mae:0.18332
== fold 4 ==
[0]	validation_0-mae:0.18484
[380]	validation_0-mae:0.17905


100%|██████████| 18/18 [12:28<00:00, 41.61s/it]


In [14]:
mae = np.abs(train["oof"] - train["target"]).mean()
print(f"MAE: {mae}")

MAE: 0.23028796363219695


In [15]:
rev_dict = {v: "oof_" + k for k, v in CATEGORY_MAPPING["target_name"].items()}
oof_df = (
    train.select(["ID", "target_name", "oof"])
    .with_columns(pl.col("target_name").replace_strict(rev_dict))
    .pivot(index="ID", columns="target_name", values="oof")
)
oof_df.head()

ID,oof_x_0,oof_y_0,oof_z_0,oof_x_1,oof_y_1,oof_z_1,oof_x_2,oof_y_2,oof_z_2,oof_x_3,oof_y_3,oof_z_3,oof_x_4,oof_y_4,oof_z_4,oof_x_5,oof_y_5,oof_z_5
str,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""00066be8e20318869c38c66be46663…",2.66032,-0.009087,0.0112,5.862541,-0.042287,0.019358,9.444977,-0.097,0.039624,13.653478,-0.18745,0.047936,17.308624,-0.370154,0.085177,21.455189,-0.437528,0.127063
"""00066be8e20318869c38c66be46663…",5.035405,-0.064468,0.01415,10.514678,-0.233034,0.035776,15.887778,-0.414633,0.055777,21.37081,-0.532647,0.08411,26.310242,-0.679434,0.088512,31.392342,-0.869482,0.064034
"""00066be8e20318869c38c66be46663…",4.717708,-0.016422,0.014091,10.118437,-0.061507,0.035542,15.277604,-0.138827,0.053336,20.977812,-0.236764,0.09642,26.509876,-0.386165,0.093717,31.893902,-0.479848,0.117592
"""000fb056f97572d384bae4f5fc1e0f…",2.725299,0.020627,-0.000482,5.767129,0.12537,-0.001532,8.653085,0.271812,-0.007948,11.57066,0.511329,-0.003318,14.297729,0.774844,-0.007069,16.974367,1.198038,-0.007766
"""000fb056f97572d384bae4f5fc1e0f…",1.598209,-0.039609,0.00284,3.841709,-0.104585,0.002361,6.19996,-0.210795,0.003604,8.824244,-0.426306,0.029741,11.542909,-0.437678,0.041353,15.56218,-0.60901,0.05184


In [16]:
oof_df.write_csv(os.path.join(CFG["output_dir"], "oof.csv"))