# MABe-2.0 Behavior Recognition: MS-TCN++ Training and Inference

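This notebook trains a multi-stage temporal convolutional network in the MS-TCN family (labelled MS-TCN++ here, although each dilated layer uses a single dilation rate) for multi-agent behavior recognition in the MABe Challenge. Each sample is a fixed-length window of pose-tracking coordinates for one ordered (agent, target) mouse pair. The model predicts per-frame, multi-label behavior probabilities. These are averaged over overlapping windows and converted into (action, start_frame, stop_frame) segments for the submission.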

In [None]:
# Install dependencies (for Kaggle)
!pip install pytorch-lightning omegaconf rich -q

In [2]:
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from scipy.ndimage import median_filter, uniform_filter1d
from tqdm.auto import tqdm

# Seed the Python, NumPy and PyTorch RNGs through Lightning so runs are repeatable
pl.seed_everything(seed=42)

ModuleNotFoundError: No module named 'numpy'

In [None]:
# Experiment configuration: paths, windowing, optimisation, model size and
# post-processing thresholds, all in one place.
CONFIG = dict(
    data_dir='/kaggle/input/MABe-mouse-behavior-detection',
    output_dir='/kaggle/working',
    window_size=256,
    stride=128,
    batch_size=8,
    num_epochs=50,
    learning_rate=5e-4,
    num_stages=4,
    num_layers=8,
    num_f_maps=64,
    dropout=0.5,
    threshold=0.5,
    min_duration=5,
)

# Resolved input/output locations (Kaggle layout)
INPUT_DIR, OUTPUT_DIR = Path(CONFIG['data_dir']), Path(CONFIG['output_dir'])

## 1. Data Loading and Preprocessing

In [None]:
class MABeDataset(Dataset):
    """Windowed (agent, target) dataset for MABe behavior recognition.

    Each sample is a fixed-length window of tracking coordinates for one ordered
    (agent, target) mouse pair from one video. When annotations are available,
    it also carries a multi-hot label matrix of shape (window_size, num_classes).

    Args:
        metadata_df: One row per video; must provide ``lab_id`` and ``video_id``.
        tracking_dir: Directory holding ``<lab_id>/<video_id>.parquet`` tracking files.
        annotation_dir: Directory with the same layout holding annotation files,
            or None for unlabeled (test) data.
        behaviors: Fixed behavior vocabulary. If None, it is collected from the annotations.
        window_size: Number of frames per sample.
        stride: Step between consecutive window starts.
        is_train: Stored as an attribute; not otherwise used by this class.
        bodyparts: Fixed body-part vocabulary that defines the feature layout.
            If None, the sorted union of body parts in this dataset's tracking
            files is used. Pass ``train_dataset.bodyparts`` when building
            validation/test datasets so their feature widths match the model.
    """

    def __init__(
        self,
        metadata_df: pd.DataFrame,
        tracking_dir: Path,
        annotation_dir: Optional[Path] = None,
        behaviors: Optional[List[str]] = None,
        window_size: int = 256,
        stride: int = 128,
        is_train: bool = True,
        bodyparts: Optional[List[str]] = None,
    ):
        self.tracking_dir = Path(tracking_dir)
        self.annotation_dir = Path(annotation_dir) if annotation_dir else None
        self.window_size = window_size
        self.stride = stride
        self.is_train = is_train
        self.metadata_df = metadata_df

        # Build behavior vocabulary
        if behaviors is None:
            self.behaviors = self._collect_behaviors()
        else:
            self.behaviors = behaviors
        self.behavior_to_idx = {b: i for i, b in enumerate(self.behaviors)}
        self.num_classes = len(self.behaviors)

        # Build the sample index. This also records every body part seen, so a
        # fixed feature layout can be derived when none was supplied.
        self._seen_bodyparts = set()
        self.samples = self._build_samples()
        if bodyparts is None:
            self.bodyparts = sorted(self._seen_bodyparts)
        else:
            self.bodyparts = list(bodyparts)
        self.bodypart_to_idx = {bp: i for i, bp in enumerate(self.bodyparts)}

    def _collect_behaviors(self) -> List[str]:
        """Collect the sorted set of unique ``action`` values across annotation files."""
        if self.annotation_dir is None:
            return []
        behaviors = set()
        for _, row in self.metadata_df.iterrows():
            ann_path = self.annotation_dir / row['lab_id'] / f"{row['video_id']}.parquet"
            if ann_path.exists():
                ann_df = pd.read_parquet(ann_path)
                behaviors.update(ann_df['action'].unique())
        return sorted(list(behaviors))

    @staticmethod
    def _window_starts(n_frames: int, window_size: int, stride: int) -> List[int]:
        """Return window start frames that together cover every frame of a video.

        Starts advance by ``stride``. A final window aligned to the end of the
        video is appended when the stride grid would leave trailing frames
        uncovered. A video shorter than ``window_size`` gets a single window at 0.
        """
        last_start = max(0, int(n_frames) - window_size)
        starts = list(range(0, last_start + 1, stride))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def _build_samples(self) -> List[Dict]:
        """Enumerate (video, window start, agent, target) samples.

        Videos without a tracking file are skipped. When an annotation directory
        is configured, videos without an annotation file are also skipped;
        otherwise ``_load_labels`` would silently label them all-negative.
        """
        samples = []
        for _, row in self.metadata_df.iterrows():
            lab_id = row['lab_id']
            video_id = row['video_id']

            track_path = self.tracking_dir / lab_id / f"{video_id}.parquet"
            if not track_path.exists():
                continue
            if self.annotation_dir is not None:
                ann_path = self.annotation_dir / lab_id / f"{video_id}.parquet"
                if not ann_path.exists():
                    continue

            track_df = pd.read_parquet(track_path)
            self._seen_bodyparts.update(track_df['bodypart'].unique())
            n_frames = int(track_df['video_frame'].max()) + 1
            mice = track_df['mouse_id'].unique()
            metadata = row.to_dict()  # shared read-only by all samples of this video

            for start in self._window_starts(n_frames, self.window_size, self.stride):
                for agent in mice:
                    for target in mice:
                        samples.append({
                            'lab_id': lab_id,
                            'video_id': video_id,
                            'start_frame': start,
                            'agent_id': agent,
                            'target_id': target,
                            'metadata': metadata
                        })
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load tracking data
        features = self._load_features(sample)

        # Load labels if available
        if self.annotation_dir:
            labels = self._load_labels(sample)
        else:
            labels = np.zeros((self.window_size, self.num_classes), dtype=np.float32)

        return {
            'features': torch.from_numpy(features),
            'labels': torch.from_numpy(labels),
            'video_id': sample['video_id'],
            'agent_id': sample['agent_id'],
            'target_id': sample['target_id'],
            'start_frame': sample['start_frame']
        }

    def _load_features(self, sample) -> np.ndarray:
        """Return (window_size, 4 * len(self.bodyparts)) float32 features.

        The agent's and then the target's x/y coordinates are concatenated per
        frame. They are divided by the pixel-per-cm scale when available, and
        NaNs are replaced by 0.
        """
        track_path = self.tracking_dir / sample['lab_id'] / f"{sample['video_id']}.parquet"
        track_df = pd.read_parquet(track_path)

        # Extract coordinates for agent and target (same body-part layout for both)
        agent_coords = self._extract_coords(track_df, sample['agent_id'])
        target_coords = self._extract_coords(track_df, sample['target_id'])

        # Get window
        start = sample['start_frame']
        end = start + self.window_size

        agent_window = self._get_window(agent_coords, start, end)
        target_window = self._get_window(target_coords, start, end)

        # Normalize.
        # NOTE(review): the metadata column name is assumed; both spellings are
        # checked. Confirm against train.csv.
        meta = sample['metadata']
        pix_per_cm = meta.get('pix per cm (approx)', meta.get('pix_per_cm_approx', 1.0))
        if pix_per_cm and pix_per_cm > 0:
            agent_window = agent_window / pix_per_cm
            target_window = target_window / pix_per_cm

        # Flatten and concatenate
        features = np.concatenate([
            agent_window.reshape(self.window_size, -1),
            target_window.reshape(self.window_size, -1)
        ], axis=-1)

        # Handle NaN
        features = np.nan_to_num(features, nan=0.0)

        return features.astype(np.float32)

    def _extract_coords(self, track_df, mouse_id) -> np.ndarray:
        """Return a (n_frames, len(self.bodyparts), 2) float32 x/y array for one mouse.

        The body-part axis follows ``self.bodyparts``, so every mouse and video
        shares the same layout. Frame/body-part cells with no tracking row stay
        NaN, and body parts outside the vocabulary are ignored.
        """
        n_frames = int(track_df['video_frame'].max()) + 1
        coords = np.full((n_frames, len(self.bodyparts), 2), np.nan, dtype=np.float32)

        mouse_df = track_df[track_df['mouse_id'] == mouse_id]
        part_idx = mouse_df['bodypart'].astype(object).map(self.bodypart_to_idx)
        known = part_idx.notna().to_numpy()
        if not known.any():
            return coords

        frames = mouse_df['video_frame'].to_numpy()[known].astype(np.int64)
        parts = part_idx.to_numpy()[known].astype(np.int64)
        coords[frames, parts, 0] = mouse_df['x'].to_numpy()[known]
        coords[frames, parts, 1] = mouse_df['y'].to_numpy()[known]
        return coords

    def _get_window(self, data, start, end) -> np.ndarray:
        """Slice frames [start, end) along axis 0, edge-padding outside the data range."""
        n_frames = data.shape[0]

        if start < 0:
            pre_pad = -start
            start = 0
        else:
            pre_pad = 0

        if end > n_frames:
            post_pad = end - n_frames
            end = n_frames
        else:
            post_pad = 0

        window = data[start:end]

        if pre_pad > 0 or post_pad > 0:
            pad_width = [(pre_pad, post_pad)] + [(0, 0)] * (window.ndim - 1)
            window = np.pad(window, pad_width, mode='edge')

        return window

    @staticmethod
    def _select_pair_annotations(ann_df: pd.DataFrame, agent_id, target_id) -> pd.DataFrame:
        """Return annotation rows belonging to the ordered (agent, target) pair.

        Self-directed rows (``target_id == 'self'``) are kept only when the
        agent and target are the same mouse. Otherwise a self-behavior of the
        agent would be labeled on every pair it belongs to.
        NOTE(review): assumes annotation ids are strings like 'mouse1'; confirm.
        """
        agent_mask = ann_df['agent_id'] == f"mouse{agent_id}"
        if agent_id == target_id:
            target_mask = ann_df['target_id'].isin([f"mouse{target_id}", 'self'])
        else:
            target_mask = ann_df['target_id'] == f"mouse{target_id}"
        return ann_df[agent_mask & target_mask]

    def _load_labels(self, sample) -> np.ndarray:
        """Return (window_size, num_classes) multi-hot labels for the sample's window.

        Annotated intervals are clipped to the window; stop_frame is treated as exclusive.
        """
        labels = np.zeros((self.window_size, self.num_classes), dtype=np.float32)

        ann_path = self.annotation_dir / sample['lab_id'] / f"{sample['video_id']}.parquet"
        if not ann_path.exists():
            return labels

        ann_df = pd.read_parquet(ann_path)
        pair_anns = self._select_pair_annotations(ann_df, sample['agent_id'], sample['target_id'])

        start = sample['start_frame']

        for _, row in pair_anns.iterrows():
            action = row['action']
            if action not in self.behavior_to_idx:
                continue

            action_idx = self.behavior_to_idx[action]
            window_start = max(0, row['start_frame'] - start)
            window_end = min(self.window_size, row['stop_frame'] - start)

            if window_start < window_end:
                labels[window_start:window_end, action_idx] = 1.0

        return labels

## 2. MS-TCN++ Model

In [None]:
class DilatedResidualLayer(nn.Module):
    """Residual block: dilated conv -> ReLU -> 1x1 conv -> dropout, plus a skip path.

    The skip path is a 1x1 convolution when the channel counts differ and the
    identity otherwise. The padding keeps the time length unchanged for odd
    kernel sizes. Inputs and outputs are (batch, channels, time).
    """

    def __init__(self, dilation, in_channels, out_channels, kernel_size=3, dropout=0.3):
        super().__init__()
        same_length_pad = dilation * (kernel_size - 1) // 2

        self.conv_dilated = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=same_length_pad,
            dilation=dilation,
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout)
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = self.skip(x)
        hidden = self.conv_1x1(F.relu(self.conv_dilated(x)))
        return self.dropout(hidden) + residual


class SingleStageTCN(nn.Module):
    """First MS-TCN stage: 1x1 input projection, dilated residual stack, 1x1 classifier.

    Residual layer ``i`` uses dilation ``2**i``. Input is (batch, in_channels,
    time) and output is (batch, num_classes, time).
    """

    def __init__(self, in_channels, num_layers, num_f_maps, num_classes, kernel_size=3, dropout=0.3):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, num_f_maps, 1)
        residual_blocks = []
        for depth in range(num_layers):
            residual_blocks.append(
                DilatedResidualLayer(2 ** depth, num_f_maps, num_f_maps, kernel_size, dropout)
            )
        self.layers = nn.ModuleList(residual_blocks)
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        hidden = self.conv_in(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.conv_out(hidden)


class RefinementStage(nn.Module):
    """Refinement stage: re-predicts class scores from a num_classes-channel input.

    It has the same structure as the first stage (1x1 projection, dilated
    residual layers with dilation ``2**i``, 1x1 classifier), but its input has
    ``num_classes`` channels. Tensors are (batch, channels, time).
    """

    def __init__(self, num_layers, num_f_maps, num_classes, kernel_size=3, dropout=0.3):
        super().__init__()
        self.conv_in = nn.Conv1d(num_classes, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                DilatedResidualLayer(dilation, num_f_maps, num_f_maps, kernel_size, dropout)
                for dilation in (2 ** level for level in range(num_layers))
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        features = self.conv_in(x)
        for idx in range(len(self.layers)):
            features = self.layers[idx](features)
        return self.conv_out(features)


class MSTCN(nn.Module):
    """Multi-stage temporal convolutional network for multi-label frame tagging.

    Stage 1 maps per-frame features to per-class logits. Each refinement stage
    re-predicts logits from the previous stage's per-class probabilities.

    Args:
        input_dim: Number of per-frame input features.
        num_classes: Number of output classes (independent, multi-label).
        num_stages: Total number of stages (1 initial + ``num_stages - 1`` refinements).
        num_layers: Dilated residual layers per stage.
        num_f_maps: Hidden channel width.
        kernel_size: Temporal kernel size of the dilated convolutions.
        dropout: Dropout probability inside residual layers.
    """

    def __init__(self, input_dim, num_classes, num_stages=4, num_layers=10,
                 num_f_maps=64, kernel_size=3, dropout=0.5):
        super().__init__()

        self.stage1 = SingleStageTCN(input_dim, num_layers, num_f_maps, num_classes, kernel_size, dropout)
        self.stages = nn.ModuleList([
            RefinementStage(num_layers, num_f_maps, num_classes, kernel_size, dropout)
            for _ in range(num_stages - 1)
        ])

    def forward(self, x, mask=None):
        """Run all stages.

        Args:
            x: Features laid out as (batch, time, input_dim).
            mask: Accepted for API compatibility; currently unused.

        Returns:
            ``(final, stage_outputs)``. ``final`` holds the last stage's logits as
            (batch, time, num_classes); ``stage_outputs`` lists every stage's
            logits in the same layout.
        """
        x = x.transpose(1, 2)  # (B, C, T)

        stage_outputs = []
        out = self.stage1(x)
        stage_outputs.append(out)

        for stage in self.stages:
            # Classes are independent (trained with per-class BCE), so refine
            # from sigmoid probabilities rather than a softmax across classes.
            out = stage(torch.sigmoid(out))
            stage_outputs.append(out)

        final = out.transpose(1, 2)  # (B, T, C)
        stage_outputs = [s.transpose(1, 2) for s in stage_outputs]

        return final, stage_outputs

## 3. Training Module

In [None]:
class BehaviorModule(pl.LightningModule):
    """Lightning wrapper that trains an MSTCN on windowed multi-label targets.

    Args:
        input_dim: Per-frame feature width fed to the model.
        num_classes: Number of behavior classes (model output channels).
        behaviors: Behavior names, index-aligned with the output channels.
        config: Hyperparameter dict. Reads num_stages, num_layers, num_f_maps,
            dropout, learning_rate, num_epochs and, optionally, threshold
            (default 0.5) for validation F1.
    """

    def __init__(self, input_dim, num_classes, behaviors, config):
        super().__init__()
        self.save_hyperparameters()

        self.behaviors = behaviors
        self.config = config

        self.model = MSTCN(
            input_dim=input_dim,
            num_classes=num_classes,
            num_stages=config['num_stages'],
            num_layers=config['num_layers'],
            num_f_maps=config['num_f_maps'],
            dropout=config['dropout']
        )
        # Epoch-level [tp, fp, fn] counts for validation F1 (plain tensor, not a buffer)
        self._val_counts = torch.zeros(3)

    def forward(self, x, mask=None):
        predictions, _ = self.model(x, mask)
        return predictions

    def _compute_loss(self, predictions, stage_outputs, labels):
        """Mean per-stage BCE-with-logits plus a 0.15-weighted temporal smoothing term.

        ``predictions`` and each stage output are logits in the (batch, time,
        classes) layout returned by MSTCN; ``labels`` is multi-hot in the same layout.
        """
        # Multi-stage loss
        total_loss = 0
        for stage_out in stage_outputs:
            loss = F.binary_cross_entropy_with_logits(stage_out, labels)
            total_loss += loss

        # Smoothing loss on per-class log-probabilities. logsigmoid matches the
        # independent-class BCE objective; log_softmax would couple classes.
        log_probs = F.logsigmoid(predictions)
        diff = log_probs[:, 1:] - log_probs[:, :-1]
        smooth_loss = torch.clamp(diff ** 2, 0, 16).mean()

        return total_loss / len(stage_outputs) + 0.15 * smooth_loss

    def training_step(self, batch, batch_idx):
        features = batch['features']
        labels = batch['labels']

        predictions, stage_outputs = self.model(features)
        loss = self._compute_loss(predictions, stage_outputs, labels)

        self.log('train_loss', loss)
        return loss

    def on_validation_epoch_start(self):
        # Reset counts so F1 reflects only the current validation epoch
        self._val_counts = torch.zeros(3)

    def validation_step(self, batch, batch_idx):
        features = batch['features']
        labels = batch['labels']

        predictions, stage_outputs = self.model(features)
        loss = self._compute_loss(predictions, stage_outputs, labels)

        threshold = self.config.get('threshold', 0.5)
        preds = (torch.sigmoid(predictions) > threshold).float()

        tp = (preds * labels).sum()
        fp = (preds * (1 - labels)).sum()
        fn = ((1 - preds) * labels).sum()
        # Accumulate raw counts; averaging per-batch F1 would not equal epoch F1
        self._val_counts += torch.stack([tp, fp, fn]).detach().float().cpu()

        self.log('val_loss', loss)
        return loss

    def on_validation_epoch_end(self):
        # Micro-F1 over every frame/class of the epoch (single-device counts)
        tp, fp, fn = self._val_counts.tolist()
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        self.log('val_f1', f1)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config['learning_rate'],
            weight_decay=0.0001
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.config['num_epochs']
        )
        return [optimizer], [scheduler]

## 4. Training

In [None]:
# Load per-video metadata
train_df = pd.read_csv(INPUT_DIR / 'train.csv')
print(f"Training videos: {len(train_df)}")

# Hold out a seeded random subset of videos for validation. Taking the last
# CSV rows would tie the split to file order (e.g. a single lab if rows are
# grouped by lab).
N_VAL_VIDEOS = 10
val_meta = train_df.sample(n=min(N_VAL_VIDEOS, max(len(train_df) - 1, 0)), random_state=42)
train_meta = train_df.drop(index=val_meta.index)

# Create datasets
train_dataset = MABeDataset(
    metadata_df=train_meta,
    tracking_dir=INPUT_DIR / 'train_tracking',
    annotation_dir=INPUT_DIR / 'train_annotation',
    window_size=CONFIG['window_size'],
    stride=CONFIG['stride']
)

# Validation must use the training feature layout; forward it when the
# dataset class exposes a body-part vocabulary.
layout_kwargs = {'bodyparts': train_dataset.bodyparts} if hasattr(train_dataset, 'bodyparts') else {}

val_dataset = MABeDataset(
    metadata_df=val_meta,
    tracking_dir=INPUT_DIR / 'train_tracking',
    annotation_dir=INPUT_DIR / 'train_annotation',
    behaviors=train_dataset.behaviors,
    window_size=CONFIG['window_size'],
    stride=CONFIG['window_size'],  # non-overlapping validation windows
    **layout_kwargs
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Behaviors: {train_dataset.behaviors}")

In [None]:
# Dataloaders share every setting except shuffling
loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# The model's input width is the last axis of one sample's feature matrix
sample = train_dataset[0]
input_dim = sample['features'].shape[-1]
print(f"Input dimension: {input_dim}")

In [None]:
# Lightning module sized to the training vocabulary and feature width
model = BehaviorModule(
    input_dim=input_dim,
    num_classes=train_dataset.num_classes,
    behaviors=train_dataset.behaviors,
    config=CONFIG,
)

# Keep only the best checkpoint by validation F1; stop after 10 epochs
# without improvement.
checkpoint_cb = ModelCheckpoint(
    dirpath=OUTPUT_DIR,
    filename='best-{epoch:02d}-{val_f1:.4f}',
    monitor='val_f1',
    mode='max',
    save_top_k=1,
)
early_stop_cb = EarlyStopping(monitor='val_f1', mode='max', patience=10)
callbacks = [checkpoint_cb, early_stop_cb]

trainer = pl.Trainer(
    max_epochs=CONFIG['num_epochs'],
    accelerator='auto',
    callbacks=callbacks,
    log_every_n_steps=10,
)

In [None]:
# Train
trainer.fit(model, train_loader, val_loader)

# After fit, `model` holds the last-epoch weights. Reload the checkpoint with
# the best val_f1 (when one was written) so inference uses the selected model.
best_ckpt = trainer.checkpoint_callback.best_model_path if trainer.checkpoint_callback else ''
if best_ckpt:
    model = BehaviorModule.load_from_checkpoint(best_ckpt, map_location='cpu')
    print(f"Loaded best checkpoint: {best_ckpt}")

## 5. Inference

In [None]:
def extract_segments(frame_probs, behavior_names, threshold=0.5, min_duration=5, filter_size=5):
    """Convert per-frame behavior probabilities into contiguous segments.

    Each behavior column is median-filtered along time and thresholded. Every
    run of positive frames at least ``min_duration`` long becomes one segment.

    Args:
        frame_probs: Array of shape (n_frames, n_behaviors) with probabilities.
        behavior_names: Names index-aligned with the columns of ``frame_probs``.
        threshold: A frame is positive when its filtered probability is >= this.
        min_duration: Minimum segment length in frames.
        filter_size: Median-filter window in frames; values <= 1 disable smoothing.

    Returns:
        List of dicts with 'action', 'start_frame', 'stop_frame' (exclusive)
        and 'confidence' (mean filtered probability over the segment).
    """
    frame_probs = np.asarray(frame_probs)
    _, n_behaviors = frame_probs.shape
    segments = []

    for b_idx in range(n_behaviors):
        probs = frame_probs[:, b_idx]
        if filter_size > 1:
            probs = median_filter(probs, size=filter_size)
        binary = (probs >= threshold).astype(int)

        # Zero-pad both ends so runs touching the borders still produce edges
        edges = np.diff(np.concatenate([[0], binary, [0]]))
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)

        for start, end in zip(starts, ends):
            if end - start >= min_duration:
                segments.append({
                    'action': behavior_names[b_idx],
                    'start_frame': int(start),
                    'stop_frame': int(end),
                    'confidence': float(probs[start:end].mean())
                })

    return segments

In [None]:
# Load test metadata and build the unlabeled inference dataset
test_df = pd.read_csv(INPUT_DIR / 'test.csv')

# Test features must use the training layout the model was built for; forward
# it when the dataset class exposes a body-part vocabulary.
test_layout_kwargs = {'bodyparts': train_dataset.bodyparts} if hasattr(train_dataset, 'bodyparts') else {}

test_dataset = MABeDataset(
    metadata_df=test_df,
    tracking_dir=INPUT_DIR / 'test_tracking',
    annotation_dir=None,
    behaviors=train_dataset.behaviors,
    window_size=CONFIG['window_size'],
    stride=CONFIG['window_size'] // 2,  # 50% overlap; predictions are averaged later
    **test_layout_kwargs
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=4
)

print(f"Test samples: {len(test_dataset)}")

In [None]:
# Run inference
model.eval()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)


def _to_python(value):
    """Unwrap 0-d tensors / NumPy scalars produced by default collation.

    Tensors hash by identity, so tensor ids would give every window its own
    dictionary key and overlapping windows would never be aggregated.
    """
    return value.item() if hasattr(value, 'item') else value


all_predictions = defaultdict(list)

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Running inference"):
        features = batch['features'].to(device)
        probs = torch.sigmoid(model(features)).cpu().numpy()

        for i in range(len(batch['video_id'])):
            key = (
                _to_python(batch['video_id'][i]),
                _to_python(batch['agent_id'][i]),
                _to_python(batch['target_id'][i]),
            )
            all_predictions[key].append({
                'start_frame': int(batch['start_frame'][i]),
                'probs': probs[i]
            })

In [None]:
# Average overlapping window predictions per (video, agent, target) and turn
# them into behavior segments.


def _scalar(value):
    """Unwrap 0-d tensors / NumPy scalars; leave other values unchanged."""
    return value.item() if hasattr(value, 'item') else value


def _mouse_label(mouse_id):
    """Format a mouse id as 'mouse<id>' unless it already carries that prefix."""
    text = str(_scalar(mouse_id))
    return text if text.startswith('mouse') else f"mouse{text}"


submission_rows = []
row_id = 0

for (video_id, agent_id, target_id), preds in tqdm(all_predictions.items()):
    video_id, agent_id, target_id = _scalar(video_id), _scalar(agent_id), _scalar(target_id)

    # Find total length covered by this pair's windows
    starts = [int(p['start_frame']) for p in preds]
    max_frame = max(s + p['probs'].shape[0] for s, p in zip(starts, preds))
    n_classes = preds[0]['probs'].shape[-1]

    # Average probabilities where windows overlap
    sum_probs = np.zeros((max_frame, n_classes))
    counts = np.zeros(max_frame)

    for start, p in zip(starts, preds):
        end = start + p['probs'].shape[0]
        sum_probs[start:end] += p['probs']
        counts[start:end] += 1

    mask = counts > 0
    sum_probs[mask] /= counts[mask, np.newaxis]

    # Extract segments
    segments = extract_segments(
        sum_probs,
        train_dataset.behaviors,
        threshold=CONFIG['threshold'],
        min_duration=CONFIG['min_duration']
    )

    # A mouse paired with itself is reported with target 'self'; a negative
    # integer target is also treated as 'self', as before.
    is_self = target_id == agent_id or (isinstance(target_id, int) and target_id < 0)
    agent_label = _mouse_label(agent_id)
    target_label = 'self' if is_self else _mouse_label(target_id)

    for seg in segments:
        submission_rows.append({
            'row_id': row_id,
            'video_id': video_id,
            'agent_id': agent_label,
            'target_id': target_label,
            'action': seg['action'],
            'start_frame': seg['start_frame'],
            'stop_frame': seg['stop_frame']
        })
        row_id += 1

In [None]:
# Create submission. Explicit columns keep the file valid (and sortable) even
# when no segments were predicted.
SUBMISSION_COLUMNS = ['row_id', 'video_id', 'agent_id', 'target_id', 'action', 'start_frame', 'stop_frame']
submission_df = (
    pd.DataFrame(submission_rows, columns=SUBMISSION_COLUMNS)
    .sort_values(['video_id', 'agent_id', 'target_id', 'start_frame'])
    .reset_index(drop=True)
)
# Renumber after sorting so row_id follows file order
submission_df['row_id'] = np.arange(len(submission_df))
submission_df.to_csv(OUTPUT_DIR / 'submission.csv', index=False)

print(f"Submission saved with {len(submission_df)} rows")
print(submission_df.head(10))