In [17]:
import datetime
import gc
import itertools
import json
import re
import sys
import time
import traceback
from collections import defaultdict
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
from tqdm.auto import tqdm

# Display full table
pl.Config.set_tbl_rows(-1)
pl.Config.set_tbl_cols(-1)
pl.Config.set_fmt_str_lengths(200) 

polars.config.Config

In [18]:
# const
INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
TRAIN_TRACKING_DIR = INPUT_DIR / "train_tracking"
WORKING_DIR = Path("/kaggle/working")

INDEX_COLS = [
    "video_id",
    "agent_mouse_id",
    "target_mouse_id",
    "video_frame",
]
BODY_PARTS = [
    "ear_left",
    "ear_right",
    "nose",
    "neck",
    "body_center",
    "lateral_left",
    "lateral_right",
    "hip_left",
    "hip_right",
    "tail_base",
    "tail_tip",
]

SELF_BEHAVIORS = [
    "biteobject",
    "climb",
    "dig",
    "exploreobject",
    "freeze",
    "genitalgroom",
    "huddle",
    "rear",
    "rest",
    "run",
    "selfgroom",
]

PAIR_BEHAVIORS = [
    "allogroom",
    "approach",
    "attack",
    "attemptmount",
    "avoid",
    "chase",
    "chaseattack",
    "defend",
    "disengage",
    "dominance",
    "dominancegroom",
    "dominancemount",
    "ejaculate",
    "escape",
    "flinch",
    "follow",
    "intromit",
    "mount",
    "reciprocalsniff",
    "shepherd",
    "sniff",
    "sniffbody",
    "sniffface",
    "sniffgenital",
    "submit",
    "tussle",
]

In [19]:
# Load train data
train_dataframe = pl.read_csv(INPUT_DIR / "train.csv")
train_dataframe = train_dataframe.filter((pl.col("video_id") != 1212811043))
# train_dataframe

In [20]:
# preprocess behavior labels for single and pair
train_behavior_dataframe = (
    train_dataframe.filter(pl.col("behaviors_labeled").is_not_null())
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled").map_elements(eval, return_dtype=pl.List(pl.Utf8)).alias("behaviors_labeled_list"),
    )
    .explode("behaviors_labeled_list")
    .rename({"behaviors_labeled_list": "behaviors_labeled_element"})
    .select(
        pl.col("lab_id"),
        pl.col("video_id"),
        pl.col("behaviors_labeled_element").str.split(",").list[0].str.replace_all("'", "").alias("agent"),
        pl.col("behaviors_labeled_element").str.split(",").list[1].str.replace_all("'", "").alias("target"),
        pl.col("behaviors_labeled_element").str.split(",").list[2].str.replace_all("'", "").alias("behavior"),
    )
)

train_self_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(SELF_BEHAVIORS))
train_pair_behavior_dataframe = train_behavior_dataframe.filter(pl.col("behavior").is_in(PAIR_BEHAVIORS))

# Analysis

Regarding the differnece among labs, chatgpt said ...

| Method         | Role                          | Identity Tracking | Multi-Animal | Accuracy | Notes                              |
| -------------- | ----------------------------- | ----------------- | ------------ | -------- | ---------------------------------- |
| **DeepLabCut** | Pose estimation               | Weak              | Moderate     | High     | Most widely used                   |
| **SLEAP**      | Pose estimation + ID tracking | Strong            | Strong       | High     | Best suited for multi-animal data  |
| **MARS**       | **Behavior classification**   | —                 | —            | —        | Uses SLEAP internally for tracking |

Hence we can consider SLEAP and MARS are the same in terms of tracking plot

## Single

In [21]:
train_self_behavior_sum = (
    train_self_behavior_dataframe
    .group_by(['lab_id', 'behavior'])
    .len()
    .rename({"len": "n"})
    .sort(['lab_id'])
)

In [22]:
"""
Integrate with tracking method and arena type
"""
# Get the tracking methods from each lab
tracking_methods = (train_dataframe.select(['lab_id','tracking_method', 'arena_type', 'body_parts_tracked']).unique())

self_behavior_sum = train_self_behavior_sum.join(
    tracking_methods,
    on="lab_id",
    how="left"
)
# self_behavior_sum

In [23]:
"""
Base Model Results
"""
rows_single = [
    ("AdaptableSnail", "rear", 0.62),
    ("CRIM13", "rear", 0.38),
    ("CRIM13", "selfgroom", 0.35),
    ("CalMS21_task1", "genitalgroom", 0.68),
    ("ElegantMink", "rear", None),
    ("ElegantMink", "selfgroom", None),
    ("GroovyShrew", "rear", 0.52),
    ("GroovyShrew", "rest", 0.66),
    ("GroovyShrew", "selfgroom", 0.34),
    ("GroovyShrew", "climb", 0.37),
    ("GroovyShrew", "dig", 0.40),
    ("GroovyShrew", "run", 0.17),
    ("InvincibleJellyfish", "dig", 0.29),
    ("InvincibleJellyfish", "selfgroom", 0.17),
    ("LyricalHare", "freeze", 0.53),
    ("LyricalHare", "rear", 0.36),
    ("NiftyGoldfinch", "biteobject", 0.03),
    ("NiftyGoldfinch", "climb", 0.56),
    ("NiftyGoldfinch", "dig", 0.50),
    ("NiftyGoldfinch", "exploreobject", 0.09),
    ("NiftyGoldfinch", "rear", 0.42),
    ("NiftyGoldfinch", "selfgroom", 0.44),
    ("TranquilPanther", "rear", 0.19),
    ("TranquilPanther", "selfgroom", 0.16),
    ("UppityFerret", "huddle", 0.63),
    ("UppityFerret", "rear", None),
    ("UppityFerret", "selfgroom", None),
]

res = pl.DataFrame(rows_single, schema=["lab_id", "behavior", "f1"])

  res = pl.DataFrame(rows_single, schema=["lab_id", "behavior", "f1"])


In [24]:
self_res = self_behavior_sum.join(
    res,
    on=["lab_id", "behavior"],
    how="left"
)
# self_res = self_res.sort("behavior")
self_res = self_res.sort("body_parts_tracked")
self_res 

lab_id,behavior,n,tracking_method,arena_type,body_parts_tracked,f1
str,str,u32,str,str,str,f64
"""AdaptableSnail""","""rear""",58,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.62
"""UppityFerret""","""huddle""",11,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",0.63
"""UppityFerret""","""selfgroom""",11,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",
"""UppityFerret""","""rear""",21,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",
"""AdaptableSnail""","""rear""",58,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""lateral_left"", ""lateral_right"", ""neck"", ""nose"", ""tail_base"", ""tail_midpoint"", ""tail_tip""]""",0.62
"""NiftyGoldfinch""","""biteobject""",22,"""SLEAP""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""nose"", ""tail_base""]""",0.03
"""NiftyGoldfinch""","""exploreobject""",22,"""SLEAP""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""nose"", ""tail_base""]""",0.09
"""NiftyGoldfinch""","""rear""",22,"""SLEAP""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""nose"", ""tail_base""]""",0.42
"""NiftyGoldfinch""","""dig""",22,"""SLEAP""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""nose"", ""tail_base""]""",0.5
"""NiftyGoldfinch""","""selfgroom""",22,"""SLEAP""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""nose"", ""tail_base""]""",0.44


## Pair

In [25]:
train_pair_behavior_sum = (
    train_pair_behavior_dataframe
    .group_by(['lab_id', 'behavior'])
    .len()
    .rename({"len": "n"})
    .sort(['lab_id'])
)

In [26]:
"""
Integrate with tracking method 
"""
# Get the tracking methods from each lab
tracking_methods = (train_dataframe.select(['lab_id','tracking_method', 'arena_type', 'body_parts_tracked']).unique())

pair_behavior_sum = train_pair_behavior_sum.join(
    tracking_methods,
    on="lab_id",
    how="left"
)
# pair_behavior_sum

In [27]:
"""
Base Model Results
"""
rows_pair = [
    ("AdaptableSnail", "approach", 0.36),
    ("AdaptableSnail", "attack", 0.19),
    ("AdaptableSnail", "avoid", 0.17),
    ("AdaptableSnail", "chase", 0.15),
    ("AdaptableSnail", "chaseattack", 0.27),
    ("AdaptableSnail", "submit", 0.42),
    ("BoisterousParrot", "shepherd", 0.46),
    ("CRIM13", "approach", 0.49),
    ("CRIM13", "attack", 0.69),
    ("CRIM13", "disengage", 0.44),
    ("CRIM13", "mount", 0.68),
    ("CRIM13", "sniff", 0.68),
    ("CalMS21_supplemental", "attack", 0.82),
    ("CalMS21_supplemental", "sniff", 0.69),
    ("CalMS21_supplemental", "sniffgenital", 0.49),
    ("CalMS21_supplemental", "mount", 0.64),
    ("CalMS21_supplemental", "approach", 0.49),
    ("CalMS21_supplemental", "dominancemount", 0.53),
    ("CalMS21_supplemental", "sniffbody", 0.61),
    ("CalMS21_supplemental", "sniffface", 0.72),
    ("CalMS21_supplemental", "attemptmount", 0.16),
    ("CalMS21_supplemental", "intromit", 0.92),
    ("CalMS21_task1", "approach", 0.40),
    ("CalMS21_task1", "mount", 0.77),
    ("CalMS21_task1", "sniffbody", 0.52),
    ("CalMS21_task1", "sniffface", 0.57),
    ("CalMS21_task1", "sniffgenital", 0.70),
    ("CalMS21_task1", "attack", 0.76),
    ("CalMS21_task1", "intromit", 0.74),
    ("CalMS21_task1", "sniff", 0.75),
    ("CalMS21_task2", "attack", 0.76),
    ("CalMS21_task2", "mount", 0.88),
    ("CalMS21_task2", "sniff", 0.80),
    ("CautiousGiraffe", "approach", None),
    ("CautiousGiraffe", "chase", 0.45),
    ("CautiousGiraffe", "escape", 0.78),
    ("CautiousGiraffe", "reciprocalsniff", 0.74),
    ("CautiousGiraffe", "sniffbody", 0.37),
    ("CautiousGiraffe", "sniffgenital", 0.55),
    ("CautiousGiraffe", "sniff", 0.48),
    ("DeliriousFly", "sniff", 0.46),
    ("DeliriousFly", "attack", 0.55),
    ("DeliriousFly", "dominance", 0.64),
    ("ElegantMink", "attack", 0.77),
    ("ElegantMink", "intromit", 0.70),
    ("ElegantMink", "mount", 0.37),
    ("ElegantMink", "sniff", 0.50),
    ("ElegantMink", "sniffgenital", None),
    ("ElegantMink", "attemptmount", 0.14),
    ("ElegantMink", "allogroom", 0.15),
    ("ElegantMink", "ejaculate", 0.13),
    ("GroovyShrew", "intromit", None),
    ("GroovyShrew", "mount", None),
    ("GroovyShrew", "sniff", 0.66),
    ("GroovyShrew", "sniffgenital", 0.53),
    ("GroovyShrew", "approach", 0.39),
    ("GroovyShrew", "defend", 0.08),
    ("GroovyShrew", "escape", 0.22),
    ("GroovyShrew", "attemptmount", 0.35),
    ("InvincibleJellyfish", "allogroom", 0.29),
    ("InvincibleJellyfish", "attack", 0.68),
    ("InvincibleJellyfish", "dominancegroom", 0.26),
    ("InvincibleJellyfish", "escape", 0.09),
    ("InvincibleJellyfish", "sniff", 0.55),
    ("InvincibleJellyfish", "sniffgenital", 0.45),
    ("JovialSwallow", "attack", 0.55),
    ("JovialSwallow", "chase", 0.03),
    ("JovialSwallow", "sniff", 0.59),
    ("LyricalHare", "approach", 0.20),
    ("LyricalHare", "attack", 0.78),
    ("LyricalHare", "defend", 0.58),
    ("LyricalHare", "escape", 0.68),
    ("LyricalHare", "sniff", 0.70),
    ("NiftyGoldfinch", "approach", 0.47),
    ("NiftyGoldfinch", "attack", 0.59),
    ("NiftyGoldfinch", "chase", 0.68),
    ("NiftyGoldfinch", "defend", 0.44),
    ("NiftyGoldfinch", "escape", 0.64),
    ("NiftyGoldfinch", "flinch", 0.08),
    ("NiftyGoldfinch", "follow", 0.42),
    ("NiftyGoldfinch", "sniff", 0.49),
    ("NiftyGoldfinch", "sniffface", 0.62),
    ("NiftyGoldfinch", "sniffgenital", 0.19),
    ("NiftyGoldfinch", "tussle", 0.39),
    ("PleasantMeerkat", "attack", 0.09),
    ("PleasantMeerkat", "chase", 0.09),
    ("PleasantMeerkat", "escape", 0.14),
    ("PleasantMeerkat", "follow", 0.70),
    ("ReflectiveManatee", "sniff", 0.88),
    ("ReflectiveManatee", "attack", 0.83),
    ("SparklingTapir", "attack", 0.66),
    ("SparklingTapir", "defend", 0.58),
    ("SparklingTapir", "escape", 0.69),
    ("SparklingTapir", "mount", 0.83),
    ("SparklingTapir", "sniffgenital", None),
    ("TranquilPanther", "intromit", 0.55),
    ("TranquilPanther", "mount", 0.41),
    ("TranquilPanther", "sniff", 0.46),
    ("TranquilPanther", "sniffgenital", 0.50),
    ("UppityFerret", "reciprocalsniff", 0.62),
    ("UppityFerret", "sniff", None),
    ("UppityFerret", "sniffgenital", 0.51),
    ("UppityFerret", "intromit", None),
    ("UppityFerret", "mount", None),
]

res = pl.DataFrame(rows_pair, schema=["lab_id", "behavior", "f1"])

  res = pl.DataFrame(rows_pair, schema=["lab_id", "behavior", "f1"])


In [28]:
pair_res = pair_behavior_sum.join(
    res,
    on=["lab_id", "behavior"],
    how="left"
)
# pair_res = pair_res.sort("behavior")
pair_res = pair_res.sort(["body_parts_tracked", "arena_type", "lab_id"])
pair_res 

lab_id,behavior,n,tracking_method,arena_type,body_parts_tracked,f1
str,str,u32,str,str,str,f64
"""AdaptableSnail""","""submit""",156,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.42
"""AdaptableSnail""","""chase""",160,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.15
"""AdaptableSnail""","""chaseattack""",158,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.27
"""AdaptableSnail""","""approach""",156,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.36
"""AdaptableSnail""","""avoid""",165,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.17
"""AdaptableSnail""","""attack""",159,"""DeepLabCut""","""familiar""","""[""body_center"", ""ear_left"", ""ear_right"", ""headpiece_bottombackleft"", ""headpiece_bottombackright"", ""headpiece_bottomfrontleft"", ""headpiece_bottomfrontright"", ""headpiece_topbackleft"", ""headpiece_topback…",0.19
"""UppityFerret""","""reciprocalsniff""",22,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",0.62
"""UppityFerret""","""sniff""",21,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",
"""UppityFerret""","""mount""",10,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",
"""UppityFerret""","""intromit""",10,"""DeepLabCut""","""neutral""","""[""body_center"", ""ear_left"", ""ear_right"", ""hip_left"", ""hip_right"", ""lateral_left"", ""lateral_right"", ""nose"", ""spine_1"", ""spine_2"", ""tail_base"", ""tail_middle_1"", ""tail_middle_2"", ""tail_tip""]""",


"ReflectiveManatee"	and "SparklingTapir" <br><br>
"CRIM13", "CalMS21_supplemental", "CautiousGiraffe", "CalMS21_task1", "CalMS21_task2", "ElegantMink", "InvincibleJellyfish", "JovialSwallow", and "TranquilPanther"<br><br>
can be integrated

# Analysis

### Merge "PleasantMeerkat" and "DeliriousFly"?

They both use the same tracked body parts and they share some behaviors in common (such as attack).
However, because the arena type is different, the distance between the two mice’s bodies during an attack is likely to differ.
This difference in spatial constraints may cause the model to learn arena-specific attack patterns, even though the behavior label is the same. <br><br>

### Merge "ReflectiveManatee" and "SparklingTapir"? 

Both have good scores. Hence I will keep them independent for now.

### Merge "CRIM13", "CalMS21_supplemental", "CautiousGiraffe", "CalMS21_task1", "CalMS21_task2", "ElegantMink", "InvincibleJellyfish", "JovialSwallow", and "TranquilPanther"?

First of all, merging all of them causes memory issues, so I need to selectively choose which labs to group.

- Group A: CalMS family (CalMS21_supplemental, CalMS21_task1, CalMS21_task2)
  (very similar action sets, strong synergy)
- Group B: CRIM13 + CautiousGiraffe
- Group C: ElegantMink + InvincibleJellyfish + JovialSwallow
(weak labs that can support each other)

- Group D: TranquilPanther + NiftyGoldfinch
(similar action variety and medium F1)