<a href="https://colab.research.google.com/github/souvikdas1990/Testing/blob/main/MABe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# {"username":"souvikdas700","key":"6351d65ad0a73b73a6fbc232d6c5bc55"}
import kagglehub
kagglehub.login()


VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

In [2]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

mabe_mouse_behavior_detection_path = kagglehub.competition_download('MABe-mouse-behavior-detection')

print('Data source import complete.')
print(mabe_mouse_behavior_detection_path)

UnauthenticatedError: User is not authenticated

In [None]:
import re
import cv2
import gc
import os
import json
import glob
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.utils import resample
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.cuda.amp import autocast, GradScaler
import time
from pathlib import Path

In [None]:

tracking_path = f'{mabe_mouse_behavior_detection_path}/train_tracking/'
annotation_path = f'{mabe_mouse_behavior_detection_path}/train_annotation/'
train_csv_path = f'{mabe_mouse_behavior_detection_path}/train.csv'

try:
    train_df = pd.read_csv(train_csv_path)
    print("First 5 rows of train.csv:")
    display(train_df.head())
    print(train_df.shape)
except FileNotFoundError:
    print(f"Error: train.csv not found at {train_csv_path}")
except Exception as e:
    print(f"An error occurred while reading train.csv: {e}")

In [None]:
valid_pairs = []

# Walk over each lab folder in annotation_path
for lab in os.listdir(annotation_path):
    ann_lab_folder = os.path.join(annotation_path, lab)
    if not os.path.isdir(ann_lab_folder):
        continue
    # collect all video_ids (strip .parquet)
    ann_files = [f.replace(".parquet", "") for f in os.listdir(ann_lab_folder) if f.endswith(".parquet")]
    for vid in ann_files:
        valid_pairs.append((lab, vid))

# Build a DataFrame of valid (lab_id, video_id) pairs
valid_df = pd.DataFrame(valid_pairs, columns=["lab_id", "video_id"])

# Filter train_df to only keep rows that appear in valid_df
train_df["video_id"] = train_df["video_id"].astype(str)
valid_df["video_id"] = valid_df["video_id"].astype(str)
train_df = train_df.merge(valid_df, on=["lab_id","video_id"], how="inner")

print("After filtering, train_df shape:", train_df.shape)
print("Unique labs left:", train_df["lab_id"].nunique())


In [None]:
# ====================================================
# Build merged labeled dataset from tracking + annotation parquet files
# ====================================================

KEEP_NIL_FRAC = 0  # keep 10% of NIL rows

all_chunks = []   # collect per-video DataFrames (beware memory if you keep all)

for _, row in train_df.iterrows():
    lab_id   = row['lab_id']
    video_id = row['video_id']

    tracking_file_path   = os.path.join(tracking_path,   lab_id, f'{video_id}.parquet')
    annotation_file_path = os.path.join(annotation_path, lab_id, f'{video_id}.parquet')

    print(f"Processing Lab ID: {lab_id}, Video ID: {video_id}")

    # --- load tracking ---
    try:
        df_tracking = pd.read_parquet(tracking_file_path)
        # init default labels
        df_tracking['target_id'] = 0
        df_tracking['action'] = "NIL"
    except Exception as e:
        print(f"  Error reading tracking file: {e}")
        continue

    # --- load annotations & stamp agent rows ---
    try:
        df_annotation = pd.read_parquet(annotation_file_path)

        # minimal safety check
        need = {'start_frame','stop_frame','agent_id','target_id','action'}
        if not need.issubset(df_annotation.columns):
            print(f"  Annotation missing cols {need - set(df_annotation.columns)}; skipping labels.")

        else:
            for _, ann in df_annotation.iterrows():
                mask_agent = (
                    (df_tracking['video_frame'] >= ann['start_frame']) &
                    (df_tracking['video_frame'] <= ann['stop_frame']) &
                    (df_tracking['mouse_id'] == ann['agent_id'])
                )
                df_tracking.loc[mask_agent, 'target_id'] = ann['target_id']
                df_tracking.loc[mask_agent, 'action']    = ann['action']

                # (optional) also tag target rows with same action:
                # if pd.notna(ann['target_id']):
                #     mask_target = (
                #         (df_tracking['video_frame'] >= ann['start_frame']) &
                #         (df_tracking['video_frame'] <= ann['stop_frame']) &
                #         (df_tracking['mouse_id'] == ann['target_id'])
                #     )
                #     df_tracking.loc[mask_target, 'action'] = ann['action']
                #     df_tracking.loc[mask_target, 'target_id'] = ann['agent_id']

    except Exception as e:
        print(f"  Error reading annotation file: {e}")

    # --- add metadata columns first ---
    df_tracking['lab_id'] = lab_id
    df_tracking['video_id'] = video_id
    cols = ['lab_id', 'video_id'] + [c for c in df_tracking.columns if c not in ('lab_id','video_id')]
    df_tracking = df_tracking[cols]

    # --- drop 90% of NIL rows (keep only 10%) per video ---
    nil_mask = (df_tracking['action'] == 'NIL')

    # per-video deterministic RNG seed
    seed = (hash((str(lab_id), str(video_id))) & 0xFFFFFFFF)
    rng = np.random.RandomState(seed)

    # vectorized keep mask: keep all positives + 10% of NILs
    keep_nil_mask = nil_mask & (rng.rand(len(df_tracking)) < KEEP_NIL_FRAC)
    pos_mask = ~nil_mask
    keep_mask = pos_mask | keep_nil_mask

    df_tracking = df_tracking.loc[keep_mask].reset_index(drop=True)

    # append reduced per-video chunk
    all_chunks.append(df_tracking)

# --- concatenate all reduced chunks ---
df_merged = pd.concat(all_chunks, ignore_index=True)

print(f"Merged (reduced) dataset shape = {df_merged.shape}")

# Optionally save
# out_path = "/kaggle/working/merged_dataset.parquet"
# df_merged.to_parquet(out_path, index=False)


In [None]:
# ====================================================
# Cell: Load ONLY tracking rows that match df_merged
#        (same lab_id, video_id, video_frame AND mouse_id == target_id)
# ====================================================
# Inputs assumed:
#   - train_df with columns ['lab_id','video_id']
#   - tracking_path root containing <lab_id>/<video_id>.parquet
#   - df_merged with columns ['lab_id','video_id','video_frame','target_id']
# Outputs:
#   - df_full_tracking_all: concatenation of ONLY the matching tracking rows
# ====================================================

# 1) Build the (lab_id, video_id, video_frame, target_id) key set from df_merged
if 'target_id' not in df_merged.columns:
    raise ValueError("df_merged must contain 'target_id' column.")

keys_all = (
    df_merged.loc[df_merged['target_id'].fillna(0).astype(int) > 0,
                  ['lab_id','video_id','video_frame','target_id']]
    .dropna()
    .drop_duplicates()
    .copy()
)

# normalize dtypes used for joining
keys_all['lab_id']      = keys_all['lab_id'].astype(str)
keys_all['video_id']    = keys_all['video_id']
keys_all['video_frame'] = keys_all['video_frame'].astype(int, errors='ignore')
keys_all['target_id']   = keys_all['target_id'].astype(int, errors='ignore')

all_train_chunks = []
seen = set()  # avoid re-reading duplicates, if any

for _, row in train_df.iterrows():
    lab_id   = str(row['lab_id'])
    video_id = row['video_id']

    # keep only if we actually have any keys for this (lab, video)
    keys_this = keys_all[(keys_all['lab_id'] == lab_id) & (keys_all['video_id'] == video_id)]
    if keys_this.empty:
        continue

    key = (lab_id, video_id)
    if key in seen:
        continue
    seen.add(key)

    tracking_file_path = os.path.join(tracking_path, lab_id, f'{video_id}.parquet')
    try:
        df_full_tracking = pd.read_parquet(tracking_file_path)

        # --- minimal schema normalization ---
        # some schemas may use 'frame' instead of 'video_frame'
        if 'video_frame' not in df_full_tracking.columns and 'frame' in df_full_tracking.columns:
            df_full_tracking = df_full_tracking.rename(columns={'frame': 'video_frame'})

        # enforce dtypes for join
        if 'video_frame' in df_full_tracking.columns:
            df_full_tracking['video_frame'] = df_full_tracking['video_frame'].astype(int, errors='ignore')
        if 'mouse_id' in df_full_tracking.columns:
            df_full_tracking['mouse_id'] = df_full_tracking['mouse_id'].astype(int, errors='ignore')

        # add metadata for traceability (if not already present)
        df_full_tracking['lab_id'] = lab_id
        df_full_tracking['video_id'] = video_id

        # --- build join keys for this video ---
        # need columns: ['video_frame','mouse_id'] where mouse_id == target_id
        join_keys = keys_this[['video_frame','target_id']].rename(columns={'target_id':'mouse_id'}).drop_duplicates()

        # inner join to keep ONLY rows that match (video_frame, mouse_id == target_id)
        df_match = join_keys.merge(
            df_full_tracking,
            on=['video_frame','mouse_id'],
            how='inner'
        )

        if df_match.empty:
            print(f"[INFO] No matching rows for lab={lab_id}, video={video_id}.")
            continue

        # reorder to keep metadata first
        cols = ['lab_id', 'video_id'] + [c for c in df_match.columns if c not in ('lab_id','video_id')]
        df_match = df_match[cols]

        all_train_chunks.append(df_match)
        print(f"Loaded filtered tracking: lab={lab_id}, video={video_id}, shape={df_match.shape}")

    except Exception as e:
        print(f"[WARN] Could not read Train tracking for lab={lab_id}, video={video_id}: {e}")

# Concatenate if any found
if len(all_train_chunks) > 0:
    df_full_tracking_all = pd.concat(all_train_chunks, ignore_index=True)
    print(f"Combined filtered tracking shape: {df_full_tracking_all.shape}")
else:
    df_full_tracking_all = pd.DataFrame()
    print("No matching tracking rows were loaded.")

# Peek
#df_full_tracking_all.head()


In [None]:
# Count the true labels in the original test labels
#print("True label counts in the data used for inference:")
# Use the un-encoded labels before they went into the DataLoader
#print(df_merged['action'].value_counts())
#unique_values_array = df_merged['bodypart'].unique()

#print("Unique values (as a NumPy array):")
#print(unique_values_array)
#print(len(unique_values_array))
#unique_action_array = df_merged['action'].unique()

#print("Unique values (as a NumPy array):")
#print(unique_action_array)
#print(len(unique_action_array))
#print(df_merged['action'].value_counts())

In [None]:
# ====================================================
# Cell: Join df_merged ⟷ df_full_tracking_all and ADD target_x / target_y
#  - Preserves df_merged row count (LEFT JOIN)
#  - Keys: lab_id, video_id, video_frame, target_id, bodypart==target_bodypart
# ====================================================

def to_long_pose(df):
    """Normalize tracking df into long format: ['lab_id','video_id','video_frame','mouse_id','bodypart','x','y']"""
    df = df.copy()
    # unify frame name
    if 'video_frame' not in df.columns and 'frame' in df.columns:
        df = df.rename(columns={'frame':'video_frame'})
    base = [c for c in ['lab_id','video_id','video_frame','mouse_id'] if c in df.columns]

    # already long?
    if 'bodypart' in df.columns and {'x','y'}.issubset(df.columns):
        return df[base + ['bodypart','x','y']].copy()

    cols = df.columns.tolist()
    # Pattern A: x_<bp>, y_<bp>
    x_bp = [(c, c.split('x_',1)[1]) for c in cols
            if c.startswith('x_') and len(c) > 2 and f"y_{c.split('x_',1)[1]}" in cols]
    # Pattern B: <bp>_x, <bp>_y
    bpx = [(c, c[:-2]) for c in cols if c.endswith('_x') and (c[:-2] + '_y') in cols]

    long_rows = []
    if x_bp or bpx:
        if x_bp:
            for xcol, bp in x_bp:
                ycol = f'y_{bp}'
                sub = df[base + [xcol, ycol]].copy()
                sub['bodypart'] = bp
                sub = sub.rename(columns={xcol:'x', ycol:'y'})
                long_rows.append(sub)
        if bpx:
            for xcol, bp in bpx:
                ycol = f'{bp}_y'
                sub = df[base + [xcol, ycol]].copy()
                sub['bodypart'] = bp
                sub = sub.rename(columns={xcol:'x', ycol:'y'})
                long_rows.append(sub)
        return pd.concat(long_rows, ignore_index=True)[base + ['bodypart','x','y']]

    # Fallback center
    out = df[base].copy()
    if {'x','y'}.issubset(df.columns):
        out['x'] = df['x']; out['y'] = df['y']; out['bodypart'] = 'body_center'
    elif {'body_center_x','body_center_y'}.issubset(df.columns):
        out['x'] = df['body_center_x']; out['y'] = df['body_center_y']; out['bodypart'] = 'body_center'
    else:
        guess_x = [c for c in cols if re.search(r'(^x$|_x$|^x_|center_x$)', c)]
        guess_y = [c for c in cols if re.search(r'(^y$|_y$|^y_|center_y$)', c)]
        out['x'] = df[guess_x].mean(axis=1, skipna=True) if guess_x else np.nan
        out['y'] = df[guess_y].mean(axis=1, skipna=True) if guess_y else np.nan
        out['bodypart'] = 'center_mean'
    return out[base + ['bodypart','x','y']]

# --- Safety: inputs
if 'df_merged' not in globals():
    raise RuntimeError("df_merged is not defined.")
if 'df_full_tracking_all' not in globals():
    raise RuntimeError("df_full_tracking_all is not defined.")

# --- Dtype align on df_merged
df_merged['lab_id']      = df_merged['lab_id'].astype(str)
df_merged['video_id']    = df_merged['video_id'].astype(str)
df_merged['video_frame'] = df_merged['video_frame'].astype(int, errors='ignore')
df_merged['target_id']   = df_merged['target_id'].fillna(0).astype(int, errors='ignore')

# --- Normalize tracking to long + dtypes
trk_long = to_long_pose(df_full_tracking_all)
for k in ['lab_id','video_id']:
    if k in trk_long.columns:
        trk_long[k] = trk_long[k].astype(str)
if 'video_frame' in trk_long.columns:
    trk_long['video_frame'] = trk_long['video_frame'].astype(int, errors='ignore')
if 'mouse_id' in trk_long.columns:
    trk_long['mouse_id'] = trk_long['mouse_id'].astype(int, errors='ignore')

# --- Build right table with target coords
right_tbl = trk_long.rename(columns={
    'mouse_id': 'target_id',
    'bodypart': 'target_bodypart',
    'x': 'target_x',
    'y': 'target_y'
})[['lab_id','video_id','video_frame','target_id','target_bodypart','target_x','target_y']]

# --- LEFT JOIN on same bodypart
df_joined = df_merged.merge(
    right_tbl,
    left_on = ['lab_id','video_id','video_frame','target_id','bodypart'],
    right_on= ['lab_id','video_id','video_frame','target_id','target_bodypart'],
    how='left'
)

# If you want canonical column names, keep 'target_x','target_y' and 'target_bodypart'
# Drop the duplicate right-side key column (target_bodypart) only if the left already had one:
# (If left also had 'target_bodypart', prefer the joined one where available.)
if 'target_bodypart_x' in df_joined.columns and 'target_bodypart_y' in df_joined.columns:
    # rare case from earlier merges; ignore
    pass
elif 'target_bodypart' in df_joined.columns and 'target_bodypart_y' in df_joined.columns:
    # also rare; ignore
    pass
else:
    # nothing to do; 'target_bodypart' is the right-side column we just created

    # If you want to keep only one column name for the bodypart on the right:
    pass

print("Joined shape (rows preserved from df_merged):", df_joined.shape)
#print("Has target_x/target_y?", 'target_x' in df_joined.columns, 'target_y' in df_joined.columns)
#display(df_joined.head(8)[['lab_id','video_id','video_frame','mouse_id','bodypart','x','y','target_id','action','target_bodypart','target_x','target_y']])


In [None]:
display(df_joined[(df_joined['video_id'] == '1335286655') & (df_joined['video_frame'] == 1807)])

In [None]:
# --- Setup ---
OUTPUT_DIR = '/content/kaggle/working/data'
os.makedirs(OUTPUT_DIR, exist_ok=True)

file_path = os.path.join(OUTPUT_DIR, 'data_final.parquet')
df_joined.to_parquet(file_path, index=False)

In [None]:
del df_merged
del df_full_tracking_all
del df_full_tracking
del train_df
del valid_df
del df_tracking
del df_annotation
#del _6
gc.collect()

In [None]:
ID_COLS = ['mouse_id', 'target_id', 'target_bodypart']
df_joined.drop(columns=ID_COLS, axis=1, inplace=True)
#display(df_joined.head())

In [None]:
rename_map = {}
if 'x' in df_joined.columns:         rename_map['x'] = 'mouse_A_x'
if 'y' in df_joined.columns:         rename_map['y'] = 'mouse_A_y'
if 'target_x' in df_joined.columns:  rename_map['target_x'] = 'mouse_B_x'
if 'target_y' in df_joined.columns:  rename_map['target_y'] = 'mouse_B_y'

df_joined = df_joined.rename(columns=rename_map)

# move `action` to the far right (if present)
cols = df_joined.columns.tolist()
if 'action' in cols:
    cols_no_action = [c for c in cols if c != 'action']
    cols = cols_no_action + ['action']
    df_joined = df_joined[cols]

print("Renamed columns applied. Shape:", df_joined.shape)
display(df_joined.head())

In [None]:
KEYS = ['lab_id','video_id','video_frame']
g = df_joined.groupby(KEYS, observed=True)

# Bounds across BOTH mice per group
Aminx = g['mouse_A_x'].transform('min'); Bminx = g['mouse_B_x'].transform('min')
Amaxx = g['mouse_A_x'].transform('max'); Bmaxx = g['mouse_B_x'].transform('max')
Aminy = g['mouse_A_y'].transform('min'); Bminy = g['mouse_B_y'].transform('min')
Amaxy = g['mouse_A_y'].transform('max'); Bmaxy = g['mouse_B_y'].transform('max')

xmin = np.minimum(Aminx, Bminx)
xmax = np.maximum(Amaxx, Bmaxx)
ymin = np.minimum(Aminy, Bminy)
ymax = np.maximum(Amaxy, Bmaxy)

# Avoid divide-by-zero
eps = 1e-6
den_x = (xmax - xmin).where((xmax - xmin) != 0, other=1.0)
den_y = (ymax - ymin).where((ymax - ymin) != 0, other=1.0)

# In-place normalize to [0,1]
df_joined['mouse_A_x'] = ((df_joined['mouse_A_x'] - xmin) / (den_x + eps)).clip(0, 1)
df_joined['mouse_A_y'] = ((df_joined['mouse_A_y'] - ymin) / (den_y + eps)).clip(0, 1)
df_joined['mouse_B_x'] = ((df_joined['mouse_B_x'] - xmin) / (den_x + eps)).clip(0, 1)
df_joined['mouse_B_y'] = ((df_joined['mouse_B_y'] - ymin) / (den_y + eps)).clip(0, 1)

print("Normalized in place. Shape:", df_joined.shape)
display(df_joined.head())

In [None]:
gc.collect()

In [None]:
del df_merged
del y_actions
del df_test
del train_df
del valid_df
del df_tracking
del df_annotation
#del _6
gc.collect()

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from collections import defaultdict
import pandas as pd # Import needed for context
import random # Needed for role switching

# Assuming df_joined, grouped, keys, act_to_idx, H, W, SIGMA, MIN_PARTS are defined

class HeatmapDataset(Dataset):
    def __init__(self, keys, grouped, act_to_idx, H=96, W=96, sigma=1.5):
        self.keys = keys
        self.grouped = grouped
        self.act_to_idx = act_to_idx
        self.H, self.W = H, W
        self.sigma = sigma

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        vid, fr = self.keys[idx]
        rows = self.grouped[(vid, fr)]

        # --- 1. DETERMINE ACTION LABEL (y) ---
        # The true action of the frame (e.g., 'rear')
        action = str(rows[0]['action'])
        y = self.act_to_idx[action]

        # --- 2. COLLECT RAW POINTS ---
        # Collect all Mouse A points and all Mouse B points
        raw_ptsA, raw_ptsB = [], []

        for r in rows:
            uA, vA = float(r['mouse_A_x']), float(r['mouse_A_y'])
            uB, vB = float(r['mouse_B_x']), float(r['mouse_B_y'])

            if np.isfinite(uA) and np.isfinite(vA):
                raw_ptsA.append((uA, vA))
            if np.isfinite(uB) and np.isfinite(vB):
                raw_ptsB.append((uB, vB))

        # --- 3. DYNAMIC ROLE SWITCHING (CRUCIAL FOR INFERENCE) ---

        # Randomly select which mouse will be the 'Agent' in this sample's input (x)
        # 0 = Mouse B is the Agent (B is put in Channel 0)
        # 1 = Mouse A is the Agent (A is put in Channel 0)
        y_dir = random.choice([0, 1])

        if y_dir == 1: # Mouse A is the agent (A -> B, Channel 0 = A)
            # Channel 0 (Agent) gets A's points, Channel 1 (Target) gets B's points
            pts_agent, pts_target = raw_ptsA, raw_ptsB
        else: # Mouse B is the agent (B -> A, Channel 0 = B)
            # Channel 0 (Agent) gets B's points, Channel 1 (Target) gets A's points
            pts_agent, pts_target = raw_ptsB, raw_ptsA

        # --- 4. FILTER / HANDLE EMPTY FRAME ---

        if len(pts_agent) + len(pts_target) < MIN_PARTS:
            x = np.zeros((2, self.H, self.W), dtype=np.float32)
            return (torch.from_numpy(x),
                    torch.tensor(y, dtype=torch.long),
                    torch.tensor(y_dir, dtype=torch.long),
                    vid, fr)

        # --- 5. RENDER HEATMAPS ---

        # Channel 0: Agent
        h_agent = render_heatmap(pts_agent, self.H, self.W, sigma_px=self.sigma, amp=1.0)
        # Channel 1: Target
        h_target = render_heatmap(pts_target, self.H, self.W, sigma_px=self.sigma, amp=1.0)

        # Stack into [2, H, W] tensor (Agent is always Channel 0)
        x  = np.stack([h_agent, h_target], axis=0).astype(np.float32)

        # --- 6. DUAL OUTPUT ---

        return (torch.from_numpy(x),
                torch.tensor(y, dtype=torch.long),
                torch.tensor(y_dir, dtype=torch.long),
                vid, fr)

# build dataset & loader
ds = HeatmapDataset(keys, grouped, act_to_idx, H=H, W=W, sigma=SIGMA)
dl = DataLoader(ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)

# The sanity check now receives the extra tensor
xb, yb, y_dirb, vidb, frb = next(iter(dl))
print("batch:", xb.shape, yb.shape, y_dirb.shape)

batch: torch.Size([128, 2, 96, 96]) torch.Size([128]) torch.Size([128])


In [None]:
# ---- model
class TinyCNNDual(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(2, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 2→32, H/2
            nn.Conv2d(32,64,3, padding=1),  nn.ReLU(), nn.MaxPool2d(2),  # 64, H/4
            nn.Conv2d(64,128,3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> [B,128,1,1]
        )
        self.action_head = nn.Linear(128, n_classes)  # action classes
        self.dir_head    = nn.Linear(128, 2)          # 0=B is agent, 1=A is agent

    def forward(self, x):
        h = self.features(x).flatten(1)        # [B,128]
        logits_a = self.action_head(h)         # [B,C]
        logits_d = self.dir_head(h)            # [B,2]
        return logits_a, logits_d

# ---- prepare device & infer num classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FIX: Add y_dirb to unpack the fifth value returned by the DataLoader
xb, yb, y_dirb, _, _ = next(iter(dl))                   # grab one batch to infer C

# --- ADD THESE CHECKS ---
if y_dirb.max().item() >= 2 or y_dirb.min().item() < 0:
    print(f"FATAL ERROR: Directional label max/min out of bounds (Expected 0 or 1).")
    print(f"Max in batch: {y_dirb.max().item()}, Min in batch: {y_dirb.min().item()}")
    # Check the data type again, must be torch.long
    print(f"y_dirb dtype: {y_dirb.dtype}")
    # This check should show you the problem.
# ------------------------

N_CLASSES = int(yb.max().item()) + 1
print(f"Inferred N_CLASSES: {N_CLASSES}")

model = TinyCNNDual(n_classes=N_CLASSES).to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-3)
ce = nn.CrossEntropyLoss()


Inferred N_CLASSES: 35


AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
N_CLASSES_TRUE = len(actions)
print(f"True N_CLASSES: {N_CLASSES_TRUE}")

True N_CLASSES: 37


In [None]:
EPOCHS = 5          # tweak as needed
DIR_LOSS_W = 0.3    # weight for direction head

def train_one_epoch(dloader):
    model.train()
    tot_loss = 0.0; n = 0
    acc_a = 0.0; acc_d = 0.0
    for x, y_action, _, _ in dloader:
        x = x.to(device, non_blocking=True)            # [B,2,H,W]
        y_action = y_action.to(device, non_blocking=True)  # [B]

        # Build swapped batch
        x_sw = x[:, [1,0], ...]                        # swap A/B channels
        x_cat = torch.cat([x, x_sw], dim=0)            # [2B,2,H,W]
        y_action_cat = torch.cat([y_action, y_action], dim=0)  # [2B]
        y_dir = torch.cat([
            torch.ones(len(y_action), dtype=torch.long, device=device),  # originals: A is agent
            torch.zeros(len(y_action), dtype=torch.long, device=device), # swapped:  B is agent
        ], dim=0)

        opt.zero_grad()
        logits_a, logits_d = model(x_cat)
        loss = ce(logits_a, y_action_cat) + DIR_LOSS_W * ce(logits_d, y_dir)
        loss.backward(); opt.step()

        # metrics
        with torch.no_grad():
            acc_a += (logits_a.argmax(1) == y_action_cat).float().sum().item()
            acc_d += (logits_d.argmax(1) == y_dir).float().sum().item()
            tot_loss += float(loss.item()) * x_cat.size(0)
            n += x_cat.size(0)

    return tot_loss / max(1,n), acc_a / max(1,n), acc_d / max(1,n)

for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc_a, tr_acc_d = train_one_epoch(dl)
    print(f"epoch {epoch:02d} | loss {tr_loss:.4f} | acc_action {tr_acc_a:.3f} | acc_dir {tr_acc_d:.3f}")


AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
MODEL_DIR = '/kaggle/working/model'
# Create the output directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_FILE = os.path.join(MODEL_DIR, 'mouse_action_lstm.pth')

# Ensure your model is on the CPU before saving to avoid GPU compatibility issues when loading
model.to('cpu')

# Save only the model's state dictionary
torch.save(model.state_dict(), MODEL_FILE)

print(f"Model successfully saved to: {MODEL_FILE}")

In [None]:
gc.collect()

In [None]:
# Define your data directory (assuming you placed the 'parquet_chunks' folder
# in the same working directory or a linked dataset)
OUTPUT_DIR = '/kaggle/working/stratified_split'

# --- Load and Prepare Data ---

# 1. Load All Parquet Files
try:
    all_files = glob.glob(os.path.join(OUTPUT_DIR, "data_test_10percent.parquet"))
    list_of_dfs = [pd.read_parquet(f) for f in all_files]
    df_test = pd.concat(list_of_dfs, ignore_index=True)
    print(f"Successfully loaded {len(all_files)} files.")
    print(f"Total rows: {len(df_test):,}")
except Exception as e:
    print(f"Error loading files. Check directory path: {e}")
    # Create a dummy DataFrame if loading fails to prevent kernel crash
    # df_final = pd.DataFrame()

In [None]:
# 1. Count the true labels in the original test labels
print("True label counts in the data used for inference:")
# Use the un-encoded labels before they went into the DataLoader
print(df_test['action'].value_counts())

In [None]:
# --- 1. Define the Index Columns ---
# These columns will define the unique rows in the resulting DataFrame.
ID_COLS = ['lab_id', 'video_id', 'video_frame', 'mouse_id', 'target_id']

# --- 2. Separate Action for Merging ---
# Since 'Action' is constant for a given combination of ID_COLS,
# we extract it separately to avoid issues with the pivot,
# then merge it back later. This is often necessary when the value
# column (like 'x'/'y') is not strictly unique.

# Keep only the unique combinations of ID_COLS and action
action_df = df_test[ID_COLS + ['action']].drop_duplicates()

# --- 3. Perform the Pivot Operation ---
# This transforms the 'bodypart' rows into columns.
df_pivoted = df_test.pivot_table(
    index=ID_COLS,           # The columns that form the new unique row identifier
    columns='bodypart',      # The column whose unique values become the new column headers
    values=['x', 'y'],       # The columns whose values will be aggregated
    aggfunc='mean'           # CRITICAL: Calculates the mean of the 'x' and 'y' duplicates
)

# --- 4. Clean Up Column Names ---
# The pivot operation creates multi-level columns (e.g., ('x', 'headpiece_bottombackright')).
# We flatten and rename them for easier use: 'bodypart_x', 'bodypart_y'.
df_pivoted.columns = [f'{col[1]}_{col[0]}' for col in df_pivoted.columns.values]

# --- 5. Reset Index and Merge Action ---
# Reset the index to turn the ID_COLS back into regular columns
df_pivoted = df_pivoted.reset_index()

# Merge the action column back into the pivoted DataFrame
df_final = pd.merge(
    df_pivoted,
    action_df,
    on=ID_COLS,
    how='left'  # Use a left merge to keep all the pivoted data
)

# --- 6. Reorder Columns for Clarity ---
# Move 'action' to be near the ID columns
final_columns = ID_COLS + ['action'] + [col for col in df_final.columns if col not in ID_COLS + ['action']]
df_final = df_final[final_columns]


In [None]:
df_final.drop(columns=ID_COLS, axis=1, inplace=True)
print("--- Final Reshaped DataFrame Head ---")
print(df_final.head())
print(f"\nFinal DataFrame Shape: {df_final.shape}")

In [None]:
INPUT_SIZE = sum(1 for col in df_final.columns if col.endswith('_x') or col.endswith('_y'))
print(INPUT_SIZE)

NUM_CLASSES = len(df_final['action'].unique())
print(NUM_CLASSES)

In [None]:
# --- Model Definition ---
class MouseActionLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(MouseActionLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # 1. LSTM Layer: Processes the time sequence
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        # 2. Fully Connected Layer: Maps the final hidden state to the class prediction
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initialize hidden state and cell state (optional, but good practice)
        # h0 and c0 will be created automatically if not passed, but explicit is cleaner

        # Pass the sequence through the LSTM
        # out has shape (batch_size, sequence_length, hidden_size)
        out, _ = self.lstm(x)

        # We only care about the output from the LAST frame in the sequence
        # out[:, -1, :] extracts the last time step output
        out = self.fc(out[:, -1, :])
        return out

# Instantiate the model
HIDDEN_SIZE = 128
NUM_LAYERS = 2
MODEL_FILE = '/kaggle/working/model/mouse_action_lstm.pth'
loaded_model  = MouseActionLSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NUM_CLASSES)
loaded_model.load_state_dict(torch.load(MODEL_FILE))
print("LSTM model defined successfully.")

In [None]:
class PoseSequenceDataset(Dataset):
    def __init__(self, features, labels, sequence_length):
        self.features = features
        self.labels = labels
        self.sequence_length = sequence_length
        self.indices = self._create_indices()

    def _create_indices(self):
        # Create indices for the start of each sequence.
        # The last possible start index is len(features) - sequence_length
        return np.arange(len(self.features) - self.sequence_length)

    def __len__(self):
        # The number of available sequences
        return len(self.indices)

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        end_idx = start_idx + self.sequence_length

        # X: Sequence of features (e.g., 30 frames x 50 coordinates)
        x_sequence = self.features[start_idx:end_idx]

        # Y: The label for the *last frame* in the sequence
        y_label = self.labels[end_idx - 1]

        return x_sequence, y_label

# --- Configuration (MUST match training) ---
SEQUENCE_LENGTH = 30  # Same as used for training
# Assuming you have the trained LabelEncoder 'le' from the training script
# If not, you must save and load it, or recreate it with ALL known classes.
# For simplicity, we assume le is available.

# --- 1. Impute NaN Values ---
print("Applying ffill and bfill to df_test...")
# Select only the feature columns for imputation
feature_cols = [col for col in df_final.columns if col.endswith('_x') or col.endswith('_y')]

# Apply FFill and BFill in sequence
df_final[feature_cols].fillna(method='ffill', inplace=True)
df_final[feature_cols].fillna(method='bfill', inplace=True)

# --- 2. Separate Features (X_test) and Target (y_test) ---
X_test_np = df_final[feature_cols].values
y_test_labels = df_final['action'].values # Keep original labels for comparison
y_test_encoded = le.transform(y_test_labels)
# Convert to PyTorch Tensor
X_test_tensor = torch.tensor(X_test_np, dtype=torch.float32)

# --- 3. Create Sequence Dataset ---
# We reuse the PoseSequenceDataset class defined during training.
# Since we need to match the structure, we use a simple DataLoader.

# IMPORTANT: Skip encoding the test labels to suppress the ValueError
# We use a placeholder tensor for the labels, which will be the correct size.
# NOTE: This means you CANNOT use 'y_test_tensor' to calculate accuracy with le.transform()
# You must handle the evaluation comparison manually later.

# ORIGINAL LINE (Caused Error): y_test_encoded = le.transform(y_test_labels)
# ORIGINAL LINE (Caused Error): y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

test_dataset = PoseSequenceDataset(X_test_tensor, y_test_tensor, SEQUENCE_LENGTH)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

print(f"Test data ready. Total sequences: {len(test_dataset):,}")

In [None]:
del df_final
del df_test
del df_pivoted
del X_test_np
del y_test_labels
del y_test_tensor
del test_dataset

In [None]:
# --- Inference Execution ---

# Ensure the model is in evaluation mode and on the correct device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model.to(DEVICE)
loaded_model.eval()

all_predictions = []
all_true_labels = []

print("Starting inference...")

with torch.no_grad(): # Essential: disables gradient calculation to save memory and speed
    for sequences, labels in test_loader:

        sequences = sequences.to(DEVICE)

        # 1. Forward Pass
        outputs = loaded_model(sequences)

        # 2. Get Predicted Class Index
        # torch.max returns (max_value, max_index). We want the index (the class ID).
        _, predicted_indices = torch.max(outputs.data, 1)

        # Store predictions and true labels
        all_predictions.extend(predicted_indices.cpu().numpy())
        all_true_labels.extend(labels.cpu().numpy()) # Store true encoded labels

print("Inference complete.")



In [None]:
# --- 3. Decode and Evaluate ---
# Decode the predicted indices back into their original string labels
print(le)
predicted_actions = le.inverse_transform(all_predictions)

# Create a final DataFrame for review
results_df = pd.DataFrame({
    'True_Action': le.inverse_transform(all_true_labels),
    'Predicted_Action': predicted_actions
})

# Calculate Final Accuracy
final_accuracy = accuracy_score(all_true_labels, all_predictions)

print("\n--- Inference Results ---")
print(f"Overall Test Accuracy: {final_accuracy:.4f}")
print("\nSample Predictions:")
print(results_df)

In [None]:
# Filter the DataFrame where the 'True_Action' column is not 'NIL'
non_nil_actions = results_df[results_df['Predicted_Action'] != 'NIL']

# Display the resulting DataFrame
print("--- Rows where True_Action is NOT 'NIL' ---")
print(len(non_nil_actions))