In [None]:
import polars as pl
import numpy as np
import torch
import torch.nn as nn
from src.model import GestureBranchedModel
from src.dataset import *
from pathlib import Path
import yaml
from tqdm import tqdm
# Appended rather than regrouped: the star import above makes ordering
# significant, so existing names keep their original precedence.
from itertools import groupby
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
exp_dir = Path('../experiments/cmi_training_20250814_220701')
dataset_dir = Path('../dataset')


def _read_sequences(csv_path):
    """Read a sequence CSV and replace nulls and NaNs with the -1.0 sentinel."""
    return pl.read_csv(csv_path).fill_null(-1.0).fill_nan(-1.0)


# Sequence tables get the sentinel fill; demographics tables are read as-is.
train_sequences_df = _read_sequences(dataset_dir / 'train.csv')
train_demographics_df = pl.read_csv(dataset_dir / 'train_demographics.csv')
test_sequences_df = _read_sequences(dataset_dir / 'test.csv')
test_demographics_df = pl.read_csv(dataset_dir / 'test_demographics.csv')

# Experiment configuration saved alongside the trained model.
config = yaml.safe_load((exp_dir / 'configs' / 'config.yaml').read_text())

In [None]:
# Attach gesture label ids to the training frame; the helper also returns the
# label encoder and the target / non-target gesture id groups (semantics as
# defined in src.dataset.prepare_gesture_labels).
(
    train_sequences_df,
    labelencoder,
    target_gesture_id,
    non_target_gesture_id,
) = prepare_gesture_labels(train_sequences_df)

In [None]:
# Turn the row-level training frame into the sequence objects consumed by
# CMIDataset, bounded by the configured maximum sequence length.
data_config = config['data']
sequence_processor = SequenceProcessor()
train_sequences = sequence_processor.process_dataframe(
    df=train_sequences_df,
    max_seq_length=data_config['max_seq_length'],
)

In [None]:
# Wrap the processed training sequences in the dataset used for inference below.
dataset = CMIDataset(
    sequences=train_sequences,
    max_length=config['data']['max_seq_length'],
)

# ToF backbone variant. It is not read from the config here, so it must match
# the variant the checkpoint was trained with or load_state_dict will fail.
TOF_BACKBONE = 'b0'

# Rebuild the architecture from the experiment's model config so the
# checkpoint's state dict lines up with it.
model_config = config['model']
model = GestureBranchedModel(
    num_classes=model_config['num_classes'],
    d_model=model_config['d_model'],
    d_reduced=model_config['d_reduced'],
    num_heads=model_config['num_heads'],
    num_layers=model_config['num_layers'],
    dropout=model_config['dropout'],
    max_seq_length=model_config['max_seq_length'],
    sequence_processor=model_config['sequence_processor'],
    tof_backbone=TOF_BACKBONE,
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints you trust. Consider weights_only=True if the
# checkpoint holds only tensors and plain containers.
checkpoint = torch.load(exp_dir / 'models/best_model.pt', map_location='cpu')

# The checkpoint is a dict; only its model weights are needed here.
model.load_state_dict(checkpoint['model_state_dict'])

# Module.to/.eval return the module itself; assigning the result keeps the
# cell from echoing the entire model repr as its output.
model = model.to(device).eval()

In [None]:
def _predict_sequence(model, chunks, device):
    """Predict one class index for a sequence from all of its dataset chunks.

    Each modality ('tof', 'acc', 'rot', 'thm', 'chunk_start_idx') is stacked
    across the chunks into a batch and moved to ``device``. One forward pass
    runs without gradient tracking. The logits are averaged over the chunk
    (batch) dimension, and the arg-max is returned as a Python int.
    """
    def stack(key):
        return torch.stack([chunk[key] for chunk in chunks]).to(device)

    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks still run.
        logits = model(
            tof_features=stack('tof'),
            acc_features=stack('acc'),
            rot_features=stack('rot'),
            thm_features=stack('thm'),
            chunk_start_idx=stack('chunk_start_idx'),
        )
    return logits.mean(dim=0).argmax().item()


predictions = {}

# Index explicitly so iteration stops at len(dataset) instead of relying on
# __getitem__ raising IndexError. An empty dataset simply yields no predictions.
dataset_items = (dataset[idx] for idx in range(len(dataset)))
progress = tqdm(dataset_items, total=len(dataset), desc="Processing sequences")

# groupby merges *consecutive* items with equal sequence_id, so all chunks of
# a sequence must be adjacent in the dataset. The check below makes a
# violation loud instead of silently overwriting an earlier prediction.
for sequence_id, chunks in groupby(progress, key=lambda item: item['sequence_id']):
    if sequence_id in predictions:
        raise ValueError(
            f"Chunks for sequence {sequence_id!r} are not contiguous in the dataset; "
            "its prediction would be computed from a partial set of chunks."
        )
    predictions[sequence_id] = _predict_sequence(model, list(chunks), device)

In [None]:
# Compare per-sequence predictions with the labels in train_sequences_df.
# NOTE(review): `dataset` was built from train_sequences_df, so these metrics
# are computed on training sequences and are likely optimistic unless the
# checkpoint never trained on them -- TODO confirm against the training split.

# One label per sequence: take the first row's gesture_id for every
# sequence_id in a single group_by, instead of filtering the whole frame once
# per sequence (which scales as n_sequences x n_rows).
sequence_labels = (
    train_sequences_df
    .group_by('sequence_id')
    .agg(pl.col('gesture_id').first())
)
true_label_by_sequence = dict(
    zip(
        sequence_labels['sequence_id'].to_list(),
        sequence_labels['gesture_id'].to_list(),
    )
)

# Keep true and predicted labels aligned by iterating predictions in one order.
true_labels = [true_label_by_sequence[sequence_id] for sequence_id in predictions]
pred_labels = list(predictions.values())

# Calculate metrics
accuracy = accuracy_score(true_labels, pred_labels)
f1_macro = f1_score(true_labels, pred_labels, average='macro')
f1_weighted = f1_score(true_labels, pred_labels, average='weighted')

print(f"Accuracy: {accuracy:.4f}")
print(f"F1-Score (Macro): {f1_macro:.4f}")
print(f"F1-Score (Weighted): {f1_weighted:.4f}")

print("\nClassification Report:")
print(classification_report(true_labels, pred_labels))