In [None]:
# Install the pinned XGBoost wheel from an attached Kaggle dataset without contacting PyPI (--no-index).
!pip install -q --no-index --find-links=/kaggle/input/mabe-package xgboost==3.1.1

In [None]:
# Empty /kaggle/working so artifacts from a previous run cannot mix with this run's outputs.
!rm -rf /kaggle/working/*

In [None]:
import ast
import datetime
import gc
import itertools
import json
import re
import sys
import time
import traceback
from collections import defaultdict
from pathlib import Path

import joblib
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import xgboost as xgb
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedGroupKFold
from tqdm.auto import tqdm

sys.path.append("/kaggle/usr/lib/mabe-f-beta")
from metric import score

In [None]:
# Paths and constants shared by the whole notebook.
INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
EXTERNAL_TRACKING_INPUT_DIR = Path("/kaggle/input/mabe-outlier17")
EXTERNAL_ANNOT_INPUT_DIR = Path("/kaggle/input/mabe-train-annotation-01")

# Tracking is read from the external cleaned copy (alternative: INPUT_DIR / "train_tracking").
TRAIN_TRACKING_DIR = EXTERNAL_TRACKING_INPUT_DIR / "cleaned_train_tracking"
# Annotations are read from the competition data
# (alternative: EXTERNAL_ANNOT_INPUT_DIR / "filtered_train_annotation").
TRAIN_ANNOTATION_DIR = INPUT_DIR / "train_annotation"

WORKING_DIR = Path("/kaggle/working")

# Columns identifying one feature row; everything else in a feature file is a model input.
INDEX_COLS = ["video_id", "agent_mouse_id", "target_mouse_id", "video_frame"]

# Tracked keypoints. This order also fixes the order of the generated feature columns.
BODY_PARTS = """
    ear_left ear_right nose neck body_center
    lateral_left lateral_right hip_left hip_right
    tail_base tail_tip
""".split()

# Behaviors modelled with single-mouse ("self") features.
SELF_BEHAVIORS = """
    biteobject climb dig exploreobject freeze genitalgroom
    huddle rear rest run selfgroom
""".split()

# Behaviors modelled with agent/target pair features.
PAIR_BEHAVIORS = """
    allogroom approach attack attemptmount avoid chase chaseattack
    defend disengage dominance dominancegroom dominancemount ejaculate
    escape flinch follow intromit mount reciprocalsniff shepherd
    sniff sniffbody sniffface sniffgenital submit tussle
""".split()

In [None]:
# Load the video-level metadata table (one row per video).
train_dataframe = pl.read_csv(source=INPUT_DIR / "train.csv")

In [None]:
# Drop video 1212811043 from all downstream processing (the reason is not recorded in this notebook).
train_dataframe = train_dataframe.filter(pl.col("video_id").ne(1212811043))
train_dataframe

# DEBUG: restrict to one lab for quick iterations, e.g.
# train_dataframe = train_dataframe.filter((pl.col("lab_id") == "AdaptableSnail") & (pl.col("video_id") != 1212811043))
# train_dataframe

# Data Processing

## Preparation

In [None]:
# Preprocess behavior labels: one row per (lab, video, agent, target, behavior) declared
# in the `behaviors_labeled` column.
# - The column is parsed with ast.literal_eval rather than eval, so file content can only
#   ever produce literals and never execute code.
# - Each element is an "agent,target,behavior" string; stray single quotes are stripped.
# - Identical triples listed more than once for a video are kept only once. Otherwise the
#   training loops would load the same (video, agent, behavior) data twice.
train_behavior_dataframe = (
    train_dataframe.filter(pl.col("behaviors_labeled").is_not_null())
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled")
        .map_elements(ast.literal_eval, return_dtype=pl.List(pl.Utf8))
        .alias("behaviors_labeled_list"),
    )
    .explode("behaviors_labeled_list")
    .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        *[
            pl.col("behaviors_labeled_element").str.split(",").list.get(position).str.replace_all("'", "").alias(name)
            for position, name in enumerate(["agent", "target", "behavior"])
        ],
    )
    .unique(maintain_order=True)
)

train_self_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(SELF_BEHAVIORS))
train_pair_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(PAIR_BEHAVIORS))

In [None]:
train_self_behavior_dataframe.head(n=5)

In [None]:
train_pair_behavior_dataframe.head(n=5)

## Feature Engineering

In [None]:
%%writefile self_features.py
import polars as pl
import itertools

def make_self_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    """Build per-frame single-mouse ("self") features for each tracked mouse of one video.

    Args:
        metadata: one train.csv row as a dict. Reads ``video_id``, ``frames_per_second``
            and ``mouse1_strain`` .. ``mouse4_strain``. The number of non-null strains is
            taken as the number of mice, with ids 1..n.
        tracking: long-format tracking with ``video_frame``, ``mouse_id``, ``bodypart``,
            ``x`` and ``y`` columns.

    Returns:
        For every mouse that has tracking rows: one row per frame from the first to the
        last tracked frame of the video. Index columns are ``video_id``,
        ``agent_mouse_id``, ``target_mouse_id`` (always -1) and ``video_frame``. They are
        followed by these features, which are null on frames where that mouse has no
        tracking row:
          * ``aa__<b1>__<b2>__distance`` for every pair of BODY_PARTS,
          * ``agent__<part>__speed_<ms>ms`` for ear_left/ear_right/tail_base at 500/1000/2000 ms,
          * ``agent__elongation`` and ``agent__body_angle``.
        If no mouse has tracking rows, an empty frame with only the index columns is
        returned (``pl.concat`` of an empty list would raise).

    Coordinates are used as given because the px -> cm conversion is not applied.
    Distances and speeds are therefore in the tracking's native units (presumably px and
    px/s). Relies on the notebook-level ``BODY_PARTS`` constant, since this file is
    executed with ``%run -i``.
    """
    fps = metadata["frames_per_second"]
    index_schema = {
        "video_id": pl.Int32,
        "agent_mouse_id": pl.Int8,
        "target_mouse_id": pl.Int8,
        "video_frame": pl.Int32,
    }

    def body_parts_distance(body_part_1, body_part_2):
        """Euclidean distance between two body parts of the agent."""
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"agent_x_{body_part_1}") - pl.col(f"agent_x_{body_part_2}")).pow(2)
            + (pl.col(f"agent_y_{body_part_1}") - pl.col(f"agent_y_{body_part_2}")).pow(2)
        ).sqrt()

    def body_part_speed(body_part, period_ms):
        """Per-second speed of one agent body part, smoothed by a centred rolling mean.

        The rolling window covers ``period_ms`` milliseconds converted to rows (at least 1).
        The displacement between consecutive rows is divided by the number of frames that
        actually elapsed, so gaps in the tracking do not appear as speed spikes. On
        contiguous frames this equals displacement * fps.
        """
        assert body_part in BODY_PARTS
        window_frames = max(1, int(round(period_ms * fps / 1000.0)))
        x = pl.col(f"agent_x_{body_part}").cast(pl.Float32)
        y = pl.col(f"agent_y_{body_part}").cast(pl.Float32)
        frames_elapsed = pl.col("video_frame").diff().cast(pl.Float32)
        speed = (x.diff().pow(2) + y.diff().pow(2)).sqrt() / frames_elapsed * fps
        return speed.rolling_mean(window_size=window_frames, center=True)

    def elongation():
        """Ratio of the nose-to-tail_base length to the ear-to-ear width."""
        d1 = body_parts_distance("nose", "tail_base")
        d2 = body_parts_distance("ear_left", "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle():
        """Cosine of the angle between (nose - body_center) and (tail_base - body_center)."""
        v1x = pl.col("agent_x_nose") - pl.col("agent_x_body_center")
        v1y = pl.col("agent_y_nose") - pl.col("agent_y_body_center")
        v2x = pl.col("agent_x_tail_base") - pl.col("agent_x_body_center")
        v2y = pl.col("agent_y_tail_base") - pl.col("agent_y_body_center")

        return (v1x * v2x + v1y * v2y) / (
            (v1x.pow(2) + v1y.pow(2)).sqrt()
            * (v2x.pow(2) + v2y.pow(2)).sqrt()
            + 1e-06
        )

    n_mice = sum(metadata[f"mouse{i}_strain"] is not None for i in range(1, 5))

    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    # Wide layout: one row per (video_frame, mouse_id) with x_<part> / y_<part> columns.
    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])

    # Frame-ordered tracking per mouse id 1..n_mice. Mice without any rows are left out.
    pivot_trackings = {}
    for mouse_id in range(1, n_mice + 1):
        df_mouse = pivot.filter(pl.col("mouse_id") == mouse_id)
        if df_mouse.height > 0:
            pivot_trackings[mouse_id] = df_mouse

    speed_specs = list(itertools.product(["ear_left", "ear_right", "tail_base"], [500, 1000, 2000]))
    speed_cols = [f"agent__{body_part}__speed_{period_ms}ms" for body_part, period_ms in speed_specs]

    result = []
    for agent_mouse_id, agent_tracking in pivot_trackings.items():
        # Dense frame grid, so every frame of the video gets a row whether tracked or not.
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": -1,
                "video_frame": pl.int_range(start_frame, end_frame + 1, eager=True),
            },
            schema=index_schema,
        )

        # Prefix every non-frame column with "agent_". This also turns mouse_id into
        # agent_mouse_id, which is overwritten with a literal further down.
        pivot_agent = agent_tracking.select(
            pl.col("video_frame"),
            pl.exclude("video_frame").name.prefix("agent_"),
        )

        # Body parts missing from this video's tracking become all-null columns.
        present = set(pivot_agent.columns)
        pivot_agent = pivot_agent.with_columns(
            [
                pl.lit(None).cast(pl.Float32).alias(name)
                for name in (f"agent_{axis}_{bp}" for axis in ("x", "y") for bp in BODY_PARTS)
                if name not in present
            ]
        )

        pivot_agent = pivot_agent.with_columns(
            [
                body_part_speed(body_part, period_ms).alias(colname)
                for (body_part, period_ms), colname in zip(speed_specs, speed_cols)
            ]
        )

        features = pivot_agent.with_columns(
            pl.lit(agent_mouse_id).alias("agent_mouse_id"),
            pl.lit(-1).alias("target_mouse_id"),
        ).select(
            pl.col("video_frame"),
            pl.col("agent_mouse_id"),
            pl.col("target_mouse_id"),
            # Body shape: pairwise keypoint distances.
            *[
                body_parts_distance(b1, b2).alias(f"aa__{b1}__{b2}__distance")
                for b1, b2 in itertools.combinations(BODY_PARTS, 2)
            ],
            # Smoothed speeds.
            *[pl.col(c) for c in speed_cols],
            # Posture summaries.
            elongation().alias("agent__elongation"),
            body_angle().alias("agent__body_angle"),
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    if not result:
        return pl.DataFrame(schema=index_schema)
    return pl.concat(result, how="vertical")

In [None]:
%%writefile pair_features.py

def make_pair_features(
    metadata: dict,
    tracking: pl.DataFrame,
) -> pl.DataFrame:
    """Build per-frame agent/target ("pair") features for every ordered pair of tracked mice.

    Args:
        metadata: one train.csv row as a dict. Reads ``video_id``, ``frames_per_second``
            and ``mouse1_strain`` .. ``mouse4_strain``. The number of non-null strains is
            taken as the number of mice, with ids 1..n.
        tracking: long-format tracking with ``video_frame``, ``mouse_id``, ``bodypart``,
            ``x`` and ``y`` columns.

    Returns:
        For every ordered (agent, target) pair of distinct mice that both have tracking
        rows: one row per frame from the first to the last tracked frame of the video.
        Index columns are ``video_id``, ``agent_mouse_id``, ``target_mouse_id`` and
        ``video_frame``, followed by:
          * ``at__<agent part>__<target part>__distance`` for every BODY_PARTS pair,
          * ``agent__*`` / ``target__*`` speeds of ear_left/ear_right/tail_base at 500/1000 ms,
          * elongation and body angle of both agent and target.
        All features are null on frames where the agent has no tracking row. Features
        that use the target are also null where only the target is missing.
        With fewer than two tracked mice (e.g. a single-mouse video) there are no pairs.
        In that case an empty frame with only the index columns is returned, instead of
        letting ``pl.concat([])`` raise.

    Coordinates are used as given because the px -> cm conversion is not applied.
    Distances and speeds are therefore in the tracking's native units (presumably px and
    px/s). Relies on notebook-level ``pl``, ``itertools`` and ``BODY_PARTS``, since this
    file is executed with ``%run -i``.
    """
    fps = metadata["frames_per_second"]
    index_schema = {
        "video_id": pl.Int32,
        "agent_mouse_id": pl.Int8,
        "target_mouse_id": pl.Int8,
        "video_frame": pl.Int32,
    }

    def body_parts_distance(agent_or_target_1, body_part_1, agent_or_target_2, body_part_2):
        """Euclidean distance between a body part of one mouse and a body part of another (or the same)."""
        assert agent_or_target_1 in ("agent", "target")
        assert agent_or_target_2 in ("agent", "target")
        assert body_part_1 in BODY_PARTS
        assert body_part_2 in BODY_PARTS
        return (
            (pl.col(f"{agent_or_target_1}_x_{body_part_1}") - pl.col(f"{agent_or_target_2}_x_{body_part_2}")).pow(2)
            + (pl.col(f"{agent_or_target_1}_y_{body_part_1}") - pl.col(f"{agent_or_target_2}_y_{body_part_2}")).pow(2)
        ).sqrt()

    def body_part_speed(agent_or_target, body_part, period_ms):
        """Per-second speed of one body part, smoothed by a centred rolling mean.

        The rolling window covers ``period_ms`` milliseconds converted to rows (at least 1).
        The displacement between consecutive rows is divided by the number of frames that
        actually elapsed, so gaps in the tracking do not appear as speed spikes. On
        contiguous frames this equals displacement * fps.
        """
        assert agent_or_target in ("agent", "target")
        assert body_part in BODY_PARTS
        window_frames = max(1, int(round(period_ms * fps / 1000.0)))
        x = pl.col(f"{agent_or_target}_x_{body_part}").cast(pl.Float32)
        y = pl.col(f"{agent_or_target}_y_{body_part}").cast(pl.Float32)
        frames_elapsed = pl.col("video_frame").diff().cast(pl.Float32)
        speed = (x.diff().pow(2) + y.diff().pow(2)).sqrt() / frames_elapsed * fps
        return speed.rolling_mean(window_size=window_frames, center=True)

    def elongation(agent_or_target):
        """Ratio of the nose-to-tail_base length to the ear-to-ear width."""
        assert agent_or_target in ("agent", "target")
        d1 = body_parts_distance(agent_or_target, "nose", agent_or_target, "tail_base")
        d2 = body_parts_distance(agent_or_target, "ear_left", agent_or_target, "ear_right")
        return d1 / (d2 + 1e-06)

    def body_angle(agent_or_target):
        """Cosine of the angle between (nose - body_center) and (tail_base - body_center)."""
        assert agent_or_target in ("agent", "target")
        v1x = pl.col(f"{agent_or_target}_x_nose") - pl.col(f"{agent_or_target}_x_body_center")
        v1y = pl.col(f"{agent_or_target}_y_nose") - pl.col(f"{agent_or_target}_y_body_center")
        v2x = pl.col(f"{agent_or_target}_x_tail_base") - pl.col(f"{agent_or_target}_x_body_center")
        v2y = pl.col(f"{agent_or_target}_y_tail_base") - pl.col(f"{agent_or_target}_y_body_center")
        return (v1x * v2x + v1y * v2y) / (
            (v1x.pow(2) + v1y.pow(2)).sqrt()
            * (v2x.pow(2) + v2y.pow(2)).sqrt()
            + 1e-06
        )

    n_mice = sum(metadata[f"mouse{i}_strain"] is not None for i in range(1, 5))
    start_frame = tracking.select(pl.col("video_frame").min()).item()
    end_frame = tracking.select(pl.col("video_frame").max()).item()

    # Wide layout: one row per (video_frame, mouse_id) with x_<part> / y_<part> columns.
    pivot = tracking.pivot(
        on=["bodypart"],
        index=["video_frame", "mouse_id"],
        values=["x", "y"],
    ).sort(["mouse_id", "video_frame"])

    # Split the wide tracking per mouse id 1..n_mice. Mice without any rows are left out.
    pivot_trackings: dict[int, pl.DataFrame] = {}
    for mouse_id in range(1, n_mice + 1):
        df_mouse = pivot.filter(pl.col("mouse_id") == mouse_id)
        if df_mouse.height > 0:
            pivot_trackings[mouse_id] = df_mouse

    speed_specs = list(itertools.product(["ear_left", "ear_right", "tail_base"], [500, 1000]))
    agent_speed_cols = [f"agent__{body_part}__speed_{period_ms}ms" for body_part, period_ms in speed_specs]
    target_speed_cols = [f"target__{body_part}__speed_{period_ms}ms" for body_part, period_ms in speed_specs]

    result = []
    for agent_mouse_id, target_mouse_id in itertools.permutations(range(1, n_mice + 1), 2):
        if agent_mouse_id not in pivot_trackings or target_mouse_id not in pivot_trackings:
            continue

        # Dense frame grid, so every frame of the video gets a row whether tracked or not.
        result_element = pl.DataFrame(
            {
                "video_id": metadata["video_id"],
                "agent_mouse_id": agent_mouse_id,
                "target_mouse_id": target_mouse_id,
                "video_frame": pl.int_range(start_frame, end_frame + 1, eager=True),
            },
            schema=index_schema,
        )

        # Left join keeps every agent frame; target columns are null where the target is
        # missing. Polars does not guarantee row order after a join, and the diff-based
        # speeds need rows in frame order, so sort explicitly.
        merged_pivot = (
            pivot_trackings[agent_mouse_id]
            .select(
                pl.col("video_frame"),
                pl.exclude("video_frame").name.prefix("agent_"),
            )
            .join(
                pivot_trackings[target_mouse_id].select(
                    pl.col("video_frame"),
                    pl.exclude("video_frame").name.prefix("target_"),
                ),
                on="video_frame",
                how="left",
            )
            .sort("video_frame")
        )

        # Body parts missing from either mouse's tracking become all-null columns.
        present = set(merged_pivot.columns)
        merged_pivot = merged_pivot.with_columns(
            [
                pl.lit(None).cast(pl.Float32).alias(name)
                for name in (
                    f"{role}_{axis}_{bp}"
                    for role in ("agent", "target")
                    for axis in ("x", "y")
                    for bp in BODY_PARTS
                )
                if name not in present
            ]
        )

        merged_pivot = merged_pivot.with_columns(
            [
                body_part_speed(role, body_part, period_ms).alias(colname)
                for role, colnames in (("agent", agent_speed_cols), ("target", target_speed_cols))
                for (body_part, period_ms), colname in zip(speed_specs, colnames)
            ]
        )

        # The prefixed agent_/target_mouse_id columns from the pivot are overwritten with literals.
        features = merged_pivot.with_columns(
            pl.lit(agent_mouse_id).alias("agent_mouse_id"),
            pl.lit(target_mouse_id).alias("target_mouse_id"),
        ).select(
            pl.col("video_frame"),
            pl.col("agent_mouse_id"),
            pl.col("target_mouse_id"),
            # Agent body part x target body part distances.
            *[
                body_parts_distance("agent", agent_body_part, "target", target_body_part).alias(
                    f"at__{agent_body_part}__{target_body_part}__distance"
                )
                for agent_body_part, target_body_part in itertools.product(BODY_PARTS, repeat=2)
            ],
            # Smoothed speeds.
            *[pl.col(c) for c in agent_speed_cols],
            *[pl.col(c) for c in target_speed_cols],
            # Posture summaries.
            elongation("agent").alias("agent__elongation"),
            elongation("target").alias("target__elongation"),
            body_angle("agent").alias("agent__body_angle"),
            body_angle("target").alias("target__body_angle"),
        )

        result_element = result_element.join(
            features,
            on=["video_frame", "agent_mouse_id", "target_mouse_id"],
            how="left",
        )
        result.append(result_element)

    if not result:
        return pl.DataFrame(schema=index_schema)
    return pl.concat(result, how="vertical")


In [None]:
%run -i self_features.py
%run -i pair_features.py

def process_video(row):
    """Compute self and pair features for one training video and write each to parquet.

    ``row`` is one train.csv row as a dict. It is passed unchanged as ``metadata`` to both
    feature builders. Outputs go to WORKING_DIR/self_features/<video_id>.parquet and
    WORKING_DIR/pair_features/<video_id>.parquet. Returns the video id.
    """
    video_id = row["video_id"]
    tracking = pl.read_parquet(TRAIN_TRACKING_DIR / f"{row['lab_id']}/{video_id}.parquet")

    outputs = {
        "self_features": make_self_features(metadata=row, tracking=tracking),
        "pair_features": make_pair_features(metadata=row, tracking=tracking),
    }
    for subdir, frame in outputs.items():
        frame.write_parquet(WORKING_DIR / subdir / f"{video_id}.parquet")

    return video_id


# Feature generation: one self and one pair parquet file per labeled video, computed in parallel.
for feature_dir in (WORKING_DIR / "self_features", WORKING_DIR / "pair_features"):
    feature_dir.mkdir(exist_ok=True, parents=True)

rows = train_dataframe.filter(pl.col("behaviors_labeled").is_not_null()).rows(named=True)
results = joblib.Parallel(n_jobs=-1, verbose=5)(
    joblib.delayed(process_video)(row) for row in rows
)
print(f"Processed {len(results)} videos successfully")

del rows, results, feature_dir
gc.collect()

# Training

In [None]:
def tune_threshold(oof_action, y_action, thresholds=None):
    """Return the probability threshold that maximises binary F1 on one validation split.

    Args:
        oof_action: predicted positive-class probabilities (array-like of floats).
        y_action: ground-truth labels (array-like of 0/1 or bools) of the same length.
        thresholds: candidate thresholds to scan. Defaults to 0.000, 0.005, ..., 1.000.

    Returns:
        The first (smallest) candidate that reaches the highest F1. A row counts as
        predicted positive when ``prediction >= threshold``. F1 is taken as 0 when it is
        undefined (no positives and no predicted positives), matching ``zero_division=0``.
        If every candidate scores 0, the first candidate is returned.

    The confusion counts for all candidates come from two sorted arrays and
    ``np.searchsorted``, which is O(n log n) overall. The previous approach called
    ``f1_score`` once per candidate, i.e. about 200 full passes over the data per fold.
    """
    if thresholds is None:
        thresholds = np.arange(0, 1.005, 0.005)
    thresholds = np.asarray(thresholds, dtype=np.float64)

    predictions = np.asarray(oof_action, dtype=np.float64).ravel()
    truth = np.asarray(y_action).ravel().astype(bool)

    positive_scores = np.sort(predictions[truth])
    negative_scores = np.sort(predictions[~truth])
    # The count of scores >= th is total minus the count of scores < th; side="left" gives the latter.
    tp = positive_scores.size - np.searchsorted(positive_scores, thresholds, side="left")
    fp = negative_scores.size - np.searchsorted(negative_scores, thresholds, side="left")
    fn = positive_scores.size - tp

    denominator = 2 * tp + fp + fn
    f1 = np.divide(
        2.0 * tp,
        denominator,
        out=np.zeros(thresholds.shape, dtype=np.float64),
        where=denominator > 0,
    )
    return thresholds[int(np.argmax(f1))]

In [None]:
def group_lab(df):
    """Return ``df`` with an extra ``grouped_lab_id`` column used as the training-group key.

    PleasantMeerkat and DeliriousFly are merged under "Pleasant_and_Delirious". Every other
    lab keeps its own ``lab_id`` as its group. Other groupings (ReflectiveManatee +
    SparklingTapir, and a CalMS21/CRIM13-family group) are currently not applied.
    """
    merged_labs = dict.fromkeys(["PleasantMeerkat", "DeliriousFly"], "Pleasant_and_Delirious")
    return df.with_columns(grouped_lab_id=pl.col("lab_id").replace(merged_labs))


In [None]:
# Attach the training-group key to both label tables.
train_self_behavior_dataframe, train_pair_behavior_dataframe = (
    group_lab(train_self_behavior_dataframe),
    group_lab(train_pair_behavior_dataframe),
)

In [None]:
def train_validate(grouped_lab_id: str, behavior: str, indices: pl.DataFrame, features: pl.DataFrame, labels: pl.Series):
    """Cross-validate an XGBoost binary classifier for one (lab group, behavior) and save artifacts.

    Uses 3-fold StratifiedGroupKFold grouped by ``video_id``, so the frames of a video never
    land in both the training and the validation split of a fold.

    Args:
        grouped_lab_id: lab group name, used as a directory level under ``results``.
        behavior: behavior name, used as a directory level under ``results``.
        indices: index columns (must include ``video_id``), row-aligned with ``features``.
        features: model inputs, row-aligned with ``labels``.
        labels: 0/1 target per row.

    Side effects, under WORKING_DIR/results/<grouped_lab_id>/<behavior>/:
        ``f1.txt`` and ``oof_predictions.parquet``. Each trained fold also gets
        ``fold_<k>/`` with ``model.json``, ``threshold.txt``, ``metric.png`` and
        (when the model has any split) ``feature_importance.png``.

    Returns:
        F1 of the out-of-fold predicted labels over all rows (0.0 if there are no positives).

    Some folds cannot be trained meaningfully: an empty validation split, or a training
    split without positives. Both are possible when a group has fewer videos than folds,
    and the second would make ``scale_pos_weight`` divide by zero. Such folds are skipped.
    Their rows keep fold -1, prediction 0 and label 0, and no ``fold_<k>`` directory is
    written for them.

    NOTE(review): each fold's threshold is tuned on that fold's own validation rows, so the
    reported OOF F1 is optimistically biased.
    """
    result_dir = WORKING_DIR / "results" / grouped_lab_id / behavior
    result_dir.mkdir(exist_ok=True, parents=True)
    n_rows = len(labels)

    if labels.sum() == 0:
        # Nothing to learn: store a zero score and an all-negative OOF table.
        with open(result_dir / "f1.txt", "w") as f:
            f.write("0.0\n")
        oof_prediction_dataframe = indices.with_columns(
            pl.Series("fold", [-1] * n_rows, dtype=pl.Int8),
            pl.Series("prediction", [0.0] * n_rows, dtype=pl.Float32),
            pl.Series("predicted_label", [0] * n_rows, dtype=pl.Int8),
        )
        oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")
        return 0.0

    folds = np.full(n_rows, -1, dtype=np.int8)
    oof_predictions = np.zeros(n_rows, dtype=np.float32)
    oof_prediction_labels = np.zeros(n_rows, dtype=np.int8)

    splitter = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=42)
    splits = splitter.split(X=features, y=labels, groups=indices.get_column("video_id"))
    for fold, (train_idx, valid_idx) in enumerate(splits):
        y_train = labels[train_idx]
        y_valid = labels[valid_idx]
        n_positive_train = y_train.sum()
        if len(valid_idx) == 0 or n_positive_train == 0:
            continue

        result_dir_fold = result_dir / f"fold_{fold}"
        result_dir_fold.mkdir(exist_ok=True, parents=True)

        X_train = features[train_idx]
        X_valid = features[valid_idx]

        # neg/pos ratio: gives the positive class the same total weight as the negatives.
        scale_pos_weight = (len(y_train) - n_positive_train) / n_positive_train

        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "device": "cpu",
            "tree_method": "hist",
            "learning_rate": 0.05,
            "max_depth": 6,
            "min_child_weight": 5,
            "subsample": 0.8,
            "colsample_bytree": 0.8,
            "scale_pos_weight": scale_pos_weight,
            "max_bin": 64,
            "seed": 42,
        }
        dtrain = xgb.QuantileDMatrix(X_train, label=y_train, feature_names=features.columns, max_bin=64)
        dvalid = xgb.DMatrix(X_valid, label=y_valid, feature_names=features.columns)

        evals_result = {}
        early_stopping_callback = xgb.callback.EarlyStopping(
            rounds=10,
            metric_name="logloss",
            data_name="valid",
            maximize=False,
            save_best=True,
        )
        model = xgb.train(
            params,
            dtrain=dtrain,
            num_boost_round=250,
            evals=[(dtrain, "train"), (dvalid, "valid")],
            callbacks=[early_stopping_callback],
            evals_result=evals_result,
            verbose_eval=0,
        )

        fold_predictions = model.predict(dvalid)

        threshold = tune_threshold(fold_predictions, y_valid)
        folds[valid_idx] = fold
        oof_predictions[valid_idx] = fold_predictions
        oof_prediction_labels[valid_idx] = (fold_predictions >= threshold).astype(np.int8)

        # Save the model, threshold and diagnostic plots for this fold.
        model.save_model(result_dir_fold / "model.json")
        with open(result_dir_fold / "threshold.txt", "w") as f:
            f.write(f"{threshold}\n")

        # plot_importance raises ValueError when the booster has no split with gain.
        if model.get_score(importance_type="gain"):
            xgb.plot_importance(model, max_num_features=20, importance_type="gain", values_format="{v:.2f}")
            plt.tight_layout()
            plt.savefig(result_dir_fold / "feature_importance.png")
            plt.close()

        # lightgbm's plot_metric accepts XGBoost's {"train"/"valid": {"logloss": [...]}} history.
        lgb.plot_metric(evals_result, metric="logloss")
        plt.tight_layout()
        plt.savefig(result_dir_fold / "metric.png")
        plt.close()

        gc.collect()

    oof_prediction_dataframe = indices.with_columns(
        pl.Series("fold", folds, dtype=pl.Int8),
        pl.Series("prediction", oof_predictions, dtype=pl.Float32),
        pl.Series("predicted_label", oof_prediction_labels, dtype=pl.Int8),
    )
    f1 = f1_score(labels, oof_prediction_labels, zero_division=0)
    with open(result_dir / "f1.txt", "w") as f:
        f.write(f"{f1}\n")

    oof_prediction_dataframe.write_parquet(result_dir / "oof_predictions.parquet")

    return f1


In [None]:
# Inspect the self-behavior label table, including its grouped_lab_id column.
train_self_behavior_dataframe

In [None]:
"""
self features
"""
# Train one classifier per (grouped lab, self behavior) from the per-video self-feature
# parquet files, printing one results-table row per group.
groups = train_self_behavior_dataframe.group_by("grouped_lab_id", "behavior", maintain_order=True)
# Iterates the GroupBy once only to count groups for the tqdm total.
total_groups = len(list(groups))
# ELAPSED TIME in the table is cumulative from this point.
start_time = time.perf_counter()

for idx, ((grouped_lab_id, behavior), group) in tqdm(enumerate(groups), total=total_groups):

    # Print the table header once, before the first group.
    if idx == 0:
        tqdm.write(
            f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
            end="\n",
        )

    # Each table row is written in pieces (end="") as this group progresses.
    tqdm.write(f"|{grouped_lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    # One iteration per labeled (video, agent) entry for this behavior.
    for row in group.rows(named=True):
        original_lab_id = row["lab_id"] 
        video_id = row["video_id"]
        agent = row["agent"]

        # Numeric id taken from the agent label (matches "mouse<digits>").
        agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))

        # Self-feature files are filtered by agent only; their target_mouse_id is always -1.
        data = pl.scan_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet").filter(
            (pl.col("agent_mouse_id") == agent_mouse_id)
        )
        index = data.select(INDEX_COLS).collect(engine="streaming")
        feature = data.select(pl.exclude(INDEX_COLS)).collect(engine="streaming")

        # Read this agent's annotated intervals for the behavior. A missing annotation file
        # is treated as "no intervals".
        annotation_path = TRAIN_ANNOTATION_DIR / original_lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter((pl.col("action") == behavior) & (pl.col("agent_id") == agent_mouse_id))
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        # Expand each [start_frame, stop_frame) interval into the set of positive frames.
        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        # Entries with no positive frame are dropped, so negatives only come from
        # (video, agent) entries that also contain positives.
        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    # Nothing usable for this group: finish the table row with zeros and move on.
    if not index_list:
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    # SAMPLES / POSITIVE / FEATURES columns.
    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    # F1 column: cross-validated out-of-fold F1.
    f1 = train_validate(grouped_lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

In [None]:
"""
pair features
"""
# Train one classifier per (grouped lab, pair behavior) from the per-video pair-feature
# parquet files, printing one results-table row per group.
groups = train_pair_behavior_dataframe.group_by("grouped_lab_id", "behavior", maintain_order=True)
# Iterates the GroupBy once only to count groups for the tqdm total.
total_groups = len(list(groups))
# ELAPSED TIME in the table is cumulative from this point.
start_time = time.perf_counter()

for idx, ((grouped_lab_id, behavior), group) in tqdm(enumerate(groups), total=total_groups):
    
    # Print the table header once, before the first group.
    if idx == 0:
        tqdm.write(
            f"|{'LAB':^25}|{'BEHAVIOR':^15}|{'SAMPLES':^10}|{'POSITIVE':^10}|{'FEATURES':^10}|{'F1':^10}|{'ELAPSED TIME':^15}|",
            end="\n",
        )

    # Each table row is written in pieces (end="") as this group progresses.
    tqdm.write(f"|{grouped_lab_id:^25}|{behavior:^15}|", end="")
    index_list = []
    feature_list = []
    label_list = []

    # One iteration per labeled (video, agent, target) entry for this behavior.
    for row in group.rows(named=True):
        original_lab_id = row["lab_id"] 
        video_id = row["video_id"]
        agent = row["agent"]
        target = row["target"]

        # Numeric ids taken from the agent/target labels (match "mouse<digits>").
        agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))
        target_mouse_id = int(re.search(r"mouse(\d+)", target).group(1))

        data = pl.scan_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet").filter(
            (pl.col("agent_mouse_id") == agent_mouse_id) & (pl.col("target_mouse_id") == target_mouse_id)
        )
        index = data.select(INDEX_COLS).collect(engine="streaming")
        feature = data.select(pl.exclude(INDEX_COLS)).collect(engine="streaming")

        # Read this (agent, target)'s annotated intervals for the behavior. A missing
        # annotation file is treated as "no intervals".
        annotation_path = TRAIN_ANNOTATION_DIR / original_lab_id / f"{video_id}.parquet"
        if annotation_path.exists():
            annotation = (
                pl.scan_parquet(annotation_path)
                .filter(
                    (pl.col("action") == behavior)
                    & (pl.col("agent_id") == agent_mouse_id)
                    & (pl.col("target_id") == target_mouse_id)
                )
                .collect()
            )
        else:
            annotation = pl.DataFrame(
                schema={
                    "agent_id": pl.Int8,
                    "target_id": pl.Int8,
                    "action": str,
                    "start_frame": pl.Int16,
                    "stop_frame": pl.Int16,
                }
            )

        # Expand each [start_frame, stop_frame) interval into the set of positive frames.
        label_frames = set()
        for annotation_row in annotation.rows(named=True):
            label_frames.update(range(annotation_row["start_frame"], annotation_row["stop_frame"]))
        label = index.select(pl.col("video_frame").is_in(label_frames).cast(pl.Int8).alias("label"))

        # Entries with no positive frame are dropped, so negatives only come from
        # (video, agent, target) entries that also contain positives.
        if label.get_column("label").sum() == 0:
            continue

        index_list.append(index)
        feature_list.append(feature)
        label_list.append(label.get_column("label"))

    # Nothing usable for this group: finish the table row with zeros and move on.
    if not index_list:
        elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
        tqdm.write(f"{0:>10,}|{0:>10,}|{0:>10,}|{'-':>10}|{str(elapsed_time):>15}|", end="\n")
        continue

    indices = pl.concat(index_list, how="vertical")
    features = pl.concat(feature_list, how="vertical")
    labels = pl.concat(label_list, how="vertical")

    del index_list, feature_list, label_list
    gc.collect()

    # SAMPLES / POSITIVE / FEATURES columns.
    tqdm.write(f"{len(indices):>10,}|{labels.sum():>10,}|{len(features.columns):>10,}|", end="")

    # F1 column: cross-validated out-of-fold F1.
    f1 = train_validate(grouped_lab_id, behavior, indices, features, labels)
    tqdm.write(f"{f1:>10.2f}|", end="")

    elapsed_time = datetime.timedelta(seconds=int(time.perf_counter() - start_time))
    tqdm.write(f"{str(elapsed_time):>15}|", end="\n")

    gc.collect()

In [None]:
%%writefile robustify.py

def robustify(submission: pl.DataFrame, dataset: pl.DataFrame, train_test: str = "train"):
    """Make a predicted-segment table safe to score.

    1. Drop segments with ``start_frame >= stop_frame``. Rows where either bound is null are
       dropped too, because the comparison yields null and ``filter`` discards it.
    2. Within each (video_id, agent_id, target_id), sort by ``start_frame`` and drop every
       segment that starts before the last kept segment stops, so kept segments never overlap.
    3. Fill videos that list ``behaviors_labeled`` but have no row in ``submission``. Each such
       video's tracked frame range is split into consecutive equal chunks, one per labelled
       action of each (agent, target) pair, so the video is not left empty.

    This file is executed with ``%run -i``. ``pl``, ``np``, ``pd``, ``json`` and the path
    constants therefore come from the notebook namespace.

    Args:
        submission: polars frame with video_id, agent_id, target_id, action, start_frame and
            stop_frame (stop_frame is exclusive).
        dataset: polars metadata frame (train.csv / test.csv). Uses lab_id, video_id and
            behaviors_labeled, a JSON list of "agent,target,action" strings (may be null).
        train_test: "train" reads tracking from ``TRAIN_TRACKING_DIR``. Any other value reads
            ``INPUT_DIR / f"{train_test}_tracking"``.

    Returns:
        A polars DataFrame with the columns and dtypes of ``submission``.
    """
    if train_test == "train":
        # Training tracking comes from the cleaned external dataset rather than INPUT_DIR.
        tracking_dir = TRAIN_TRACKING_DIR
    else:
        tracking_dir = INPUT_DIR / f"{train_test}_tracking"

    n_rows = len(submission)
    submission = submission.filter(pl.col("start_frame") < pl.col("stop_frame"))
    if len(submission) != n_rows:
        print("ERROR: Dropped frames with start >= stop")

    n_rows = len(submission)
    kept_groups = []
    # maintain_order makes the output row order deterministic across runs.
    for _, group in submission.group_by("video_id", "agent_id", "target_id", maintain_order=True):
        group = group.sort("start_frame")
        keep = np.ones(len(group), dtype=bool)
        last_stop_frame = 0
        for i, (start, stop) in enumerate(group.select("start_frame", "stop_frame").iter_rows()):
            if start < last_stop_frame:
                keep[i] = False  # overlaps the previously kept segment
            else:
                last_stop_frame = stop
        kept_groups.append(group.filter(pl.Series("mask", keep)))
    if kept_groups:  # pl.concat([]) raises, so leave an empty submission untouched
        submission = pl.concat(kept_groups)
    if len(submission) != n_rows:
        print("ERROR: Dropped duplicate frames")

    # Build the membership set once instead of converting the column per dataset row.
    predicted_videos = set(submission.get_column("video_id").to_list())
    fill_rows = []
    for row in dataset.rows(named=True):
        lab_id = row["lab_id"]
        video_id = row["video_id"]
        behaviors_labeled = row["behaviors_labeled"]
        # Only videos with a labelled-behavior list can be filled (null -> skip).
        if not isinstance(behaviors_labeled, str):
            continue
        # MABe22 videos are excluded here as they are from the solution in compute_validation_metrics.
        if lab_id.startswith("MABe22") or video_id in predicted_videos:
            continue

        print(f"Video {video_id} has no predictions.")

        # Relative components only: joining an absolute "/..." part would discard tracking_dir.
        path = tracking_dir / lab_id / f"{video_id}.parquet"
        video_frames = pd.read_parquet(path, columns=["video_frame"])["video_frame"]
        start_frame = int(video_frames.min())
        stop_frame = int(video_frames.max()) + 1

        vid_behaviors = json.loads(behaviors_labeled)
        vid_behaviors = sorted({b.replace("'", "") for b in vid_behaviors})
        vid_behaviors = pd.DataFrame([b.split(",") for b in vid_behaviors], columns=["agent", "target", "action"])

        for (agent, target), actions in vid_behaviors.groupby(["agent", "target"]):
            batch_length = int(np.ceil((stop_frame - start_frame) / len(actions)))
            for i, action in enumerate(actions["action"]):
                batch_start = start_frame + i * batch_length
                batch_stop = min(batch_start + batch_length, stop_frame)
                if batch_start >= batch_stop:
                    break  # more actions than frames: never emit empty/inverted segments
                fill_rows.append((video_id, agent, target, action, batch_start, batch_stop))

    if fill_rows:
        columns = ["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"]
        filler = pl.DataFrame(fill_rows, schema=columns, orient="row")
        # Match submission's dtypes so the frames can be stacked without a schema error.
        filler = filler.cast({name: submission.schema[name] for name in columns})
        submission = pl.concat([submission, filler], how="diagonal")
        print("ERROR: Filled empty videos")

    return submission

In [None]:
group_oof_predictions = []
train_behavior_dataframe = group_lab(train_behavior_dataframe)
# Materialize the groups once so tqdm's total does not need a second pass over the GroupBy.
groups = list(
    train_behavior_dataframe.group_by("grouped_lab_id", "video_id", "agent", "target", maintain_order=True)
)

for (grouped_lab_id, video_id, agent, target), group in tqdm(groups, total=len(groups)):
    agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))
    target_mouse_id = -1 if target == "self" else int(re.search(r"mouse(\d+)", target).group(1))

    # One frame per behavior with an OOF model: INDEX_COLS + a column named after the behavior
    # holding prediction * predicted_label (0 where the fold threshold was not reached).
    prediction_dataframe_list = []
    for behavior in group.get_column("behavior"):
        oof_path = WORKING_DIR / "results" / grouped_lab_id / behavior / "oof_predictions.parquet"
        if not oof_path.exists():
            continue

        prediction = (
            pl.scan_parquet(oof_path)
            .filter(
                (pl.col("video_id") == video_id)
                & (pl.col("agent_mouse_id") == agent_mouse_id)
                & (pl.col("target_mouse_id") == target_mouse_id)
            )
            .select(*INDEX_COLS, (pl.col("prediction") * pl.col("predicted_label")).alias(behavior))
            .collect()
        )
        if len(prediction) > 0:
            prediction_dataframe_list.append(prediction)

    if not prediction_dataframe_list:
        continue

    # "align" outer-joins the per-behavior frames on INDEX_COLS. Run detection below needs
    # frame order, so sort explicitly (video/agent/target are constant within the group).
    prediction_dataframe = pl.concat(prediction_dataframe_list, how="align").sort("video_frame")

    # Per frame: the highest-scoring behavior, or "none" when every behavior scored 0.
    # A behavior missing a frame after the outer join counts as 0 (not predicted).
    cols = prediction_dataframe.select(pl.exclude(INDEX_COLS)).columns
    scores = prediction_dataframe.select(pl.col(cols).fill_null(0.0)).to_numpy()
    frame_labels = np.where(scores.sum(axis=1) == 0, "none", np.asarray(cols)[scores.argmax(axis=1)])
    prediction_labels_dataframe = prediction_dataframe.select(INDEX_COLS).with_columns(
        pl.Series("prediction", frame_labels.tolist(), dtype=pl.String)
    )

    # Exclusive stop for the final run: one past the last frame of this group.
    end_frame = prediction_labels_dataframe.get_column("video_frame").max() + 1

    group_oof_prediction = (
        prediction_labels_dataframe
        # Keep the first frame of each run of identical labels. ne_missing treats the null that
        # shift(1) yields on row 0 as "different"; a plain != gives null, which filter() drops.
        .filter(pl.col("prediction").ne_missing(pl.col("prediction").shift(1)))
        # A run stops where the next run starts; the last run stops after the final frame.
        .with_columns(pl.col("video_frame").shift(-1).fill_null(end_frame).alias("stop_frame"))
        .filter(pl.col("prediction") != "none")
        .select(
            pl.col("video_id"),
            ("mouse" + pl.col("agent_mouse_id").cast(str)).alias("agent_id"),
            pl.when(pl.col("target_mouse_id") == -1)
            .then(pl.lit("self"))
            .otherwise("mouse" + pl.col("target_mouse_id").cast(str))
            .alias("target_id"),
            pl.col("prediction").alias("action"),
            pl.col("video_frame").alias("start_frame"),
            pl.col("stop_frame"),
        )
    )

    group_oof_predictions.append(group_oof_prediction)

%run -i robustify.py

oof_predictions = pl.concat(group_oof_predictions, how="vertical")
oof_predictions = robustify(oof_predictions, train_dataframe, train_test="train")
oof_predictions.with_row_index("row_id").write_csv(WORKING_DIR / "oof_predictions.csv")

In [None]:
def group_lab_pd(df):
    """Add a ``grouped_lab_id`` column that merges selected labs into one group (pandas version).

    ``PleasantMeerkat`` and ``DeliriousFly`` both map to ``"Pleasant_and_Delirious"``. Every
    other ``lab_id`` is copied unchanged.

    Note: ``df`` is modified in place (a column is added), and the same object is returned.
    """
    merged_labs = ("PleasantMeerkat", "DeliriousFly")
    df["grouped_lab_id"] = df["lab_id"].replace(dict.fromkeys(merged_labs, "Pleasant_and_Delirious"))
    return df

In [None]:
def _segment_key(alias):
    """Return a polars expression building "{video_id}_{agent_id}_{target_id}_{action}" as ``alias``."""
    return pl.concat_str(
        [
            pl.col("video_id").cast(pl.Utf8),
            pl.col("agent_id").cast(pl.Utf8),
            pl.col("target_id").cast(pl.Utf8),
            pl.col("action"),
        ],
        separator="_",
    ).alias(alias)


def _load_train_solution():
    """Collect every readable training annotation into one ground-truth frame for the metric.

    Reads ``train.csv`` and skips ``MABe22*`` labs and videos with no annotation parquet under
    ``TRAIN_ANNOTATION_DIR``. Each row gets lab_id, grouped_lab_id, video_id and
    behaviors_labeled attached. agent_id/target_id are rewritten to "mouseN", and a target
    equal to the agent becomes "self".

    Returns:
        A pandas DataFrame with one row per annotated segment, or ``None`` if no annotation
        file could be read.
    """
    dataset = group_lab_pd(pl.read_csv(INPUT_DIR / "train.csv").to_pandas())

    annotations = []
    for _, row in dataset.iterrows():
        lab_id = row["lab_id"]
        if lab_id.startswith("MABe22"):
            continue

        video_id = row["video_id"]
        path = TRAIN_ANNOTATION_DIR / lab_id / f"{video_id}.parquet"
        try:
            annot = pd.read_parquet(path)
        except FileNotFoundError:
            continue

        annot["lab_id"] = lab_id
        annot["grouped_lab_id"] = row["grouped_lab_id"]
        annot["video_id"] = video_id
        annot["behaviors_labeled"] = row["behaviors_labeled"]
        # target_id must be rewritten before agent_id: the comparison needs the raw numeric ids.
        annot["target_id"] = np.where(
            annot.target_id != annot.agent_id, annot["target_id"].apply(lambda s: f"mouse{s}"), "self"
        )
        annot["agent_id"] = annot["agent_id"].apply(lambda s: f"mouse{s}")
        annotations.append(annot)

    return pd.concat(annotations) if annotations else None


def _per_key_action_stats(solution_pl, submission_pl):
    """Accumulate a frame-level F1 per (video, agent, target, action) key.

    Results are split into "single" (target "self") and "pair" modes. Within each grouped lab,
    the frames covered by every label key and prediction key are compared as sets. One F1 is
    computed per key that has any labelled or predicted frame.

    NOTE(review): this is a diagnostic, not a re-implementation of ``metric.score``. It averages
    one F1 per key, whereas the official metric presumably pools TP/FN/FP per action within a lab
    before computing F1. The numbers need not match the overall score; confirm against metric.py.

    Args:
        solution_pl: polars solution frame with grouped_lab_id, video_id, start_frame,
            stop_frame and a ``label_key`` column.
        submission_pl: polars submission frame with video_id, start_frame, stop_frame and a
            ``prediction_key`` column.

    Returns:
        defaultdict mapping action -> {"single"|"pair": {"count": int, "f1": float}}, where "f1"
        is the SUM of per-key F1 values (divide by "count" for the mean).
    """
    action_stats = defaultdict(lambda: {"single": {"count": 0, "f1": 0.0}, "pair": {"count": 0, "f1": 0.0}})

    for lab in solution_pl["grouped_lab_id"].unique():
        lab_solution = solution_pl.filter(pl.col("grouped_lab_id") == lab)
        lab_videos = set(lab_solution["video_id"].unique())
        lab_submission = submission_pl.filter(pl.col("video_id").is_in(lab_videos))

        label_frames = defaultdict(set)
        prediction_frames = defaultdict(set)
        for row in lab_solution.to_dicts():
            label_frames[row["label_key"]].update(range(row["start_frame"], row["stop_frame"]))
        for row in lab_submission.to_dicts():
            prediction_frames[row["prediction_key"]].update(range(row["start_frame"], row["stop_frame"]))

        for key in label_frames.keys() | prediction_frames.keys():
            # key = "{video_id}_{agent_id}_{target_id}_{action}"; check the target field itself
            # rather than searching the whole key for "self".
            parts = key.split("_")
            action = parts[-1]
            mode = "single" if parts[-2] == "self" else "pair"

            pred_frames = prediction_frames.get(key, set())
            true_frames = label_frames.get(key, set())
            tp = len(pred_frames & true_frames)
            fn = len(true_frames - pred_frames)
            fp = len(pred_frames - true_frames)

            if tp + fn + fp > 0:
                action_stats[action][mode]["count"] += 1
                action_stats[action][mode]["f1"] += 2 * tp / (2 * tp + fn + fp)  # F-beta with beta=1

    return action_stats


def compute_validation_metrics(submission, verbose=True):
    """Score ``submission`` against the training annotations and print a per-action report.

    Only videos present in ``submission`` are scored. The overall number comes from
    ``metric.score``. The per-action table averages per-key F1 values (see
    ``_per_key_action_stats``).

    Args:
        submission: pandas DataFrame in submission format (row_id, video_id, agent_id,
            target_id, action, start_frame, stop_frame).
        verbose: when True, explain why nothing was scored and print a warning with a
            truncated traceback if scoring fails. When False, failures are silent.

    Returns:
        None. Results are printed; exceptions raised while scoring are caught so the notebook
        keeps running.
    """
    solution = _load_train_solution()
    if solution is None:
        if verbose:
            print("\nWarning: Could not compute validation metrics: no training annotations found")
        return

    try:
        submission_single = submission[submission["target_id"] == "self"].copy()
        submission_pair = submission[submission["target_id"] != "self"].copy()

        # Restrict the ground truth to the videos the submission covers.
        submitted_videos = set(submission["video_id"].unique())
        solution = solution[solution["video_id"].isin(submitted_videos)]
        if len(solution) == 0:
            if verbose:
                print("\nWarning: Could not compute validation metrics: no annotated video in submission")
            return

        overall_f1 = score(solution, submission, "row_id", beta=1.0)
        print(f"\n{'=' * 60}")
        print("PERFORMANCE METRICS")
        print(f"{'=' * 60}")
        print(f"Overall F1 Score: {overall_f1:.4f}")
        print(f"Total predictions: {len(submission)}")
        print(f"  - Single behaviors: {len(submission_single)}")
        print(f"  - Pair behaviors: {len(submission_pair)}")

        solution_pl = pl.DataFrame(solution).with_columns(_segment_key("label_key"))
        submission_pl = pl.DataFrame(submission).with_columns(_segment_key("prediction_key"))
        action_stats = _per_key_action_stats(solution_pl, submission_pl)

        print("\nPer-Action Performance Summary:")
        print(f"{'-' * 60}")
        print(f"{'Action':<20} {'Mode':<10} {'Count':<10} {'Avg F1':<10}")
        print(f"{'-' * 60}")
        for action in sorted(action_stats.keys()):
            for mode in ["single", "pair"]:
                stats = action_stats[action][mode]
                if stats["count"] > 0:
                    avg_f1 = stats["f1"] / stats["count"]
                    print(f"{action:<20} {mode:<10} {stats['count']:<10} {avg_f1:<10.4f}")

        # Mean of the per-action averages, per mode.
        single_f1s = [s["single"]["f1"] / s["single"]["count"] for s in action_stats.values() if s["single"]["count"] > 0]
        pair_f1s = [s["pair"]["f1"] / s["pair"]["count"] for s in action_stats.values() if s["pair"]["count"] > 0]
        if single_f1s:
            print(f"\nSingle behaviors: {len(single_f1s)} actions, Avg F1: {np.mean(single_f1s):.4f}")
        if pair_f1s:
            print(f"Pair behaviors: {len(pair_f1s)} actions, Avg F1: {np.mean(pair_f1s):.4f}")

        print(f"{'=' * 60}\n")

    except Exception as e:
        # Best-effort diagnostic: report instead of interrupting the notebook.
        if verbose:
            error_msg = str(e)
            if len(error_msg) > 200:
                error_msg = error_msg[:200] + "..."
            print(f"\nWarning: Could not compute validation metrics: {error_msg}")
            print(f"Traceback: {traceback.format_exc()[:300]}")


compute_validation_metrics(submission=pd.read_csv(WORKING_DIR / "oof_predictions.csv"))

# Inference

In [None]:
# # read data
# test_dataframe = pl.read_csv(INPUT_DIR / "test.csv")

In [None]:
# # preprocess behavior labels
# test_behavior_dataframe = (
#     test_dataframe.filter(pl.col("behaviors_labeled").is_not_null())
#     .select(
#         pl.col("lab_id"),
#         pl.col("video_id"),
#         pl.col("behaviors_labeled").map_elements(eval, return_dtype=pl.List(pl.Utf8)).alias("behaviors_labeled_list"),
#     )
#     .explode("behaviors_labeled_list")
#     .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
#     .select(
#         pl.col("lab_id"),
#         pl.col("video_id"),
#         pl.col("behaviors_labeled_element").str.split(",").list[0].str.replace_all("'", "").alias("agent"),
#         pl.col("behaviors_labeled_element").str.split(",").list[1].str.replace_all("'", "").alias("target"),
#         pl.col("behaviors_labeled_element").str.split(",").list[2].str.replace_all("'", "").alias("behavior"),
#     )
# )

# test_self_behavior_dataframe = test_behavior_dataframe.filter(pl.col("behavior").is_in(SELF_BEHAVIORS))
# test_pair_behavior_dataframe = test_behavior_dataframe.filter(pl.col("behavior").is_in(PAIR_BEHAVIORS))

In [None]:
# (WORKING_DIR / "self_features").mkdir(exist_ok=True, parents=True)
# (WORKING_DIR / "pair_features").mkdir(exist_ok=True, parents=True)

# rows = test_dataframe.rows(named=True)

# for row in tqdm(rows, total=len(rows)):
#     lab_id = row["lab_id"]
#     video_id = row["video_id"]

#     tracking_path = TEST_TRACKING_DIR / f"{lab_id}/{video_id}.parquet"
#     tracking = pl.read_parquet(tracking_path)

#     self_features = make_self_features(metadata=row, tracking=tracking)
#     pair_features = make_pair_features(metadata=row, tracking=tracking)

#     self_features.write_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet")
#     pair_features.write_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet")

#     del self_features, pair_features
#     gc.collect()

In [None]:
# group_submissions = []
# groups = list(test_behavior_dataframe.group_by("lab_id", "video_id", "agent", "target", maintain_order=True))

# for (lab_id, video_id, agent, target), group in tqdm(groups, total=len(list(groups))):
#     agent_mouse_id = int(re.search(r"mouse(\d+)", agent).group(1))
#     target_mouse_id = -1 if target == "self" else int(re.search(r"mouse(\d+)", target).group(1))

#     if target == "self":
#         index = (
#             pl.scan_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet")
#             .filter((pl.col("agent_mouse_id") == agent_mouse_id))
#             .select(INDEX_COLS)
#             .collect()
#         )
#         feature = (
#             pl.scan_parquet(WORKING_DIR / "self_features" / f"{video_id}.parquet")
#             .filter((pl.col("agent_mouse_id") == agent_mouse_id))
#             .select(pl.exclude(INDEX_COLS))
#             .collect()
#         )
#     else:
#         index = (
#             pl.scan_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet")
#             .filter((pl.col("agent_mouse_id") == agent_mouse_id) & (pl.col("target_mouse_id") == target_mouse_id))
#             .select(INDEX_COLS)
#             .collect()
#         )
#         feature = (
#             pl.scan_parquet(WORKING_DIR / "pair_features" / f"{video_id}.parquet")
#             .filter((pl.col("agent_mouse_id") == agent_mouse_id) & (pl.col("target_mouse_id") == target_mouse_id))
#             .select(pl.exclude(INDEX_COLS))
#             .collect()
#         )

#     prediction_dataframe = index.clone()

#     for row in group.rows(named=True):
#         behavior = row["behavior"]

#         predictions = []
#         prediction_labels = []

#         fold_dirs = list((WORKING_DIR / "results" / lab_id / behavior).glob("fold_*"))
#         if not fold_dirs:
#             continue

#         for fold_dir in fold_dirs:
#             with open(fold_dir / "threshold.txt", "r") as f:
#                 threshold = float(f.read().strip())
#             model = xgb.Booster(model_file=fold_dir / "model.json")
#             dtest = xgb.DMatrix(feature, feature_names=feature.columns)
#             fold_predictions = model.predict(dtest)
#             predictions.append(fold_predictions)
#             prediction_labels.append((fold_predictions >= threshold).astype(np.int8))

#         prediction_dataframe = prediction_dataframe.with_columns(
#             *[
#                 pl.Series(name=f"{behavior}_{fold}", values=predictions[fold] * prediction_labels[fold], dtype=pl.Float32)
#                 for fold in range(len(fold_dirs))
#             ]
#         )

#     cols = prediction_dataframe.select(pl.exclude(INDEX_COLS)).columns
#     if not cols:
#         tqdm.write(f"Warning: No predictions found for {lab_id}, {video_id}, {agent}, {target}")
#         continue

#     prediction_labels_dataframe = prediction_dataframe.with_columns(
#         pl.struct(pl.col(cols))
#         .map_elements(
#             lambda row: "none" if sum(row.values()) == 0 else (cols[np.argmax(list(row.values()))]).split("_")[0],
#             return_dtype=pl.String,
#         )
#         .alias("prediction")
#     ).select(INDEX_COLS + ["prediction"])

#     group_submission = (
#         prediction_labels_dataframe.filter((pl.col("prediction") != pl.col("prediction").shift(1)))
#         .with_columns(pl.col("video_frame").shift(-1).alias("stop_frame"))
#         .filter(pl.col("prediction") != "none")
#         .select(
#             pl.col("video_id"),
#             ("mouse" + pl.col("agent_mouse_id").cast(str)).alias("agent_id"),
#             pl.when(pl.col("target_mouse_id") == -1)
#             .then(pl.lit("self"))
#             .otherwise("mouse" + pl.col("target_mouse_id").cast(str))
#             .alias("target_id"),
#             pl.col("prediction").alias("action"),
#             pl.col("video_frame").alias("start_frame"),
#             pl.col("stop_frame"),
#         )
#     )

#     group_submissions.append(group_submission)

# submission = pl.concat(group_submissions, how="vertical").sort(
#     "video_id",
#     "agent_id",
#     "target_id",
#     "action",
#     "start_frame",
#     "stop_frame",
# )
# submission = robustify(submission, test_dataframe, train_test="test")
# submission.with_row_index("row_id").write_csv(WORKING_DIR / "submission.csv")

In [None]:
# !head submission.csv