In [1]:
import importlib
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from atmacup_18 import constants

import utils

# Re-import the local helper module so edits to utils.py take effect without restarting the kernel.
importlib.reload(utils)

<module 'utils' from '/home/tatsuya/projects/atmacup/atmacup_18/experiments/main/v00/v00_11_02/utils.py'>

In [2]:
# Seed the random number generators via utils.seed_everything before any data loading or training.
RANDOM_STATE = 2024
utils.seed_everything(RANDOM_STATE)

## データ読み込み

In [3]:
# Paths are resolved from the current working directory. For this notebook
# (.../atmacup_18/experiments/main/v00/v00_11_02), parents[3] is the atmacup_18 directory.
notebook_dir = Path().resolve()
DATA_DIR = notebook_dir.parents[3] / "data"
DATASET_DIR = DATA_DIR / "atmaCup#18_dataset"

# Tabular inputs
TR_FEATURES_CSV = DATASET_DIR / "train_features.csv"
TS_FEATURES_CSV = DATASET_DIR / "test_features.csv"
TRAFFIC_LIGHTS_CSV = DATASET_DIR / "traffic_lights.csv"

# Per-sample image files: three frames (t, t-0.5, t-1.0) plus a traffic-light bbox array
IMAGES_DIR = DATASET_DIR / "images"
IMAGE_NAMES = ["image_t.png", "image_t-0.5.png", "image_t-1.0.png"]
TRAFFIC_LIGHTS_BBOX_IMAGE_NAME = constants.TRAFFIC_LIGHT_BBOX_IMAGE_NAME

In [4]:
# Train features; unlike the test frame, this one also carries the x_i / y_i / z_i target columns.
tr_df = utils.read_feature_csv(TR_FEATURES_CSV)
tr_df.head(2)

ID,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,x_0,y_0,z_0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5,scene_id,scene_dsec,origin_idx
str,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,i32,i64
"""00066be8e20318869c38c66be46663…",5.701526,1.538456,-2.165777,-139.0,0.0,False,0.25,True,"""drive""",False,False,2.82959,0.032226,0.045187,6.231999,0.065895,0.107974,9.785009,0.124972,0.203649,13.485472,0.163448,0.302818,17.574227,0.174289,0.406331,21.951269,0.199503,0.485079,"""00066be8e20318869c38c66be46663…",320,0
"""00066be8e20318869c38c66be46663…",11.176292,0.279881,-11.625697,-44.0,0.0,False,0.0,False,"""drive""",False,True,4.970268,-0.007936,0.005028,10.350489,-0.032374,-0.020701,15.770054,0.084073,0.008645,21.132415,0.391343,0.036335,26.316489,0.843124,0.065,31.383814,1.42507,0.073083,"""00066be8e20318869c38c66be46663…",420,1


In [5]:
# Test features: same input columns as train, without the x_i / y_i / z_i targets.
ts_df = utils.read_feature_csv(TS_FEATURES_CSV)
ts_df.head(2)

ID,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,scene_id,scene_dsec,origin_idx
str,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,str,i32,i64
"""012baccc145d400c896cb82065a93d…",3.374273,-0.01936,-34.008415,17.0,0.0,False,0.0,False,"""drive""",False,False,"""012baccc145d400c896cb82065a93d…",120,0
"""012baccc145d400c896cb82065a93d…",2.441048,-0.022754,307.860077,295.0,0.0,True,0.0,False,"""drive""",False,False,"""012baccc145d400c896cb82065a93d…",220,1


In [6]:
def _load_tl_bbox_images(df: pl.DataFrame):
    """Load the traffic-light bbox array for every ID in ``df`` (IDs passed in row order)."""
    return utils.load_npy_images(
        IMAGES_DIR,
        ids=df.get_column("ID").to_list(),
        image_names=[TRAFFIC_LIGHTS_BBOX_IMAGE_NAME],
    )


tr_tl_bbox_images = _load_tl_bbox_images(tr_df)
print(tr_tl_bbox_images.shape)
ts_tl_bbox_images = _load_tl_bbox_images(ts_df)
print(ts_tl_bbox_images.shape)

(43371, 1, 64, 128, 8)
(1727, 1, 64, 128, 8)


In [7]:
# 3-channel frames at t, t-0.5 and t-1.0 for every sample; IDs are passed in each frame's row order.
tr_images = utils.load_images(IMAGES_DIR, ids=tr_df["ID"].to_list(), image_names=IMAGE_NAMES)
print(tr_images.shape)
ts_images = utils.load_images(IMAGES_DIR, ids=ts_df["ID"].to_list(), image_names=IMAGE_NAMES)
print(ts_images.shape)

(43371, 3, 64, 128, 3)
(1727, 3, 64, 128, 3)


In [8]:
# Combine the frames and the bbox arrays into one channel-first array per sample. Per the printed
# shapes: 3 frames x 3 channels + 8 bbox channels = 17 channels, i.e. (N, 17, 64, 128).
# NOTE(review): this rebinds tr_images / ts_images, so re-running this cell alone would feed the
# already-preprocessed arrays back into preprocess_images; rerun from the loading cells instead.
tr_images = utils.preprocess_images([tr_images, tr_tl_bbox_images])
ts_images = utils.preprocess_images([ts_images, ts_tl_bbox_images])

print(tr_images.shape)
print(ts_images.shape)

(43371, 17, 64, 128)
(1727, 17, 64, 128)


## scene_id・scene_dsec順に並び替える（画像も同じ順序に揃える）

In [9]:
def _sort_by_scene(df: pl.DataFrame, images):
    """Sort ``df`` by (scene_id, scene_dsec) and reorder ``images`` so they stay row-aligned.

    Assumes ``images[i]`` belongs to the current row ``i`` of ``df`` and that ``origin_idx``
    is a permutation of ``0..len(df)-1`` (it is used as a row index into the images).

    The permutation is computed relative to the frame's *current* order. Indexing the images
    with the sorted ``origin_idx`` directly is only correct on the first run; re-running the
    cell would permute the images a second time and silently misalign them with the rows.
    """
    sorted_df = df.sort(["scene_id", "scene_dsec"])
    # row_of_origin[k] is the current row position whose origin_idx equals k.
    row_of_origin = np.argsort(df.get_column("origin_idx").to_numpy())
    perm = row_of_origin[sorted_df.get_column("origin_idx").to_numpy()]
    return sorted_df, images[perm]


tr_df, tr_images = _sort_by_scene(tr_df, tr_images)
ts_df, ts_images = _sort_by_scene(ts_df, ts_images)

## Target

In [10]:
target = utils.CoordinateTarget(prefix="tg_")
target.fit(tr_df)

tg_df = target.transform(tr_df)
print(tg_df.columns)
# DataFrame.glimpse() prints the summary itself and returns None; wrapping it in print()
# only appends a stray "None" line.
tg_df.describe().glimpse()
tr_df = pl.concat([tr_df, tg_df], how="horizontal")

['tg_cood_x_0', 'tg_cood_y_0', 'tg_cood_z_0', 'tg_cood_x_1', 'tg_cood_y_1', 'tg_cood_z_1', 'tg_cood_x_2', 'tg_cood_y_2', 'tg_cood_z_2', 'tg_cood_x_3', 'tg_cood_y_3', 'tg_cood_z_3', 'tg_cood_x_4', 'tg_cood_y_4', 'tg_cood_z_4', 'tg_cood_x_5', 'tg_cood_y_5', 'tg_cood_z_5']
Rows: 9
Columns: 19
$ statistic   <str> 'count', 'null_count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'
$ tg_cood_x_0 <f64> 43371.0, 0.0, 4.122443942757371, 3.2667167639213908, -1.7321542071537557, 1.116530690565041, 3.843337458989432, 6.4281197924248215, 12.392587231992154
$ tg_cood_y_0 <f64> 43371.0, 0.0, 0.0019486856369589753, 0.11686590022408185, -2.5341378248203235, -0.02646308932096216, 0.0010884804706597444, 0.030664179622664968, 3.4595563267615925
$ tg_cood_z_0 <f64> 43371.0, 0.0, 0.001247332026343412, 0.040745039147660006, -0.9965478318668152, -0.017624552286421614, 0.0011112325970167394, 0.01959056385186172, 1.4479292511292303
$ tg_cood_x_1 <f64> 43371.0, 0.0, 8.694386412319787, 6.894354638218222, -3.1

## 特徴量

In [11]:
feature = utils.Feature(prefix="ft_")
# Fit on train only; the same fitted transform is then applied to test.
feature.fit(tr_df)

ft_df = feature.transform(tr_df)
print(ft_df.columns)
# DataFrame.glimpse() prints the summary itself and returns None; wrapping it in print()
# only appends a stray "None" line.
ft_df.describe().glimpse()
tr_df = pl.concat([tr_df, ft_df], how="horizontal")

ft_df = feature.transform(ts_df)
print(ft_df.columns)
ft_df.describe().glimpse()
ts_df = pl.concat([ts_df, ft_df], how="horizontal")

['ft_vEgo', 'ft_aEgo', 'ft_steeringAngleDeg', 'ft_steeringTorque', 'ft_brake', 'ft_brakePressed', 'ft_gas', 'ft_gasPressed', 'ft_is_gearShifter_drive', 'ft_is_gearShifter_neutral', 'ft_is_gearShifter_park', 'ft_is_gearShifter_reverse', 'ft_leftBlinker', 'ft_rightBlinker']
Rows: 9
Columns: 15
$ statistic                 <str> 'count', 'null_count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'
$ ft_vEgo                   <f64> 43371.0, 0.0, 9.172175823216334, 7.226919878374694, -0.1619189828634262, 2.5786657333374023, 8.518790245056152, 14.286815643310547, 27.55126190185547
$ ft_aEgo                   <f64> 43371.0, 0.0, -0.015654028629347255, 0.6324016778486632, -4.936206340789795, -0.2363678514957428, -1.8347540436410405e-15, 0.22229795157909396, 3.1400704383850098
$ ft_steeringAngleDeg       <f64> 43371.0, 0.0, -2.065172574071012, 65.54882159006848, -481.394287109375, -3.461754322052002, -0.35647091269493103, 2.6269068717956543, 484.69171142578125
$ ft_steeringTorque         <f64>

## モデリング

In [12]:
# Number of CV folds; passed to utils.train together with group_col="scene_id".
N_SPLITS = 2

In [13]:
# Number of samples from the same scene combined into one model input (presumably neighbouring
# samples in scene_dsec order); image channels and feature counts below are multiplied by it.
n_sample_in_scene = 3

model_params = {
    "dnn": {
        "n_sample_in_scene": n_sample_in_scene,
        # per-sample channels (17 after preprocessing) x samples per input
        "n_img_channels": tr_images.shape[1] * n_sample_in_scene,
        "n_features": len(feature.columns) * n_sample_in_scene,
        "n_targets": len(target.columns),
        "segmentation_model_params": {
            "encoder_name": "mit_b1",
            "encoder_weights": "imagenet",
            "decoder_channels": (256, 128, 64, 32, 16),
        },
        "dropout": 0.0,
        "embed_dim": 128,
        "n_layers": 1,
    },
    "dev": "cuda",
}

# Peak learning rate, shared by the backbone / non-backbone optimizers and the scheduler.
lr = 1e-4
fit_params = {
    "dnn": {
        "tr_batch_size": 32,
        "vl_batch_size": 128,
        "trainer_params": {
            "criterion_params": {},
            "opt": "adamw",
            "opt_params": {"lr": lr, "weight_decay": 1e-4},
            "backbone_opt_params": {"lr": lr, "weight_decay": 1e-4},
            # These look like torch OneCycleLR arguments: training starts at
            # max_lr / div_factor (4e-6, matching the logged lr) and ramps up over the first
            # pct_start fraction of steps.
            "sch_params": {
                "max_lr": lr,
                "pct_start": 0.1,
                "div_factor": 25,
                "final_div_factor": 1000,
            },
            # In the logged run one epoch took roughly 3 minutes (train + valid) per fold.
            "epochs": 40,
            "dev": "cuda",
            "prefix": "",
            "save_best": False,
            "save_epochs": [],
            "maximize_score": False,
            "grad_max_norm": None,
        },
    },
}

In [14]:
# Cross-validated training: returns the fitted models and out-of-fold predictions for tr_df
# (used below for scoring). Grouping by scene_id presumably keeps each scene within one fold.
# NOTE(review): at ~3 min/epoch x 40 epochs x N_SPLITS folds this cell runs for hours; the saved
# output ends in KeyboardInterrupt, so `models` / `oof_preds` were not produced in that run.
models, oof_preds = utils.train(
    model_params=model_params,
    fit_params=fit_params,
    df=tr_df,
    images=tr_images,
    target_cols=target.columns,
    feature_cols=feature.columns,
    group_col="scene_id",
    scene_id_col="scene_id",
    scene_dsec_col="scene_dsec",
    n_splits=N_SPLITS,
)

-----------------
-----------------
Training fold 0...
train samples: 21685, valid samples: 21686


Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /home/tatsuya/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:02<00:00, 78.6MB/s] 


Save model : model.pth

epoch  0
lr  4.000000000000002e-06
lr  4.000000000000002e-06
lr  4.000000000000002e-06
lr  4.000000000000002e-06


100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 19.2190
{'loss': 19.219037520480473, 'loss_mse_0': 0.7827025393307473, 'loss_mse_1': 0.09017284735321911, 'loss_mse_2': 0.0564818971885294, 'loss_mse_3': 1.8210845052579656, 'loss_mse_4': 0.2583827936975079, 'loss_mse_5': 0.09690396750464134, 'loss_mse_6': 1.8921904253994872, 'loss_mse_7': 0.44666905322578493, 'loss_mse_8': 0.14657857079757722, 'loss_mse_9': 2.852236091562353, 'loss_mse_10': 0.6633169852648387, 'loss_mse_11': 0.20122152013343758, 'loss_mse_12': 2.789030848965159, 'loss_mse_13': 0.9823969762821846, 'loss_mse_14': 0.25682503435116927, 'loss_mse_15': 4.16047440088482, 'loss_mse_16': 1.4264837627792992, 'loss_mse_17': 0.29588531041339017}
Valid Loss: 5.9667
{'loss': 5.966720338428722, 'loss_mse_0': 0.1311399317401297, 'loss_mse_1': 0.07168888846302733, 'loss_mse_2': 0.040499751089031207, 'loss_mse_3': 0.23673596373375724, 'loss_mse_4': 0.15495490655303001, 'loss_mse_5': 0.0812061960425447, 'loss_mse_6': 0.3765395766671966, 'loss_mse_7': 0.2779758483171463, 'lo

100%|██████████| 677/677 [01:52<00:00,  6.03it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 5.0274
{'loss': 5.027449174857034, 'loss_mse_0': 0.09497704933035674, 'loss_mse_1': 0.05897562413653868, 'loss_mse_2': 0.042379714311974084, 'loss_mse_3': 0.1853024845934601, 'loss_mse_4': 0.13625019401080762, 'loss_mse_5': 0.08613392712336523, 'loss_mse_6': 0.25579492124884307, 'loss_mse_7': 0.24750512045982845, 'loss_mse_8': 0.1294470084432177, 'loss_mse_9': 0.34210981460538537, 'loss_mse_10': 0.4096266144179523, 'loss_mse_11': 0.17463012470518469, 'loss_mse_12': 0.41029032782006297, 'loss_mse_13': 0.6047731814589599, 'loss_mse_14': 0.22246347922980345, 'loss_mse_15': 0.49396100981460894, 'loss_mse_16': 0.8612493800447821, 'loss_mse_17': 0.27157918434243433}
Valid Loss: 4.1915
{'loss': 4.191537984679727, 'loss_mse_0': 0.06372094564139844, 'loss_mse_1': 0.06117416689281954, 'loss_mse_2': 0.03887483644134858, 'loss_mse_3': 0.11171929790254902, 'loss_mse_4': 0.1230156598283964, 'loss_mse_5': 0.07982480960952885, 'loss_mse_6': 0.18101558010367785, 'loss_mse_7': 0.21585216890

100%|██████████| 677/677 [01:52<00:00,  6.00it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 3.9176
{'loss': 3.917564993066661, 'loss_mse_0': 0.05573227101224344, 'loss_mse_1': 0.051387007205170664, 'loss_mse_2': 0.04140171879542059, 'loss_mse_3': 0.11117101423277284, 'loss_mse_4': 0.1072945073136903, 'loss_mse_5': 0.08458153372754286, 'loss_mse_6': 0.16085215537441533, 'loss_mse_7': 0.19143974844316788, 'loss_mse_8': 0.12785372103081585, 'loss_mse_9': 0.2135187641853684, 'loss_mse_10': 0.318382156745942, 'loss_mse_11': 0.17283813858974278, 'loss_mse_12': 0.27290005815742285, 'loss_mse_13': 0.4822246783428125, 'loss_mse_14': 0.21956765394729386, 'loss_mse_15': 0.33092749136963656, 'loss_mse_16': 0.7071959178790208, 'loss_mse_17': 0.26829644269481895}
Valid Loss: 3.2532
{'loss': 3.253187926376567, 'loss_mse_0': 0.04275953753248734, 'loss_mse_1': 0.05338859202011543, 'loss_mse_2': 0.038667778374955934, 'loss_mse_3': 0.08329458931351409, 'loss_mse_4': 0.0952356785097543, 'loss_mse_5': 0.07830357284230345, 'loss_mse_6': 0.11831613430643782, 'loss_mse_7': 0.16070396674

100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:15<00:00,  2.24it/s]



Train Loss: 3.3498
{'loss': 3.349848783174681, 'loss_mse_0': 0.043163679423226124, 'loss_mse_1': 0.04570175961768764, 'loss_mse_2': 0.04130617406857551, 'loss_mse_3': 0.0870528781547196, 'loss_mse_4': 0.08934191695140727, 'loss_mse_5': 0.08424581039388784, 'loss_mse_6': 0.12766055565222856, 'loss_mse_7': 0.155908620741598, 'loss_mse_8': 0.1274336521875559, 'loss_mse_9': 0.16697863543315336, 'loss_mse_10': 0.25839528988391997, 'loss_mse_11': 0.17272772465556024, 'loss_mse_12': 0.21714350598227325, 'loss_mse_13': 0.3965819483592644, 'loss_mse_14': 0.2193635280135375, 'loss_mse_15': 0.26521789134539214, 'loss_mse_16': 0.583442521214133, 'loss_mse_17': 0.2681826993690988}
Valid Loss: 2.6227
{'loss': 2.622738172727473, 'loss_mse_0': 0.02266184103072566, 'loss_mse_1': 0.048675340320914987, 'loss_mse_2': 0.03828933053814313, 'loss_mse_3': 0.04162030478610712, 'loss_mse_4': 0.08021168227800551, 'loss_mse_5': 0.0778349248582826, 'loss_mse_6': 0.05625933413119877, 'loss_mse_7': 0.13257412765832

100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:16<00:00,  2.22it/s]



Train Loss: 2.8713
{'loss': 2.87130234477439, 'loss_mse_0': 0.03277778977313611, 'loss_mse_1': 0.042171182152863314, 'loss_mse_2': 0.04105177586636832, 'loss_mse_3': 0.06511188267236315, 'loss_mse_4': 0.07683284070749975, 'loss_mse_5': 0.08364193258941878, 'loss_mse_6': 0.09487145176182117, 'loss_mse_7': 0.13015279658909124, 'loss_mse_8': 0.12656235994779377, 'loss_mse_9': 0.12742059126199173, 'loss_mse_10': 0.21165612569433723, 'loss_mse_11': 0.17153686123241174, 'loss_mse_12': 0.16677986727573496, 'loss_mse_13': 0.32678862544467213, 'loss_mse_14': 0.21737818235044465, 'loss_mse_15': 0.20702455051936638, 'loss_mse_16': 0.48407354906883693, 'loss_mse_17': 0.2654699737445467}
Valid Loss: 2.4101
{'loss': 2.4100508395363303, 'loss_mse_0': 0.022214970880133265, 'loss_mse_1': 0.04658511932939291, 'loss_mse_2': 0.03830660669790471, 'loss_mse_3': 0.03909295256085256, 'loss_mse_4': 0.07266487730119159, 'loss_mse_5': 0.07789053782163298, 'loss_mse_6': 0.06575112750425058, 'loss_mse_7': 0.11055

100%|██████████| 677/677 [01:52<00:00,  6.02it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 2.5638
{'loss': 2.5637740369561506, 'loss_mse_0': 0.025944402193075578, 'loss_mse_1': 0.038673035731673945, 'loss_mse_2': 0.0410762168246939, 'loss_mse_3': 0.05138464341218001, 'loss_mse_4': 0.06646850039739509, 'loss_mse_5': 0.08353075366934962, 'loss_mse_6': 0.07523977600225142, 'loss_mse_7': 0.11006482729916646, 'loss_mse_8': 0.12650122166098132, 'loss_mse_9': 0.10152772013950911, 'loss_mse_10': 0.17982271794873556, 'loss_mse_11': 0.1711128707562451, 'loss_mse_12': 0.13561435062520788, 'loss_mse_13': 0.28167388926602716, 'loss_mse_14': 0.21686420274383914, 'loss_mse_15': 0.1711357262211149, 'loss_mse_16': 0.4225230217420192, 'loss_mse_17': 0.2646161589325795}
Valid Loss: 2.1968
{'loss': 2.196847980162677, 'loss_mse_0': 0.022168270262944347, 'loss_mse_1': 0.04511633280883817, 'loss_mse_2': 0.0382020422257483, 'loss_mse_3': 0.03879130935186849, 'loss_mse_4': 0.06612589462276768, 'loss_mse_5': 0.07812543325126171, 'loss_mse_6': 0.061702298614032126, 'loss_mse_7': 0.0960252

100%|██████████| 677/677 [01:52<00:00,  6.00it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 2.3902
{'loss': 2.3901655688947905, 'loss_mse_0': 0.022441105913656986, 'loss_mse_1': 0.03738607481903392, 'loss_mse_2': 0.04087881647474084, 'loss_mse_3': 0.04423919074138299, 'loss_mse_4': 0.06325427401381678, 'loss_mse_5': 0.0832129000126215, 'loss_mse_6': 0.06496681116570882, 'loss_mse_7': 0.10185942878091, 'loss_mse_8': 0.12606121828584235, 'loss_mse_9': 0.08821151286165506, 'loss_mse_10': 0.1636077867190806, 'loss_mse_11': 0.17028484144076111, 'loss_mse_12': 0.1188719635210282, 'loss_mse_13': 0.2542317585752757, 'loss_mse_14': 0.21580156919434998, 'loss_mse_15': 0.15059168127253983, 'loss_mse_16': 0.38092058248806915, 'loss_mse_17': 0.2633440508844229}
Valid Loss: 2.2353
{'loss': 2.2353431996177227, 'loss_mse_0': 0.023463279392351124, 'loss_mse_1': 0.04415124169386485, 'loss_mse_2': 0.03823599792359506, 'loss_mse_3': 0.05087526943534613, 'loss_mse_4': 0.06323227703790454, 'loss_mse_5': 0.07789186659981223, 'loss_mse_6': 0.06912027011241983, 'loss_mse_7': 0.0928544471

100%|██████████| 677/677 [01:53<00:00,  5.99it/s]
100%|██████████| 170/170 [01:16<00:00,  2.22it/s]



Train Loss: 2.3404
{'loss': 2.340373482274516, 'loss_mse_0': 0.02235621822814965, 'loss_mse_1': 0.03653777372009469, 'loss_mse_2': 0.0408397511162271, 'loss_mse_3': 0.044502310079289946, 'loss_mse_4': 0.060387153596654426, 'loss_mse_5': 0.08304892015124793, 'loss_mse_6': 0.06613189159332605, 'loss_mse_7': 0.09629449397985833, 'loss_mse_8': 0.1257696875111039, 'loss_mse_9': 0.08934256406377127, 'loss_mse_10': 0.15434973778277417, 'loss_mse_11': 0.16977942950660405, 'loss_mse_12': 0.12022385548970813, 'loss_mse_13': 0.2395656217863588, 'loss_mse_14': 0.21507090104515478, 'loss_mse_15': 0.1535731206099163, 'loss_mse_16': 0.36051470456174417, 'loss_mse_17': 0.262085355394568}
Valid Loss: 2.0688
{'loss': 2.0688234420383678, 'loss_mse_0': 0.013588156756561469, 'loss_mse_1': 0.044661793364759755, 'loss_mse_2': 0.03822872826062581, 'loss_mse_3': 0.024226192867054658, 'loss_mse_4': 0.06613345866036766, 'loss_mse_5': 0.07775221929392394, 'loss_mse_6': 0.03588031942353529, 'loss_mse_7': 0.097444

100%|██████████| 677/677 [01:53<00:00,  5.99it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 2.1805
{'loss': 2.1805037390358164, 'loss_mse_0': 0.01809084952218758, 'loss_mse_1': 0.03584369911951188, 'loss_mse_2': 0.040740205256331666, 'loss_mse_3': 0.03580894525626473, 'loss_mse_4': 0.057716327268173544, 'loss_mse_5': 0.0828198407731257, 'loss_mse_6': 0.05290790430876802, 'loss_mse_7': 0.09017007120580933, 'loss_mse_8': 0.1253432702944453, 'loss_mse_9': 0.07297001241746207, 'loss_mse_10': 0.14246390720325874, 'loss_mse_11': 0.16913121274536258, 'loss_mse_12': 0.09903645878891561, 'loss_mse_13': 0.2203900121768697, 'loss_mse_14': 0.21432856071954992, 'loss_mse_15': 0.13028625575784303, 'loss_mse_16': 0.3313023703423173, 'loss_mse_17': 0.2611538366506776}
Valid Loss: 2.7220
{'loss': 2.721984101744259, 'loss_mse_0': 0.022413363406325087, 'loss_mse_1': 0.04906297574258026, 'loss_mse_2': 0.038228249423863256, 'loss_mse_3': 0.04769250536885332, 'loss_mse_4': 0.0862273286918507, 'loss_mse_5': 0.07815692802781568, 'loss_mse_6': 0.06000691631480175, 'loss_mse_7': 0.1437068

100%|██████████| 677/677 [01:53<00:00,  5.99it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 2.0545
{'loss': 2.054451432150611, 'loss_mse_0': 0.01692743444859299, 'loss_mse_1': 0.03370185848394274, 'loss_mse_2': 0.04059787204370805, 'loss_mse_3': 0.03325166773020759, 'loss_mse_4': 0.05195899412002926, 'loss_mse_5': 0.08241452727648409, 'loss_mse_6': 0.04900564691890994, 'loss_mse_7': 0.07985292269299399, 'loss_mse_8': 0.12449008296003106, 'loss_mse_9': 0.06814039716472302, 'loss_mse_10': 0.1259187486468425, 'loss_mse_11': 0.16763371552727246, 'loss_mse_12': 0.09275361862545352, 'loss_mse_13': 0.1972129013543423, 'loss_mse_14': 0.2119701446104173, 'loss_mse_15': 0.12139962548651406, 'loss_mse_16': 0.2994632969061504, 'loss_mse_17': 0.257757969805902}
Valid Loss: 2.0234
{'loss': 2.023382799765643, 'loss_mse_0': 0.015546890238628668, 'loss_mse_1': 0.043097536326112115, 'loss_mse_2': 0.03844286012189353, 'loss_mse_3': 0.028222301881760357, 'loss_mse_4': 0.06152070655542261, 'loss_mse_5': 0.07828518815119477, 'loss_mse_6': 0.04521129923489164, 'loss_mse_7': 0.085831206

100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:16<00:00,  2.23it/s]



Train Loss: 1.9347
{'loss': 1.9346580191940466, 'loss_mse_0': 0.015429892709294398, 'loss_mse_1': 0.032478969845442074, 'loss_mse_2': 0.040505387605700306, 'loss_mse_3': 0.030377169185080065, 'loss_mse_4': 0.049606570057802044, 'loss_mse_5': 0.08180657215141299, 'loss_mse_6': 0.04460366247593256, 'loss_mse_7': 0.07473823809487253, 'loss_mse_8': 0.12282041524106325, 'loss_mse_9': 0.06194673535555733, 'loss_mse_10': 0.11577313183978795, 'loss_mse_11': 0.16446110449918439, 'loss_mse_12': 0.08512876575133121, 'loss_mse_13': 0.17779566536889999, 'loss_mse_14': 0.20668914041962017, 'loss_mse_15': 0.11373388320702833, 'loss_mse_16': 0.2672494667220415, 'loss_mse_17': 0.24951324887584725}
Valid Loss: 2.1304
{'loss': 2.1304268465322607, 'loss_mse_0': 0.0182428660145139, 'loss_mse_1': 0.04302267302375506, 'loss_mse_2': 0.03804854266893338, 'loss_mse_3': 0.039271961240207445, 'loss_mse_4': 0.06364867118570734, 'loss_mse_5': 0.07643801840570043, 'loss_mse_6': 0.056306475948761495, 'loss_mse_7': 0

100%|██████████| 677/677 [01:52<00:00,  5.99it/s]
100%|██████████| 170/170 [01:16<00:00,  2.22it/s]



Train Loss: 1.7926
{'loss': 1.7925823065871855, 'loss_mse_0': 0.013714038755571362, 'loss_mse_1': 0.03206606782584876, 'loss_mse_2': 0.04017636656384238, 'loss_mse_3': 0.027051829935176686, 'loss_mse_4': 0.04676097219306398, 'loss_mse_5': 0.08039170227108448, 'loss_mse_6': 0.039926258240450874, 'loss_mse_7': 0.0672048466786346, 'loss_mse_8': 0.11985705213475438, 'loss_mse_9': 0.05540777637349494, 'loss_mse_10': 0.10117318049524436, 'loss_mse_11': 0.15921932141374978, 'loss_mse_12': 0.07799377094024305, 'loss_mse_13': 0.15445299572593352, 'loss_mse_14': 0.19905321944939686, 'loss_mse_15': 0.10481404858961856, 'loss_mse_16': 0.23426980019562177, 'loss_mse_17': 0.23904906032246537}
Valid Loss: 1.9786
{'loss': 1.9786050642237945, 'loss_mse_0': 0.012251850198406506, 'loss_mse_1': 0.04207802059895852, 'loss_mse_2': 0.03798569052973214, 'loss_mse_3': 0.026076713593347983, 'loss_mse_4': 0.06008940828854547, 'loss_mse_5': 0.07680158836438375, 'loss_mse_6': 0.04345186039367143, 'loss_mse_7': 0.

100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:15<00:00,  2.24it/s]



Train Loss: 1.6793
{'loss': 1.679313516687255, 'loss_mse_0': 0.012999552351242396, 'loss_mse_1': 0.030425168142586594, 'loss_mse_2': 0.03850830556959618, 'loss_mse_3': 0.025763712302359312, 'loss_mse_4': 0.0431432085417134, 'loss_mse_5': 0.07695953926261366, 'loss_mse_6': 0.0385512064095515, 'loss_mse_7': 0.0609678344627735, 'loss_mse_8': 0.11477470337144997, 'loss_mse_9': 0.05415464034840531, 'loss_mse_10': 0.09063650473699172, 'loss_mse_11': 0.15199522328517637, 'loss_mse_12': 0.07579225715541417, 'loss_mse_13': 0.13801785157594493, 'loss_mse_14': 0.1892635738325788, 'loss_mse_15': 0.10176896177606445, 'loss_mse_16': 0.20862505255897076, 'loss_mse_17': 0.22696621527281788}
Valid Loss: 1.8566
{'loss': 1.8565642297267915, 'loss_mse_0': 0.009265792846460553, 'loss_mse_1': 0.0419599525137421, 'loss_mse_2': 0.0380818399710252, 'loss_mse_3': 0.018605685973649515, 'loss_mse_4': 0.061628951888312314, 'loss_mse_5': 0.07608458500574616, 'loss_mse_6': 0.026398210254881312, 'loss_mse_7': 0.0884

100%|██████████| 677/677 [01:51<00:00,  6.08it/s]
100%|██████████| 170/170 [01:14<00:00,  2.28it/s]



Train Loss: 1.5725
{'loss': 1.5725210898912112, 'loss_mse_0': 0.011749735135073323, 'loss_mse_1': 0.02938757072776512, 'loss_mse_2': 0.03888345675947395, 'loss_mse_3': 0.02280059583592428, 'loss_mse_4': 0.04087532369868398, 'loss_mse_5': 0.0764828223743348, 'loss_mse_6': 0.03400461982242334, 'loss_mse_7': 0.05658932460645321, 'loss_mse_8': 0.11228994408861749, 'loss_mse_9': 0.04814663169798242, 'loss_mse_10': 0.08237316945988679, 'loss_mse_11': 0.14767130721839175, 'loss_mse_12': 0.06764954919521002, 'loss_mse_13': 0.12404374757910937, 'loss_mse_14': 0.1825291860513004, 'loss_mse_15': 0.09231026693793395, 'loss_mse_16': 0.1868860410853802, 'loss_mse_17': 0.21784779448694985}
Valid Loss: 1.8187
{'loss': 1.8187429841826943, 'loss_mse_0': 0.00936351178542656, 'loss_mse_1': 0.041686683504239604, 'loss_mse_2': 0.03865852280903388, 'loss_mse_3': 0.017236219710834763, 'loss_mse_4': 0.05737106055021286, 'loss_mse_5': 0.07816838519099881, 'loss_mse_6': 0.025725565067328076, 'loss_mse_7': 0.080

100%|██████████| 677/677 [01:50<00:00,  6.13it/s]
100%|██████████| 170/170 [01:14<00:00,  2.28it/s]



Train Loss: 1.4281
{'loss': 1.428080781435861, 'loss_mse_0': 0.010769667901770051, 'loss_mse_1': 0.02921859928175664, 'loss_mse_2': 0.03749559092867784, 'loss_mse_3': 0.02078470596740988, 'loss_mse_4': 0.038668054750978814, 'loss_mse_5': 0.0728283186609605, 'loss_mse_6': 0.030763034022343035, 'loss_mse_7': 0.04990064222888369, 'loss_mse_8': 0.10625727714604782, 'loss_mse_9': 0.04368760123810397, 'loss_mse_10': 0.06929559501760159, 'loss_mse_11': 0.13888241925677133, 'loss_mse_12': 0.06202472082065471, 'loss_mse_13': 0.1025302275352213, 'loss_mse_14': 0.17076107728648574, 'loss_mse_15': 0.08532523665373576, 'loss_mse_16': 0.15539386804921324, 'loss_mse_17': 0.20349415482477212}
Valid Loss: 1.8469
{'loss': 1.8469228667371413, 'loss_mse_0': 0.010601372537477052, 'loss_mse_1': 0.04056756192380014, 'loss_mse_2': 0.03791778605321751, 'loss_mse_3': 0.02177308147006175, 'loss_mse_4': 0.05659494456999442, 'loss_mse_5': 0.07623503951027114, 'loss_mse_6': 0.03292856420225957, 'loss_mse_7': 0.080

100%|██████████| 677/677 [01:50<00:00,  6.12it/s]
100%|██████████| 170/170 [01:14<00:00,  2.28it/s]



Train Loss: 1.3221
{'loss': 1.3221386108035704, 'loss_mse_0': 0.009812250799063367, 'loss_mse_1': 0.027730338089030505, 'loss_mse_2': 0.036106351542904244, 'loss_mse_3': 0.01876326069026701, 'loss_mse_4': 0.036166010586822755, 'loss_mse_5': 0.06934261421206651, 'loss_mse_6': 0.02725269568209128, 'loss_mse_7': 0.04648802506887358, 'loss_mse_8': 0.0997458484293439, 'loss_mse_9': 0.038654039856018924, 'loss_mse_10': 0.06415478908697603, 'loss_mse_11': 0.12993165601644682, 'loss_mse_12': 0.054802729606496355, 'loss_mse_13': 0.0946019730686129, 'loss_mse_14': 0.15926812500575974, 'loss_mse_15': 0.07625446665861493, 'loss_mse_16': 0.14310248302326067, 'loss_mse_17': 0.1899609534259737}
Valid Loss: 1.8732
{'loss': 1.8732471157522763, 'loss_mse_0': 0.01065775168758324, 'loss_mse_1': 0.042195717059075835, 'loss_mse_2': 0.03849738023298628, 'loss_mse_3': 0.01979932174743975, 'loss_mse_4': 0.06157373430974343, 'loss_mse_5': 0.07819470442174112, 'loss_mse_6': 0.029226822225267395, 'loss_mse_7': 0

100%|██████████| 677/677 [01:50<00:00,  6.13it/s]
100%|██████████| 170/170 [01:14<00:00,  2.27it/s]



Train Loss: 1.1947
{'loss': 1.1946583492364982, 'loss_mse_0': 0.009095653148706414, 'loss_mse_1': 0.02719891395115815, 'loss_mse_2': 0.032847614396236936, 'loss_mse_3': 0.0174660722235358, 'loss_mse_4': 0.034589215759545607, 'loss_mse_5': 0.06211979587891165, 'loss_mse_6': 0.02541617401139577, 'loss_mse_7': 0.04321479164696734, 'loss_mse_8': 0.08836331448766735, 'loss_mse_9': 0.03661787991124471, 'loss_mse_10': 0.058418742382411945, 'loss_mse_11': 0.1153766854034745, 'loss_mse_12': 0.052387277409984584, 'loss_mse_13': 0.08419415150030325, 'loss_mse_14': 0.1402150508307988, 'loss_mse_15': 0.0733417945128468, 'loss_mse_16': 0.1267094305046508, 'loss_mse_17': 0.16708579142213395}
Valid Loss: 1.9118
{'loss': 1.9117710877867307, 'loss_mse_0': 0.01231254806249019, 'loss_mse_1': 0.04095195563071791, 'loss_mse_2': 0.03923162802615587, 'loss_mse_3': 0.027490712899495572, 'loss_mse_4': 0.057951804250478745, 'loss_mse_5': 0.08149200835648705, 'loss_mse_6': 0.041351452721830675, 'loss_mse_7': 0.0

100%|██████████| 677/677 [01:52<00:00,  6.01it/s]
100%|██████████| 170/170 [01:15<00:00,  2.26it/s]



Train Loss: 1.0895
{'loss': 1.0894794871307019, 'loss_mse_0': 0.008836279280326898, 'loss_mse_1': 0.026363614403217633, 'loss_mse_2': 0.031159336886408585, 'loss_mse_3': 0.016770206692393326, 'loss_mse_4': 0.032079027892754824, 'loss_mse_5': 0.056849017417391255, 'loss_mse_6': 0.024214412186823048, 'loss_mse_7': 0.03868191804120556, 'loss_mse_8': 0.08079062068237863, 'loss_mse_9': 0.03452800226534003, 'loss_mse_10': 0.051210319683880186, 'loss_mse_11': 0.10384734885896414, 'loss_mse_12': 0.0492975749914566, 'loss_mse_13': 0.07453975302650094, 'loss_mse_14': 0.12654870065091456, 'loss_mse_15': 0.06951108408549819, 'loss_mse_16': 0.11377339917709679, 'loss_mse_17': 0.1504788689997647}
Valid Loss: 1.7635
{'loss': 1.763493873792536, 'loss_mse_0': 0.007536072559271227, 'loss_mse_1': 0.04043494512272232, 'loss_mse_2': 0.039211961276390976, 'loss_mse_3': 0.01499291999703821, 'loss_mse_4': 0.055860644162577744, 'loss_mse_5': 0.08038900051923359, 'loss_mse_6': 0.02143009013551123, 'loss_mse_7'

 16%|█▌        | 107/677 [00:17<01:35,  5.99it/s]


KeyboardInterrupt: 

In [None]:
# Build a prefixed copy instead of overwriting `oof_preds`, so re-running this cell does not
# produce `pred_pred_*` columns.
# NOTE(review): assumes oof_preds rows are aligned with tr_df's current row order.
oof_pred_df = oof_preds.select(pl.all().name.prefix("pred_"))
pred_cols = oof_pred_df.columns

# Drop prediction columns left over from a previous run of this cell before re-attaching them,
# so the horizontal concat never hits duplicate column names.
tr_df = pl.concat(
    [tr_df.select([c for c in tr_df.columns if c not in pred_cols]), oof_pred_df],
    how="horizontal",
)
tr_df

## 評価

In [None]:
def calc_score(df: pl.DataFrame, pred_cols: list[str]):
    """Mean absolute error of the prediction columns against the x/y/z targets.

    Args:
        df: Frame holding the target columns x_0, y_0, z_0, ..., x_5, y_5, z_5 and the
            prediction columns.
        pred_cols: Prediction column names, paired by position with the targets in the order
            x_0, y_0, z_0, x_1, ... (one column per target, 18 in total).

    Returns:
        Dict with one ``score_<pred_col>`` MAE per target/prediction pair, plus ``avg``,
        the MAE over all target values.

    Raises:
        ValueError: if ``pred_cols`` does not contain exactly one column per target.
    """
    tg_cols = [f"{axis}_{i}" for i in range(6) for axis in ("x", "y", "z")]
    # Columns are paired by position. A count mismatch either fails with an opaque broadcasting
    # error or, for a single prediction column, broadcasts silently into wrong scores.
    if len(pred_cols) != len(tg_cols):
        raise ValueError(
            f"pred_cols must have {len(tg_cols)} columns, but got {len(pred_cols)}"
        )

    tg = df.select(tg_cols).to_numpy()
    pred = df.select(pred_cols).to_numpy()

    abs_err = np.abs(tg - pred)
    per_col = abs_err.mean(axis=0)
    scores = {f"score_{col}": float(score) for col, score in zip(pred_cols, per_col)}
    scores["avg"] = float(abs_err.mean())
    return scores


scores = calc_score(tr_df, pred_cols)
scores

In [None]:
# Calibration plot of the out-of-fold prediction columns (pred_cols) in tr_df, using 40 bins.
utils.plot_calibration_curve(tr_df, pred_cols, n_bins=40)

## Submission

In [None]:
# Predict on the test set with the trained fold models.
# NOTE(review): presumably returns one row per ts_df row, in ts_df's current (scene-sorted)
# order; the reordering cell below depends on that.
preds = utils.predict(
    models,
    ts_images,
    ts_df,
    feature.columns,
    scene_id_col="scene_id",
    scene_dsec_col="scene_dsec",
    pred_cols=pred_cols,
)
preds

In [None]:
# Restore the original test-row order (ts_df was sorted by scene_id / scene_dsec earlier, and
# origin_idx records each row's original position).
# NOTE(review): not idempotent. Re-running this cell attaches the sorted origin_idx to rows that
# are already in original order and scrambles them; rerun from the prediction cell instead.
preds = (
    preds.with_columns(ts_df.get_column("origin_idx"))
    .sort("origin_idx")
    .drop("origin_idx")
)

In [None]:
def create_submission_csv(preds: pl.DataFrame, filename: str = "submission.csv"):
    """Write ``preds`` as a submission CSV with columns x_0, y_0, z_0, ..., x_5, y_5, z_5.

    Columns are renamed by position on a copy; the caller's DataFrame is left untouched.

    Args:
        preds: Prediction frame with exactly 18 columns, already ordered
            x_0, y_0, z_0, x_1, ... and with rows in submission order.
        filename: Output CSV path.

    Raises:
        ValueError: if ``preds`` does not have exactly 18 columns.
    """
    submission_cols = [f"{axis}_{i}" for i in range(6) for axis in ("x", "y", "z")]

    # validate preds columns
    if len(preds.columns) != len(submission_cols):
        raise ValueError(
            f"preds columns must be {len(submission_cols)}, but got {len(preds.columns)}"
        )

    # Rename on a clone: assigning to `preds.columns` directly would rename the caller's
    # DataFrame as a side effect of writing the file.
    submission = preds.clone()
    submission.columns = submission_cols
    submission.write_csv(filename)
    print(f"Submission file is created: {filename}")


# Writes submission.csv (the default filename) relative to the current working directory.
create_submission_csv(preds)