In [None]:
import polars as pl
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import sys
import os
from src.model import GestureBranchedModel
from src.dataset import *
from pathlib import Path
import yaml
from tqdm import tqdm

# Prefer the GPU for inference when CUDA is usable; otherwise stay on the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

In [None]:
def load_config(seed=0):
    """
    Locate the experiment directory for ``seed`` and parse its YAML configuration.

    Args:
        seed: Value interpolated into the directory name ``../experiments/seed{seed}``
            (resolved relative to the current working directory).

    Returns:
        tuple: ``(exp_dir, config)`` where ``exp_dir`` is the experiment directory as a
        ``Path`` and ``config`` is whatever ``yaml.safe_load`` produced from
        ``configs/config.yaml`` inside it.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    exp_dir = Path(f'../experiments/seed{seed}')
    config_path = exp_dir / 'configs' / 'config.yaml'
    with config_path.open('r') as config_file:
        parsed = yaml.safe_load(config_file)
    return exp_dir, parsed

def load_model(exp_dir, config):
    """
    Build a GestureBranchedModel from ``config`` and restore its trained weights.

    Args:
        exp_dir: Experiment directory (a ``Path``); the checkpoint is read from
            ``models/best_model.pt`` beneath it.
        config: Parsed configuration; its ``'model'`` section supplies every
            constructor argument listed in ``constructor_keys`` below.

    Returns:
        The model with weights loaded, moved to the module-level ``device`` and
        switched to evaluation mode.
    """
    model_cfg = config['model']
    constructor_keys = (
        'num_classes',
        'd_model',
        'hidden_dim',
        'num_heads',
        'num_layers',
        'dropout',
        'sequence_processor',
        'tof_backbone',
    )
    model = GestureBranchedModel(**{key: model_cfg[key] for key in constructor_keys})

    # Deserialize onto the CPU first; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint = torch.load(exp_dir / 'models/best_model.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()
    return model

In [None]:
# Read the training table (needed here to set up label encoding) and replace
# missing values, both nulls and NaNs, with 0.0.
train_sequences_df = (
    pl.read_csv('../dataset/train.csv')
    .fill_null(0.0)
    .fill_nan(0.0)
)

# prepare_gesture_labels hands back the frame together with the label encoder and
# the target / non-target gesture ids.
(
    train_sequences_df,
    labelencoder,
    target_gesture_id,
    non_target_gesture_id,
) = prepare_gesture_labels(train_sequences_df)

# Position in this list == class id, so it maps predicted ids back to gesture names.
gesture_classes = [*labelencoder.classes_]

In [None]:
def smooth_prediction(logits):
    """
    Smooth logits along the first axis with a symmetric [0.2, 0.6, 0.2] kernel.

    Every interior row becomes ``0.2 * previous + 0.6 * current + 0.2 * next``; the
    first and last rows, which have only one neighbour, become
    ``0.8 * self + 0.2 * neighbour``. All output rows are computed from the
    *unsmoothed* input rows. (Updating the rows in place while iterating would feed
    already-smoothed values into later rows and skew the result.)

    Args:
        logits: Array indexed by row along the first axis, e.g. a torch tensor
            (anything with ``.clone()``) or a numpy array (anything with ``.copy()``).

    Returns:
        A new object of the same type and shape as ``logits`` containing the smoothed
        values. The input is not modified. Inputs with fewer than two rows have no
        neighbour to blend with and are returned as an unchanged copy.
    """
    # Write into a copy so that every read below sees the original values.
    smoothed = logits.clone() if hasattr(logits, "clone") else logits.copy()

    if logits.shape[0] < 2:
        return smoothed

    # Interior rows: vectorised three-tap average (an empty slice when there are 2 rows).
    smoothed[1:-1] = logits[:-2] * 0.2 + logits[1:-1] * 0.6 + logits[2:] * 0.2
    # Boundary rows: blend with the single available neighbour.
    smoothed[0] = logits[1] * 0.2 + logits[0] * 0.8
    smoothed[-1] = logits[-2] * 0.2 + logits[-1] * 0.8
    return smoothed

In [None]:
def is_chunk_valid(tof, acc, rot, thm, zero_threshold=0.8):
    """
    Decide whether a chunk carries enough non-zero sensor data to be usable.

    Args:
        tof, acc, rot, thm: Tensor chunks, one per sensor modality.
        zero_threshold: Largest tolerated fraction of exactly-zero entries in any
            single modality; a fraction strictly above it rejects the chunk.

    Returns:
        bool: False as soon as at least one modality exceeds the threshold,
        True otherwise.
    """
    # Fraction of entries equal to zero, measured separately for each modality.
    zero_fractions = [
        (modality == 0).float().mean().item()
        for modality in (tof, acc, rot, thm)
    ]

    # One zero-dominated modality is enough to discard the whole chunk.
    return not any(fraction > zero_threshold for fraction in zero_fractions)

# Loaded ensemble members keyed by seed. predict() is called once per sequence, so
# caching here avoids re-reading every checkpoint from disk on each call.
_ENSEMBLE_MODEL_CACHE = {}


def _get_ensemble_model(seed):
    """Return the model for ``seed``, loading its config and checkpoint on first use only."""
    if seed not in _ENSEMBLE_MODEL_CACHE:
        exp_dir, config = load_config(seed)
        _ENSEMBLE_MODEL_CACHE[seed] = load_model(exp_dir, config)
    return _ENSEMBLE_MODEL_CACHE[seed]


def predict(sequence_df, zero_threshold=0.5, min_valid_chunks=1, seeds=(0, 1, 2), chunk_size=30):
    """
    Predict the gesture for one sequence with a model ensemble, skipping mostly-zero chunks.

    The sequence is split into chunks, chunks dominated by zero values are dropped,
    each ensemble member scores the remaining chunks, and the logits are averaged
    first over chunks and then over models before taking the arg-max.

    Args:
        sequence_df: Sequence dataframe.
        zero_threshold: Fraction of zeros above which a chunk is filtered out
            (forwarded to ``is_chunk_valid``).
        min_valid_chunks: Minimum number of valid chunks required for prediction.
            If fewer chunks survive the filter, all chunks are used instead.
        seeds: Seeds identifying the experiments whose models form the ensemble.
        chunk_size: Chunk length passed to the sequence processor and the dataset.

    Returns:
        The predicted gesture name, looked up in the module-level ``gesture_classes``.

    Raises:
        ValueError: If ``seeds`` is empty or the sequence yields no chunks at all.
    """
    if not seeds:
        raise ValueError("predict() needs at least one seed to build the ensemble")

    # Process the sequence into chunks and wrap them in a chunking-enabled dataset.
    sequences = SequenceProcessor().process_dataframe(
        df=sequence_df,
        chunk_size=chunk_size,
    )
    dataset = CMIDataset(
        sequences=sequences,
        chunk_size=chunk_size,
        use_chunking=True,
    )

    # Materialise once so the fallback below does not iterate the dataset a second time.
    all_chunks = list(dataset)
    if not all_chunks:
        raise ValueError("predict() received a sequence that produced no chunks")

    valid_chunks = [
        chunk for chunk in all_chunks
        if is_chunk_valid(chunk['tof'], chunk['acc'], chunk['rot'], chunk['thm'], zero_threshold)
    ]
    # Too few usable chunks: fall back to scoring every chunk rather than failing.
    if len(valid_chunks) < min_valid_chunks:
        valid_chunks = all_chunks

    # Stack per-chunk tensors into one batch per modality (first axis = chunk).
    tof_batch = torch.stack([chunk['tof'] for chunk in valid_chunks]).to(device)
    acc_batch = torch.stack([chunk['acc'] for chunk in valid_chunks]).to(device)
    rot_batch = torch.stack([chunk['rot'] for chunk in valid_chunks]).to(device)
    thm_batch = torch.stack([chunk['thm'] for chunk in valid_chunks]).to(device)

    per_model_logits = []
    with torch.no_grad():
        for seed in seeds:
            model = _get_ensemble_model(seed)
            logits = model(
                tof_features=tof_batch,
                acc_features=acc_batch,
                rot_features=rot_batch,
                thm_features=thm_batch,
            )
            # Average this model's logits over all scored chunks.
            per_model_logits.append(logits.mean(dim=0).cpu().numpy())

    # Average across ensemble members, then map the winning class id to its name.
    predicted_class_id = int(np.argmax(np.mean(per_model_logits, axis=0)))
    return gesture_classes[predicted_class_id]

In [None]:
# Score the ensemble on a fixed subset of sequences.
# NOTE(review): the sequences come from the training table, so the metrics computed
# from them are presumably optimistic; confirm whether a held-out split should be used.
NUM_EVAL_SEQUENCES = 101  # the original loop scored indices 0..100 inclusive

predictions = []
true_labels = []
# maintain_order=True makes the groups follow the row order of the frame. Without it
# polars does not guarantee group order, so a different subset would be scored per run.
sequence_groups = train_sequences_df.group_by('sequence_id', maintain_order=True)
for i, (_, group) in enumerate(sequence_groups):
    if i >= NUM_EVAL_SEQUENCES:
        break
    pred = predict(group)
    predictions.append(pred)
    # assumes the 'gesture' column still holds gesture names comparable with
    # predict()'s output after prepare_gesture_labels; TODO confirm
    true_labels.append(group['gesture'][0])

In [None]:
from sklearn.metrics import f1_score

# Accuracy: share of evaluated sequences whose predicted gesture equals the true one.
correct = sum(1 for p, t in zip(predictions, true_labels) if p == t)
accuracy = correct / len(true_labels)
print(f"Accuracy: {accuracy:.4f}")

# F1 Score: macro average, i.e. the unweighted mean of the per-class F1 values.
f1 = f1_score(y_true=true_labels, y_pred=predictions, average='macro')
print(f"F1 Score: {f1:.4f}")