# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm
import torch.nn.functional as F

def extract_attention_weights(model, data_loader, event_mapper, device='cuda'):
    """
    Collect the model's temporal attention weights and average them per event type.

    The forward pass is replayed step by step (feature reducer -> three
    parallel convolutions -> batch norm -> LSTM -> batch norm -> attention
    scorer) so the softmax-normalised attention weights can be captured.
    These steps were copied from the model's forward method and have to be
    kept in sync with it by hand.

    Args:
        model: Trained model exposing the sub-modules ``feature_reducer``,
            ``conv1``, ``conv2``, ``conv3``, ``bn1``, ``lstm1``, ``bn2`` and
            ``attention``. It is switched to eval mode.
        data_loader: Iterable of ``(features, labels)`` batches. ``features``
            must be 3-D (unpacked as batch, sequence length, feature size)
            and every ``labels[i]`` must hold a single class index.
        event_mapper: Object whose ``event_to_index(name)`` returns the class
            index of an event name.
        device: Device the feature batches are moved to.

    Returns:
        dict mapping an event name to a 1-D numpy array holding the attention
        weights averaged over all samples of that event. Keys are the target
        events ('Goal', 'Foul', 'Corner', 'Yellow card', 'Red card') that
        occurred at least once, plus 'Card', which pools 'Yellow card',
        'Red card' and 'Yellow->red card' samples.
    """
    model.eval()

    # Per-event accumulators, keyed by class index
    target_events = ['Goal', 'Foul', 'Corner', 'Yellow card', 'Red card']
    target_indices = [event_mapper.event_to_index(event) for event in target_events]
    event_attention = {event_idx: [] for event_idx in target_indices}

    # Combined card category (every card type pooled together)
    card_indices = {
        event_mapper.event_to_index('Yellow card'),
        event_mapper.event_to_index('Red card'),
        event_mapper.event_to_index('Yellow->red card'),
    }
    card_attention = []

    with torch.no_grad():
        for features, labels in tqdm(data_loader, desc="Extracting attention weights"):
            features = features.to(device)

            # Unpacking enforces the expected 3-D input layout
            batch_size, seq_len, feat_dim = features.size()

            # Feature reduction
            x = model.feature_reducer(features)

            # Multi-scale temporal convolution (channels-first layout)
            x_perm = x.permute(0, 2, 1)
            conv1_out = torch.relu(model.conv1(x_perm))
            conv2_out = torch.relu(model.conv2(x_perm))
            conv3_out = torch.relu(model.conv3(x_perm))

            # Concatenate along the channel dimension, then normalise
            conv_combined = torch.cat([conv1_out, conv2_out, conv3_out], dim=1)
            conv_combined = model.bn1(conv_combined)

            # LSTM consumes the sequence in (batch, time, channels) layout
            lstm_in = conv_combined.permute(0, 2, 1)
            lstm_out, _ = model.lstm1(lstm_in)

            # bn2 is applied in channels-first layout, then permuted back
            lstm_bn = model.bn2(lstm_out.permute(0, 2, 1)).permute(0, 2, 1)

            # One score per position, normalised over dimension 1
            attn_scores = model.attention(lstm_bn).squeeze(-1)
            attn_weights = F.softmax(attn_scores, dim=1)

            # Labels are only read back as Python scalars, so they stay on
            # the device the loader produced them on.
            for sample_weights, label_tensor in zip(attn_weights, labels):
                label = label_tensor.item()
                is_target = label in event_attention
                is_card = label in card_indices
                if not (is_target or is_card):
                    continue

                weights_np = sample_weights.cpu().numpy()
                if is_target:
                    event_attention[label].append(weights_np)
                # Independent check (not elif): yellow and red cards are also
                # target events and must still contribute to the pooled
                # 'Card' category.
                if is_card:
                    card_attention.append(weights_np)

    # Average the collected weights across all instances of each event
    processed_weights = {}
    for event, event_idx in zip(target_events, target_indices):
        collected = event_attention[event_idx]
        if collected:
            processed_weights[event] = np.mean(np.vstack(collected), axis=0)

    if card_attention:
        processed_weights['Card'] = np.mean(np.vstack(card_attention), axis=0)

    return processed_weights

def visualize_attention(attention_weights, window_size=WINDOW_SIZE, fps=FPS):
    """
    Plot the averaged attention weights of each event type against time.

    The event is taken to occur at the centre of every weight sequence
    (index ``len(weights) // 2``), which is placed at t = 0. Sequences longer
    than ``window_size`` are cropped to ``window_size`` samples around that
    centre; shorter sequences are plotted in full.

    Args:
        attention_weights: dict mapping an event name to a 1-D array of
            attention weights.
        window_size: Number of frames in the nominal window; sets the crop
            width and the extent of the shaded pre/post-event regions.
        fps: Frames per second, used to convert frame offsets to seconds.

    Returns:
        The matplotlib figure. As side effects the figure is written to
        ``attention_visualization.png`` under ``SOCCERNET_PATH`` and shown.
    """
    half_window = window_size // 2
    seconds_per_frame = 1.0 / fps

    # Time axis of the nominal window; used for the shaded context regions
    time_steps = (np.arange(window_size) - half_window) * seconds_per_frame

    fig, ax = plt.subplots(figsize=(12, 7))

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    markers = ['o', 's', '^', 'D', 'x']

    for i, (event, weights) in enumerate(attention_weights.items()):
        weights = np.asarray(weights)
        if weights.size == 0:
            # Nothing to draw, and np.argmax would raise on an empty array
            continue

        # Crop to (at most) window_size samples centred on the event
        center_idx = len(weights) // 2
        start_idx = max(0, center_idx - half_window)
        end_idx = min(len(weights), start_idx + window_size)
        plot_weights = weights[start_idx:end_idx]

        # Derive the time of every plotted sample from its own offset to the
        # centre, so t = 0 is the event even when the sequence length differs
        # from window_size.
        plot_time = (np.arange(start_idx, end_idx) - center_idx) * seconds_per_frame

        ax.plot(plot_time, plot_weights, label=event, linewidth=2,
                color=colors[i % len(colors)], marker=markers[i % len(markers)], markersize=8)

        # Mark peak attention; alternate the label offset to reduce overlap
        peak_idx = np.argmax(plot_weights)
        ax.annotate(f'Peak: {plot_weights[peak_idx]:.2f}',
                    xy=(plot_time[peak_idx], plot_weights[peak_idx]),
                    xytext=(10, (-15 if i % 2 == 0 else 10)),
                    textcoords='offset points',
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=.2'))

    # Vertical line at event occurrence (t=0)
    ax.axvline(x=0, color='black', linestyle='--', alpha=0.7, label='Event Occurrence')
    ax.text(0.2, 0.02, 'Event\nOccurrence', fontsize=9)

    ax.set_xlabel('Time Relative to Event (seconds)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Attention Weight', fontsize=12, fontweight='bold')
    ax.set_title('Temporal Attention Distribution Across Event Types', fontsize=14, fontweight='bold')
    ax.grid(linestyle='--', alpha=0.7)
    ax.legend(loc='upper left', frameon=True, framealpha=0.9)

    # Shaded pre-event / post-event regions. They are added after the legend
    # was built, so they are identified by the text labels below rather than
    # by legend entries.
    min_time = time_steps.min()
    max_time = time_steps.max()
    ax.axvspan(min_time, 0, alpha=0.1, color='blue', label='Pre-event Context')
    ax.axvspan(0, max_time, alpha=0.1, color='red', label='Post-event Context')
    ax.text(min_time + 0.5, 0.23, 'Pre-event Context', fontsize=9, color='darkblue')
    ax.text(1, 0.23, 'Post-event Context', fontsize=9, color='darkred')

    plt.tight_layout()
    save_path = os.path.join(SOCCERNET_PATH, 'attention_visualization.png')
    fig.savefig(save_path, dpi=300)
    print(f"Saved visualization to {save_path}")
    plt.show()

    return fig

def run_attention_experiment(model, test_loader, event_mapper, device='cuda'):
    """
    Run the complete attention visualization experiment.

    Extracts the per-event mean attention weights, prints where each event's
    attention peaks relative to the event, plots the distributions and saves
    the numerical results.

    Args:
        model: Trained model, passed through to ``extract_attention_weights``.
        test_loader: Data loader yielding ``(features, labels)`` batches.
        event_mapper: Event-name to class-index mapper.
        device: Device used for the forward passes.

    Returns:
        dict mapping an event name to its mean attention-weight array.

    Side effects:
        Writes ``attention_analysis_results.npy`` under ``SOCCERNET_PATH``.
        The file stores a pickled dict, so it has to be read back with
        ``np.load(path, allow_pickle=True).item()``.
    """
    print("Starting attention visualization experiment...")

    # 1. Extract attention weights from model
    attention_weights = extract_attention_weights(model, test_loader, event_mapper, device)
    print(f"Extracted attention weights for {len(attention_weights)} event types")

    # 2. Report the peak position of every event type
    print("Processing attention patterns...")
    for event, weights in attention_weights.items():
        peak_idx = np.argmax(weights)
        # The event sits at the centre of the sequence (the same convention
        # visualize_attention uses), so measure the offset from the sequence's
        # own midpoint rather than from the nominal window size.
        peak_time = (peak_idx - len(weights) // 2) / FPS
        print(f"{event}: Peak attention at {peak_time:.2f}s relative to event occurrence")

    # 3. Visualize the attention distributions (also saves and shows the plot)
    visualize_attention(attention_weights)

    # 4. Save numerical results for future reference
    results_path = os.path.join(SOCCERNET_PATH, 'attention_analysis_results.npy')
    np.save(results_path, dict(attention_weights))

    print("Experiment completed successfully!")
    return attention_weights

# Run the experiment as part of the main function
def run_attention_analysis(model_path='/content/drive/MyDrive/soccernet_event_detection_model.pth'):
    """
    Load the trained model and run the attention experiment on test matches.

    Builds a test dataset from the last 10 match paths, restores the model
    weights from the checkpoint and delegates to ``run_attention_experiment``.
    The dataset cache is cleaned up even if loading or the experiment fails.

    Args:
        model_path: Path to a checkpoint containing a ``'model_state_dict'``
            entry. Defaults to the original hard-coded location.

    Returns:
        dict mapping an event name to its mean attention-weight array.
    """
    # Pick the device first so the loader can decide whether pinning is useful
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create event mapper
    event_mapper = EventMapper()

    # Take the last 10 matches for testing
    test_matches = get_match_paths()[-10:]

    # Create test dataset
    test_dataset = SoccerNetDataset(test_matches, event_mapper, mode='test')

    try:
        test_loader = DataLoader(
            test_dataset,
            batch_size=16,
            shuffle=False,
            num_workers=2,
            # Pinned host memory only speeds up host-to-GPU copies
            pin_memory=(device.type == 'cuda'),
        )

        # Create model
        input_dim = 2048  # ResNet features are 2048-dimensional
        hidden_dim = 256
        num_classes = event_mapper.get_num_classes()
        model = EventDetectionModel(input_dim, hidden_dim, num_classes, dropout_rate=0.5)

        # Load trained model weights.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        # Run the experiment
        return run_attention_experiment(model, test_loader, event_mapper, device)
    finally:
        # Always release the dataset cache, including on failure
        test_dataset.cleanup_cache()

if __name__ == "__main__":
    # Entry point: run the attention analysis experiment. The result is bound
    # at module level so it remains available for inspection afterwards.
    attention_weights = run_attention_analysis()