In [1]:
import gc
import importlib
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from atmacup_18 import constants

import utils

# Re-import the local utils module so edits to utils.py take effect without restarting the kernel.
importlib.reload(utils)

<module 'utils' from '/home/tatsuya/projects/atmacup/atmacup_18/experiments/main/v00/v00_18_02/utils.py'>

In [2]:
# Fixed seed for reproducibility; utils.seed_everything presumably seeds Python/NumPy/torch — confirm in utils.py.
RANDOM_STATE = 2024
utils.seed_everything(RANDOM_STATE)

## データ読み込み

In [3]:
# All dataset paths are resolved from the notebook's working directory, expected to be
# <project>/experiments/main/v00/<version>/, so parents[3] is the project root.
notebook_dir = Path().resolve()
DATA_DIR = notebook_dir.parents[3].joinpath("data")
DATASET_DIR = DATA_DIR.joinpath("atmaCup#18_dataset")
TR_FEATURES_CSV = DATASET_DIR.joinpath("train_features.csv")
TS_FEATURES_CSV = DATASET_DIR.joinpath("test_features.csv")
IMAGES_DIR = DATASET_DIR.joinpath("images")
TRAFFIC_LIGHTS_CSV = DATASET_DIR.joinpath("traffic_lights.csv")

IMAGE_NAMES = ["image_t.png", "image_t-0.5.png", "image_t-1.0.png"]
TRAFFIC_LIGHTS_BBOX_IMAGE_NAME = constants.TRAFFIC_LIGHT_BBOX_IMAGE_NAME
OPTICAL_FLOW_IMAGE_NAME = constants.OPTICAL_FLOW_IMAGE_NAME

# Previous experiment whose predictions serve as the base estimate: when set, this run
# trains on (target - base prediction). Set to None to skip the base-prediction step.
BASE_PRED_DIR = Path("..", "..", "..", "main2", "v00", "v00_04_00")
# Only derive the CSV paths when a base directory is configured; calling .joinpath on
# None would make the "BASE_PRED_DIR = None" switch crash in this cell.
if BASE_PRED_DIR is not None:
    BASE_OOF_PRED_CSV = BASE_PRED_DIR.joinpath("oof_preds.csv")
    BASE_SUBMISSION_CSV = BASE_PRED_DIR.joinpath("submission.csv")
else:
    BASE_OOF_PRED_CSV = None
    BASE_SUBMISSION_CSV = None

In [4]:
# Flat list of target column names, (x, y, z) for each index 0..5:
# ["x_0", "y_0", "z_0", "x_1", ..., "z_5"]
TARGET_COLS = [f"{axis}_{step}" for step in range(6) for axis in ("x", "y", "z")]

In [5]:
# Load train features; the preview shows the target columns x_0..z_5 plus scene_id / scene_dsec / origin_idx.
tr_df = utils.read_feature_csv(TR_FEATURES_CSV)
tr_df.head(2)

ID,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,x_0,y_0,z_0,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5,scene_id,scene_dsec,origin_idx
str,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,i32,i64
"""00066be8e20318869c38c66be46663…",5.701526,1.538456,-2.165777,-139.0,0.0,False,0.25,True,"""drive""",False,False,2.82959,0.032226,0.045187,6.231999,0.065895,0.107974,9.785009,0.124972,0.203649,13.485472,0.163448,0.302818,17.574227,0.174289,0.406331,21.951269,0.199503,0.485079,"""00066be8e20318869c38c66be46663…",320,0
"""00066be8e20318869c38c66be46663…",11.176292,0.279881,-11.625697,-44.0,0.0,False,0.0,False,"""drive""",False,True,4.970268,-0.007936,0.005028,10.350489,-0.032374,-0.020701,15.770054,0.084073,0.008645,21.132415,0.391343,0.036335,26.316489,0.843124,0.065,31.383814,1.42507,0.073083,"""00066be8e20318869c38c66be46663…",420,1


In [6]:
# Load test features (same schema as train minus the target columns, per the preview).
ts_df = utils.read_feature_csv(TS_FEATURES_CSV)
ts_df.head(2)

ID,vEgo,aEgo,steeringAngleDeg,steeringTorque,brake,brakePressed,gas,gasPressed,gearShifter,leftBlinker,rightBlinker,scene_id,scene_dsec,origin_idx
str,f64,f64,f64,f64,f64,bool,f64,bool,str,bool,bool,str,i32,i64
"""012baccc145d400c896cb82065a93d…",3.374273,-0.01936,-34.008415,17.0,0.0,False,0.0,False,"""drive""",False,False,"""012baccc145d400c896cb82065a93d…",120,0
"""012baccc145d400c896cb82065a93d…",2.441048,-0.022754,307.860077,295.0,0.0,True,0.0,False,"""drive""",False,False,"""012baccc145d400c896cb82065a93d…",220,1


In [7]:
def reduce_base_pred(
    df: pl.DataFrame, base_pred_df: pl.DataFrame, has_target: bool
) -> pl.DataFrame:
    """
    Attach base predictions to ``df`` and optionally turn its targets into residuals.

    The columns of ``base_pred_df`` are appended to ``df`` (row-aligned horizontal
    concat) with a ``base_pred_`` prefix. When ``has_target`` is True, every column
    in ``TARGET_COLS`` is replaced by ``target - base_pred`` so the model can be
    trained on the residual.

    Args:
        df (pl.DataFrame): Feature DataFrame (containing TARGET_COLS if ``has_target``).
        base_pred_df (pl.DataFrame): Base predictions with one row per row of ``df``,
            in the same row order. Must contain the TARGET_COLS columns when
            ``has_target`` is True.
        has_target (bool): Whether ``df`` has target columns to convert to residuals.

    Returns:
        pl.DataFrame: ``df`` with ``base_pred_*`` columns appended and, if
        ``has_target``, its target columns replaced by residuals.

    Raises:
        ValueError: If ``df`` and ``base_pred_df`` have different row counts.
    """
    if df.height != base_pred_df.height:
        # A horizontal pl.concat pads the shorter frame with nulls instead of failing,
        # which would silently produce null base predictions / residuals.
        raise ValueError(
            f"row count mismatch: df has {df.height} rows, "
            f"base_pred_df has {base_pred_df.height} rows"
        )

    pf = "base_pred_"
    df = pl.concat(
        [df, base_pred_df.select(pl.all().name.prefix(pf))], how="horizontal"
    )

    if has_target:
        df = df.with_columns(
            [
                (pl.col(tg_col) - pl.col(f"{pf}{tg_col}")).alias(tg_col)
                for tg_col in TARGET_COLS
            ]
        )
    return df


def add_base_pred_to_target(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
    """
    Add the ``base_pred_*`` columns back onto ``target_cols``.

    This undoes the subtraction done by ``reduce_base_pred``. Columns are paired
    *by position* with ``TARGET_COLS``: ``target_cols[i]`` receives
    ``base_pred_{TARGET_COLS[i]}``. The same function can therefore restore both
    the residual targets (``TARGET_COLS``) and differently named prediction
    columns (e.g. ``pred_*``), as long as they are ordered like ``TARGET_COLS``.

    Args:
        df (pl.DataFrame): DataFrame containing ``target_cols`` and the
            ``base_pred_{TARGET_COLS}`` columns.
        target_cols (list[str]): Columns to add the base predictions to, in
            TARGET_COLS order.

    Returns:
        pl.DataFrame: ``df`` with each of ``target_cols`` replaced by ``col + base_pred``.

    Raises:
        ValueError: If ``target_cols`` does not have one entry per TARGET_COLS column.
    """
    pf = "base_pred_"
    base_pred_cols = [f"{pf}{tg_col}" for tg_col in TARGET_COLS]
    if len(target_cols) != len(base_pred_cols):
        # zip() below would otherwise silently drop the unmatched columns.
        raise ValueError(
            f"target_cols must have {len(base_pred_cols)} columns, "
            f"but got {len(target_cols)}"
        )

    df = df.with_columns(
        [
            (pl.col(tg_col) + pl.col(bp_col)).alias(tg_col)
            for tg_col, bp_col in zip(target_cols, base_pred_cols)
        ]
    )

    return df


if BASE_PRED_DIR is not None:
    # Base predictions from a previous experiment; columns: "x_0", "y_0", "z_0", ..., "x_5", "y_5", "z_5"
    base_oof_pred_df = pl.read_csv(BASE_OOF_PRED_CSV)
    base_submission_df = pl.read_csv(BASE_SUBMISSION_CSV)

    # Train: the new targets are (original target - base OOF prediction).
    # Test: only the base_pred_* columns are attached (there are no targets).
    tr_df = reduce_base_pred(tr_df, base_oof_pred_df, has_target=True)
    ts_df = reduce_base_pred(ts_df, base_submission_df, has_target=False)

    # The base predictions now live inside tr_df / ts_df; release the standalone copies.
    del base_oof_pred_df, base_submission_df
    gc.collect()

In [8]:
# tr_tl_bbox_images = utils.load_npy_images(
#    IMAGES_DIR,
#    ids=tr_df.get_column("ID").to_list(),
#    image_names=[TRAFFIC_LIGHTS_BBOX_IMAGE_NAME],
# )
# print(tr_tl_bbox_images.shape)
# ts_tl_bbox_images = utils.load_npy_images(
#    IMAGES_DIR,
#    ids=ts_df.get_column("ID").to_list(),
#    image_names=[TRAFFIC_LIGHTS_BBOX_IMAGE_NAME],
# )
# print(ts_tl_bbox_images.shape)

In [9]:
# tr_optical_flow_images = utils.load_npy_images(
#    IMAGES_DIR,
#    ids=tr_df.get_column("ID").to_list(),
#    image_names=[OPTICAL_FLOW_IMAGE_NAME],
# )
# print(tr_optical_flow_images.shape)
# ts_optical_flow_images = utils.load_npy_images(
#    IMAGES_DIR,
#    ids=ts_df.get_column("ID").to_list(),
#    image_names=[OPTICAL_FLOW_IMAGE_NAME],
# )
# print(ts_optical_flow_images.shape)

In [10]:
# Load the images listed in IMAGE_NAMES for every row, in the current row order of each frame.
tr_ids = tr_df["ID"].to_list()
ts_ids = ts_df["ID"].to_list()

tr_images = utils.load_images(IMAGES_DIR, ids=tr_ids, image_names=IMAGE_NAMES)
print(tr_images.shape)
ts_images = utils.load_images(IMAGES_DIR, ids=ts_ids, image_names=IMAGE_NAMES)
print(ts_images.shape)

(43371, 3, 64, 128, 3)
(1727, 3, 64, 128, 3)


In [11]:
# Merge each sample's images into one channel-first array:
# (N, 3, 64, 128, 3) -> (N, 9, 64, 128) per the printed shapes.
# Only the RGB images are used in this run. To also feed the traffic-light bbox and/or
# optical-flow images, load them (commented-out cells above) and pass e.g.
# [tr_images, tr_tl_bbox_images, tr_optical_flow_images] instead.
tr_images = utils.preprocess_images([tr_images])
ts_images = utils.preprocess_images([ts_images])

print(tr_images.shape)
print(ts_images.shape)

(43371, 9, 64, 128)
(1727, 9, 64, 128)


In [12]:
# del tr_tl_bbox_images
# gc.collect()
#
# del ts_tl_bbox_images
# gc.collect()
#
# del tr_optical_flow_images
# gc.collect()
#
# del ts_optical_flow_images
# gc.collect()

## scene_dsec順に並び替える

In [13]:
# Sort rows chronologically within each scene and apply the same permutation to the images.
# The image arrays were loaded in the frames' original row order, so image i belongs to
# the row with origin_idx == i (assumes origin_idx is the row position set by
# read_feature_csv). Check that before permuting: re-running this cell would otherwise
# apply the permutation a second time and silently misalign images and rows.
assert np.array_equal(
    tr_df.get_column("origin_idx").to_numpy(), np.arange(len(tr_images))
), "tr_df / tr_images are not in original order; re-run from the image-loading cells"
assert np.array_equal(
    ts_df.get_column("origin_idx").to_numpy(), np.arange(len(ts_images))
), "ts_df / ts_images are not in original order; re-run from the image-loading cells"

tr_df = tr_df.sort(["scene_id", "scene_dsec"])
ts_df = ts_df.sort(["scene_id", "scene_dsec"])

tr_images = tr_images[tr_df.get_column("origin_idx").to_numpy()]
ts_images = ts_images[ts_df.get_column("origin_idx").to_numpy()]

## Target

In [14]:
# Fit the target transformer on the training targets (residuals when BASE_PRED_DIR is set)
# and append its tg_* columns; target.columns is what utils.train is told to fit.
target = utils.CoordinateTarget(prefix="tg_")
target.fit(tr_df)

tg_df = target.transform(tr_df)
print(tg_df.columns)
print(tg_df.describe().glimpse())
tr_df = pl.concat([tr_df, tg_df], how="horizontal")

del tg_df
gc.collect()

['tg_cood_x_0', 'tg_cood_y_0', 'tg_cood_z_0', 'tg_cood_x_1', 'tg_cood_y_1', 'tg_cood_z_1', 'tg_cood_x_2', 'tg_cood_y_2', 'tg_cood_z_2', 'tg_cood_x_3', 'tg_cood_y_3', 'tg_cood_z_3', 'tg_cood_x_4', 'tg_cood_y_4', 'tg_cood_z_4', 'tg_cood_x_5', 'tg_cood_y_5', 'tg_cood_z_5']
Rows: 9
Columns: 19
$ statistic   <str> 'count', 'null_count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'
$ tg_cood_x_0 <f64> 43371.0, 0.0, -6.398172096493292e-05, 0.09748330055151712, -2.049703424044697, -0.03889588129092658, -0.0004598591444684372, 0.04074949288609275, 1.1879026061010576
$ tg_cood_y_0 <f64> 43371.0, 0.0, -1.8151129571594546e-05, 0.06193532851615388, -2.548924091006372, -0.022029638339427435, -0.0006609180929956013, 0.02146723659824248, 3.8037482134210494
$ tg_cood_z_0 <f64> 43371.0, 0.0, 8.790075275766784e-06, 0.04012050718886584, -0.9958970690087603, -0.018423500383461567, 0.0003501801640828847, 0.01789679478671308, 1.444553594762872
$ tg_cood_x_1 <f64> 43371.0, 0.0, 7.390598477211262e-05, 0.19

0

## 特徴量

In [15]:
# Fit the feature transformer on train only, then append its ft_* columns to both splits.
feature = utils.Feature(prefix="ft_")
feature.fit(tr_df)


def _append_features(df: pl.DataFrame) -> pl.DataFrame:
    """Transform ``df`` with the fitted ``feature``, print a summary, and append the new columns."""
    new_cols = feature.transform(df)
    print(new_cols.columns)
    print(new_cols.describe().glimpse())
    return pl.concat([df, new_cols], how="horizontal")


tr_df = _append_features(tr_df)
ts_df = _append_features(ts_df)
gc.collect()

['ft_vEgo', 'ft_aEgo', 'ft_steeringAngleDeg', 'ft_steeringTorque', 'ft_brake', 'ft_brakePressed', 'ft_gas', 'ft_gasPressed', 'ft_is_gearShifter_drive', 'ft_is_gearShifter_neutral', 'ft_is_gearShifter_park', 'ft_is_gearShifter_reverse', 'ft_leftBlinker', 'ft_rightBlinker', 'ft_base_pred_x0', 'ft_base_pred_y0', 'ft_base_pred_z0', 'ft_base_pred_x1', 'ft_base_pred_y1', 'ft_base_pred_z1', 'ft_base_pred_x2', 'ft_base_pred_y2', 'ft_base_pred_z2', 'ft_base_pred_x3', 'ft_base_pred_y3', 'ft_base_pred_z3', 'ft_base_pred_x4', 'ft_base_pred_y4', 'ft_base_pred_z4', 'ft_base_pred_x5', 'ft_base_pred_y5', 'ft_base_pred_z5']
Rows: 9
Columns: 33
$ statistic                 <str> 'count', 'null_count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max'
$ ft_vEgo                   <f64> 43371.0, 0.0, 9.172175823216334, 7.226919878374694, -0.1619189828634262, 2.5786657333374023, 8.518790245056152, 14.286815643310547, 27.55126190185547
$ ft_aEgo                   <f64> 43371.0, 0.0, -0.015654028629347255, 0.63

0

## モデリング

In [16]:
# Number of CV folds passed to utils.train (folds are grouped by scene_id).
N_SPLITS = 2

In [17]:
n_sample_in_scene = 3

# NOTE(review): the channel / feature counts below scale with n_sample_in_scene, which
# suggests that many samples from one scene are stacked into a single model input —
# confirm in utils.
model_params = {
    "dnn": {
        "n_sample_in_scene": n_sample_in_scene,
        # per-sample image channels (9 after preprocess_images) x samples per input
        "n_img_channels": tr_images.shape[1] * n_sample_in_scene,
        "n_features": len(feature.columns) * n_sample_in_scene,
        "n_targets": len(target.columns),
        # presumably forwarded to a segmentation_models_pytorch-style model — confirm in utils
        "segmentation_model_params": {
            "encoder_name": "mit_b1",
            "encoder_weights": "imagenet",
            "decoder_channels": (256, 128, 64, 32, 16),
        },
        "dropout": 0.0,
        "embed_dim": 128,
        "n_layers": 1,
    },
    "dnn_pretrained_model": {
        # list[str]: len(list) == n_splits
        "weight_path": None,
        "load_only_backbone": None,
    },
    "dev": "cuda",
}

# Peak learning rate; used for opt_params, backbone_opt_params and the scheduler's max_lr.
lr = 5e-5
fit_params = {
    "dnn": {
        "tr_batch_size": 32,
        "vl_batch_size": 1024,
        "trainer_params": {
            "criterion_params": {},
            "opt": "adamw",
            "opt_params": {"lr": lr, "weight_decay": 1e-4},
            "backbone_opt_params": {"lr": lr, "weight_decay": 1e-4},
            # Names match torch OneCycleLR; the first logged lr (2e-06 = max_lr / div_factor)
            # is consistent with that — confirm in utils.
            "sch_params": {
                "max_lr": lr,
                "pct_start": 0.1,
                "div_factor": 25,
                "final_div_factor": 1000,
            },
            "epochs": 5,
            "dev": "cuda",
            "val_freq": 1,
            "prefix": "",
            "save_best": False,
            "save_epochs": [],
            "maximize_score": False,
            "grad_max_norm": None,
        },
    },
}

In [18]:
# Cross-validated training with N_SPLITS folds grouped by scene_id; returns the trained
# models (presumably one per fold) and the out-of-fold predictions.
# NOTE(review): the saved output shows this run ended in KeyboardInterrupt during fold 0,
# so models / oof_preds were never assigned and the cells below have not run.
models, oof_preds = utils.train(
    model_params=model_params,
    fit_params=fit_params,
    df=tr_df,
    images=tr_images,
    target_cols=target.columns,
    feature_cols=feature.columns,
    group_col="scene_id",
    scene_id_col="scene_id",
    scene_dsec_col="scene_dsec",
    n_splits=N_SPLITS,
)

-----------------
-----------------
Training fold 0...
train samples: 21685, valid samples: 21686
Save model : fold0_model.pth

epoch  0
lr  2.000000000000001e-06
lr  2.000000000000001e-06
lr  2.000000000000001e-06
lr  2.000000000000001e-06


100%|██████████| 677/677 [00:33<00:00, 20.36it/s]
100%|██████████| 22/22 [00:09<00:00,  2.27it/s]



Train Loss: 6.0416
{'loss': 6.041637798884066, 'loss_mse_0': 0.10027925635480282, 'loss_mse_1': 0.06259180630861261, 'loss_mse_2': 0.04347558127801475, 'loss_mse_3': 0.20440936460540776, 'loss_mse_4': 0.12272585514916363, 'loss_mse_5': 0.08802750970799603, 'loss_mse_6': 0.34315643184426264, 'loss_mse_7': 0.20565721811933038, 'loss_mse_8': 0.13363107364828336, 'loss_mse_9': 0.533099793304897, 'loss_mse_10': 0.33997773916145546, 'loss_mse_11': 0.18118558574792554, 'loss_mse_12': 0.7817452829468726, 'loss_mse_13': 0.5305131286096537, 'loss_mse_14': 0.2297010048508556, 'loss_mse_15': 1.0729827450772684, 'loss_mse_16': 0.787785679674219, 'loss_mse_17': 0.28069272404055123}
Valid Loss: 5.8365
{'loss': 5.836506323380903, 'loss_mse_0': 0.10447026348926804, 'loss_mse_1': 0.07121787448836998, 'loss_mse_2': 0.03852981041100892, 'loss_mse_3': 0.21161795813928952, 'loss_mse_4': 0.13356894491748375, 'loss_mse_5': 0.07847689058292996, 'loss_mse_6': 0.33521666580980475, 'loss_mse_7': 0.21174740588123

100%|██████████| 677/677 [00:33<00:00, 20.21it/s]
100%|██████████| 22/22 [00:07<00:00,  2.86it/s]



Train Loss: 5.7475
{'loss': 5.7475195952321085, 'loss_mse_0': 0.09566783864034596, 'loss_mse_1': 0.058692906917165, 'loss_mse_2': 0.04100613992804306, 'loss_mse_3': 0.1941836594907977, 'loss_mse_4': 0.11465556646093308, 'loss_mse_5': 0.08361798861618408, 'loss_mse_6': 0.32642881655886746, 'loss_mse_7': 0.19488232004479608, 'loss_mse_8': 0.12685994046741608, 'loss_mse_9': 0.5108752582370795, 'loss_mse_10': 0.3219507498064105, 'loss_mse_11': 0.1717581079565704, 'loss_mse_12': 0.7442213817000213, 'loss_mse_13': 0.5054386724639502, 'loss_mse_14': 0.21787688499825567, 'loss_mse_15': 1.0249418494576366, 'loss_mse_16': 0.7480084145967154, 'loss_mse_17': 0.26645309931462413}
Valid Loss: 5.7183
{'loss': 5.718297069722956, 'loss_mse_0': 0.1031824028627439, 'loss_mse_1': 0.07108725090934472, 'loss_mse_2': 0.03835704279216853, 'loss_mse_3': 0.2111990519545295, 'loss_mse_4': 0.13164673305370592, 'loss_mse_5': 0.07780905931510708, 'loss_mse_6': 0.33452619273554196, 'loss_mse_7': 0.20654458755796606

100%|██████████| 677/677 [00:32<00:00, 20.52it/s]
100%|██████████| 22/22 [00:07<00:00,  2.81it/s]



Train Loss: 5.6234
{'loss': 5.623433191279717, 'loss_mse_0': 0.09520218674605231, 'loss_mse_1': 0.05802854034290312, 'loss_mse_2': 0.040825438718731066, 'loss_mse_3': 0.19234750287791652, 'loss_mse_4': 0.1122115966609433, 'loss_mse_5': 0.08314797284695123, 'loss_mse_6': 0.3223149251396406, 'loss_mse_7': 0.1886916761967509, 'loss_mse_8': 0.1260341169199792, 'loss_mse_9': 0.5040203524723187, 'loss_mse_10': 0.30946573787857196, 'loss_mse_11': 0.1698649681092101, 'loss_mse_12': 0.7344623017654461, 'loss_mse_13': 0.48416190459069686, 'loss_mse_14': 0.2148393346863404, 'loss_mse_15': 1.0110509010854465, 'loss_mse_16': 0.7154080631153622, 'loss_mse_17': 0.2613556915495209}
Valid Loss: 5.6315
{'loss': 5.631495497443459, 'loss_mse_0': 0.1045788089660081, 'loss_mse_1': 0.07082702583548697, 'loss_mse_2': 0.03863816310397603, 'loss_mse_3': 0.21341595527800647, 'loss_mse_4': 0.1298867944966663, 'loss_mse_5': 0.07800378921357068, 'loss_mse_6': 0.3373113112016158, 'loss_mse_7': 0.2007917958227071, '

100%|██████████| 677/677 [00:34<00:00, 19.91it/s]
100%|██████████| 22/22 [00:07<00:00,  2.87it/s]



Train Loss: 5.4070
{'loss': 5.406982464487479, 'loss_mse_0': 0.09373199438436873, 'loss_mse_1': 0.057154518303129835, 'loss_mse_2': 0.04053500776854584, 'loss_mse_3': 0.18811154405708855, 'loss_mse_4': 0.10826014181766852, 'loss_mse_5': 0.08209217511436612, 'loss_mse_6': 0.3138556908183295, 'loss_mse_7': 0.17926846718638476, 'loss_mse_8': 0.1242541279396155, 'loss_mse_9': 0.4880739949874723, 'loss_mse_10': 0.2899310442796486, 'loss_mse_11': 0.1664650716668922, 'loss_mse_12': 0.7126070860239568, 'loss_mse_13': 0.4521027695569189, 'loss_mse_14': 0.20986788261522218, 'loss_mse_15': 0.9822743333865128, 'loss_mse_16': 0.6640718039923088, 'loss_mse_17': 0.25432477836418715}
Valid Loss: 5.5873
{'loss': 5.587331663478505, 'loss_mse_0': 0.10365593602711504, 'loss_mse_1': 0.0709340146488764, 'loss_mse_2': 0.038866904310204765, 'loss_mse_3': 0.21144481003284454, 'loss_mse_4': 0.12972533364187588, 'loss_mse_5': 0.07854279621758244, 'loss_mse_6': 0.3355046463283626, 'loss_mse_7': 0.198498483408581

100%|██████████| 677/677 [00:33<00:00, 20.19it/s]
100%|██████████| 22/22 [00:07<00:00,  2.92it/s]
  model.load_state_dict(torch.load(model_path))



Train Loss: 5.1299
{'loss': 5.129887324844541, 'loss_mse_0': 0.09195736179642787, 'loss_mse_1': 0.05612358979143455, 'loss_mse_2': 0.040258394555633056, 'loss_mse_3': 0.18210941495888341, 'loss_mse_4': 0.10458868195851673, 'loss_mse_5': 0.08118045502280379, 'loss_mse_6': 0.3005291968363955, 'loss_mse_7': 0.1689939306804049, 'loss_mse_8': 0.1221785542735365, 'loss_mse_9': 0.4679850478226008, 'loss_mse_10': 0.2673067119044074, 'loss_mse_11': 0.16258292878528113, 'loss_mse_12': 0.6789759365151165, 'loss_mse_13': 0.4152492538116531, 'loss_mse_14': 0.2043558641706911, 'loss_mse_15': 0.9354731081331498, 'loss_mse_16': 0.6035881920702524, 'loss_mse_17': 0.24645070884063114}
Valid Loss: 5.5945
{'loss': 5.59454835544933, 'loss_mse_0': 0.10404779220169241, 'loss_mse_1': 0.07124704579738053, 'loss_mse_2': 0.03909642147746953, 'loss_mse_3': 0.21227225254882465, 'loss_mse_4': 0.13016190312125467, 'loss_mse_5': 0.07895680432292548, 'loss_mse_6': 0.336767968129028, 'loss_mse_7': 0.19915723665194077,

KeyboardInterrupt: 

In [None]:
# Prefix prediction columns with "pred_" so they can sit next to the targets in tr_df.
oof_preds = oof_preds.select(pl.all().name.prefix("pred_"))
pred_cols = oof_preds.columns

# NOTE(review): assumes oof_preds rows follow the current (scene-sorted) row order of tr_df — confirm in utils.train.
tr_df = pl.concat([tr_df, oof_preds], how="horizontal")
tr_df

## 評価

In [None]:
def calc_score(df: pl.DataFrame, pred_cols: list[str]) -> dict[str, float]:
    """
    Mean absolute error of ``pred_cols`` against the target columns x_0..z_5.

    ``pred_cols`` are paired with the targets by position, so they must be
    ordered like ["x_0", "y_0", "z_0", ..., "z_5"].

    Args:
        df (pl.DataFrame): Frame holding both the target columns and ``pred_cols``.
        pred_cols (list[str]): Prediction columns, one per target, in target order.

    Returns:
        dict[str, float]: ``"score_<pred_col>"`` -> per-column MAE, plus
        ``"avg"`` -> MAE over all target cells.

    Raises:
        ValueError: If ``pred_cols`` does not have exactly one column per target.
    """
    tg_cols = [f"{axis}_{i}" for i in range(6) for axis in ("x", "y", "z")]
    if len(pred_cols) != len(tg_cols):
        # Without this check numpy broadcasting (e.g. a single pred column) would
        # silently produce meaningless scores.
        raise ValueError(
            f"pred_cols must have {len(tg_cols)} columns, but got {len(pred_cols)}"
        )

    tg = df.select(tg_cols).to_numpy()
    pred = df.select(pred_cols).to_numpy()

    abs_err = np.abs(tg - pred)  # computed once, reused for per-column and overall MAE
    scores = {
        f"score_{col}": float(score)
        for col, score in zip(pred_cols, abs_err.mean(axis=0))
    }
    scores["avg"] = float(abs_err.mean())
    return scores


# Score before the base predictions are added back (targets are still residuals when BASE_PRED_DIR is set).
scores = calc_score(tr_df, pred_cols)
scores

In [None]:
# Calibration plot of the OOF predictions vs targets, before the base predictions are added back.
utils.plot_calibration_curve(tr_df, pred_cols, n_bins=40)

In [None]:
if BASE_PRED_DIR is not None:
    # Add the subtracted base predictions back to restore the original target and pred columns.
    # Not idempotent: running this cell twice adds the base predictions twice.
    tr_df = add_base_pred_to_target(tr_df, TARGET_COLS)
    tr_df = add_base_pred_to_target(tr_df, pred_cols)

In [None]:
if BASE_PRED_DIR is not None:
    # NOTE(review): the same base_pred column is added to each target and its prediction,
    # so this MAE should equal the earlier score up to float rounding.
    scores = calc_score(tr_df, pred_cols)
    display(scores)

In [None]:
if BASE_PRED_DIR is not None:
    # Calibration plot on the original coordinate scale (base predictions added back).
    utils.plot_calibration_curve(tr_df, pred_cols, n_bins=40)

## oofを保存

In [None]:
def create_submission_csv(preds: pl.DataFrame, filename: str = "submission.csv"):
    """
    Write ``preds`` to ``filename`` with its columns renamed to ``TARGET_COLS``.

    Columns are renamed by position, so ``preds`` must list its prediction columns
    in TARGET_COLS order. The caller's DataFrame is left unmodified.

    Args:
        preds (pl.DataFrame): Predictions with exactly ``len(TARGET_COLS)`` columns.
        filename (str): Output CSV path.

    Raises:
        ValueError: If ``preds`` does not have ``len(TARGET_COLS)`` columns.
    """
    submission_cols = TARGET_COLS

    # validate preds columns
    if len(preds.columns) != len(submission_cols):
        raise ValueError(
            f"preds columns must be {len(submission_cols)}, but got {len(preds.columns)}"
        )

    # Build a renamed copy instead of assigning preds.columns, which would rename the
    # caller's DataFrame in place. select + alias renames all columns simultaneously.
    out = preds.select(
        [pl.col(old).alias(new) for old, new in zip(preds.columns, submission_cols)]
    )
    out.write_csv(filename)
    print(f"Submission file is created: {filename}")


# Restore the original row order (origin_idx) and save the OOF predictions.
create_submission_csv(tr_df.sort("origin_idx").select(pred_cols), "oof_preds.csv")

## Submission

In [None]:
# Predict the test set with the trained models. pred_cols is passed so the output columns
# presumably match the OOF prediction names; the actual names are re-read from preds below.
# NOTE(review): the horizontal concat assumes preds rows follow ts_df's current (scene-sorted) order.
preds = utils.predict(
    models,
    ts_images,
    ts_df,
    feature.columns,
    scene_id_col="scene_id",
    scene_dsec_col="scene_dsec",
    pred_cols=pred_cols,
)
pred_cols = preds.columns
ts_df = pl.concat([ts_df, preds], how="horizontal")

preds

In [None]:
if BASE_PRED_DIR is not None:
    # Add the subtracted base predictions back onto the pred columns (test has no targets).
    # Not idempotent: running this cell twice adds the base predictions twice.
    ts_df = add_base_pred_to_target(ts_df, pred_cols)
    display(ts_df)

In [None]:
# Restore the original test row order (origin_idx) before writing the submission.
ts_df = ts_df.sort("origin_idx")

In [None]:
# Write the test predictions, renamed to TARGET_COLS, as the submission file.
create_submission_csv(ts_df.select(pred_cols), "submission.csv")