# Tennis Action Recognition - Model Training

This notebook trains and evaluates both classical ML and deep learning models for tennis action recognition.

In [None]:
import random
import sys
# Make the project root importable so the `src.*` imports below resolve.
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from src.data.dataset_processor import TennisDatasetProcessor
from src.models.classical_ml import ClassicalMLTrainer
from src.models.deep_learning import DeepLearningTrainer, KeypointMLP, HybridCNN

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

%matplotlib inline

## 1. Load Processed Data

In [None]:
# Load the preprocessed train / validation / test splits.
# NOTE(review): pd.read_pickle can execute code embedded in the file — only
# point this at pickles produced locally by this project's pipeline.
data_dir = "../data/processed"
processed_dir = Path(data_dir)

train_df = pd.read_pickle(processed_dir / "train_data.pkl")
val_df = pd.read_pickle(processed_dir / "val_data.pkl")
test_df = pd.read_pickle(processed_dir / "test_data.pkl")

for split_label, split_df in [("Train set", train_df),
                              ("Validation set", val_df),
                              ("Test set", test_df)]:
    print(f"{split_label}: {len(split_df)} samples")

# The processor provides extract_features_for_classical_ml, used in Section 2.
processor = TennisDatasetProcessor("../data/tennis_dataset")

## 2. Classical Machine Learning Models

In [None]:
# Turn each split DataFrame into a (feature matrix, label vector) pair for the
# classical models, processing train, val and test in that order.
print("Extracting features for classical ML...")
(X_train, y_train), (X_val, y_val), (X_test, y_test) = [
    processor.extract_features_for_classical_ml(split_frame)
    for split_frame in (train_df, val_df, test_df)
]

feature_shape = X_train.shape
print(f"Feature shape: {feature_shape}")
print(f"Classes: {np.unique(y_train)}")

In [None]:
# Scale all three feature matrices through the trainer.
# NOTE(review): presumably the scaler is fitted on the train split only —
# confirm in ClassicalMLTrainer.prepare_data to rule out leakage.
classical_trainer = ClassicalMLTrainer()
X_train_scaled, X_val_scaled, X_test_scaled = classical_trainer.prepare_data(
    X_train,
    X_val,
    X_test,
)

print("Data prepared and scaled for classical ML models")

### 2.1 Random Forest

In [None]:
# Random Forest: fit on the scaled train split, with the val split passed alongside.
rf_model = classical_trainer.train_random_forest(
    X_train_scaled, y_train,
    X_val_scaled, y_val,
)

### 2.2 Support Vector Machine

In [None]:
# SVM: fit on the scaled train split, with the val split passed alongside.
svm_model = classical_trainer.train_svm(
    X_train_scaled, y_train,
    X_val_scaled, y_val,
)

### 2.3 Logistic Regression

In [None]:
# Logistic Regression: fit on the scaled train split, with the val split passed alongside.
lr_model = classical_trainer.train_logistic_regression(
    X_train_scaled, y_train,
    X_val_scaled, y_val,
)

### 2.4 K-Nearest Neighbors

In [None]:
# K-Nearest Neighbors: fit on the scaled train split, with the val split passed alongside.
knn_model = classical_trainer.train_knn(
    X_train_scaled, y_train,
    X_val_scaled, y_val,
)

### 2.5 Evaluate Classical Models

In [None]:
# Score every fitted classical model on the held-out test split and print
# accuracy plus macro/weighted F1 for each.
classical_results = classical_trainer.evaluate_models(X_test_scaled, y_test)

metric_labels = (
    ("accuracy", "Accuracy"),
    ("f1_macro", "F1 (macro)"),
    ("f1_weighted", "F1 (weighted)"),
)

print("\n=== Classical ML Results ===")
for name, metrics in classical_results.items():
    print(f"\n{name.upper()}:")
    for metric_key, metric_label in metric_labels:
        print(f"  {metric_label}: {metrics[metric_key]:.4f}")

In [None]:
# Confusion matrices for each classical model, built from the test-split results.
classical_trainer.plot_confusion_matrices(
    classical_results
)

In [None]:
# Side-by-side view of the classical models' test-split metrics.
classical_trainer.plot_model_comparison(
    classical_results
)

## 3. Deep Learning Models

In [None]:
# Build the deep-learning trainer and report which device it selected,
# so CPU vs GPU runs are easy to tell apart in the saved output.
dl_trainer = DeepLearningTrainer()
print(f"Using device: {dl_trainer.device}")

### 3.1 Keypoint MLP Model

In [None]:
# Keypoint-only loaders (keypoints_only=True) feeding the MLP, batch size 32.
train_loader_mlp, val_loader_mlp, test_loader_mlp = dl_trainer.create_data_loaders(
    train_df, val_df, test_df,
    batch_size=32,
    keypoints_only=True,
)

print("MLP Data loaders created:")
for split_name, split_loader in (("Train", train_loader_mlp),
                                 ("Val", val_loader_mlp),
                                 ("Test", test_loader_mlp)):
    print(f"  {split_name} batches: {len(split_loader)}")

In [None]:
# Pose-only MLP. input_dim=36 presumably means 18 keypoints x (x, y), matching
# the kp_* feature names in Section 6 — confirm against the processor.
mlp_hidden_dims = [512, 256, 128]
mlp_model = KeypointMLP(input_dim=36, hidden_dims=mlp_hidden_dims, num_classes=4)

trained_mlp = dl_trainer.train_model(
    mlp_model,
    train_loader_mlp,
    val_loader_mlp,
    'keypoint_mlp',
    num_epochs=50,
    learning_rate=0.001,
)

In [None]:
# Training curves recorded by the trainer under the 'keypoint_mlp' run name.
dl_trainer.plot_training_history(
    'keypoint_mlp'
)

### 3.2 Hybrid CNN Model

In [None]:
# Image + keypoint loaders (keypoints_only=False) for the hybrid CNN; the
# smaller batch size of 16 is kept from the original configuration.
train_loader_cnn, val_loader_cnn, test_loader_cnn = dl_trainer.create_data_loaders(
    train_df, val_df, test_df,
    batch_size=16,
    keypoints_only=False,
)

print("CNN Data loaders created:")
for split_name, split_loader in (("Train", train_loader_cnn),
                                 ("Val", val_loader_cnn),
                                 ("Test", test_loader_cnn)):
    print(f"  {split_name} batches: {len(split_loader)}")

In [None]:
# Hybrid model fed by the image + keypoint loaders above; keypoint_dim=36 is
# the same keypoint vector width the MLP uses.
cnn_model = HybridCNN(num_classes=4, keypoint_dim=36)

trained_cnn = dl_trainer.train_model(
    cnn_model,
    train_loader_cnn,
    val_loader_cnn,
    'hybrid_cnn',
    num_epochs=50,
    learning_rate=0.001,
)

In [None]:
# Training curves recorded by the trainer under the 'hybrid_cnn' run name.
dl_trainer.plot_training_history(
    'hybrid_cnn'
)

### 3.3 Evaluate Deep Learning Models

In [None]:
# Score each network on its own test loader (keypoint-only vs image+keypoint),
# then print the same three metrics reported for the classical models.
mlp_results = dl_trainer.evaluate_model(trained_mlp, test_loader_mlp, 'keypoint_mlp')
cnn_results = dl_trainer.evaluate_model(trained_cnn, test_loader_cnn, 'hybrid_cnn')

print("\n=== Deep Learning Results ===")
for heading, metrics in (("KEYPOINT MLP", mlp_results), ("HYBRID CNN", cnn_results)):
    print(f"\n{heading}:")
    print(f"  Accuracy: {metrics['accuracy']:.4f}")
    print(f"  F1 (macro): {metrics['f1_macro']:.4f}")
    print(f"  F1 (weighted): {metrics['f1_weighted']:.4f}")

## 4. Model Comparison

In [None]:
# Merge classical and deep-learning results into one table, sorted by test
# accuracy (best first). Any key not in classical_results is labelled Deep Learning.
all_results = {**classical_results, 'keypoint_mlp': mlp_results, 'hybrid_cnn': cnn_results}

comparison_data = [
    {
        'Model': name.replace('_', ' ').title(),
        'Type': 'Classical ML' if name in classical_results else 'Deep Learning',
        'Accuracy': metrics['accuracy'],
        'F1 (Macro)': metrics['f1_macro'],
        'F1 (Weighted)': metrics['f1_weighted'],
    }
    for name, metrics in all_results.items()
]
comparison_df = pd.DataFrame(comparison_data).sort_values('Accuracy', ascending=False)

print("\n=== Model Performance Comparison ===")
print(comparison_df.to_string(index=False, float_format='%.4f'))

In [None]:
# Two-panel comparison of all models, in comparison_df order (best accuracy first):
#   left  - test accuracy, coloured by model family (with a legend for the colours)
#   right - macro vs weighted F1 side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

x = np.arange(len(comparison_df))
model_labels = comparison_df['Model'].tolist()
accuracies = comparison_df['Accuracy'].to_numpy()
model_types = comparison_df['Type'].to_numpy()

# Accuracy comparison: one bar() call per model family so each colour gets a
# legend entry; positions come from `x`, so the sorted order is preserved.
type_colors = {'Classical ML': '#FF6B6B', 'Deep Learning': '#4ECDC4'}
for model_type, color in type_colors.items():
    mask = model_types == model_type
    if mask.any():
        ax1.bar(x[mask], accuracies[mask], color=color, alpha=0.8, label=model_type)

ax1.set_title('Model Accuracy Comparison', fontsize=14, fontweight='bold')
ax1.set_ylabel('Accuracy', fontsize=12)
# Headroom above 1.0 so the value labels on near-perfect bars stay inside the axes.
ax1.set_ylim(0, 1.08)
ax1.set_xticks(x)
# ha='right' anchors each rotated label's end under its bar instead of centring it.
ax1.set_xticklabels(model_labels, rotation=45, ha='right')
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Add value labels just above each accuracy bar
for bar_x, acc in zip(x, accuracies):
    ax1.text(bar_x, acc + 0.01, f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')

# F1 Score comparison: grouped bars offset by half a bar width either side of x
width = 0.35
ax2.bar(x - width/2, comparison_df['F1 (Macro)'], width,
        label='F1 (Macro)', alpha=0.8, color='#45B7D1')
ax2.bar(x + width/2, comparison_df['F1 (Weighted)'], width,
        label='F1 (Weighted)', alpha=0.8, color='#96CEB4')

ax2.set_title('F1 Score Comparison', fontsize=14, fontweight='bold')
ax2.set_ylabel('F1 Score', fontsize=12)
ax2.set_ylim(0, 1)
ax2.set_xticks(x)
ax2.set_xticklabels(model_labels, rotation=45, ha='right')
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Save Models

In [None]:
# Persist every trained model: classical models and deep-learning networks are
# written to separate directories under ../models.
for trainer, target_dir in (
    (classical_trainer, "../models/classical_ml"),
    (dl_trainer, "../models/deep_learning"),
):
    trainer.save_models(target_dir)

print("All models saved successfully!")

## 6. Feature Importance Analysis

In [None]:
# Analyze feature importance for Random Forest (presumably sklearn's
# impurity-based feature_importances_ — confirm in ClassicalMLTrainer).
if 'random_forest' in classical_trainer.models:
    rf_model = classical_trainer.models['random_forest']
    feature_importance = rf_model.feature_importances_

    # Expected column layout: 18 keypoints x (x, y) followed by 6 engineered
    # features. NOTE(review): must match the order produced by
    # processor.extract_features_for_classical_ml — confirm there.
    keypoint_features = [f"kp_{i//2}_{['x','y'][i%2]}" for i in range(36)]
    engineered_features = ['hand_dist', 'shoulder_dist', 'left_arm_angle',
                           'right_arm_angle', 'body_center_x', 'body_center_y']
    feature_names = keypoint_features + engineered_features
    if len(feature_names) != len(feature_importance):
        # Layout mismatch: use positional names rather than silently mislabel
        # features (or hit an IndexError) with the hard-coded list above.
        print(f"Warning: expected {len(feature_names)} features but the model has "
              f"{len(feature_importance)}; using positional feature names.")
        feature_names = [f"feature_{i}" for i in range(len(feature_importance))]

    # Top-k feature indices ordered from MOST to least important.
    top_k = min(20, len(feature_importance))
    ranked_indices = np.argsort(feature_importance)[::-1][:top_k]
    top_importance = feature_importance[ranked_indices]
    top_names = [feature_names[i] for i in ranked_indices]

    # barh draws position 0 at the bottom, so reverse to put the most important on top.
    plt.figure(figsize=(10, 8))
    plt.barh(range(top_k), top_importance[::-1], color='skyblue')
    plt.yticks(range(top_k), top_names[::-1])
    plt.xlabel('Feature Importance', fontsize=12)
    plt.title(f'Top {top_k} Most Important Features (Random Forest)', fontsize=14, fontweight='bold')
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Rank 1 is the single most important feature (previously the ascending
    # slice labelled the 10th-most-important feature as #1).
    print("\nTop 10 Most Important Features:")
    for rank, (name, importance) in enumerate(zip(top_names[:10], top_importance[:10]), start=1):
        print(f"{rank:2d}. {name:20s}: {importance:.4f}")

## Summary

This notebook successfully trained and evaluated multiple models for tennis action recognition:

### Classical ML Models:
- **Random Forest**: Ensemble method with feature importance analysis
- **SVM**: Support Vector Machine with RBF kernel
- **Logistic Regression**: Linear classifier with regularization
- **K-Nearest Neighbors**: Instance-based learning

### Deep Learning Models:
- **Keypoint MLP**: Multi-layer perceptron using only keypoint features
- **Hybrid CNN**: Convolutional neural network combining images and keypoints

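### Key Findings:
> **Note:** the points below are the expected outcomes. Confirm each one against the comparison table and plots in Section 4 after a full Restart & Run All, and revise any that the numbers do not support.

1. All models achieved good performance on the tennis action recognition task
2. Deep learning models generally outperformed classical ML approaches
3. The hybrid CNN model combining visual and pose features achieved the best results
4. Feature engineering from keypoints provided valuable information for classical models
5. The dataset is well-balanced and suitable for multi-class classification (class balance is not checked in this notebook; verify it with per-class counts of the training labels)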

### Next Steps:
- Deploy the best performing models via REST API
- Implement ensemble methods combining multiple models
- Explore sequence-based models for video action recognition
- Consider Graph Convolutional Networks (GCN) for pose-based recognition