# CMI Gesture Recognition - Demographic-Enhanced Inference

This notebook demonstrates inference that combines gesture embeddings from the branched sensor model with participant demographic features, classified by a LightGBM model.

In [None]:
import sys
import os
from pathlib import Path
import warnings

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import yaml
from collections import defaultdict
import lightgbm as lgb
import pickle
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

from src.dataset import CMIDataset, SequenceProcessor, prepare_gesture_labels
from src.model import create_model

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")

# Report the runtime environment, one "label: value" line per entry.
print("\n".join(
    f"{label}: {value}"
    for label, value in (
        ("Project root", project_root),
        ("PyTorch version", torch.__version__),
        ("CUDA available", torch.cuda.is_available()),
        ("LightGBM version", lgb.__version__),
    )
))

## Load Model and Configuration

In [None]:
# Load the training configuration (config.yaml at the project root).
with open(project_root / 'config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
print(f"- Model d_model: {config['model']['d_model']}")
print(f"- Model layers: {config['model']['num_layers']}")
print(f"- Max sequence length for chunking: {config['data']['max_seq_length']}")
print(f"- Max sequence length for positional encoding: {config['model']['max_seq_length']}")

# Locate training runs. Runs are ranked by directory name, which assumes the
# cmi_training_* suffix sorts chronologically (e.g. a timestamp) - TODO confirm.
experiment_dirs = sorted(
    (project_root / 'experiments').glob('cmi_training_*'),
    key=lambda x: x.name,
)
if not experiment_dirs:
    raise FileNotFoundError("No training experiments found")

# Use the most recent run that actually saved a best checkpoint, so an
# interrupted or still-running experiment does not make loading fail.
runs_with_model = [d for d in experiment_dirs if (d / 'models' / 'best_model.pt').exists()]
if not runs_with_model:
    raise FileNotFoundError("No training experiment contains models/best_model.pt")
latest_experiment = runs_with_model[-1]
model_path = latest_experiment / 'models' / 'best_model.pt'

print(f"Loading model from: {model_path}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint stores non-tensor Python objects (e.g. the label encoder read
# below), which torch's weights-only unpickler (the default since PyTorch 2.6)
# refuses. Full unpickling runs arbitrary code: only load checkpoints produced
# by our own training runs.
try:
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no `weights_only` argument and always fully unpickles.
    checkpoint = torch.load(model_path, map_location=device)

# Rebuild the model from the configuration saved alongside its weights.
model_config = checkpoint['model_config']
gesture_model = create_model(**model_config).to(device)
gesture_model.load_state_dict(checkpoint['model_state_dict'])
gesture_model.eval()

print(f"Gesture model loaded successfully on {device}")
print(f"Model config: {model_config}")

# The label encoder maps class indices back to gesture names.
label_encoder = checkpoint.get('label_encoder')
if label_encoder is None:
    print("Warning: No label encoder found in checkpoint")
    # Rebuild it from the training data as a fallback.
    train_df = pl.read_csv(project_root / 'dataset' / 'train.csv')
    train_df, label_encoder, _, _ = prepare_gesture_labels(train_df)
    print(f"Label encoder loaded from training data with {len(label_encoder.classes_)} classes")
else:
    print(f"Label encoder loaded with {len(label_encoder.classes_)} classes")

## Extract Gesture Embeddings from Pre-trained Model

In [None]:
class GestureEmbeddingExtractor:
    """Turn one raw sensor sequence into a fixed-size gesture embedding.

    The embedding is the pooled representation the gesture model computes just
    before its classification layer. A sequence is chunked first, and the
    per-chunk embeddings are averaged into a single vector.
    """

    def __init__(self, gesture_model, device):
        self.gesture_model = gesture_model
        self.device = device
        self.sequence_processor = SequenceProcessor()

    def extract_embedding(self, sequence):
        """Return the chunk-averaged embedding of ``sequence``.

        On any error the message is printed and an all-zero vector of length
        ``gesture_model.d_model`` is returned instead, so one bad sequence does
        not abort the caller.
        """
        try:
            chunk_list = self._process_sequence_for_inference(sequence)
            chunk_dataset = CMIDataset(
                chunk_list,
                max_length=config['data']['max_seq_length'],
            )
            # Every chunk of the sequence goes through the model in one batch.
            loader = DataLoader(
                chunk_dataset,
                batch_size=len(chunk_list),
                shuffle=False,
                num_workers=0,
            )

            chunk_embeddings = []
            self.gesture_model.eval()
            with torch.no_grad():
                for batch in loader:
                    tof, acc, rot, thm = (
                        batch[key].to(self.device) for key in ("tof", "acc", "rot", "thm")
                    )
                    start_idx = batch.get("chunk_start_idx")
                    if start_idx is not None:
                        start_idx = start_idx.to(self.device)

                    pooled = self._extract_embedding_from_model(tof, acc, rot, thm, start_idx)
                    chunk_embeddings.extend(pooled.cpu().numpy())

            # Average across chunks; a single chunk is returned as-is. An empty
            # list raises IndexError here and lands in the handler below.
            if len(chunk_embeddings) > 1:
                return np.mean(chunk_embeddings, axis=0)
            return chunk_embeddings[0]

        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.gesture_model.d_model)

    def _extract_embedding_from_model(self, tof_data, acc_data, rot_data, thm_data, chunk_start_idx):
        """Replay the model's forward pass up to (not including) the classifier.

        Mirrors the gesture model's internals: per-sensor branches, channel-wise
        concatenation, fusion normalization, the feature transformer, and global
        pooling over time. The branch outputs are presumably
        (batch, d_model, seq_len) - TODO confirm against the model definition.
        """
        model = self.gesture_model
        branch_outputs = (
            model.tof_branch(tof_data),
            model.acc_branch(acc_data),
            model.rot_branch(rot_data),
            model.thm_branch(thm_data),
        )
        # Stack sensors on the channel axis, normalize, then go time-major for the transformer.
        fused = model.fusion_norm(torch.cat(branch_outputs, dim=1)).transpose(1, 2)
        transformed = model.feature_transformer(fused, chunk_start_idx)
        # Pool over time (back to channel-major first) and drop the pooled axis.
        return model.global_pool(transformed.transpose(1, 2)).squeeze(-1)

    def _process_sequence_for_inference(self, sequence):
        """Chunk one sequence (polars or pandas) into model-ready pieces."""
        sequence_id = (
            sequence['sequence_id'][0]
            if isinstance(sequence, pl.DataFrame)
            else sequence['sequence_id'].iloc[0]
        )

        try:
            features = self.sequence_processor.feature_processor.create_sequence_features(sequence)
            # Gesture id 0 is a placeholder: labels are unknown at inference time.
            return self.sequence_processor._chunk_sequence(
                features,
                0,
                sequence_id,
                config['data']['max_seq_length'],
            )
        except Exception as e:
            print(f"Error processing enhanced features: {e}")
            return self._fallback_sequence_processing(sequence, sequence_id)

    def _fallback_sequence_processing(self, sequence, sequence_id):
        """Wrap the raw sensor columns into a single chunk dict (no feature engineering)."""
        columns = (
            ["acc_x", "acc_y", "acc_z"]
            + ["rot_w", "rot_x", "rot_y", "rot_z"]
            + [f"thm_{sensor}" for sensor in range(1, 6)]
            + [f"tof_{sensor}_v{pixel}" for sensor in range(1, 6) for pixel in range(64)]
        )

        if isinstance(sequence, pl.DataFrame):
            raw_values = sequence.select(columns).to_numpy()
        else:
            raw_values = sequence[columns].values

        # Label 0 is a placeholder for inference.
        return [{"sequence_id": sequence_id, "data": raw_values, "label": 0}]

# One shared extractor serves both the LightGBM training cell and inference.
embedding_extractor = GestureEmbeddingExtractor(gesture_model=gesture_model, device=device)
print("Gesture embedding extractor initialized")

## Prepare Demographic Features and Load LightGBM Training Data

In [None]:
def prepare_demographic_features(demographics_df):
    """Fit encoders and a scaler on a demographics table.

    Every column except ``participant_id`` becomes one feature:

    - Text or categorical columns are label-encoded into a new ``<col>_encoded``
      column. Missing values become ``'unknown'`` first.
    - Numeric columns are median-imputed in place.

    All resulting columns are then standardized with a ``StandardScaler``.

    Args:
        demographics_df: polars or pandas DataFrame, or ``None``.

    Returns:
        ``(None, None)`` when ``demographics_df`` is ``None``. Otherwise a tuple:

        - a preprocessor dict with keys ``features`` (scaled matrix),
          ``feature_names``, ``categorical_encoders``, ``scaler`` and
          ``demo_df`` (the encoded/imputed pandas frame);
        - the *unscaled* feature frame ``demo_df[feature_names]``.
    """
    if demographics_df is None:
        return None, None

    # Work on a pandas copy so the caller's frame is never mutated.
    if isinstance(demographics_df, pl.DataFrame):
        demo_df = demographics_df.to_pandas()
    else:
        demo_df = demographics_df.copy()

    categorical_encoders = {}
    # Names of all model-ready columns (encoded categoricals + imputed numerics).
    numerical_features = []

    # Snapshot the column list: `<col>_encoded` columns are added while looping.
    for col in list(demo_df.columns):
        if col == 'participant_id':
            continue  # identifier, not a feature

        column = demo_df[col]
        # Object, categorical and pandas "string" columns cannot be median-imputed
        # or scaled directly, so they are label-encoded.
        is_categorical = (
            column.dtype == object
            or isinstance(column.dtype, pd.CategoricalDtype)
            or pd.api.types.is_string_dtype(column.dtype)
        )

        if is_categorical:
            encoder = LabelEncoder()
            # Cast to object before filling: fillna('unknown') on a pandas
            # Categorical whose categories lack 'unknown' raises TypeError.
            filled = column.astype(object).fillna('unknown')
            demo_df[f'{col}_encoded'] = encoder.fit_transform(filled)
            categorical_encoders[col] = encoder
            numerical_features.append(f'{col}_encoded')
        else:
            demo_df[col] = column.fillna(column.median())
            numerical_features.append(col)

    scaler = StandardScaler()
    feature_matrix = scaler.fit_transform(demo_df[numerical_features])

    return {
        'features': feature_matrix,
        'feature_names': numerical_features,
        'categorical_encoders': categorical_encoders,
        'scaler': scaler,
        'demo_df': demo_df
    }, demo_df[numerical_features]

def _is_missing_value(value):
    """Return True for scalar ``None``/NaN/``pd.NA`` (the values training imputed)."""
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _encode_unseen_category(encoder, feature_means, idx):
    """Pick a code for a category the encoder never saw during fitting.

    Prefers the encoder's ``'unknown'`` class. Otherwise uses the feature's
    training mean, which scales to 0 (a neutral value). Returns 0.0 when the
    scaler exposes no means.
    """
    if 'unknown' in list(encoder.classes_):
        return encoder.transform(['unknown'])[0]
    if feature_means is not None:
        return feature_means[idx]
    return 0.0


def encode_demographics_for_inference(demographics_row, demo_preprocessor):
    """Encode one participant's demographics into the scaled feature vector.

    Mirrors ``prepare_demographic_features``:

    - Categorical features (``<col>_encoded``) are label-encoded, and missing
      values map to ``'unknown'``.
    - Numerical features are median-imputed using the training frame.
    - The vector is passed through the fitted scaler.

    Args:
        demographics_row: pandas Series (anything with ``to_dict``) or a plain
            mapping of column name -> value for a single participant, or ``None``.
        demo_preprocessor: dict returned by ``prepare_demographic_features``
            (keys ``feature_names``, ``categorical_encoders``, ``scaler``,
            ``demo_df``), or ``None``.

    Returns:
        1-D numpy array. Whenever a preprocessor exists its length is
        ``len(demo_preprocessor['feature_names'])``, so it lines up with the
        features the downstream model was trained on. A length-10 zero vector is
        returned only when there is no preprocessor at all.
    """
    if demo_preprocessor is None:
        # No fitted preprocessing; keep the historical 10-wide default.
        return np.zeros(10)

    feature_names = demo_preprocessor['feature_names']
    if demographics_row is None:
        # Keep the width the scaler (and hence the model) was fitted on.
        return np.zeros(len(feature_names))

    try:
        if hasattr(demographics_row, 'to_dict'):
            demo_dict = demographics_row.to_dict()
        else:
            demo_dict = demographics_row

        scaler = demo_preprocessor['scaler']
        # Per-feature training means (StandardScaler.mean_); used as neutral codes.
        feature_means = getattr(scaler, 'mean_', None)
        encoders = demo_preprocessor['categorical_encoders']
        suffix = '_encoded'

        feature_vector = []
        for idx, feature_name in enumerate(feature_names):
            # Strip only a trailing suffix; str.replace would also hit inner matches.
            base_col = feature_name[:-len(suffix)] if feature_name.endswith(suffix) else None
            encoder = encoders.get(base_col) if base_col is not None else None

            if encoder is not None:
                value = demo_dict.get(base_col)
                # Training filled missing categories with 'unknown' before encoding.
                category = 'unknown' if _is_missing_value(value) else str(value)
                try:
                    encoded_value = encoder.transform([category])[0]
                except ValueError:
                    encoded_value = _encode_unseen_category(encoder, feature_means, idx)
                feature_vector.append(encoded_value)
            else:
                value = demo_dict.get(feature_name)
                if _is_missing_value(value):
                    # Same median imputation as used during training.
                    value = demo_preprocessor['demo_df'][feature_name].median()
                feature_vector.append(value)

        feature_array = np.asarray(feature_vector, dtype=float).reshape(1, -1)
        return scaler.transform(feature_array).flatten()

    except Exception as e:
        print(f"Error encoding demographics: {e}")
        return np.zeros(len(feature_names))

# Load the raw training data used to fit the demographic preprocessor
# (and, in the next cell, the LightGBM model).
print("Loading training data for LightGBM training...")

try:
    train_df = pl.read_csv(project_root / 'dataset' / 'train.csv')
    train_demographics = pl.read_csv(project_root / 'dataset' / 'train_demographics.csv')

    # Presumably adds the `gesture_id` column that LightGBM is trained on - TODO confirm.
    train_df, train_label_encoder, _, _ = prepare_gesture_labels(train_df)

    # Fit encoders/scaler on the participant demographics.
    demo_preprocessor, demo_features = prepare_demographic_features(train_demographics)

    print(f"Demographic features shape: {demo_features.shape if demo_features is not None else 'None'}")
    print(f"Training data shape: {train_df.shape}")

    # LightGBM predictions are decoded with `label_encoder` (loaded with the
    # gesture model), so it must list classes in the same order as this one.
    if not np.array_equal(train_label_encoder.classes_, label_encoder.classes_):
        print(
            "Warning: training-data label encoder differs from the checkpoint's; "
            "LightGBM predictions may be decoded to the wrong gesture names"
        )

except Exception as e:
    print(f"Error loading training data: {e}")
    demo_preprocessor = None
    print("Continuing without demographic preprocessing...")

## Train LightGBM on Combined Features (Optional Training Step)

In [None]:
def _demographics_for_sequence(sequence_data, demographics, demo_preprocessor, n_demo):
    """Return encoded demographics for the sequence's participant, or ``n_demo`` zeros."""
    if 'participant_id' in sequence_data.columns:
        participant_id = sequence_data['participant_id'][0]
        demo_row = demographics.filter(pl.col('participant_id') == participant_id)
        if len(demo_row) > 0:
            return encode_demographics_for_inference(demo_row.to_pandas().iloc[0], demo_preprocessor)
    return np.zeros(n_demo)


def _split_train_val(X, y, test_size=0.2, random_state=42):
    """Split off a validation set, stratified by label when scikit-learn allows it."""
    try:
        return train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)
    except ValueError as err:
        # e.g. a gesture that occurs only once in the sampled subset.
        print(f"Stratified split not possible ({err}); falling back to a random split")
        return train_test_split(X, y, test_size=test_size, random_state=random_state)


def train_lightgbm_classifier(train_df, train_demographics, demo_preprocessor, embedding_extractor,
                              num_samples=500, num_classes=None):
    """Train LightGBM on gesture embeddings concatenated with demographic features.

    Args:
        train_df: polars frame with ``sequence_id`` and ``gesture_id`` columns
            (plus ``participant_id`` for the demographic join, if present).
        train_demographics: polars demographics frame keyed by ``participant_id``.
        demo_preprocessor: output of ``prepare_demographic_features`` (or ``None``).
        embedding_extractor: object exposing ``extract_embedding(sequence)``.
        num_samples: number of sequences (first-seen order) to train on.
        num_classes: number of LightGBM classes. Defaults to ``max(gesture_id) + 1``
            because labels are used directly as class indices. Presumably
            ``gesture_id`` values are 0-based integers - TODO confirm.

    Returns:
        The trained ``lgb.Booster``, or ``None`` if too few sequences could be
        featurized.
    """
    print(f"Training LightGBM on {num_samples} samples...")

    # Width of the demographic block; must match what inference produces.
    n_demo = len(demo_preprocessor['feature_names']) if demo_preprocessor else 10

    features_list = []
    labels_list = []

    # maintain_order=True keeps the subsample deterministic (first-seen order).
    unique_sequences = train_df['sequence_id'].unique(maintain_order=True)[:num_samples]

    for i, seq_id in enumerate(unique_sequences):
        if i % 50 == 0:
            print(f"Processing sequence {i+1}/{len(unique_sequences)}...")

        try:
            sequence_data = train_df.filter(pl.col('sequence_id') == seq_id)
            label = sequence_data['gesture_id'][0]
            gesture_embedding = embedding_extractor.extract_embedding(sequence_data)
            demo_features = _demographics_for_sequence(
                sequence_data, train_demographics, demo_preprocessor, n_demo
            )
            combined_features = np.concatenate([gesture_embedding, demo_features])
        except Exception as e:
            print(f"Error processing sequence {seq_id}: {e}")
            continue

        # Append only after every step succeeded, so X and y stay aligned.
        features_list.append(combined_features)
        labels_list.append(label)

    if len(features_list) == 0:
        print("No features extracted for training")
        return None
    if len(features_list) < 2:
        print("Need at least 2 sequences to create a validation split")
        return None

    X = np.array(features_list)
    y = np.array(labels_list)

    print(f"Training features shape: {X.shape}")
    print(f"Training labels shape: {y.shape}")

    if num_classes is None:
        # len(np.unique(y)) would undercount when some gesture is missing from
        # the subsample, and LightGBM requires every label < num_class.
        num_classes = int(y.max()) + 1

    X_train, X_val, y_train, y_val = _split_train_val(X, y)

    lgb_params = {
        'objective': 'multiclass',
        'num_class': num_classes,
        'metric': 'multi_logloss',
        'boosting_type': 'gbdt',
        'num_leaves': 31,
        'learning_rate': 0.1,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': -1
    }

    train_dataset = lgb.Dataset(X_train, label=y_train)
    val_dataset = lgb.Dataset(X_val, label=y_val, reference=train_dataset)

    lgb_model = lgb.train(
        lgb_params,
        train_dataset,
        valid_sets=[val_dataset],
        num_boost_round=100,
        callbacks=[lgb.early_stopping(10), lgb.log_evaluation(0)]
    )

    val_preds_class = np.argmax(lgb_model.predict(X_val), axis=1)

    # Pin the label set explicitly: without `labels`, target_names must match the
    # classes present in y_val/preds, which fails when a class is absent.
    class_ids = list(range(num_classes))
    print("\nValidation Results:")
    print(classification_report(
        y_val,
        val_preds_class,
        labels=class_ids,
        target_names=[f'Class_{i}' for i in class_ids],
        zero_division=0,
    ))

    return lgb_model

# Reuse a LightGBM model cached on disk when present; otherwise train a new one
# (training needs the demographic preprocessor built in the previous cell).
lgb_model_path = project_root / 'models' / 'lightgbm_demographic_model.pkl'

if lgb_model_path.exists():
    print("Loading pre-trained LightGBM model...")
    # NOTE: pickle can execute arbitrary code on load - only use model files we wrote.
    with open(lgb_model_path, 'rb') as f:
        lgb_model = pickle.load(f)
    print("LightGBM model loaded successfully")
elif demo_preprocessor is None:
    print("Cannot train LightGBM without demographic preprocessor")
    lgb_model = None
else:
    print("Training new LightGBM model...")
    lgb_model = train_lightgbm_classifier(
        train_df, train_demographics, demo_preprocessor, embedding_extractor, num_samples=200
    )

    # Cache the freshly trained model for later runs.
    if lgb_model is not None:
        lgb_model_path.parent.mkdir(exist_ok=True)
        with open(lgb_model_path, 'wb') as f:
            pickle.dump(lgb_model, f)
        print(f"LightGBM model saved to {lgb_model_path}")

## Demographic-Enhanced Inference Function

In [None]:
def _demographic_features_for_prediction(sequence, demographics):
    """Encode the sequence's participant demographics, or zeros of the trained width."""
    if demo_preprocessor is None:
        return np.zeros(10)  # default width used when no preprocessor exists

    # With a preprocessor, the LightGBM model was trained on exactly this many
    # demographic columns, so every fallback must keep that width.
    n_demo = len(demo_preprocessor['feature_names'])
    if demographics is None or 'participant_id' not in sequence.columns:
        return np.zeros(n_demo)

    participant_id = sequence['participant_id'][0]
    demo_row = demographics.filter(pl.col('participant_id') == participant_id)
    if len(demo_row) == 0:
        return np.zeros(n_demo)
    return encode_demographics_for_inference(demo_row.to_pandas().iloc[0], demo_preprocessor)


def predict_with_demographics(sequence: pl.DataFrame, demographics: pl.DataFrame = None) -> str:
    """
    Predict gesture using combined gesture embeddings and demographic features.

    Falls back to the plain gesture model when no LightGBM model is loaded or
    when anything in the demographic pipeline fails.

    Args:
        sequence: Polars DataFrame containing sensor data for one sequence
        demographics: Optional demographics data (joined on ``participant_id``)

    Returns:
        String containing the predicted gesture name
    """
    try:
        if lgb_model is None:
            # Skip the embedding work entirely; it would be discarded anyway.
            print("LightGBM model not available, using fallback prediction")
            return fallback_gesture_prediction(sequence)

        gesture_embedding = embedding_extractor.extract_embedding(sequence)
        demo_features = _demographic_features_for_prediction(sequence, demographics)
        combined_features = np.concatenate([gesture_embedding, demo_features]).reshape(1, -1)

        class_probabilities = lgb_model.predict(combined_features)[0]
        predicted_class = int(np.argmax(class_probabilities))

        # NOTE(review): assumes LightGBM class indices (training `gesture_id`s)
        # use the same ordering as `label_encoder` - verify.
        return label_encoder.inverse_transform([predicted_class])[0]

    except Exception as e:
        print(f"Error in demographic-enhanced prediction: {e}")
        return fallback_gesture_prediction(sequence)

def fallback_gesture_prediction(sequence, default_gesture="Text on phone"):
    """Predict a gesture with the gesture model alone (no demographics).

    The sequence is chunked exactly as for embedding extraction. Softmax
    probabilities are averaged over chunks and the argmax is decoded with
    ``label_encoder``.

    Args:
        sequence: sensor data for one sequence (polars or pandas DataFrame).
        default_gesture: answer returned when prediction fails for any reason.

    Returns:
        The predicted gesture name, or ``default_gesture`` on error.
    """
    try:
        chunks = embedding_extractor._process_sequence_for_inference(sequence)

        dataset = CMIDataset(
            chunks,
            max_length=config['data']['max_seq_length']
        )
        # All chunks go through the model in one batch.
        dataloader = DataLoader(
            dataset,
            batch_size=len(chunks),
            shuffle=False,
            num_workers=0
        )

        all_probabilities = []
        gesture_model.eval()
        with torch.no_grad():
            for batch in dataloader:
                tof_data = batch["tof"].to(device)
                acc_data = batch["acc"].to(device)
                rot_data = batch["rot"].to(device)
                thm_data = batch["thm"].to(device)
                chunk_start_idx = batch.get("chunk_start_idx")
                if chunk_start_idx is not None:
                    chunk_start_idx = chunk_start_idx.to(device)

                outputs = gesture_model(
                    tof_data, acc_data, rot_data, thm_data, chunk_start_idx
                )
                all_probabilities.extend(F.softmax(outputs, dim=1).cpu().numpy())

        if not all_probabilities:
            raise ValueError("no chunks were produced for this sequence")

        # Average class probabilities across chunks (a single chunk is unchanged).
        avg_probabilities = np.mean(all_probabilities, axis=0)
        final_prediction = int(np.argmax(avg_probabilities))

        return label_encoder.inverse_transform([final_prediction])[0]

    except Exception as e:
        print(f"Error in fallback prediction: {e}")
        return default_gesture

print("Demographic-enhanced prediction function ready")

## Inference Server Setup

In [None]:
# Use the real Kaggle inference server when available, else a local mock.
try:
    import kaggle_evaluation.cmi_inference_server

    # Create inference server with our demographic-enhanced predict function
    inference_server = kaggle_evaluation.cmi_inference_server.CMIInferenceServer(predict_with_demographics)

    print("Demographic-enhanced inference server created successfully")
    print("Available methods:")
    print("- inference_server.serve(): Start the server for competition environment")
    print("- inference_server.run_local_gateway(): Run local testing")

except ImportError as e:
    print(f"kaggle_evaluation not available: {e}")
    print("This is expected when running outside Kaggle environment")
    print("You can still test the predict_with_demographics function directly")

    class MockInferenceServer:
        """Local stand-in exposing the same ``serve``/``run_local_gateway`` entry points."""

        def __init__(self, predict_fn):
            self.predict_fn = predict_fn

        def serve(self):
            print("Mock server: serve() called")

        def run_local_gateway(self, data_paths=None):
            """Print a notice and, when ``data_paths`` is given, run sample predictions."""
            print("Mock server: run_local_gateway() called")
            if data_paths:
                print(f"Data paths: {data_paths}")
                self.test_predictions(data_paths)

        def test_predictions(self, data_paths, num_sequences=3):
            """Predict the first ``num_sequences`` sequences of a test CSV.

            Args:
                data_paths: ``(test_csv, demographics_csv)``. The demographics file
                    is optional and skipped when it does not exist.
                num_sequences: how many sequences, in file order, to predict.
            """
            try:
                test_csv, demographics_csv = data_paths

                test_df = pl.read_csv(test_csv)
                if os.path.exists(demographics_csv):
                    demographics_df = pl.read_csv(demographics_csv)
                else:
                    demographics_df = None

                print(f"Loaded test data: {test_df.shape}")
                print(f"Loaded demographics: {demographics_df.shape if demographics_df is not None else 'None'}")

                # maintain_order=True: polars' unique() is otherwise unordered, so
                # these would not reliably be the first sequences in the file.
                unique_sequences = test_df['sequence_id'].unique(maintain_order=True)[:num_sequences]

                for seq_id in unique_sequences:
                    sequence_data = test_df.filter(pl.col('sequence_id') == seq_id)
                    prediction = self.predict_fn(sequence_data, demographics_df)
                    print(f"Sequence {seq_id}: {prediction}")

            except Exception as e:
                print(f"Error in test predictions: {e}")

    inference_server = MockInferenceServer(predict_with_demographics)

print("Demographic-enhanced inference server setup complete")

## Run Inference Server

In [None]:
# Locally (no KAGGLE_IS_COMPETITION_RERUN), replay the local test files through
# the server; in the competition rerun, hand control to the gateway.
if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    print("Running locally - testing demographic-enhanced inference server")

    test_data_path = project_root / 'dataset' / 'test.csv'
    demographics_path = project_root / 'dataset' / 'test_demographics.csv'

    if not test_data_path.exists():
        print(f"Test data not found at {test_data_path}")
        print("Please provide the correct path to test data files")

        # Point the user at the direct-call alternative.
        print("\nExample: Testing predict_with_demographics function directly...")
        print("You can call: predict_with_demographics(sequence_dataframe, demographics_dataframe)")
        print("Where sequence_dataframe contains sensor data for one sequence")
        print("And demographics_dataframe contains participant demographic information")
    else:
        data_paths = (str(test_data_path), str(demographics_path))
        inference_server.run_local_gateway(data_paths=data_paths)
else:
    print("Running in competition environment - starting demographic-enhanced inference server")
    inference_server.serve()