# CMI Gesture Recognition Training Notebook

This notebook trains the four-branch gesture recognition model for the CMI competition.

## Setup and Imports

In [None]:
import os
import sys

sys.path.append(os.path.dirname(os.getcwd()))

import warnings

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

warnings.filterwarnings("ignore")

from feature_processor import FeatureProcessor
from src.model import create_model

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Load and Prepare Data

In [None]:
# Read the training recordings and the demographics table from ../dataset.
print("Loading training data...")
train_df, demographics_df = (
    pl.read_csv(csv_path)
    for csv_path in ("../dataset/train.csv", "../dataset/train_demographics.csv")
)

# Quick sanity overview of what was loaded.
print(f"Training data shape: {train_df.shape}")
print(f"Demographics data shape: {demographics_df.shape}")
print(f"\nColumns in training data: {len(train_df.columns)}")
print(f"Unique sequences: {train_df['sequence_id'].n_unique()}")
print(f"Unique gestures: {train_df['gesture'].n_unique()}")

In [None]:
# Show how many rows of train_df belong to each gesture, most frequent first.
# NOTE: this counts rows, not sequences, so gestures with longer sequences
# appear more frequent than they are at the sequence level.
print("Gesture distribution:")
gesture_counts = (
    train_df.group_by("gesture")
    .agg(pl.len().alias("count"))  # pl.len() replaces the deprecated pl.count()
    .sort("count", descending=True)
)
print(gesture_counts)

In [None]:
# Gesture names this notebook treats as "target" classes.
target_gestures = [
    "Above ear - pull hair",
    "Cheek - pinch skin",
    "Eyebrow - pull hair",
    "Eyelash - pull hair",
    "Forehead - pull hairline",
    "Forehead - scratch",
    "Neck - pinch skin",
    "Neck - scratch",
]

# Every other gesture name; used later to tag results as "Non-target".
non_target_gestures = [
    "Text on phone",
    "Wave hello",
    "Write name in air",
    "Pull air toward your face",
    "Feel around in tray and pull out an object",
    "Glasses on/off",
    "Drink from bottle/cup",
    "Scratch knee/leg skin",
    "Write name on leg",
    "Pinch knee/leg skin",
]

all_gestures = [*target_gestures, *non_target_gestures]
print(f"Total gesture classes: {len(all_gestures)}")

# Fit the encoder on the gesture names found in the data (not on
# all_gestures) and attach the integer ids as a "gesture_id" column.
label_encoder = LabelEncoder()
gesture_ids = label_encoder.fit_transform(train_df["gesture"].to_numpy())
train_df = train_df.with_columns(pl.Series("gesture_id", gesture_ids))

print("\nGesture ID mapping:")
for i, gesture in enumerate(label_encoder.classes_):
    print(f"{i}: {gesture}")

## Data Processing and Feature Extraction

In [None]:
# Column names of each sensor group, generated from their naming pattern.
acc_cols = [f"acc_{axis}" for axis in "xyz"]
rot_cols = [f"rot_{component}" for component in "wxyz"]
thm_cols = [f"thm_{sensor}" for sensor in range(1, 6)]
# Five ToF sensors (1..5), each with 64 value columns (v0..v63).
tof_cols = [
    f"tof_{sensor}_v{pixel}"
    for sensor in range(1, 6)
    for pixel in range(64)
]

print(f"Accelerometer columns: {len(acc_cols)}")
print(f"Rotation columns: {len(rot_cols)}")
print(f"Thermal columns: {len(thm_cols)}")
print(f"ToF columns: {len(tof_cols)}")
print(f"Total sensor columns: {len(acc_cols + rot_cols + thm_cols + tof_cols)}")

In [None]:
# Prepare sequences with enhanced features.
#
# Every sequence_id becomes one dict in `sequences`:
#   - normally  {"sequence_id", "enhanced_data", "label"}
#   - fallback  {"sequence_id", "data", "label"} when FeatureProcessor fails
# `sequence_lengths` holds the matching number of time steps.
print("Preparing sequences with enhanced feature processing...")
sequences = []
sequence_lengths = []

feature_processor = FeatureProcessor()

# Column order for the raw fallback array. CMIDataset slices raw arrays as
# ToF [0:320], acc [320:323], rot [323:327], thm [327:332], so the columns
# must be selected in exactly that order (ToF first).
raw_feature_cols = tof_cols + acc_cols + rot_cols + thm_cols

# Group by sequence_id and process each sequence
grouped = train_df.group_by("sequence_id")

for seq_id, group in tqdm(grouped, desc="Processing sequences"):
    # Read the label before the try block: the fallback path needs it too,
    # and must never reuse a gesture_id left over from an earlier iteration.
    gesture_id = group["gesture_id"][0]

    try:
        # Create enhanced features using FeatureProcessor
        enhanced_features = feature_processor.create_sequence_features(group)
        record = {
            "sequence_id": seq_id[0],
            "enhanced_data": enhanced_features,
            "label": gesture_id,
        }
        n_steps = enhanced_features["sequence_length"]
    except Exception as e:
        # Deliberate best-effort: report the failure and keep the sequence
        # using the raw sensor columns instead of dropping it.
        # NOTE(review): raw fallback items only batch together with enhanced
        # items if FeatureProcessor emits the same per-modality widths
        # (320/3/4/5) — confirm, or drop failed sequences instead.
        print(f"Error processing sequence {seq_id[0]}: {e}")
        seq_data = group.select(raw_feature_cols).to_numpy()
        record = {
            "sequence_id": seq_id[0],
            "data": seq_data,
            "label": gesture_id,
        }
        n_steps = len(seq_data)

    # Append exactly once per sequence, after the record is fully built, so a
    # late failure in the try block cannot leave a duplicate entry behind.
    sequences.append(record)
    sequence_lengths.append(n_steps)

print(f"\nProcessed {len(sequences)} sequences")
print("Sequence length statistics:")
print(f"  Min: {min(sequence_lengths)}")
print(f"  Max: {max(sequence_lengths)}")
print(f"  Mean: {np.mean(sequence_lengths):.1f}")
print(f"  Median: {np.median(sequence_lengths):.1f}")

# Check enhanced feature dimensions
if "enhanced_data" in sequences[0]:
    sample_features = sequences[0]["enhanced_data"]
    print("\nEnhanced feature dimensions:")
    print(f"  ToF: {sample_features['tof'].shape[1]} features")
    print(f"  ACC: {sample_features['acc'].shape[1]} features (enhanced)")
    print(f"  ROT: {sample_features['rot'].shape[1]} features (enhanced)")
    print(f"  THM: {sample_features['thm'].shape[1]} features (enhanced)")

In [None]:
# Cap the sequence length at the 95th percentile of observed lengths so a few
# very long sequences do not dictate the padded size.
max_seq_length = int(np.percentile(sequence_lengths, 95))
print(f"Using max sequence length: {max_seq_length}")

# Per-sequence ids and labels, in the same order as `sequences`.
sequence_ids = [entry["sequence_id"] for entry in sequences]
labels = [entry["label"] for entry in sequences]

# 80/20 split over sequence positions, stratified by gesture label.
# NOTE(review): the split ignores who performed the sequence; if one subject
# contributes several sequences they can end up on both sides — confirm.
train_indices, val_indices = train_test_split(
    range(len(sequences)),
    stratify=labels,
    test_size=0.2,
    random_state=42,
)

train_sequences = [sequences[position] for position in train_indices]
val_sequences = [sequences[position] for position in val_indices]

print(f"Training sequences: {len(train_sequences)}")
print(f"Validation sequences: {len(val_sequences)}")

In [None]:
class CMIDataset(Dataset):
    """Serve per-sequence sensor arrays as fixed-length float tensors.

    Each element of ``sequences`` is a dict with a ``"label"`` and either

    * ``"enhanced_data"``: a mapping holding 2-D arrays under ``"tof"``,
      ``"acc"``, ``"rot"`` and ``"thm"``, or
    * ``"data"``: one 2-D array that is sliced by column as ToF ``[0:320]``,
      acc ``[320:323]``, rot ``[323:327]`` and thm ``[327:332]``.

    Args:
        sequences: List of sequence dicts as described above.
        max_length: Number of time steps every item is padded or truncated
            to. When falsy, the longest sequence in ``sequences`` is used.
        use_enhanced_features: Use ``"enhanced_data"`` for items that have it;
            otherwise (or when False) the raw ``"data"`` array is used.
    """

    def __init__(self, sequences, max_length=None, use_enhanced_features=True):
        self.sequences = sequences
        self.use_enhanced_features = use_enhanced_features
        if max_length:
            self.max_length = max_length
        else:
            # Longest sequence, measured on the ToF array for enhanced items
            # and on the raw array otherwise.
            self.max_length = max(
                len(item["enhanced_data"]["tof"])
                if "enhanced_data" in item
                else len(item["data"])
                for item in sequences
            )

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def _fit_length(self, array, n_steps):
        """Zero-pad at the end or cut ``array`` so it spans ``max_length`` rows.

        ``n_steps`` is the reference length deciding how much to pad; the
        caller passes the same value for all arrays of one item.
        """
        shortfall = self.max_length - n_steps
        if shortfall > 0:
            filler = np.zeros((shortfall, array.shape[1]))
            return np.vstack([array, filler])
        if shortfall < 0:
            # Keeps the first max_length rows.
            # NOTE(review): confirm the informative part is not at the tail.
            return array[: self.max_length]
        return array

    def __getitem__(self, idx):
        """Return a dict of ``tof``/``acc``/``rot``/``thm`` tensors plus ``label``."""
        item = self.sequences[idx]
        label = item["label"]

        if self.use_enhanced_features and "enhanced_data" in item:
            enhanced = item["enhanced_data"]
            # The ToF array's length is the reference for all four modalities.
            n_steps = len(enhanced["tof"])
            tof_data, acc_data, rot_data, thm_data = (
                self._fit_length(enhanced[key], n_steps)
                for key in ("tof", "acc", "rot", "thm")
            )
        else:
            raw = item["data"]
            raw = self._fit_length(raw, len(raw))
            # Slice the modalities out of the raw array and replace the
            # missing-value marker (-1.0) with 0.
            tof_data, acc_data, rot_data, thm_data = (
                np.where(block == -1.0, 0.0, block)
                for block in (
                    raw[:, :320],  # ToF features (320)
                    raw[:, 320:323],  # Accelerometer (3)
                    raw[:, 323:327],  # Rotation (4)
                    raw[:, 327:332],  # Thermal (5)
                )
            )

        sample = {
            name: torch.FloatTensor(values)
            for name, values in (
                ("tof", tof_data),
                ("acc", acc_data),
                ("rot", rot_data),
                ("thm", thm_data),
            )
        }
        sample["label"] = torch.LongTensor([label])[0]
        return sample

In [None]:
# Central hyper-parameter table read by the dataset, model and training cells.
CONFIG = dict(
    batch_size=16,  # samples per DataLoader batch
    d_model=128,  # passed to create_model
    num_heads=8,  # passed to create_model
    max_seq_length=max_seq_length,  # padded/truncated length of every sequence
    use_enhanced_features=True,  # forwarded to CMIDataset
    lr=1e-4,  # AdamW learning rate
    weight_decay=1e-2,  # AdamW weight decay
    patience=5,  # ReduceLROnPlateau patience (epochs)
    decay_factor=0.5,  # ReduceLROnPlateau multiplicative factor
    num_epochs=10,  # number of training epochs
    device=device,  # torch.device the model is moved to
)

In [None]:
# Build the train/validation datasets with identical settings, then wrap them
# in DataLoaders; only the training loader shuffles.
dataset_kwargs = {
    "max_length": CONFIG["max_seq_length"],
    "use_enhanced_features": CONFIG["use_enhanced_features"],
}
train_dataset = CMIDataset(train_sequences, **dataset_kwargs)
val_dataset = CMIDataset(val_sequences, **dataset_kwargs)

loader_kwargs = {"batch_size": CONFIG["batch_size"], "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## Model Setup and Training

In [None]:
# Instantiate the network from CONFIG and move it to the selected device.
model = create_model(
    seq_len=CONFIG["max_seq_length"],
    d_model=CONFIG["d_model"],
    num_heads=CONFIG["num_heads"],
)
model = model.to(CONFIG["device"])

# Tally parameter counts in a single pass over the parameters.
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
print(f"model at {CONFIG['device']}")

In [None]:
# Loss, optimizer and LR scheduler, all configured from CONFIG.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"]
)
# mode="max": the scheduler is stepped with validation accuracy, so the
# learning rate is reduced when that value stops increasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=CONFIG["decay_factor"], patience=CONFIG["patience"]
)

# Bookkeeping filled in by the training loop.
num_epochs = CONFIG["num_epochs"]
best_val_acc = 0.0
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

print(f"Training setup complete. Starting training for {num_epochs} epochs...")

In [None]:
# Training loop.
#
# Each epoch runs one optimisation pass over train_loader and one evaluation
# pass over val_loader, records loss/accuracy, steps the LR scheduler on the
# validation accuracy, and checkpoints the model whenever that accuracy
# improves.
best_model_path = "../models/best_model.pt"
# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch in train_pbar:
        # Move to device
        tof_data = batch["tof"].to(device)
        acc_data = batch["acc"].to(device)
        rot_data = batch["rot"].to(device)
        thm_data = batch["thm"].to(device)
        labels = batch["label"].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(tof_data, acc_data, rot_data, thm_data)
        loss = criterion(outputs, labels)

        # Backward pass; gradients are clipped to a total norm of 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Statistics (argmax yields integer indices, so no graph is kept)
        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

        # Update progress bar with the running accuracy for this epoch
        train_acc = 100.0 * train_correct / train_total
        train_pbar.set_postfix(
            {
                "Loss": f"{loss.item():.4f}",
                "Acc": f"{train_acc:.2f}%",
            },
        )

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        for batch in val_pbar:
            # Move to device
            tof_data = batch["tof"].to(device)
            acc_data = batch["acc"].to(device)
            rot_data = batch["rot"].to(device)
            thm_data = batch["thm"].to(device)
            labels = batch["label"].to(device)

            # Forward pass
            outputs = model(tof_data, acc_data, rot_data, thm_data)
            loss = criterion(outputs, labels)

            # Statistics
            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            # Update progress bar
            val_acc = 100.0 * val_correct / val_total
            val_pbar.set_postfix(
                {
                    "Loss": f"{loss.item():.4f}",
                    "Acc": f"{val_acc:.2f}%",
                },
            )

    # Calculate epoch metrics (loss is the mean of the per-batch losses)
    epoch_train_loss = train_loss / len(train_loader)
    epoch_train_acc = 100.0 * train_correct / train_total
    epoch_val_loss = val_loss / len(val_loader)
    epoch_val_acc = 100.0 * val_correct / val_total

    # Store metrics
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_acc)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)

    # Learning rate scheduling (scheduler was created with mode="max")
    scheduler.step(epoch_val_acc)

    # Save best model
    if epoch_val_acc > best_val_acc:
        best_val_acc = epoch_val_acc
        # NOTE: the checkpoint embeds the fitted LabelEncoder object, so it
        # is a pickle and must be loaded with weights_only=False.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_val_acc": best_val_acc,
                "label_encoder": label_encoder,
            },
            best_model_path,
        )

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%")
    print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%")
    print(f"  Best Val Acc: {best_val_acc:.2f}%")
    print("-" * 60)

print(f"Training completed! Best validation accuracy: {best_val_acc:.2f}%")

## Training Visualization

In [None]:
# Draw loss (left) and accuracy (right) per epoch for train and validation.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# (axis, train series, train label, val series, val label, y label, title)
panels = (
    (
        ax1,
        train_losses,
        "Train Loss",
        val_losses,
        "Validation Loss",
        "Loss",
        "Training and Validation Loss",
    ),
    (
        ax2,
        train_accuracies,
        "Train Accuracy",
        val_accuracies,
        "Validation Accuracy",
        "Accuracy (%)",
        "Training and Validation Accuracy",
    ),
)
for axis, train_series, train_label, val_series, val_label, y_label, title in panels:
    axis.plot(train_series, label=train_label, color="blue")
    axis.plot(val_series, label=val_label, color="red")
    axis.set_xlabel("Epoch")
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.legend()
    axis.grid(True)

plt.tight_layout()
plt.show()

# Last-epoch numbers next to the best validation accuracy seen.
print("Final Results:")
print(f"  Final Train Accuracy: {train_accuracies[-1]:.2f}%")
print(f"  Final Validation Accuracy: {val_accuracies[-1]:.2f}%")
print(f"  Best Validation Accuracy: {best_val_acc:.2f}%")

## Model Evaluation

In [None]:
# Load best model for evaluation.
# The checkpoint was written by the training loop and embeds a pickled
# LabelEncoder, which the weights-only loader (the default since PyTorch 2.6)
# rejects. weights_only=False unpickles arbitrary objects, so only use it on
# checkpoints produced by this notebook, never on untrusted files.
checkpoint = torch.load(
    "../models/best_model.pt",
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Evaluate on validation set, collecting predicted and true class ids
all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        # Move to device
        tof_data = batch["tof"].to(device)
        acc_data = batch["acc"].to(device)
        rot_data = batch["rot"].to(device)
        thm_data = batch["thm"].to(device)
        labels = batch["label"].to(device)

        # Forward pass
        outputs = model(tof_data, acc_data, rot_data, thm_data)
        predicted = outputs.argmax(dim=1)

        all_predictions.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Calculate metrics (accuracy is a fraction in [0, 1], not a percentage)
accuracy = accuracy_score(all_labels, all_predictions)
print(f"Validation Accuracy: {accuracy:.4f}")

# Classification report.
# Passing `labels` explicitly keeps the rows aligned with target_names even
# when some class never shows up in the validation labels or predictions;
# without it sklearn raises on the length mismatch.
print("\nClassification Report:")
report = classification_report(
    all_labels,
    all_predictions,
    labels=np.arange(len(label_encoder.classes_)),
    target_names=label_encoder.classes_,
    digits=3,
)
print(report)

In [None]:
# Confusion matrix (rows: actual class, columns: predicted class).
# `labels` pins the matrix to one row/column per encoder class so the tick
# labels below always line up, even if a class is absent from the
# validation labels and predictions.
cm = confusion_matrix(
    all_labels,
    all_predictions,
    labels=np.arange(len(label_encoder.classes_)),
)

plt.figure(figsize=(16, 12))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=label_encoder.classes_,
    yticklabels=label_encoder.classes_,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

## Model Analysis

In [None]:
# Analyze per-class performance
from sklearn.metrics import precision_recall_fscore_support

# `labels` forces one entry per encoder class, so the metric arrays always
# have the same length as label_encoder.classes_ (otherwise building the
# DataFrame below fails when a class is missing from the validation data).
# zero_division=0 scores classes without predictions/support as 0.
precision, recall, f1, support = precision_recall_fscore_support(
    all_labels,
    all_predictions,
    labels=np.arange(len(label_encoder.classes_)),
    average=None,
    zero_division=0,
)

# One row per gesture class
results_df = pl.DataFrame(
    {
        "gesture": label_encoder.classes_,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "support": support,
    },
)

# Add target/non-target classification
results_df = results_df.with_columns(
    pl.when(pl.col("gesture").is_in(target_gestures))
    .then(pl.lit("Target"))
    .otherwise(pl.lit("Non-target"))
    .alias("category"),
)

# Sort by F1 score
results_df = results_df.sort("f1_score", descending=True)
print("Per-class Performance:")
print(results_df)

In [None]:
# Unweighted mean precision/recall/F1 per category and over all classes.
metric_cols = ["precision", "recall", "f1_score"]
category = pl.col("category")

target_avg = results_df.filter(category == "Target").select(metric_cols).mean()
non_target_avg = (
    results_df.filter(category == "Non-target").select(metric_cols).mean()
)

print("\nAverage Performance by Category:")
print(
    f"Target gestures - Precision: {target_avg['precision'][0]:.3f}, Recall: {target_avg['recall'][0]:.3f}, F1: {target_avg['f1_score'][0]:.3f}",
)
print(
    f"Non-target gestures - Precision: {non_target_avg['precision'][0]:.3f}, Recall: {non_target_avg['recall'][0]:.3f}, F1: {non_target_avg['f1_score'][0]:.3f}",
)

# Macro average: every class counts equally regardless of its support.
macro_avg = results_df.select(metric_cols).mean()
print(
    f"\nOverall Macro Average - Precision: {macro_avg['precision'][0]:.3f}, Recall: {macro_avg['recall'][0]:.3f}, F1: {macro_avg['f1_score'][0]:.3f}",
)

## Save Final Model and Metadata

In [None]:
# Create models directory if it doesn't exist
os.makedirs("../models", exist_ok=True)

# Save final model with all metadata.
# `model` holds whatever weights are currently loaded; if the evaluation cell
# was run beforehand, those are the best-checkpoint weights.
final_model_path = "../models/cmi_gesture_model_final.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        # Derived from the objects actually used for training instead of
        # hard-coded literals, so the metadata cannot drift from CONFIG or
        # from the set of classes the encoder was fitted on.
        "model_config": {
            "num_classes": len(label_encoder.classes_),
            "d_model": CONFIG["d_model"],
            "num_heads": CONFIG["num_heads"],
            "max_seq_length": CONFIG["max_seq_length"],
        },
        # Pickled sklearn object: load this file with weights_only=False.
        "label_encoder": label_encoder,
        "training_history": {
            "train_losses": train_losses,
            "train_accuracies": train_accuracies,
            "val_losses": val_losses,
            "val_accuracies": val_accuracies,
        },
        "final_metrics": {
            "best_val_accuracy": best_val_acc,  # percentage (0-100)
            "final_val_accuracy": float(accuracy),  # fraction (0-1)
            "macro_precision": macro_avg["precision"][0],
            "macro_recall": macro_avg["recall"][0],
            "macro_f1": macro_avg["f1_score"][0],
        },
        "feature_columns": {
            "tof_cols": tof_cols,
            "acc_cols": acc_cols,
            "rot_cols": rot_cols,
            "thm_cols": thm_cols,
        },
    },
    final_model_path,
)

print(f"Final model saved to: {final_model_path}")
print(f"Model file size: {os.path.getsize(final_model_path) / (1024*1024):.1f} MB")

print("\n=== Training Summary ===")
print(f"Total parameters: {total_params:,}")
print(f"Max sequence length: {CONFIG['max_seq_length']}")
print(f"Training sequences: {len(train_sequences)}")
print(f"Validation sequences: {len(val_sequences)}")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Final validation accuracy: {accuracy*100:.2f}%")
print("Training completed successfully!")