In [2]:
import sys, os
# Put the parent directory on the import path so the `src` package imported
# further down can be resolved (assumes the notebook is run from a direct
# subdirectory of the project -- TODO confirm).
# Guarded so that re-running this cell does not add duplicate entries.
if os.pardir not in sys.path:
    sys.path.append(os.pardir)

In [3]:
from pathlib import Path
import numpy as np
import polars as pl
import os
from hydra import initialize, compose

# Load the "cv_train" Hydra configuration from ../run/conf.
with initialize(version_base=None, config_path="../run/conf"):
    cfg = compose(config_name="cv_train")

In [4]:
from src.utils.metrics import event_detection_ap
from src.utils.periodicity import get_periodicity_dict
from src.utils.common import trace
# Periodicity lookup built from the config. make_submission indexes it by
# series_id and multiplies the predictions by `1 - value` before peak picking.
periodicity_dict = get_periodicity_dict(cfg)



In [5]:
# Raw sensor series; `timestamp` is stored as an ISO-8601 string with a UTC
# offset and is parsed into a timezone-aware datetime column.
train_df = (
    pl.read_parquet(Path(cfg.dir.data_dir) / "train_series.parquet")
    .with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%z"))
)

In [6]:
# Ground-truth events. Rows containing any null are dropped before the
# `timestamp` strings are parsed into timezone-aware datetimes.
event_df = (
    pl.read_csv(Path(cfg.dir.data_dir) / "train_events.csv")
    .drop_nulls()
    .with_columns(pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%z"))
)

In [7]:
# Prediction file -> model name used as the column prefix in the merged frame.

# Files whose prediction columns are named `prediction_onset` / `prediction_wakeup`.
pred1_path2col_dict = dict(
    [
        ("148_gru_scale_factor.parquet", "148_gru_scale_factor"),
        ("156_gru_transformer_residual.parquet", "156_gru_transformer_residual"),
    ]
)

# Files whose prediction columns are named `pred_onset` / `pred_wakeup`.
pred2_path2col_dict = dict(
    [
        ("../output/cv_inference/exp065_split_drop/single/train_pred.parquet", "exp065_split_drop"),
        ("../output/cv_inference/exp068_transformer/single/train_pred.parquet", "exp068_transformer"),
        ("../output/cv_inference/exp078_lstm/single/train_pred.parquet", "exp078_lstm"),
    ]
)

In [8]:
events = ["onset", "wakeup"]
df_list = [train_df]

# The two groups of prediction files differ only in the prefix of their
# prediction columns, so both are handled by the same loop. Each model's
# columns are cast to Float32 and renamed to `<model name>_<event>`.
for path2col_dict, src_prefix in (
    (pred1_path2col_dict, "prediction_"),
    (pred2_path2col_dict, "pred_"),
):
    for path, name in path2col_dict.items():
        pred_df = pl.read_parquet(path).select(
            [pl.col(src_prefix + col).cast(pl.Float32).alias(name + "_" + col) for col in events]
        )
        # A horizontal concat pairs rows purely by position, so every prediction
        # file must contain exactly one row per row of train_df. Fail early with
        # a message that names the offending file.
        if pred_df.height != train_df.height:
            raise ValueError(
                f"{path}: {pred_df.height} rows, expected {train_df.height} to match train_df"
            )
        df_list.append(pred_df)

pred_all_df = pl.concat(df_list, how="horizontal")
pred_all_df.head()

series_id,step,timestamp,anglez,enmo,148_gru_scale_factor_onset,148_gru_scale_factor_wakeup,156_gru_transformer_residual_onset,156_gru_transformer_residual_wakeup,exp065_split_drop_onset,exp065_split_drop_wakeup,exp068_transformer_onset,exp068_transformer_wakeup,exp078_lstm_onset,exp078_lstm_wakeup
str,u32,"datetime[μs, UTC]",f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
"""038441c925bb""",0,2018-08-14 19:30:00 UTC,2.6367,0.0217,1.3214e-09,3e-06,6.2617e-13,2.8e-05,0.014008,0.02298,2.9e-05,3.5e-05,0.004398,0.011375
"""038441c925bb""",1,2018-08-14 19:30:05 UTC,2.6368,0.0215,2.4195e-09,3e-06,9.6546e-13,3.4e-05,0.01252,0.018784,3.1e-05,3.2e-05,0.003567,0.009117
"""038441c925bb""",2,2018-08-14 19:30:10 UTC,2.637,0.0216,4.3217e-09,4e-06,1.1043e-12,3.4e-05,0.009552,0.010391,3.3e-05,2.6e-05,0.001902,0.004608
"""038441c925bb""",3,2018-08-14 19:30:15 UTC,2.6368,0.0213,4.5551e-09,4e-06,7.27e-13,3.4e-05,0.009407,0.006439,3.9e-05,2.2e-05,0.001182,0.002462
"""038441c925bb""",4,2018-08-14 19:30:20 UTC,2.6368,0.0215,5.2894e-09,3e-06,7.1623e-13,3.2e-05,0.012085,0.006927,4.8e-05,1.9e-05,0.001407,0.002687


In [9]:
from scipy.optimize import minimize

In [10]:
from tqdm.auto import tqdm
from scipy.signal import find_peaks

def make_submission(
    preds_df: pl.DataFrame,
    periodicity_dict: dict[str, np.ndarray],
    height: float = 0.001,
    distance: int = 100,
    day_norm: bool = False,
    daily_score_offset: float = 1.0,
    pred_prefix: str = "prediction",
    late_date_rate: float | None = None,
) -> pl.DataFrame:
    """Turn per-step event predictions into a table of detected events.

    For every series and for each of the events "onset" and "wakeup", the
    `{pred_prefix}_{event}` column is multiplied by `1 - periodicity`, peaks are
    picked with `scipy.signal.find_peaks`, and the rows at those peaks are
    emitted. The emitted `score` is the value of the original prediction column
    (not the periodicity-scaled copy used for peak picking).

    Args:
        preds_df: Frame with the columns `series_id`, `step`, `timestamp`,
            `{pred_prefix}_onset` and `{pred_prefix}_wakeup`.
        periodicity_dict: Lookup indexed by `series_id`; the first
            `len(series)` entries of each value are used.
        height: Minimum peak height, forwarded to `find_peaks`.
        distance: Minimum spacing between peaks, forwarded to `find_peaks`.
        day_norm: If true, divide every score by the sum of the scores that
            share its series, event and date, plus `daily_score_offset`.
        daily_score_offset: Constant added to the denominator of `day_norm`.
        pred_prefix: Prefix of the prediction columns in `preds_df`.
        late_date_rate: If given, scores are multiplied by
            `1 - (1 - late_date_rate) * days_since_first_date / (date_span_days + 1)`,
            computed per series, so later dates are scaled further away from 1.

    Returns:
        Frame with the columns `row_id`, `series_id`, `step`, `event`, `score`,
        sorted by `series_id` and `step`; `row_id` numbers the rows in that order.

    In both optional adjustments, "date" is the calendar date of `timestamp`
    shifted forward by two hours.
    """
    event_dfs = []

    for group_key, series_df in tqdm(
        preds_df.group_by("series_id"),
        desc="find peaks",
        leave=False,
        total=preds_df["series_id"].n_unique(),
    ):
        # Depending on the polars version, iterating a group_by yields the key
        # either as a bare value or as a 1-tuple; the dict is keyed by the value.
        series_id = group_key[0] if isinstance(group_key, tuple) else group_key

        for event in ["onset", "wakeup"]:
            score_col = f"{pred_prefix}_{event}"
            # Copy so the in-place scaling below never touches the frame's data.
            event_preds = series_df[score_col].to_numpy().copy()
            event_preds *= 1 - periodicity_dict[series_id][: len(event_preds)]
            # find_peaks returns positions within this series. Matching them
            # against `step` assumes `step` is the 0-based row position inside
            # the series -- NOTE(review): confirm for the input data.
            steps = find_peaks(event_preds, height=height, distance=distance)[0]
            event_dfs.append(
                series_df.filter(pl.col("step").is_in(steps))
                .with_columns(pl.lit(event).alias("event"))
                .rename({score_col: "score"})
                .select(["series_id", "step", "timestamp", "event", "score"])
            )

    submission_df = (
        pl.concat(event_dfs).sort(["series_id", "step"]).with_columns(pl.arange(0, pl.count()).alias("row_id"))
    )

    if day_norm:
        submission_df = submission_df.with_columns(
            pl.col("timestamp").dt.offset_by("2h").dt.date().alias("date")
        ).with_columns(
            pl.col("score") / (pl.col("score").sum().over(["series_id", "event", "date"]) + daily_score_offset)
        )

    if late_date_rate is not None:
        submission_df = (
            submission_df.with_columns(pl.col("timestamp").dt.offset_by("2h").dt.date().alias("date"))
            .with_columns(
                pl.col("date").min().over("series_id").alias("min_date"),
                pl.col("date").max().over("series_id").alias("max_date"),
            )
            .with_columns(
                pl.col("score")
                * (
                    1
                    - (
                        (1 - pl.lit(late_date_rate))
                        * (
                            (pl.col("date") - pl.col("min_date")).dt.days()
                            # +1.0 keeps the denominator non-zero for one-day series.
                            / ((pl.col("max_date") - pl.col("min_date")).dt.days() + 1.0)
                        )
                    )
                )
            )
        )

    return submission_df.select(["row_id", "series_id", "step", "event", "score"])


In [16]:
# Models taking part in the blend; calc_score assigns weight i to cols[i].
cols = [
    "148_gru_scale_factor",
    "156_gru_transformer_residual",
    "exp065_split_drop",
    "exp068_transformer",
    "exp078_lstm",
]

def calc_score(param, pred_all_df, day_norm: bool = True):
    """Objective for `scipy.optimize.minimize`: negated AP of a weighted blend.

    The weights are rescaled to sum to 1, the per-model predictions listed in
    the module-level `cols` are blended into `prediction_onset` /
    `prediction_wakeup`, events are extracted with `make_submission`, and the
    result is scored against the module-level `event_df`.

    Args:
        param: One weight per entry of `cols` (ndarray, list or tuple).
        pred_all_df: Frame holding a `<model>_<event>` column for every model
            in `cols` and every event in the module-level `events`.
        day_norm: Forwarded to `make_submission`.

    Returns:
        The negated score, so that minimizing it maximizes the score. The
        normalized weights and the score are printed on every call.
    """
    # minimize passes an ndarray, but accept a plain list/tuple of weights too.
    param = np.asarray(param, dtype=float)
    param = param / param.sum()

    # Blend directly into the prediction columns instead of first adding one
    # temporary weighted column per model and event.
    tmp_df = pred_all_df.with_columns(
        [
            pl.sum_horizontal(
                [pl.col(f"{col}_{event}") * param[i] for i, col in enumerate(cols)]
            ).alias(f"prediction_{event}")
            for event in events
        ]
    )

    sub_df1 = make_submission(
        tmp_df,
        periodicity_dict=periodicity_dict,
        height=0.001,
        distance=107,
        daily_score_offset=1.0,
        day_norm=day_norm,
    )
    score = event_detection_ap(
        event_df.to_pandas(),
        sub_df1.to_pandas(),
    )
    print(param, score)

    return -score

In [18]:
# Score a fixed blend that puts weight only on the first two entries of `cols`;
# calc_score rescales the weights to sum to 1 before blending.
calc_score(np.array([0.42164168, 0.4814393,0,0,0]),pred_all_df)

find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.46689244 0.53310756 0.         0.         0.        ] 0.816267170390424


-0.816267170390424

In [25]:

# Start from a uniform blend. No sum-to-one constraint is passed to the
# optimizer: calc_score rescales the weights itself, so only the box bounds
# are needed.
weights = [1.0 / len(cols)] * len(cols)
res = minimize(
    calc_score,
    weights,
    args=(pred_all_df,),  # must be a tuple; the trailing comma makes it one
    method="Nelder-Mead",
    options={"disp": True, "maxiter": 100},
    bounds=[(0.0, 1.0)] * len(weights),
)
# [0.42164168 0.48143932 0.096919   0.         0.        ] 0.8197241628618371


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.2 0.2 0.2 0.2 0.2] 0.8142127037458327


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20792079 0.1980198  0.1980198  0.1980198  0.1980198 ] 0.8142750518990867


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.1980198  0.20792079 0.1980198  0.1980198  0.1980198 ] 0.8143136473030057


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.1980198  0.1980198  0.20792079 0.1980198  0.1980198 ] 0.8140901037752398


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.1980198  0.1980198  0.1980198  0.20792079 0.1980198 ] 0.8141702212727759


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.1980198  0.1980198  0.1980198  0.1980198  0.20792079] 0.8140862828696005


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.2027833 0.2027833 0.2027833 0.2027833 0.1888668] 0.814304728599401


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20469932 0.20469932 0.19076065 0.20469932 0.19514138] 0.8142589228999724


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20739203 0.20739203 0.19781272 0.19342221 0.193981  ] 0.8144717139830571


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.21213335 0.21213335 0.19770797 0.18608751 0.19193781] 0.8147063986937246


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.21013667 0.21013667 0.19495065 0.19589977 0.18887623] 0.8145343504610083


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20767015 0.20767015 0.20578478 0.18770735 0.19116757] 0.8145173543639357


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20434591 0.21828694 0.20169153 0.19018182 0.18549381] 0.8147366825155249


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20254301 0.2285081  0.20354327 0.18622895 0.17917667] 0.8149001383376133


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.20940299 0.22371187 0.19721563 0.1788392  0.19083031] 0.8148746771330089


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.21883078 0.22498515 0.20167553 0.17580837 0.17870017] 0.8152863202162917


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.22937859 0.23363403 0.20352839 0.16455076 0.16890822] 0.8157381504178438


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.2177902  0.2357245  0.19287349 0.176937   0.1766748 ] 0.8153401712245925


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.21841554 0.24372455 0.20309646 0.16076324 0.17400021] 0.8155634893593178


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.21882274 0.25393503 0.20238562 0.16091119 0.16394542] 0.8156850250047398


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.22551881 0.25480199 0.20504074 0.1607822  0.15385626] 0.8159479915649527


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.2338449  0.2708644  0.2090835  0.15145323 0.13475398] 0.8163377799661083


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.24516952 0.2668884  0.20074919 0.13918921 0.14800367] 0.816429709168061


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.26722685 0.2867485  0.19930338 0.11484824 0.13187303] 0.8171895562751565


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.2492963  0.27992991 0.21432772 0.12384592 0.13260015] 0.8167471853985322


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.26091551 0.28613529 0.20835123 0.12570079 0.11889718] 0.8171195205124353


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.27841096 0.28931204 0.21162324 0.11056215 0.11009161] 0.8175212537713386


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31005347 0.30809792 0.2165286  0.08382582 0.08149419] 0.8182452943087728


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.30074034 0.34219919 0.21582118 0.07296156 0.06827774] 0.8186128278150113


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.34004406 0.40199334 0.22259164 0.0225172  0.01285375] 0.8189309956023463


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.33766475 0.35373995 0.21516352 0.03637204 0.05705975] 0.8188456029232074


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.3588009  0.37549056 0.20992518 0.02858372 0.02719963] 0.8190222563287488


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.39737511 0.40591084 0.19671404 0.         0.        ] 0.8191139044160535


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.39057313 0.40545672 0.20397015 0.         0.        ] 0.8189991755445727


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.39379372 0.41102143 0.19518485 0.         0.        ] 0.8190733023265071


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.39002178 0.43368112 0.1762971  0.         0.        ] 0.8192164961758843


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.3961521  0.45500757 0.14884033 0.         0.        ] 0.8194531767849604


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.3980711  0.44329595 0.15863294 0.         0.        ] 0.819327660739704


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.42640975 0.43305512 0.14053513 0.         0.        ] 0.8195421857838251


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.44999379 0.43826409 0.11174212 0.         0.        ] 0.8196141840618415


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.4225677  0.45149347 0.12593883 0.         0.        ] 0.8195120444586599


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43079172 0.46235637 0.10685191 0.         0.        ] 0.8196647986574159


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.44301423 0.47931517 0.0776706  0.         0.        ] 0.8196572574376566


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43454122 0.47569407 0.08976471 0.         0.        ] 0.8196638220247431


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.44883225 0.46715017 0.08401757 0.         0.        ] 0.8196585652085593


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.4662883  0.46322341 0.07048829 0.         0.        ] 0.8195037860000072


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.45322128 0.46169272 0.085086   0.         0.        ] 0.8196537962996517


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.46013234 0.46973134 0.07013632 0.         0.        ] 0.8195478986364358


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.45237016 0.46596275 0.08166709 0.         0.        ] 0.8196606356687006


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43990531 0.48901236 0.07108233 0.         0.        ] 0.819598769820929


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.44698161 0.45341632 0.09960208 0.         0.        ] 0.8196631450103946


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43192289 0.46905769 0.09901941 0.         0.        ] 0.8196721133921571


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.41989977 0.47321529 0.10688495 0.         0.        ] 0.8196167790607775


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.42884933 0.46373173 0.10741894 0.         0.        ] 0.819673286275953


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.4162813  0.46158174 0.12213696 0.         0.        ] 0.8195768039830463


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.41326797 0.46411677 0.12261526 0.         0.        ] 0.8195965120766735


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.44395467 0.46556546 0.09047987 0.         0.        ] 0.8197110943674901


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.42164168 0.48143932 0.096919   0.         0.        ] 0.8197241628618371


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.40928021 0.49510967 0.09561013 0.         0.        ] 0.8196699951160633


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.42778362 0.45959694 0.11261944 0.         0.        ] 0.8196323371621813


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43312903 0.47233011 0.09454086 0.         0.        ] 0.8197038022914865


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.4330609  0.47740473 0.08953438 0.         0.        ] 0.8196724681733403


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43257848 0.4752695  0.09215202 0.         0.        ] 0.8196750171475944


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43110852 0.46526831 0.10362318 0.         0.        ] 0.8196595750724319


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43262263 0.47468035 0.09269702 0.         0.        ] 0.819695659577053


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.43619067 0.48223903 0.0815703  0.         0.        ] 0.8196486734982109


KeyboardInterrupt: 

In [33]:
# Restrict the blend to three models; calc_score assigns weight i to cols[i].
cols = [
    "exp065_split_drop",
    "exp068_transformer",
    "exp078_lstm",
]

def calc_score(param, pred_all_df, day_norm: bool = False):
    """Objective for `scipy.optimize.minimize`: negated AP of a weighted blend.

    The weights are rescaled to sum to 1, the per-model predictions listed in
    the module-level `cols` are blended into `prediction_onset` /
    `prediction_wakeup`, events are extracted with `make_submission`, and the
    result is scored against the module-level `event_df`.

    Args:
        param: One weight per entry of `cols` (ndarray, list or tuple).
        pred_all_df: Frame holding a `<model>_<event>` column for every model
            in `cols` and every event in the module-level `events`.
        day_norm: Forwarded to `make_submission`; off by default here.

    Returns:
        The negated score, so that minimizing it maximizes the score. The
        normalized weights and the score are printed on every call.
    """
    # minimize passes an ndarray, but accept a plain list/tuple of weights too.
    param = np.asarray(param, dtype=float)
    param = param / param.sum()

    # Blend directly into the prediction columns instead of first adding one
    # temporary weighted column per model and event.
    tmp_df = pred_all_df.with_columns(
        [
            pl.sum_horizontal(
                [pl.col(f"{col}_{event}") * param[i] for i, col in enumerate(cols)]
            ).alias(f"prediction_{event}")
            for event in events
        ]
    )

    sub_df1 = make_submission(
        tmp_df,
        periodicity_dict=periodicity_dict,
        height=0.001,
        distance=107,
        daily_score_offset=1.0,
        day_norm=day_norm,
    )
    score = event_detection_ap(
        event_df.to_pandas(),
        sub_df1.to_pandas(),
    )
    print(param, score)

    return -score



# Start from a uniform blend. No sum-to-one constraint is passed to the
# optimizer: calc_score rescales the weights itself, so only the box bounds
# are needed.
weights = [1.0 / len(cols)] * len(cols)
res = minimize(
    calc_score,
    weights,
    args=(pred_all_df,),  # must be a tuple; the trailing comma makes it one
    method="Nelder-Mead",
    options={"disp": True, "maxiter": 100},
    bounds=[(0.0, 1.0)] * len(weights),
)
# 0.7966527549813314


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.33333333 0.33333333 0.33333333] 0.8023403512023469


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.3442623  0.32786885 0.32786885] 0.8022210785269291


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32786885 0.3442623  0.32786885] 0.8023365106590106


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32786885 0.32786885 0.3442623 ] 0.8024864496357588


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31491713 0.34254144 0.34254144] 0.8024616704234082


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32282004 0.32467532 0.35250464] 0.8024416261817242


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31059683 0.33008526 0.3593179 ] 0.802435645090285


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31622365 0.33088909 0.35288726] 0.8024914211667002


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31664656 0.34258142 0.34077201] 0.8024761876956356


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32546037 0.32526021 0.34927942] 0.8024247522013199


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31759678 0.33814926 0.34425396] 0.8024713620676294


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32286792 0.32950943 0.34762264] 0.8024950002261494


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32546037 0.32526021 0.34927942] 0.8024247522013199


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.32810111 0.31607379 0.35582509] 0.8023762275673343


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

[0.31948052 0.3360232  0.34449628] 0.8024224492237793


find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from tqdm.auto import tqdm
def score_ternary_search_distance(
    val_event_df: pl.DataFrame, pred_df, score_th: float = 0.005, end_diff: int=2, prefix="prediction"
) -> tuple[float, int]:
    """Ternary-search the `distance` parameter of `make_submission`.

    The search starts on the interval [5, 150] and runs for at most five
    iterations. Each iteration scores the two interior probe distances and
    discards the third of the interval next to the worse probe. It stops early
    once the interval is narrower than 1 or the two probes are at most
    `end_diff` apart.

    Args:
        val_event_df: Ground-truth events passed to `event_detection_ap`.
        pred_df: Predictions passed to `make_submission`.
        score_th: Peak `height` threshold passed to `make_submission`.
        end_diff: Probe spacing at or below which the search stops.
        prefix: Prefix of the prediction columns (`pred_prefix` of
            `make_submission`).

    Returns:
        `(score, distance)` of the better probe of the last iteration that ran.
    """

    def _score_at(distance: int) -> float:
        # Score one candidate distance. `make_submission` names the column
        # prefix argument `pred_prefix`, not `prefix`.
        return event_detection_ap(
            val_event_df.to_pandas(),
            make_submission(
                pred_df,
                height=score_th,
                distance=distance,
                periodicity_dict=periodicity_dict,
                pred_prefix=prefix,
            ).to_pandas(),
        )

    l = 5
    r = 150
    best_score = 0.0
    best_distance = 0

    for _ in tqdm(range(5)):
        if r - l < 1:
            break
        # Probes at one third and two thirds of the current interval.
        m1 = int(l + (r - l) / 3)
        m2 = int(r - (r - l) / 3)
        score1 = _score_at(m1)
        score2 = _score_at(m2)

        if score1 >= score2:
            r = m2
            best_score = score1
            best_distance = m1
        else:
            l = m1
            best_score = score2
            best_distance = m2

        tqdm.write(f"score1(m1): {score1:.5f}({m1:.5f}), score2(m2): {score2:.5f}({m2:.5f}), l: {l:.5f}, r: {r:.5f}")

        if abs(m2 - m1) <= end_diff:
            break

    return best_score, best_distance


In [15]:
cols = [
    "148_gru_scale_factor",
    "156_gru_transformer_residual",
    "exp065_split_drop",
    "exp068_transformer",
    "exp078_lstm",
]

# Score one model on its own: expose its columns under the `prediction_*`
# names that make_submission reads by default.
col = "exp065_split_drop"
tmp_df = pred_all_df.with_columns(
    [pl.col(f"{col}_{event}").alias(f"prediction_{event}") for event in events]
)

sub_df1 = make_submission(
    tmp_df,
    periodicity_dict=periodicity_dict,
    height=0.001,
    distance=107,
    daily_score_offset=1.0,
    day_norm=True,
)
score = event_detection_ap(event_df.to_pandas(), sub_df1.to_pandas())
print(score) #0.78818291 0.756548
 

find peaks:   0%|          | 0/277 [00:00<?, ?it/s]

Matching detections to ground truth events:   0%|          | 0/538 [00:00<?, ?it/s]

0.7565489802637723
