# XAI Interpretability Analysis: Motorcycle Racing Model

Este notebook demuestra cómo usar técnicas de explicabilidad (XAI) —mapas de atención y valores SHAP— para entender las decisiones de un modelo Transformer en la detección de maniobras de motocicleta, generando visualizaciones listas para publicación académica.

## 1. Importar Librerías y Cargar Modelo

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import torch

PROJECT_ROOT = Path('/workspaces/Coaching-for-Competitive-Motorcycle-Racing')
sys.path.insert(0, str(PROJECT_ROOT))

from src.analysis.explainability import (
    SimpleTransformerEncoder,
    load_model,
    generate_synthetic_telemetry,
    compute_attention_heatmap,
    compute_shap_importance
)

# Setup plotting
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

print("✓ Libraries imported successfully")

In [None]:
# Load the trained Transformer checkpoint if present; otherwise fall back to a
# freshly initialised (untrained) encoder so the rest of the notebook still runs.
model_path = PROJECT_ROOT / "models" / "transformer_maneuver_detector.pt"

if model_path.exists():
    print(f"Cargando modelo desde {model_path}...")
    model = load_model(model_path)
    print(f"✓ Modelo cargado: {type(model).__name__}")
else:
    print("Modelo no encontrado. Creando modelo de ejemplo...")
    model = SimpleTransformerEncoder(input_size=6, hidden_size=64, num_heads=4, num_layers=2)
    print(f"✓ Modelo creado: {type(model).__name__}")
    # Attention maps and SHAP values of an untrained model carry no domain meaning.
    print("⚠ Modelo sin entrenar: las explicaciones son solo ilustrativas.")

# torch.nn.Module instances start in training mode; any dropout layers would
# make attention weights and SHAP values vary between runs. Switch to
# inference mode before computing explanations.
if isinstance(model, torch.nn.Module):
    model.eval()

# Mostrar arquitectura
print("\nArquitectura del modelo:")
print(model)

## 2. Cargar Datos de Telemetría (sintéticos en este ejemplo)

In [None]:
# Generate synthetic telemetry. In a real analysis, load recorded data
# (e.g. CSV or HDF5) instead.
telemetry = generate_synthetic_telemetry(n_samples=500, n_sensors=6)
sensor_names = ["ax (m/s²)", "ay (m/s²)", "az (m/s²)", "gx (rad/s)", "gy (rad/s)", "gz (rad/s)"]

# Later cells index telemetry[:, i] per sensor name and zip columns with names,
# so a column/name mismatch would silently mislabel or drop sensors.
assert telemetry.ndim == 2 and telemetry.shape[1] == len(sensor_names), (
    f"expected (n_samples, {len(sensor_names)}) telemetry, got {tuple(telemetry.shape)}"
)

print("Datos de telemetría cargados:")
print(f"  Shape: {telemetry.shape}")
print(f"  Sensores: {sensor_names}")
print("\nEstadísticas:")
for i, name in enumerate(sensor_names):
    print(f"  {name:15} μ={telemetry[:, i].mean():7.3f}, σ={telemetry[:, i].std():6.3f}")

In [None]:
# Plot the raw telemetry: one panel per sensor channel, in a 3x2 grid.
fig, axes = plt.subplots(3, 2, figsize=(14, 8))
for idx, (ax, name) in enumerate(zip(axes.flat, sensor_names)):
    ax.plot(telemetry[:, idx], linewidth=1, alpha=0.7, color='steelblue')
    ax.set_title(name, fontweight='bold')
    ax.set_ylabel('Value (SI)')
    ax.grid(True, alpha=0.3)

# Every panel is plotted against the sample index, so label the bottom row only.
for ax in axes[-1, :]:
    ax.set_xlabel('Time Step')
# Hide panels left empty when there are fewer sensors than grid cells.
for ax in axes.ravel()[len(sensor_names):]:
    ax.set_visible(False)

fig.suptitle('Raw Telemetry Data: Motorcycle Maneuver', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()
print("✓ Telemetría visualizada")

## 3. Visualizar Mapas de Atención (Attention Weights)

In [None]:
# Calcular mapas de atención
print("Calculando mapas de atención del Transformer...")
attention_heatmap = compute_attention_heatmap(model, telemetry, window_size=100)

print(f"Attention heatmap shape: {attention_heatmap.shape}")

# For a 2-D map, plot its last row; a 1-D map is plotted as-is. Which query
# position the last row corresponds to depends on compute_attention_heatmap —
# TODO confirm orientation.
attn_row = attention_heatmap[-1, :] if attention_heatmap.ndim == 2 else attention_heatmap
n_steps = len(attn_row)

fig, ax = plt.subplots(figsize=(14, 4))
im = ax.imshow(
    attn_row[np.newaxis, :],
    aspect='auto',
    cmap='RdYlBu_r',
    extent=[0, n_steps, 0, 1],
    vmin=0,
    vmax=attn_row.max()
)
ax.set_ylabel('Attention\nWeight', fontsize=11, fontweight='bold')
ax.set_xlabel('Time Step', fontsize=11, fontweight='bold')
ax.set_title('Attention Weights Over Time: Identification of Critical Phases', fontsize=12, fontweight='bold')

cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.15)
cbar.set_label('Attention Magnitude', fontsize=10)

# Mark manually chosen reference time steps (not derived from the data).
critical_points = [
    (50, "Aceleración"),
    (150, "Frenada"),
    (300, "Máx. Inclinación"),
]
colors_cp = plt.cm.rainbow(np.linspace(0, 1, len(critical_points)))
for (t, label), color in zip(critical_points, colors_cp):
    # A marker outside the heatmap would make axvline widen the x-axis past the image extent.
    if not 0 <= t < n_steps:
        continue
    ax.axvline(t, color=color, linestyle='--', linewidth=1.5, alpha=0.8)
    # x in data coordinates, y in axes fraction: the label stays at the top of the strip.
    ax.text(
        t, 0.95, label,
        transform=ax.get_xaxis_transform(),
        color='black', ha='center', va='top', fontsize=9, fontweight='bold',
        bbox=dict(boxstyle='round', facecolor=color, alpha=0.6),
    )

plt.tight_layout()
plt.show()
print("✓ Mapas de atención visualizados")

## 4. Calcular la Importancia de Características con SHAP

In [None]:
# Compute per-sensor feature importance with SHAP.
print("Calculando SHAP feature importance...")
shap_importance, sensor_names_short = compute_shap_importance(telemetry, model=model, sensor_names=sensor_names)

# The 'Tipo' column below assumes exactly 3 accelerometer channels followed by
# 3 gyroscope channels, in that order.
assert len(shap_importance) == len(sensor_names_short) == 6, (
    f"expected 6 sensor importances, got {len(shap_importance)} values "
    f"for {len(sensor_names_short)} names"
)

BAR_WIDTH = 50  # characters in the text bar chart below
print("\nImportancia de características (SHAP):")
for name, imp in zip(sensor_names_short, shap_importance):
    # Clamp so importances outside [0, 1] cannot overflow the fixed-width bar.
    bar_len = min(max(int(imp * BAR_WIDTH), 0), BAR_WIDTH)
    print(f"  {name:15} │{'█' * bar_len:<50}│ {imp:.4f}")

# Crear DataFrame para análisis
importance_df = pd.DataFrame({
    'Sensor': sensor_names_short,
    'Importancia': shap_importance,
    'Tipo': ['Acelerómetro'] * 3 + ['Giroscopio'] * 3
})

print("\nResumen por tipo de sensor:")
print(importance_df.groupby('Tipo')['Importancia'].sum())

In [None]:
# Two-panel view: per-sensor SHAP bars (left) and the accelerometer vs.
# gyroscope share of total importance (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_bar, ax_pie = axes

# Red tones for the three accelerometer rows, blue tones for the three gyro rows.
colors = ['#FF6B6B', '#FF9999', '#FFB3B3', '#4ECDC4', '#45B7D1', '#3498DB']
ax_bar.barh(
    importance_df['Sensor'],
    importance_df['Importancia'],
    color=colors,
    edgecolor='black',
    linewidth=1.2,
)
ax_bar.set_xlabel('Importancia Relativa', fontsize=11, fontweight='bold')
ax_bar.set_title('SHAP Feature Importance por Sensor', fontsize=12, fontweight='bold')
ax_bar.grid(True, alpha=0.3, axis='x')

# Aggregate importance per sensor type; later cells reuse `tipo_importance`.
tipo_importance = importance_df.groupby('Tipo')['Importancia'].sum()
pie_style = {
    'autopct': '%1.1f%%',
    'colors': ['#FF6B6B', '#4ECDC4'],
    'startangle': 90,
    'explode': (0.05, 0.05),
    'textprops': {'fontsize': 11, 'fontweight': 'bold'},
}
ax_pie.pie(tipo_importance.values, labels=tipo_importance.index, **pie_style)
ax_pie.set_title('Contribución: Acelerómetro vs. Giroscopio', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()
print("✓ Feature importance visualizado")

## 5. Crear Visualización Combinada para Publicación

In [None]:
# Combined, publication-ready figure: (A) raw signals, (B) attention strip,
# (C) SHAP bars, plus a text summary panel.
fig = plt.figure(figsize=(16, 10))
gs = GridSpec(3, 2, figure=fig, height_ratios=[1.5, 1.5, 1], hspace=0.4, wspace=0.3)

# Panel A: raw telemetry signals, one line per sensor column.
ax_signals = fig.add_subplot(gs[0, :])
time_axis = np.arange(telemetry.shape[0])
colors_sensors = plt.cm.tab10(np.linspace(0, 1, telemetry.shape[1]))
for signal, name, color in zip(telemetry.T, sensor_names, colors_sensors):
    ax_signals.plot(time_axis, signal, label=name, color=color, linewidth=2, alpha=0.8)

ax_signals.set_ylabel('Sensor Value (SI units)', fontsize=12, fontweight='bold')
ax_signals.set_title('(A) Raw Telemetry Data from Motorcycle Maneuver', fontsize=13, fontweight='bold', loc='left')
ax_signals.legend(loc='upper right', ncol=3, framealpha=0.95, fontsize=10)
ax_signals.grid(True, alpha=0.3)

# Panel B: attention strip, using the same row selection and color scaling as section 3.
ax_attn = fig.add_subplot(gs[1, :])
attn_row = attention_heatmap[-1, :] if attention_heatmap.ndim == 2 else attention_heatmap
im = ax_attn.imshow(
    attn_row[np.newaxis, :],
    aspect='auto',
    cmap='RdYlBu_r',
    extent=[0, len(attn_row), 0, 1],
    vmin=0,
    vmax=attn_row.max(),
)
ax_attn.set_ylabel('Attention\nWeight', fontsize=12, fontweight='bold')
ax_attn.set_xlabel('Time Step', fontsize=12, fontweight='bold')
ax_attn.set_title('(B) Transformer Attention Weights: Temporal Importance', fontsize=13, fontweight='bold', loc='left')
cbar = fig.colorbar(im, ax=ax_attn, orientation='horizontal', pad=0.1)
cbar.set_label('Attention Magnitude', fontsize=10)

# Panel C: SHAP feature importance with value labels beside each bar.
ax_importance = fig.add_subplot(gs[2, 0])
bars = ax_importance.barh(importance_df['Sensor'], importance_df['Importancia'], color=colors,
                          edgecolor='black', linewidth=1.2)
ax_importance.set_xlabel('Relative Importance', fontsize=11, fontweight='bold')
ax_importance.set_title('(C) SHAP Feature Importance', fontsize=13, fontweight='bold', loc='left')
max_importance = importance_df['Importancia'].max()
ax_importance.set_xlim(0, max_importance * 1.15)

# Offset relative to the data range so labels stay beside the bars at any scale
# (a fixed 0.01 would overshoot the axis when importances are small).
label_offset = 0.01 * max_importance
for bar, val in zip(bars, importance_df['Importancia']):
    ax_importance.text(bar.get_width() + label_offset, bar.get_y() + bar.get_height() / 2,
                       f'{val:.4f}', ha='left', va='center', fontsize=9, fontweight='bold')

# Panel D: text summary.
ax_summary = fig.add_subplot(gs[2, 1])
ax_summary.axis('off')

top_sensor = importance_df.loc[importance_df['Importancia'].idxmax(), 'Sensor']
# Express contributions as shares of the total so the percentages are correct
# even if compute_shap_importance does not return values summing to 1.
total_importance = tipo_importance.sum()
accel_share = tipo_importance['Acelerómetro'] / total_importance if total_importance else float('nan')
gyro_share = tipo_importance['Giroscopio'] / total_importance if total_importance else float('nan')

# NOTE(review): the sampling rate and architecture values below are hardcoded.
# They mirror the fallback SimpleTransformerEncoder(...) call and may not match
# a checkpoint loaded from disk — confirm before publishing.
summary_text = f"""
MODEL EXPLAINABILITY SUMMARY
{'─'*40}

Input Configuration:
  • Sequence length: {telemetry.shape[0]} timesteps
  • Sampling rate: 50 Hz
  • Sensors: {telemetry.shape[1]} (3 accel + 3 gyro)

Model Architecture:
  • Type: Transformer Encoder
  • Hidden size: 64
  • Attention heads: 4
  • Encoder layers: 2

SHAP Insights:
  • Top sensor: {top_sensor}
  • Accel contribution: {accel_share:.2%}
  • Gyro contribution: {gyro_share:.2%}

Interpretation:
The attention heatmap highlights
critical phases in the maneuver.
Feature importance shows which
sensors drive the model's decision.
"""

ax_summary.text(0.05, 0.95, summary_text, fontsize=10, verticalalignment='top',
                family='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.6))

fig.suptitle('Explainability Analysis: Motorcycle Maneuver Detection Model',
             fontsize=15, fontweight='bold', y=0.995)

output_path = PROJECT_ROOT / 'outputs' / 'explainability_combined.pdf'
# savefig does not create missing directories; ensure outputs/ exists first.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

print("✓ Figura combinada para publicación guardada en outputs/explainability_combined.pdf")

## 6. Analizar Decisiones por Sensor

In [None]:
# In-depth analysis: accelerometer vs. gyroscope contributions, plus a LaTeX
# summary table for the paper.
print("=" * 60)
print("ANÁLISIS DE CONTRIBUCIÓN POR TIPO DE SENSOR")
print("=" * 60)

accel_importance = importance_df.loc[importance_df['Tipo'] == 'Acelerómetro', 'Importancia'].sum()
gyro_importance = importance_df.loc[importance_df['Tipo'] == 'Giroscopio', 'Importancia'].sum()
# Percentages are shares of the total, so they stay correct even if the SHAP
# importances are not normalised to sum to 1.
total_importance = importance_df['Importancia'].sum()


def _pct(value, total=total_importance):
    """Return ``value`` as a percentage of ``total`` (NaN when ``total`` is zero)."""
    return 100.0 * value / total if total else float('nan')


_LATEX_SPECIAL = {'&': r'\&', '%': r'\%', '$': r'\$', '#': r'\#', '_': r'\_', '{': r'\{', '}': r'\}'}


def _latex_escape(text):
    """Escape the LaTeX special characters & % $ # _ { } in a table cell."""
    return ''.join(_LATEX_SPECIAL.get(ch, ch) for ch in str(text))


print("\n1. RESUMEN GENERAL")
print(f"   Importancia Acelerómetro (ax, ay, az): {accel_importance:.4f} ({_pct(accel_importance):.1f}%)")
print(f"   Importancia Giroscopio    (gx, gy, gz): {gyro_importance:.4f} ({_pct(gyro_importance):.1f}%)")

print("\n2. DESGLOSE POR SENSOR")
for row in importance_df.itertuples(index=False):
    print(f"   {row.Sensor:15} ({row.Tipo:12}): {row.Importancia:.4f}")

print("\n3. INTERPRETACIÓN FÍSICA")
print("   • El Acelerómetro captura dinámicas de frenada y aceleración.")
print("   • El Giroscopio captura dinámicas de giro (yaw) e inclinación (roll).")
# Guard the ratio: a zero gyro importance would otherwise raise ZeroDivisionError.
if gyro_importance:
    ratio_text = f"{accel_importance / gyro_importance:.2f}:1"
else:
    ratio_text = "indefinida (importancia del giroscopio = 0)"
print(f"   • Relación Accel:Gyro = {ratio_text}")
emphasis = (
    'Modelo enfatiza dinámicas longitudinales (frenada/aceleración)'
    if accel_importance > gyro_importance
    else 'Modelo enfatiza dinámicas rotacionales (giros/inclinaciones)'
)
print(f"     → {emphasis}")

# Summary table for the paper (LaTeX).
print("\n4. TABLA RESUMEN (formato LaTeX para paper)")
print(r"\begin{table}[h]")
print(r"\centering")
print(r"\begin{tabular}{lcc}")
print(r"Sensor & Importancia & Porcentaje \\")
print(r"\hline")
for row in importance_df.itertuples(index=False):
    print(f"{_latex_escape(row.Sensor)} & {row.Importancia:.4f} & {_pct(row.Importancia):.1f}\\% \\\\")
print(r"\hline")
print(f"Acelerómetro Total & {accel_importance:.4f} & {_pct(accel_importance):.1f}\\% \\\\")
print(f"Giroscopio Total   & {gyro_importance:.4f} & {_pct(gyro_importance):.1f}\\% \\\\")
print(r"\end{tabular}")
print(r"\caption{SHAP Feature Importance: Sensor Contributions to Maneuver Detection}")
print(r"\label{tab:shap_importance}")
print(r"\end{table}")

print("\n✓ Análisis completado")

In [None]:
"""
Notebook: Interpretability Deep Dive

This Jupyter notebook demonstrates how to use the explainability module
to analyze model decisions, visualize attention patterns, and compute
feature importance for motorcycle racing models.

Steps:
1. Train a simple Transformer on synthetic telemetry
2. Extract attention weights and visualize them
3. Compute SHAP values for feature importance
4. Generate publication-ready explainability figures

Usage:
    jupyter notebook notebooks/explainability_analysis.ipynb
"""