In [None]:
import os
import polars as pl
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from src.model import GestureBranchedModel
from src.dataset import *
from pathlib import Path
import yaml
from tqdm import tqdm

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Rebuild the model from the experiment's saved config, then restore its best weights.
exp_dir = Path('../experiments/cmi_training_20250818_234605')
config_path = exp_dir / 'configs' / 'config.yaml'
with config_path.open('r') as config_file:
    config = yaml.safe_load(config_file)

# Every constructor argument comes straight from the `model` section of the config.
model_cfg = config['model']
model_kwargs = {
    key: model_cfg[key]
    for key in (
        'num_classes',
        'd_model',
        'd_reduced',
        'num_heads',
        'num_layers',
        'dropout',
        'max_seq_length',
        'sequence_processor',
        'tof_backbone',
    )
}
model = GestureBranchedModel(**model_kwargs)

# Load on CPU first so the checkpoint opens regardless of the device it was saved from,
# then move the model to the selected device and switch to inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(exp_dir / 'models/best_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully")

In [None]:
# Load the training tables. The label encoder fitted by prepare_gesture_labels is what
# maps the model's predicted class ids back to gesture names further down.
dataset_dir = Path('../dataset')
train_sequences_df = (
    pl.read_csv(dataset_dir / 'train.csv')
    .fill_null(-1.0)
    .fill_nan(-1.0)
)
train_demographics_df = pl.read_csv(dataset_dir / 'train_demographics.csv')

(
    train_sequences_df,
    labelencoder,
    target_gesture_id,
    non_target_gesture_id,
) = prepare_gesture_labels(train_sequences_df)

# Index i of this list is the gesture name for class id i.
gesture_classes = [*labelencoder.classes_]
print(f"Loaded {len(gesture_classes)} gesture classes")

In [None]:
def smooth_chunk_average(logits, neighbor_weight=0.2):
    """Smooth per-chunk logits along the first (chunk) axis with a 3-tap kernel.

    With ``w = neighbor_weight``:
      * interior row ``i`` becomes ``w*x[i-1] + (1-2w)*x[i] + w*x[i+1]``;
      * the first/last rows become ``(1-w)*x[edge] + w*x[neighbor]``.
    The default ``w=0.2`` gives the 0.2/0.6/0.2 interior and 0.8/0.2 edge weights.

    Every term is read from the *unsmoothed* input. The previous version wrote
    into ``logits`` while looping, so each row mixed in the already-smoothed
    previous row, and the edge rows used smoothed neighbours.

    Args:
        logits: ``torch.Tensor`` or array-like whose first axis indexes chunks
            (e.g. ``(num_chunks, num_classes)``).
        neighbor_weight: Weight given to each adjacent chunk.

    Returns:
        A new tensor/array with the same shape as ``logits``. The input is not
        modified. Inputs with fewer than two chunks come back as an unchanged copy.
    """
    if isinstance(logits, torch.Tensor):
        src = logits
        smoothed = logits.clone()
    else:
        src = np.asarray(logits)
        smoothed = src.copy()

    if src.shape[0] < 2:
        return smoothed  # no neighbours to smooth against

    w = neighbor_weight
    smoothed[1:-1] = w * src[:-2] + (1 - 2 * w) * src[1:-1] + w * src[2:]
    smoothed[0] = (1 - w) * src[0] + w * src[1]
    smoothed[-1] = (1 - w) * src[-1] + w * src[-2]
    return smoothed

In [None]:
def predict(data_batch, demographics_df=None):
    """
    Predict the gesture name for a single sequence.

    ``SequenceProcessor`` splits the sequence into chunks of
    ``config['data']['chunk_size']`` rows. A single chunk goes through
    ``model``'s regular forward pass. Several chunks are aggregated by
    ``model.predict_with_chunks`` (attention-based chunk aggregation).

    Args:
        data_batch: Either a ``(sequence_df, demographics_df)`` tuple/list, as
            the local evaluation loop passes it, or the sequence DataFrame
            itself when the two-argument form is used.
        demographics_df: Demographics DataFrame for the two-argument form
            ``predict(sequence_df, demographics_df)``. Ignored when
            ``data_batch`` is a tuple/list. NOTE(review): the Kaggle CMI
            gateway presumably calls ``predict(sequence, demographics)`` with
            two positional arguments — confirm against
            ``kaggle_evaluation.cmi_inference_server``. Demographics are not
            passed to the model call below.

    Returns:
        str: Predicted gesture name, or ``"Text on phone"`` when the sequence
        yields no chunks.
    """
    if isinstance(data_batch, (tuple, list)):
        sequence_df, demographics_df = data_batch
    else:
        sequence_df = data_batch

    fallback_gesture = "Text on phone"  # answer returned when nothing can be scored
    chunk_size = config['data']['chunk_size']

    # The data-loading cell fills nulls/NaNs with -1.0 before the local evaluation.
    # Apply the same fill here so raw server input matches what was validated.
    # This is a no-op on frames that are already filled (e.g. the local loop).
    # NOTE(review): confirm this matches the training pipeline's preprocessing.
    if hasattr(sequence_df, 'fill_null') and hasattr(sequence_df, 'fill_nan'):
        sequence_df = sequence_df.fill_null(-1.0).fill_nan(-1.0)

    sequences = SequenceProcessor().process_dataframe(
        df=sequence_df,
        max_seq_length=chunk_size,
    )
    if not sequences:
        return fallback_gesture

    dataset = CMIDataset(sequences=sequences, max_length=chunk_size)

    # Move every per-chunk tensor to the model's device. Shapes per the original
    # author's notes: tof (seq_len, 320), acc (seq_len, 3), rot (seq_len, 4),
    # thm (seq_len, 5), chunk_start_idx scalar.
    feature_keys = ('tof', 'acc', 'rot', 'thm', 'chunk_start_idx')
    chunk_data_list = [
        {key: data[key].to(device) for key in feature_keys}
        for data in dataset
    ]
    if not chunk_data_list:
        return fallback_gesture

    with torch.no_grad():
        if len(chunk_data_list) == 1:
            # Single chunk: add a batch dimension of 1 and use the regular forward pass.
            chunk = chunk_data_list[0]
            logits = model(
                tof_features=chunk['tof'].unsqueeze(0),
                acc_features=chunk['acc'].unsqueeze(0),
                rot_features=chunk['rot'].unsqueeze(0),
                thm_features=chunk['thm'].unsqueeze(0),
                chunk_start_idx=chunk['chunk_start_idx'].unsqueeze(0),
            ).squeeze(0)  # (num_classes,)
        else:
            # Multiple chunks: attention-based aggregation. The attention weights are
            # also returned and can be printed for debugging.
            logits, _attention_weights = model.predict_with_chunks(chunk_data_list)

    # argmax() without a dim indexes the flattened tensor -> class id
    predicted_class_id = logits.argmax().item()
    return gesture_classes[predicted_class_id]


print("Predict function with chunk attention defined")

In [None]:
# Run predict() on every sequence in the training table and collect (prediction, label) pairs.
# NOTE(review): train.csv presumably contains sequences the checkpoint was trained on,
# so these scores are likely optimistic — confirm whether a held-out fold should be used.

# A polars GroupBy has no len(), so give tqdm an explicit total for a real progress bar.
num_sequences = train_sequences_df["sequence_id"].n_unique()
# maintain_order=True keeps the iteration order (and therefore the lists) reproducible.
grouped = train_sequences_df.group_by("sequence_id", maintain_order=True)
predictions = []
true_labels = []
for _sequence_key, sequence in tqdm(grouped, total=num_sequences, desc="predicting"):
    predictions.append(predict((sequence, train_demographics_df)))
    # Label taken from the sequence's first row (presumably constant within a sequence).
    true_labels.append(sequence['gesture'][0])

In [None]:
from sklearn.metrics import accuracy_score, f1_score

# Weighted F1 averages the per-class F1 scores, weighting each class by its support.
accuracy = accuracy_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions, average='weighted')

print(f"F1 Score: {f1:.4f}", f"Accuracy: {accuracy:.4f}", sep="\n")

In [None]:
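# Recorded scores by chunk-logit aggregation strategy (earlier runs of the evaluation cell):
# - mean(logits): F1 Score 0.7187, Accuracy 0.7314
# - smooth_chunk_average(logits): TODO — not yet recorded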


In [None]:
# # Initialize and run the inference server
# import kaggle_evaluation.cmi_inference_server

# inference_server = kaggle_evaluation.cmi_inference_server.CMIInferenceServer(predict)

# if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
#     # Running in Kaggle competition environment
#     inference_server.serve()
# else:
#     # # Running locally for testing
#     inference_server.run_local_gateway(
#         data_paths=(
#             '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv',
#             '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv',
#         )
#     )

# # Show results if running locally
# if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
#         results = pd.read_parquet("submission.parquet")
#         print(results.head())