In [None]:
import os
import sys
from pathlib import Path
from collections import OrderedDict
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Check if running on Kaggle
ON_KAGGLE = os.path.exists('/kaggle/input')

if ON_KAGGLE:
    # Kaggle paths
    REPO_DIR = Path('/kaggle/input/convtransformer/MABe-ConvTransformer')
    DATA_DIR = Path('/kaggle/input/MABe-mouse-behavior-detection')
    CHECKPOINT_PATH = Path('/kaggle/input/convtransformer/MABe-ConvTransformer/outputs/checkpoints/mabe-epoch=97-val_segment_f1=0.2440.ckpt')
    OUTPUT_DIR = Path('/kaggle/working')
else:
    # Local fallback so the path constants are always defined (previously the
    # prints below raised NameError outside Kaggle).
    # NOTE(review): the default local layout (./data, ./outputs) is an
    # assumption; override any of these through the environment variables.
    REPO_DIR = Path(os.environ.get('MABE_REPO_DIR', '.'))
    DATA_DIR = Path(os.environ.get('MABE_DATA_DIR', REPO_DIR / 'data'))
    CHECKPOINT_PATH = Path(os.environ.get(
        'MABE_CHECKPOINT_PATH',
        REPO_DIR / 'outputs' / 'checkpoints' / 'mabe-epoch=97-val_segment_f1=0.2440.ckpt',
    ))
    OUTPUT_DIR = Path(os.environ.get('MABE_OUTPUT_DIR', REPO_DIR / 'outputs'))

# Add src to path (guarded so re-running the cell does not stack duplicates)
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

print(f"Data directory: {DATA_DIR}")
print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Import from src
from src.models.lightning_module import BehaviorRecognitionModule
from src.data.preprocessing import (
    CoordinateNormalizer,
    TemporalResampler,
    MissingDataHandler,
    BodyPartMapper
)
from src.utils.postprocessing import (
    aggregate_window_predictions,
    extract_segments,
    merge_segments,
    apply_nms,
    create_submission,
    BehaviorSegment,
)

print("Imports successful!")

# Configuration

In [None]:
# Model and data configuration
CONFIG = dict(
    # Windowing / loader settings
    window_size=512,
    stride=256,
    target_fps=30.0,
    batch_size=64,
    num_workers=2,

    # Evaluation settings - OPTIMIZED based on validation analysis
    # Lower threshold to capture more behaviors (was 0.39, best macro_f1 at 0.30)
    threshold=0.30,
    min_duration=5,
    smoothing_kernel=5,
    nms_threshold=0.3,
    merge_gap=5,

    # Per-behavior thresholds for rare classes
    # (submit has max prob ~0.05, so we need special handling)
    behavior_thresholds=dict(
        submit=0.02,        # Very rare, need lower threshold
        chaseattack=0.25,   # Rare but detectable
    ),
)

# Behavior classes (must match training); order defines the class index
BEHAVIORS = [
    # Self behaviors
    *"biteobject climb dig exploreobject freeze "
     "genitalgroom huddle rear rest run selfgroom".split(),
    # Pair behaviors
    *"allogroom approach attack attemptmount avoid "
     "chase chaseattack defend disengage dominance "
     "dominancegroom dominancemount ejaculate escape "
     "flinch follow intromit mount reciprocalsniff "
     "shepherd sniff sniffbody sniffface sniffgenital "
     "submit tussle".split(),
]

# Test behaviors (what Kaggle evaluates)
TEST_BEHAVIORS = "approach attack avoid chase chaseattack submit rear".split()

print(f"Number of behavior classes: {len(BEHAVIORS)}")
print(f"Test behaviors: {TEST_BEHAVIORS}")
print(f"Default threshold: {CONFIG['threshold']}")
print(f"Per-behavior thresholds: {CONFIG['behavior_thresholds']}")

# Datasets

In [None]:
class TestDataset(Dataset):
    """
    Sliding-window dataset over test tracking data (no annotations).

    Each sample is one window of ``window_size`` frames for one ordered
    (agent, target) mouse pair of one video; pairs with agent == target are
    included. Window start frames live on the resampled (``target_fps``) time
    axis, because the frame count is rescaled before windows are generated.
    """
    
    def __init__(
        self,
        metadata_df: pd.DataFrame,
        tracking_dir: Path,
        behaviors: List[str],
        window_size: int = 512,
        stride: int = 256,
        target_fps: float = 30.0,
        tracking_cache_size: int = 4
    ):
        """
        Args:
            metadata_df: One row per video; must provide ``lab_id`` and
                ``video_id`` columns (a frame-rate column is optional).
            tracking_dir: Root folder holding ``<lab_id>/<video_id>.parquet``.
            behaviors: Behavior class names; only their count is used here
                (to size the dummy label array).
            window_size: Window length in frames.
            stride: Step between consecutive window starts, in frames.
            target_fps: Frame rate the tracking is resampled to.
            tracking_cache_size: Max number of preprocessed videos kept in
                the LRU cache (clamped to at least 1).
        """
        self.tracking_dir = Path(tracking_dir)
        self.window_size = window_size
        self.stride = stride
        self.target_fps = target_fps
        self.tracking_cache_size = max(1, tracking_cache_size)
        
        self.behaviors = behaviors
        self.num_classes = len(behaviors)
        
        # Preprocessors
        self.coord_normalizer = CoordinateNormalizer()
        self.temporal_resampler = TemporalResampler(target_fps)
        self.missing_handler = MissingDataHandler()
        self.bodypart_mapper = BodyPartMapper(use_core_only=False)
        
        # LRU cache of preprocessed videos, keyed by (lab_id, video_id)
        self._tracking_cache: OrderedDict = OrderedDict()
        
        # Parse metadata
        self.metadata_df = metadata_df.copy()
        self.video_ids = metadata_df['video_id'].unique().tolist()
        
        # Build sample index
        self.samples = self._build_sample_index()
        print(f"Built {len(self.samples)} test samples from {len(self.video_ids)} videos")
    
    def _get_fps(self, metadata: Dict) -> float:
        """Return the frame rate stored under any of the known keys, else 30."""
        return metadata.get('frames_per_second',
                           metadata.get('frames per second',
                                        metadata.get('fps', 30)))
    
    def _window_starts(self, n_frames: int) -> List[int]:
        """
        Return the window start frames that cover ``n_frames`` frames.

        Starts are spaced by ``self.stride``. When the regular grid leaves
        trailing frames uncovered, one extra window aligned to the end of the
        video is appended so that every frame receives a prediction. Videos
        shorter than one window yield the single start ``0``.
        """
        starts = list(range(0, max(1, n_frames - self.window_size + 1), self.stride))
        tail_start = max(0, n_frames - self.window_size)
        if starts[-1] < tail_start:
            # The strided grid stopped before the end of the video.
            starts.append(tail_start)
        return starts
    
    def _build_sample_index(self) -> List[Dict]:
        """
        Enumerate every (video, window start, agent, target) sample.

        Videos whose tracking file is missing or empty are skipped.
        """
        samples = []
        
        for _, row in tqdm(self.metadata_df.iterrows(), total=len(self.metadata_df), desc="Building sample index"):
            lab_id = row['lab_id']
            video_id = row['video_id']
            fps = self._get_fps(row)
            
            # Load tracking to get video length and mice
            track_path = self.tracking_dir / f"{lab_id}" / f"{video_id}.parquet"
            if not track_path.exists():
                continue
            
            # Only these two columns are needed for indexing; skipping the
            # coordinate columns keeps the index build cheap.
            track_df = pd.read_parquet(track_path, columns=['video_frame', 'mouse_id'])
            if track_df.empty:
                continue
            n_frames = int(track_df['video_frame'].max()) + 1
            mice = sorted(track_df['mouse_id'].unique())
            
            # Adjust for resampling
            if fps != self.target_fps:
                duration = n_frames / fps
                n_frames = int(duration * self.target_fps)
            
            # Generate windows for each agent-target pair
            for start in self._window_starts(n_frames):
                for agent in mice:
                    for target in mice:
                        samples.append({
                            'lab_id': lab_id,
                            'video_id': video_id,
                            'start_frame': start,
                            'agent_id': agent,
                            'target_id': target,
                            'metadata': row.to_dict()
                        })
        
        return samples
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return features, all-zero dummy labels, validity mask and ids of sample ``idx``."""
        sample_info = self.samples[idx]
        features, valid_mask = self._load_tracking(sample_info)
        
        # Dummy labels (zeros) for test data
        labels = np.zeros((self.window_size, self.num_classes), dtype=np.float32)
        
        return {
            'features': torch.from_numpy(features),
            'labels': torch.from_numpy(labels),
            'valid_mask': torch.from_numpy(valid_mask),
            'video_id': sample_info['video_id'],
            'agent_id': sample_info['agent_id'],
            'target_id': sample_info['target_id'],
            'start_frame': sample_info['start_frame']
        }
    
    def _get_cached_tracking(self, lab_id: str, video_id: str, metadata: Dict):
        """
        Return preprocessed per-mouse coordinates and validity for one video.

        Results are kept in an LRU cache of at most ``tracking_cache_size``
        videos.

        Raises:
            FileNotFoundError: If the tracking parquet file does not exist.
        """
        key = (lab_id, video_id)
        if key in self._tracking_cache:
            # Mark as most recently used
            self._tracking_cache.move_to_end(key)
            return self._tracking_cache[key]
        
        track_path = self.tracking_dir / f"{lab_id}" / f"{video_id}.parquet"
        if not track_path.exists():
            raise FileNotFoundError(f"Tracking file not found: {track_path}")
        
        track_df = pd.read_parquet(track_path)
        fps = self._get_fps(metadata)
        bodyparts = sorted(track_df['bodypart'].unique().tolist())
        
        coords_by_mouse = {}
        valid_by_mouse = {}
        
        for mouse_id in track_df['mouse_id'].unique():
            raw_coords = self._extract_mouse_coords(track_df, mouse_id, bodyparts)
            mapped_coords, mapped_parts, availability = self.bodypart_mapper.map_bodyparts(raw_coords, bodyparts)
            mapped_coords = self.bodypart_mapper.compute_derived_parts(mapped_coords, mapped_parts, availability)
            
            if fps != self.target_fps:
                mapped_coords = self.temporal_resampler(mapped_coords, fps)
            
            mapped_coords = self.coord_normalizer(mapped_coords, metadata)
            mapped_coords, valid_mask = self.missing_handler.interpolate_missing(mapped_coords)
            # Anything still NaN after interpolation is zero-filled
            mapped_coords = np.nan_to_num(mapped_coords, nan=0.0)
            
            coords_by_mouse[mouse_id] = mapped_coords.astype(np.float32)
            valid_by_mouse[mouse_id] = valid_mask.astype(np.float32)
        
        cache_entry = {
            'coords_by_mouse': coords_by_mouse,
            'valid_by_mouse': valid_by_mouse
        }
        self._tracking_cache[key] = cache_entry
        
        # Evict the least recently used video once over capacity
        if len(self._tracking_cache) > self.tracking_cache_size:
            self._tracking_cache.popitem(last=False)
        
        return cache_entry
    
    def _extract_mouse_coords(self, track_df: pd.DataFrame, mouse_id: int, bodyparts: List[str]) -> Dict[str, np.ndarray]:
        """
        Build a dense ``(n_frames, 2)`` x/y array per body part for one mouse.

        Frames without a row for a body part stay NaN. ``n_frames`` is taken
        from the whole video (all mice), so every array has the same length.
        """
        mouse_df = track_df[track_df['mouse_id'] == mouse_id]
        n_frames = track_df['video_frame'].max() + 1
        
        coords = {}
        for bp in bodyparts:
            bp_df = mouse_df[mouse_df['bodypart'] == bp].sort_values('video_frame')
            frames = bp_df['video_frame'].values
            part_coords = np.full((n_frames, 2), np.nan, dtype=np.float32)
            part_coords[frames, 0] = bp_df['x'].values
            part_coords[frames, 1] = bp_df['y'].values
            coords[bp] = part_coords
        
        return coords
    
    def _load_tracking(self, sample_info: Dict) -> Tuple[np.ndarray, np.ndarray]:
        """
        Assemble the feature window and validity mask for one sample.

        Features are the agent window followed by the target window, each
        flattened to ``(window_size, -1)`` and concatenated on the last axis.

        Raises:
            ValueError: If the agent or target has no coordinates in the video.
        """
        lab_id = sample_info['lab_id']
        video_id = sample_info['video_id']
        start_frame = sample_info['start_frame']
        agent_id = sample_info['agent_id']
        target_id = sample_info['target_id']
        metadata = sample_info['metadata']
        
        cache_entry = self._get_cached_tracking(lab_id, video_id, metadata)
        coords_by_mouse = cache_entry['coords_by_mouse']
        valid_by_mouse = cache_entry['valid_by_mouse']
        
        agent_coords = coords_by_mouse.get(agent_id)
        target_coords = coords_by_mouse.get(target_id)
        agent_valid = valid_by_mouse.get(agent_id)
        target_valid = valid_by_mouse.get(target_id)
        
        if agent_coords is None or target_coords is None:
            raise ValueError(f"Missing coordinates for agent {agent_id} or target {target_id} in {video_id}")
        
        end_frame = start_frame + self.window_size
        agent_window = self._get_window(agent_coords, start_frame, end_frame)
        target_window = self._get_window(target_coords, start_frame, end_frame)
        
        features = np.concatenate([
            agent_window.reshape(self.window_size, -1),
            target_window.reshape(self.window_size, -1)
        ], axis=-1)
        
        # NOTE(review): validity windows are edge-padded like the coordinates,
        # so frames padded past the end of the video inherit the validity of
        # the last real frame. Confirm this matches the training pipeline
        # before changing it.
        agent_valid_window = self._get_window(agent_valid.astype(np.float32), start_frame, end_frame)
        target_valid_window = self._get_window(target_valid.astype(np.float32), start_frame, end_frame)
        # A frame is valid when more than half of the entries along the last
        # axis are valid for both agent and target.
        valid_mask = (agent_valid_window.mean(axis=-1) > 0.5) & (target_valid_window.mean(axis=-1) > 0.5)
        
        return features.astype(np.float32), valid_mask.astype(np.float32)
    
    def _get_window(self, data: np.ndarray, start: int, end: int) -> np.ndarray:
        """
        Slice ``data[start:end]`` along axis 0, padding out-of-range frames.

        Frames requested before index 0 or past the end of ``data`` are filled
        by repeating the nearest edge frame. When the requested range contains
        no real frame at all (empty data, or ``start`` at/after the end), an
        all-zero window of the requested length is returned, because edge
        padding cannot extend an empty axis.
        """
        n_frames = data.shape[0]
        requested_len = end - start
        
        pre_pad = max(0, -start)
        post_pad = max(0, end - n_frames)
        window = data[max(start, 0):min(end, n_frames)]
        
        if window.shape[0] == 0:
            # np.pad(mode='edge') raises on an empty axis; fall back to zeros.
            return np.zeros((max(requested_len, 0),) + data.shape[1:], dtype=data.dtype)
        
        if pre_pad > 0 or post_pad > 0:
            pad_width = [(pre_pad, post_pad)] + [(0, 0)] * (window.ndim - 1)
            window = np.pad(window, pad_width, mode='edge')
        
        return window

# Model

In [None]:
# Device
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load checkpoint with safe globals for numpy scalars
# NOTE(review): `np.core` is presumably a deprecated alias on NumPy >= 2.0 -
# confirm against the NumPy version of the target environment.
safe_classes = [np.core.multiarray.scalar]
# Best effort: register the class in torch's global allowlist. Any failure
# (presumably a torch version without this API) is deliberately ignored.
try:
    torch.serialization.add_safe_globals(safe_classes)
except Exception:
    pass

# Same allowlist as a scoped context manager; falls back to a no-op context
# when it cannot be created.
try:
    safe_ctx = torch.serialization.safe_globals(safe_classes)
except Exception:
    from contextlib import nullcontext
    safe_ctx = nullcontext()

print(f"Loading model from {CHECKPOINT_PATH}")
with safe_ctx:
    # NOTE(review): weights_only=False requests a full pickle-based load, so
    # only use checkpoints from a trusted source. How this keyword is handled
    # depends on the installed Lightning version - confirm.
    model = BehaviorRecognitionModule.load_from_checkpoint(
        str(CHECKPOINT_PATH),
        map_location=device,
        weights_only=False,
        strict=True,
    )

# Move to the inference device and switch to evaluation mode.
model.to(device)
model.eval()

# Summary of attributes read from the loaded module.
print(f"Model loaded successfully!")
print(f"  - Model type: {model.model_name}")
print(f"  - Input dim: {model.input_dim}")
print(f"  - Num classes: {model.num_classes}")
print(f"  - Behaviors: {model.behaviors[:5]}... ({len(model.behaviors)} total)")

# Inference

In [None]:
# Load test metadata (one row per test video)
test_csv = DATA_DIR / 'test.csv'
test_df = pd.read_csv(test_csv)
print(f"Test metadata: {len(test_df)} videos")
print(test_df.head())

# Create test dataset; windowing parameters come straight from CONFIG
test_dataset = TestDataset(
    test_df,
    DATA_DIR / 'test_tracking',
    BEHAVIORS,
    **{name: CONFIG[name] for name in ('window_size', 'stride', 'target_fps')},
    tracking_cache_size=8,
)

print(f"\nTest dataset: {len(test_dataset)} samples")

# Create dataloader: fixed order, keep the final partial batch
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

def _to_scalar(value: Any) -> Any:
    """Convert tensors/NumPy scalars to plain Python values."""
    if isinstance(value, torch.Tensor):
        value = value.item()
    if isinstance(value, np.generic):
        value = value.item()
    return value


def _format_mouse_id(mouse_id: Any, agent_id: Any = None) -> str:
    """
    Normalize mouse identifiers to submission format.
    
    Args:
        mouse_id: The mouse ID to format (already 1-indexed from tracking data)
        agent_id: If provided and equals mouse_id, returns 'self' (for target formatting)
    
    Returns:
        Formatted string: 'mouse1', 'mouse2', etc. or 'self' if target == agent
    """
    mouse_id = _to_scalar(mouse_id)
    agent_id = _to_scalar(agent_id) if agent_id is not None else None
    
    # Check if this is a self-reference (target == agent)
    if agent_id is not None and mouse_id == agent_id:
        return 'self'
    
    # Handle string mouse IDs
    if isinstance(mouse_id, str):
        cleaned = mouse_id.strip()
        if cleaned.lower().startswith('mouse'):
            return cleaned
        if cleaned.lstrip('-').isdigit():
            mouse_id = int(cleaned)
        else:
            return cleaned
    
    # Convert to integer and format as mouse ID
    try:
        mouse_int = int(mouse_id)
    except (TypeError, ValueError):
        return str(mouse_id)
    
    # Tracking data uses 1-indexed mouse IDs (1, 2, 3, 4)
    # Format directly as 'mouse1', 'mouse2', etc. to match behaviors_labeled
    return f'mouse{mouse_int}'

In [None]:
def collect_predictions(
    model: BehaviorRecognitionModule,
    dataloader: DataLoader,
    device: torch.device
) -> List[Dict]:
    """
    Run the model over ``dataloader`` and gather one record per window.

    Each record holds the window's video/agent/target ids, its start frame
    and the sigmoid probabilities; when the batch carries a ``valid_mask``
    the probabilities of masked-out frames are multiplied to zero.
    """
    records: List[Dict] = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running inference"):
            features = batch['features'].to(device)
            mask = batch.get('valid_mask')
            if mask is not None:
                mask = mask.to(device)

            probs = torch.sigmoid(model(features, mask))

            # One device-to-host transfer per batch instead of one per window
            probs_np = probs.detach().cpu().numpy()
            mask_np = mask.detach().cpu().numpy() if mask is not None else None

            for row in range(features.shape[0]):
                window_prob = probs_np[row]
                if mask_np is not None:
                    # Broadcast the per-frame mask over the class axis
                    window_prob = window_prob * mask_np[row].reshape(-1, 1)

                records.append({
                    'video_id': _to_scalar(batch['video_id'][row]),
                    'agent_id': _to_scalar(batch['agent_id'][row]),
                    'target_id': _to_scalar(batch['target_id'][row]),
                    'start_frame': int(_to_scalar(batch['start_frame'][row])),
                    'probabilities': window_prob,
                })

    return records


# Run inference over every test window
print("Running inference on test data...")
window_predictions = collect_predictions(
    model=model,
    dataloader=test_loader,
    device=device,
)
print(f"Collected {len(window_predictions)} window predictions")

# Post Processing

In [None]:
def resolve_overlaps(segments: List[BehaviorSegment], min_duration: int = 5) -> List[BehaviorSegment]:
    """
    Make the segments of each (video_id, agent_id, target_id) non-overlapping.

    Within a group, segments are visited in (start_frame, stop_frame) order.
    A segment that starts before the end of the last kept segment has its
    start moved to that end; if what remains is empty or shorter than
    ``min_duration`` the segment is dropped.

    Args:
        segments: List of BehaviorSegment objects
        min_duration: Minimum duration (in frames) for a kept segment

    Returns:
        New BehaviorSegment objects, grouped in first-seen group order and
        ordered by start within each group.
    """
    # Bucket segments per agent-target pair, preserving first-seen order
    by_pair = {}
    for seg in segments:
        by_pair.setdefault((seg.video_id, seg.agent_id, seg.target_id), []).append(seg)

    kept = []
    for bucket in by_pair.values():
        # In-place, like the grouping lists themselves are only used here
        bucket.sort(key=lambda s: (s.start_frame, s.stop_frame))

        # Stop frame of the last segment that was kept
        frontier = -1
        for seg in bucket:
            begin = seg.start_frame
            if begin < frontier:
                begin = frontier
            finish = seg.stop_frame

            # Drop segments that became empty or too short after clipping
            if begin >= finish or (finish - begin) < min_duration:
                continue

            kept.append(BehaviorSegment(
                video_id=seg.video_id,
                agent_id=seg.agent_id,
                target_id=seg.target_id,
                action=seg.action,
                start_frame=begin,
                stop_frame=finish,
                confidence=seg.confidence,
            ))
            frontier = finish

    return kept


def extract_segments_with_per_behavior_threshold(
    frame_probs: np.ndarray,
    behavior_names: List[str],
    default_threshold: float = 0.3,
    behavior_thresholds: Dict[str, float] = None,
    min_duration: int = 5,
    smoothing_kernel: int = 5
) -> List[Tuple[str, int, int, float]]:
    """
    Turn frame-level probabilities into behavior segments.

    Each behavior column is median-smoothed (when ``smoothing_kernel`` > 1),
    thresholded with its own threshold (or the default), and every run of
    consecutive above-threshold frames lasting at least ``min_duration``
    frames becomes a segment.

    Args:
        frame_probs: Array of shape (n_frames, n_behaviors)
        behavior_names: Behavior name for each column of ``frame_probs``
        default_threshold: Threshold for behaviors without an override
        behavior_thresholds: Optional {behavior name: threshold} overrides
        min_duration: Minimum segment duration in frames
        smoothing_kernel: Median filter size

    Returns:
        List of (behavior, start_frame, stop_frame, confidence) tuples, where
        stop_frame is exclusive and confidence is the mean smoothed
        probability over the segment.
    """
    from scipy.ndimage import median_filter

    _, n_behaviors = frame_probs.shape
    overrides = behavior_thresholds if behavior_thresholds else {}
    found = []

    for column_idx, column in enumerate(frame_probs.T[:n_behaviors]):
        scores = np.array(column)
        name = behavior_names[column_idx]
        cutoff = overrides.get(name, default_threshold)

        if smoothing_kernel > 1:
            scores = median_filter(scores, size=smoothing_kernel)

        # Pad with False so runs touching either border still produce edges
        active = np.concatenate(([False], scores >= cutoff, [False]))
        # Edges alternate: run start, run end (exclusive), run start, ...
        edges = np.flatnonzero(active[1:] != active[:-1])

        for run_start, run_end in zip(edges[0::2], edges[1::2]):
            if run_end - run_start < min_duration:
                continue
            found.append((name, run_start, run_end, float(scores[run_start:run_end].mean())))

    return found


def build_submission_rows(
    window_predictions: List[Dict],
    behaviors: List[str],
    threshold: float,
    min_duration: int,
    smoothing_kernel: int,
    nms_threshold: float,
    merge_gap: int = 5,
    behavior_thresholds: Optional[Dict[str, float]] = None,
) -> List[Dict]:
    """
    Convert window-level predictions into submission-format rows.

    Pipeline: aggregate overlapping windows per (video, agent, target),
    extract segments with per-behavior thresholds, merge segments separated
    by small gaps, apply NMS, resolve overlaps between behaviors of the same
    pair, then emit rows sorted by (video, agent, target, start_frame).

    Args:
        window_predictions: Records produced by ``collect_predictions``.
        behaviors: Behavior name for each probability column.
        threshold: Default probability threshold.
        min_duration: Minimum segment duration in frames.
        smoothing_kernel: Median filter size used during extraction.
        nms_threshold: IoU threshold passed to ``apply_nms``.
        merge_gap: Gap threshold passed to ``merge_segments``.
        behavior_thresholds: Optional {behavior: threshold} overrides for
            rare classes.

    Returns:
        List of row dicts with keys row_id, video_id, agent_id, target_id,
        action, start_frame, stop_frame. ``row_id`` is contiguous from 0.

    NOTE(review): start/stop frames are emitted exactly as they come out of
    the aggregated windows. If the windows were built on a resampled time
    axis, these are not original video frame indices - confirm against the
    submission format.
    """
    print("Aggregating window predictions...")
    aggregated = aggregate_window_predictions(window_predictions, overlap_strategy='average')
    print(f"  {len(aggregated)} unique (video, agent, target) combinations")
    
    all_segments: List[BehaviorSegment] = []
    
    print("Extracting segments...")
    for (video_id, agent_id, target_id), frame_probs in tqdm(aggregated.items(), desc="Processing videos"):
        # Use per-behavior thresholds
        raw_segments = extract_segments_with_per_behavior_threshold(
            frame_probs,
            behaviors,
            default_threshold=threshold,
            behavior_thresholds=behavior_thresholds,
            min_duration=min_duration,
            smoothing_kernel=smoothing_kernel,
        )
        merged = merge_segments(raw_segments, gap_threshold=merge_gap)
        final_segments = apply_nms(merged, iou_threshold=nms_threshold)
        
        # Format mouse IDs: agent as "mouse1", "mouse2", etc.
        # target as "self" if same as agent, otherwise "mouse1", "mouse2", etc.
        formatted_agent = _format_mouse_id(agent_id)
        formatted_target = _format_mouse_id(target_id, agent_id=agent_id)
        
        for behavior, start, stop, conf in final_segments:
            all_segments.append(BehaviorSegment(
                video_id=int(_to_scalar(video_id)),
                agent_id=formatted_agent,
                target_id=formatted_target,
                action=behavior,
                start_frame=int(start),
                stop_frame=int(stop),
                confidence=float(conf),
            ))
    
    # Resolve overlapping behaviors for the same agent-target pair
    print("Resolving overlapping behaviors...")
    all_segments = resolve_overlaps(all_segments, min_duration=min_duration)
    print(f"  {len(all_segments)} segments after overlap resolution")
    
    # Sort segments and create submission rows directly
    print("Creating submission rows...")
    all_segments.sort(key=lambda s: (s.video_id, s.agent_id, s.target_id, s.start_frame))
    
    rows = []
    for seg in all_segments:
        if seg.duration >= min_duration and seg.start_frame < seg.stop_frame:
            rows.append({
                # Number emitted rows only, so ids stay contiguous even when
                # a segment is filtered out above.
                'row_id': len(rows),
                'video_id': seg.video_id,
                'agent_id': seg.agent_id,
                'target_id': seg.target_id,
                'action': seg.action,
                'start_frame': seg.start_frame,
                'stop_frame': seg.stop_frame,
            })
    
    return rows


# Build submission with per-behavior thresholds; every post-processing knob
# is read from CONFIG under the same name as the keyword argument.
submission_rows = build_submission_rows(
    window_predictions,
    BEHAVIORS,
    **{
        option: CONFIG[option]
        for option in ('threshold', 'min_duration', 'smoothing_kernel', 'nms_threshold', 'merge_gap')
    },
    behavior_thresholds=CONFIG.get('behavior_thresholds', {}),
)

print(f"\nGenerated {len(submission_rows)} submission rows")

# Submission

In [None]:
# Create submission DataFrame with the columns in submission order
submission_df = pd.DataFrame(
    submission_rows,
    columns="row_id video_id agent_id target_id action start_frame stop_frame".split(),
)

# Save submission (without the DataFrame index)
submission_path = OUTPUT_DIR / 'submission.csv'
submission_df.to_csv(submission_path, index=False)

print(
    f"Submission saved to {submission_path}",
    f"\nSubmission shape: {submission_df.shape}",
    "\nSubmission preview:",
    sep="\n",
)
print(submission_df.head(20))

In [None]:
# Submission statistics
print("\n=== Submission Statistics ===")
print(f"Total predictions: {len(submission_df)}")
print(f"Unique videos: {submission_df['video_id'].nunique()}")
print("\nPredictions per action:")
print(submission_df['action'].value_counts().head(20))

print("\nAverage segment duration:")
# Stored as a column on the frame; later cells only look at the first 7 columns
submission_df['duration'] = submission_df['stop_frame'].sub(submission_df['start_frame'])
print(f"  Mean: {submission_df['duration'].mean():.1f} frames")
print(f"  Median: {submission_df['duration'].median():.1f} frames")
print(f"  Min: {submission_df['duration'].min()} frames")
print(f"  Max: {submission_df['duration'].max()} frames")

# Validate test behavior coverage
print("\n=== Test Behavior Coverage ===")
predicted_behaviors = set(submission_df['action'].unique())
covered_test_behaviors = predicted_behaviors.intersection(TEST_BEHAVIORS)
missing_test_behaviors = set(TEST_BEHAVIORS).difference(predicted_behaviors)

print(f"Test behaviors covered: {len(covered_test_behaviors)}/{len(TEST_BEHAVIORS)}")
for beh in TEST_BEHAVIORS:
    count = submission_df['action'].eq(beh).sum()
    if beh in covered_test_behaviors:
        status = "✓"
    else:
        status = "✗ MISSING"
    print(f"  {beh:15s}: {count:5d} predictions {status}")

if not missing_test_behaviors:
    print("\n✓ All test behaviors have predictions")
else:
    print(f"\n⚠️  WARNING: Missing predictions for: {missing_test_behaviors}")
    print("   Consider lowering per-behavior thresholds for these behaviors")

In [None]:
# Final check: the file was written and has the expected leading columns
print("\n=== Final Checks ===")
print(f"Submission file exists: {submission_path.exists()}")
print(f"File size: {submission_path.stat().st_size / 1024:.1f} KB")

# Verify columns (only the first 7; extra analysis columns may follow)
expected_columns = "row_id video_id agent_id target_id action start_frame stop_frame".split()
actual_columns = submission_df.columns[:7].tolist()
print(f"\nColumns match expected: {actual_columns == expected_columns}")

print("\nDone!")