In [10]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from tqdm.notebook import tqdm
from google.colab import drive
import gc
from torch.nn.utils.rnn import pad_sequence
from moviepy.editor import VideoFileClip, concatenate_videoclips
import librosa

# Constants
FPS = 2  # Feature extraction rate
WINDOW_SIZE = 40  # As per your training scripts
EVENT_WINDOW = 4  # From your event detection script
HIGHLIGHT_SECONDS_BEFORE = 5  # Seconds before event to include in highlight
HIGHLIGHT_SECONDS_AFTER = 7   # Seconds after event to include in highlight
MIN_IMPORTANCE_THRESHOLD = 0.25  # Minimum importance score to include in highlights
MAX_EVENTS_PER_HALF = 5  # Limit to top 5 events per half

# Event importance scores - predefined from your data
EVENT_IMPORTANCE = {
    "Goal": 0.9560,
    "Offside": 0.6015,
    "Shot": 0.5415,  # Shots off target
    "Shot on target": 0.3765,
    "Foul": 0.5302,
    "Penalty": 0.4277,
    "Red card": 0.2364,
    "Yellow->red card": 0.2174,
    "Direct free-kick": 0.1627,
    "Ball out of play": 0.1019,
    "Yellow card": 0.0811,
    "Corner": 0.0695,
    "Indirect free-kick": 0.0170,
    "Clearance": 0.0127,
    "Kick-off": 0.0125,
    "Throw-in": 0.0087,
    "Substitution": 0.0042,
    "Background": 0.0
}

# Event mapper class from your event detection script
class EventMapper:
    def __init__(self):
        self.events = [
            "Ball out of play", "Throw-in", "Foul", "Indirect free-kick",
            "Clearance", "Shot", "Shot on target", "Goal", "Corner", "Substitution",
            "Kick-off", "Yellow card", "Offside", "Direct free-kick", "Red card",
            "Yellow->red card", "Penalty", "Background"
        ]
        self.event_to_idx = {event: i for i, event in enumerate(self.events)}
        self.idx_to_event = {i: event for i, event in enumerate(self.events)}

    def get_num_classes(self):
        return len(self.events)

    def event_to_index(self, event):
        return self.event_to_idx.get(event, self.event_to_idx["Background"])

    def index_to_event(self, idx):
        return self.idx_to_event[idx]

# Event Detection Model from your script
class EventDetectionModel(nn.Module):
    """Advanced neural network for soccer event detection"""

    def __init__(self, input_dim, hidden_dim, num_classes, dropout_rate=0.5):
        super(EventDetectionModel, self).__init__()

        # Input dimension reduction - input now includes difference features
        self.input_dim = input_dim * 2  # Original features + temporal difference features

        # Feature reduction
        self.feature_reducer = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim * 2),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(WINDOW_SIZE),
            nn.Dropout(dropout_rate)
        )

        # 1D convolutional layers with different kernel sizes for multi-scale temporal features
        self.conv1 = nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=7, padding=3)

        # Bi-directional LSTM layers
        self.lstm1 = nn.LSTM(
            input_size=hidden_dim * 3,  # Combined conv outputs
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate
        )

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

        # Output layers with skip connection
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.classifier = nn.Linear(hidden_dim // 2, num_classes)

        # Additional skip connection
        self.skip_connection = nn.Linear(hidden_dim * 2, hidden_dim // 2)

        # Batch normalization layers
        self.bn1 = nn.BatchNorm1d(hidden_dim * 3)
        self.bn2 = nn.BatchNorm1d(hidden_dim * 2)
        self.bn3 = nn.BatchNorm1d(hidden_dim)
        self.bn4 = nn.BatchNorm1d(hidden_dim // 2)

    def forward(self, x):
        batch_size, seq_len, feat_dim = x.size()

        # Apply feature reduction
        orig_x = x.clone()  # Save for residual connection
        x = x.view(-1, feat_dim)
        x = self.feature_reducer(x.view(batch_size, seq_len, -1))

        # Multi-scale temporal convolution
        x_perm = x.permute(0, 2, 1)  # [batch, hidden_dim*2, seq_len]

        conv1_out = torch.relu(self.conv1(x_perm))
        conv2_out = torch.relu(self.conv2(x_perm))
        conv3_out = torch.relu(self.conv3(x_perm))

        # Concatenate convolutional outputs
        conv_combined = torch.cat([conv1_out, conv2_out, conv3_out], dim=1)

        # Apply batch normalization
        conv_combined = self.bn1(conv_combined)

        # Pass through LSTM
        lstm_in = conv_combined.permute(0, 2, 1)  # [batch, seq_len, hidden_dim*3]
        lstm_out, _ = self.lstm1(lstm_in)  # [batch, seq_len, hidden_dim*2]

        # Apply batch normalization
        lstm_bn = self.bn2(lstm_out.permute(0, 2, 1)).permute(0, 2, 1)

        # Apply attention mechanism
        attn_weights = self.attention(lstm_bn).squeeze(-1)  # [batch, seq_len]
        attn_weights = torch.softmax(attn_weights, dim=1).unsqueeze(-1)  # [batch, seq_len, 1]

        # Context vector is weighted sum of LSTM outputs
        context = torch.sum(lstm_bn * attn_weights, dim=1)  # [batch, hidden_dem*2]

        # Feed-forward layers with skip connection
        out1 = torch.relu(self.fc1(context))
        out1 = self.bn3(out1)
        out1 = nn.functional.dropout(out1, p=0.4, training=self.training)

        out2 = torch.relu(self.fc2(out1))
        out2 = self.bn4(out2)

        # Skip connection
        skip = self.skip_connection(context)
        out = out2 + skip

        # Final classification
        logits = self.classifier(out)

        return logits

# Importance Scoring Model from your first script
class TemporalEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.lstm = torch.nn.LSTM(hidden_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.norm = torch.nn.LayerNorm(hidden_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, time, dim) -> (batch, dim, time)
        x = self.conv(x).relu()
        x = x.permute(0, 2, 1)  # (batch, time, hidden_dim)
        x, _ = self.lstm(x)
        x = self.norm(x)
        return x

class HighlightModel(torch.nn.Module):
    def __init__(self, video_dim, audio_dim, hidden_dim):
        super().__init__()
        self.video_encoder = TemporalEncoder(video_dim, hidden_dim)
        self.audio_encoder = TemporalEncoder(audio_dim, hidden_dim)
        self.fusion = torch.nn.Linear(hidden_dim * 2, hidden_dim)
        self.scorer = torch.nn.Linear(hidden_dim, 1)

    def forward(self, video, audio, event_timestamps):
        event_timestamps = event_timestamps.long()
        video_features = self.video_encoder(video)  # (batch, time, hidden)
        audio_features = self.audio_encoder(audio)  # (batch, time, hidden)
        fused_features = torch.cat([video_features, audio_features], dim=-1)  # (batch, time, hidden*2)
        fused_features = self.fusion(fused_features).relu()  # (batch, time, hidden)
        event_features = fused_features[torch.arange(fused_features.size(0))[:, None], event_timestamps]
        scores = self.scorer(event_features).sigmoid()  # (batch, num_events, 1)
        return scores.squeeze(-1)  # (batch, num_events)

def load_models(event_model_path, highlight_model_path, device):
    """Load both models from disk"""
    event_mapper = EventMapper()
    num_classes = event_mapper.get_num_classes()
    event_model = EventDetectionModel(input_dim=2048, hidden_dim=256, num_classes=num_classes).to(device)

    # Load the checkpoint
    checkpoint = torch.load(event_model_path, map_location=device)
    # Extract the model state dictionary
    event_model.load_state_dict(checkpoint['model_state_dict'])
    event_model.eval()

    highlight_model = HighlightModel(video_dim=2048, audio_dim=20, hidden_dim=512).to(device)
    highlight_model.load_state_dict(torch.load(highlight_model_path, map_location=device))
    highlight_model.eval()

    return event_model, highlight_model, event_mapper

def extract_audio_features(audio_path, target_length, feature_rate=FPS):
    """Extract MFCC features from audio file"""
    try:
        y, sr = librosa.load(audio_path, sr=None)
        hop_length = int(sr / feature_rate)
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20, hop_length=hop_length)
        audio_features = mfcc.T
        if audio_features.shape[0] < target_length:
            padding = np.zeros((target_length - audio_features.shape[0], audio_features.shape[1]))
            audio_features = np.vstack([audio_features, padding])
        elif audio_features.shape[0] > target_length:
            audio_features = audio_features[:target_length]
        return torch.tensor(audio_features, dtype=torch.float32), y, sr
    except Exception as e:
        print(f"Error extracting audio features: {e}")
        return None, None, None

def detect_events(event_model, match_features, device):
    """Detect events in the match using the event detection model"""
    event_model.eval()
    results = []

    with torch.no_grad():
        for i in range(0, len(match_features) - WINDOW_SIZE + 1, 1):
            window_features = match_features[i:i+WINDOW_SIZE]
            # Normalize features
            window_features = (window_features - np.mean(window_features, axis=0)) / (np.std(window_features, axis=0) + 1e-5)
            # Add temporal difference features
            if window_features.shape[0] > 1:
                diff_features = np.diff(window_features, axis=0)
                diff_features = np.vstack([np.zeros((1, window_features.shape[1])), diff_features])
                combined_features = np.concatenate([window_features, diff_features], axis=1)
            else:
                combined_features = np.concatenate([window_features, np.zeros_like(window_features)], axis=1)

            inputs = torch.tensor(combined_features, dtype=torch.float32).unsqueeze(0).to(device)

            outputs = event_model(inputs)
            probabilities = torch.softmax(outputs, dim=1)
            event_id = torch.argmax(outputs, dim=1).item()
            confidence = probabilities[0, event_id].item()

            frame_idx = i + WINDOW_SIZE // 2
            time_seconds = frame_idx / FPS

            results.append({
                'frame': frame_idx,
                'time': time_seconds,
                'event_id': event_id,
                'confidence': confidence
            })

    return results

def filter_events(events, event_mapper, confidence_threshold=0.7, min_frame_distance=10):
    """Filter events - remove duplicates and low confidence predictions"""
    events.sort(key=lambda x: x['confidence'], reverse=True)

    filtered_events = []
    used_frames = set()

    for event in events:
        if event_mapper.index_to_event(event['event_id']) == "Background":
            continue

        if event['confidence'] < confidence_threshold:
            continue

        frame = event['frame']
        is_nearby = any(abs(frame - used_frame) < min_frame_distance for used_frame in used_frames)

        if not is_nearby:
            used_frames.add(frame)
            event['event_name'] = event_mapper.index_to_event(event['event_id'])
            filtered_events.append(event)

    filtered_events.sort(key=lambda x: x['frame'])
    return filtered_events

def score_events(filtered_events, highlight_model, match_video_features, match_audio_features, device):
    """Score events using the highlight importance model"""
    if not filtered_events:
        return []

    highlight_model.eval()

    timestamps = [event['frame'] for event in filtered_events]

    video_tensor = torch.tensor(match_video_features, dtype=torch.float32).unsqueeze(0).to(device)
    audio_tensor = torch.tensor(match_audio_features, dtype=torch.float32).unsqueeze(0).to(device)
    timestamp_tensor = torch.tensor(timestamps, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        importance_scores = highlight_model(video_tensor, audio_tensor, timestamp_tensor)

    for i, event in enumerate(filtered_events):
        model_score = importance_scores[0, i].item()
        event_name = event['event_name']
        predefined_score = EVENT_IMPORTANCE.get(event_name, 0.0)

        combined_score = 0.6 * model_score + 0.4 * predefined_score
        event['model_score'] = model_score
        event['predefined_score'] = predefined_score
        event['combined_score'] = combined_score

    filtered_events.sort(key=lambda x: x['combined_score'], reverse=True)
    return filtered_events

def generate_highlight_timestamps(scored_events, max_highlight_duration=180, max_events=MAX_EVENTS_PER_HALF):
    """Generate timestamps for the highlight video, limited to max_events"""
    print("Generating highlight timestamps...")
    # Sort by score to select top events
    scored_events = sorted(scored_events, key=lambda x: x['combined_score'], reverse=True)[:max_events]
    # Sort by time for chronological order
    scored_events.sort(key=lambda x: x['time'])

    highlight_segments = []
    total_duration = 0

    for event in scored_events:
        print(f"Evaluating event: {event['event_name']} at {event['time']:.1f}s, "
              f"Combined Score: {event['combined_score']:.4f}, "
              f"Model Score: {event['model_score']:.4f}, "
              f"Predefined Score: {event['predefined_score']:.4f}")

        if event['combined_score'] < MIN_IMPORTANCE_THRESHOLD:
            print(f"Skipping event: Score {event['combined_score']:.4f} below threshold {MIN_IMPORTANCE_THRESHOLD}")
            continue

        start_time = max(0, event['time'] - HIGHLIGHT_SECONDS_BEFORE)
        end_time = event['time'] + HIGHLIGHT_SECONDS_AFTER
        segment_duration = end_time - start_time

        if total_duration + segment_duration > max_highlight_duration:
            print(f"Skipping event: Adding {segment_duration}s would exceed max duration "
                  f"(Current: {total_duration}s, Max: {max_highlight_duration}s)")
            if len(highlight_segments) > 0:
                continue

        highlight_segments.append({
            'start': start_time,
            'end': end_time,
            'event': event['event_name'],
            'score': event['combined_score']
        })
        total_duration += segment_duration
        print(f"Added segment: {event['event_name']} from {start_time}s to {end_time}s, "
              f"Duration: {segment_duration}s, Total Duration: {total_duration}s")

        if total_duration >= max_highlight_duration:
            print("Max duration reached, stopping segment generation")
            break

    if highlight_segments:
        print("Merging overlapping segments...")
        merged_segments = [highlight_segments[0]]
        for segment in highlight_segments[1:]:
            prev = merged_segments[-1]
            if segment['start'] <= prev['end']:
                prev['end'] = max(prev['end'], segment['end'])
                if segment['score'] > prev['score']:
                    prev['event'] = segment['event']
                    prev['score'] = segment['score']
                print(f"Merged segment: {prev['event']} from {prev['start']}s to {prev['end']}s")
            else:
                merged_segments.append(segment)
                print(f"Added non-overlapping segment: {segment['event']} from {segment['start']}s to {segment['end']}s")

        print(f"Generated {len(merged_segments)} highlight segments")
        return merged_segments

    print("No highlight segments generated")
    return []

def create_highlight_video(video_path, highlight_segments, output_path):
    """Create the highlight video by extracting and concatenating clips"""
    print(f"Attempting to create highlight video: {output_path}")
    print(f"Input video path: {video_path}")
    print(f"Highlight segments: {highlight_segments}")

    try:
        if not os.path.exists(video_path):
            print(f"Video file not found: {video_path}")
            return

        print(f"Loading video file: {video_path}")
        video = VideoFileClip(video_path)
        print(f"Video duration: {video.duration}s, FPS: {video.fps}")

        clips = []
        for segment in highlight_segments:
            start_time = segment['start']
            end_time = segment['end']
            print(f"Processing segment: {segment['event']} from {start_time}s to {end_time}s")

            start_time = max(0, min(start_time, video.duration))
            end_time = max(start_time, min(end_time, video.duration))

            if end_time <= start_time:
                print(f"Skipping invalid segment: start_time ({start_time}s) >= end_time ({end_time}s)")
                continue

            try:
                print(f"Extracting subclip from {start_time}s to {end_time}s")
                clip = video.subclip(start_time, end_time)
                clips.append(clip)
                print(f"Added clip of duration {clip.duration}s")
            except Exception as e:
                print(f"Error extracting subclip from {start_time}s to {end_time}s: {e}")

        if clips:
            print(f"Concatenating {len(clips)} clips")
            final_clip = concatenate_videoclips(clips, method="compose")
            print(f"Writing output video to {output_path}")
            final_clip.write_videofile(
                output_path,
                codec="libx264",
                audio_codec="aac",
                temp_audiofile="temp-audio.m4a",
                remove_temp=True,
                verbose=True
            )
            print(f"Highlight video successfully saved to {output_path}")
        else:
            print("No valid clips to concatenate!")

    except Exception as e:
        print(f"Error in create_highlight_video: {e}")
    finally:
        if 'video' in locals():
            video.close()
        if 'final_clip' in locals():
            final_clip.close()
        if 'clips' in locals():
            for clip in clips:
                if clip:
                    clip.close()

def combine_highlight_videos(half1_video_path, half2_video_path, output_path):
    """Combine two highlight videos into a single video"""
    print(f"Attempting to combine highlight videos into: {output_path}")
    print(f"Half 1 video: {half1_video_path}")
    print(f"Half 2 video: {half2_video_path}")

    try:
        clips = []
        if os.path.exists(half1_video_path):
            print(f"Loading Half 1 video: {half1_video_path}")
            clip1 = VideoFileClip(half1_video_path)
            clips.append(clip1)
            print(f"Half 1 clip loaded: Duration {clip1.duration}s")
        else:
            print(f"Half 1 video not found: {half1_video_path}")

        if os.path.exists(half2_video_path):
            print(f"Loading Half 2 video: {half2_video_path}")
            clip2 = VideoFileClip(half2_video_path)
            clips.append(clip2)
            print(f"Half 2 clip loaded: Duration {clip2.duration}s")
        else:
            print(f"Half 2 video not found: {half2_video_path}")

        if clips:
            print(f"Concatenating {len(clips)} clips")
            final_clip = concatenate_videoclips(clips, method="compose")
            print(f"Writing combined video to {output_path}")
            final_clip.write_videofile(
                output_path,
                codec="libx264",
                audio_codec="aac",
                temp_audiofile="temp-audio.m4a",
                remove_temp=True,
                verbose=True
            )
            print(f"Combined highlight video saved to {output_path}")

            final_clip.close()
            for clip in clips:
                clip.close()
        else:
            print("No valid clips to concatenate for combined video!")

    except Exception as e:
        print(f"Error in combine_highlight_videos: {e}")
    finally:
        if 'final_clip' in locals():
            final_clip.close()
        if 'clips' in locals():
            for clip in clips:
                if clip:
                    clip.close()

def generate_highlights(match_dir, event_model, highlight_model, event_mapper, output_dir, device):
    """Main function to generate highlights for a match"""
    os.makedirs(output_dir, exist_ok=True)

    half1_features_path = os.path.join(match_dir, "1_ResNET_TF2.npy")
    half2_features_path = os.path.join(match_dir, "2_ResNET_TF2.npy")

    if not os.path.exists(half1_features_path) or not os.path.exists(half2_features_path):
        print(f"Error: Feature files not found in {match_dir}")
        return

    half1_features = np.load(half1_features_path)
    half2_features = np.load(half2_features_path)

    half1_video_path = os.path.join(match_dir, "1_224p.mkv")
    half2_video_path = os.path.join(match_dir, "2_224p.mkv")

    if not os.path.exists(half1_video_path):
        print(f"Warning: Half 1 video file not found: {half1_video_path}")
    if not os.path.exists(half2_video_path):
        print(f"Warning: Half 2 video file not found: {half2_video_path}")

    half1_audio_path = os.path.join(match_dir, "1_224p.wav")
    half2_audio_path = os.path.join(match_dir, "2_224p.wav")

    match_name = os.path.basename(match_dir)

    all_highlight_segments = []

    for half_idx, (features, video_path, audio_path) in enumerate(
        [(half1_features, half1_video_path, half1_audio_path),
         (half2_features, half2_video_path, half2_audio_path)],
        start=1):

        print(f"\nProcessing half {half_idx} of {match_name}...")

        if os.path.exists(audio_path):
            audio_features, audio_signal, sr = extract_audio_features(audio_path, features.shape[0])
            if audio_features is None:
                print(f"Failed to extract audio features from {audio_path}")
                audio_features = torch.zeros(features.shape[0], 20)
        else:
            print(f"Audio file not found: {audio_path}, using zeros for audio features")
            audio_features = torch.zeros(features.shape[0], 20)
            audio_signal, sr = None, None

        print("Detecting events...")
        events = detect_events(event_model, features, device)

        print("Filtering events...")
        filtered_events = filter_events(events, event_mapper)

        if not filtered_events:
            print(f"No events detected in half {half_idx}")
            continue

        print("Scoring events...")
        scored_events = score_events(filtered_events, highlight_model, features, audio_features, device)

        print("Generating highlight timestamps...")
        highlight_segments = generate_highlight_timestamps(scored_events)
        print(f"Generated {len(highlight_segments)} highlight segments")

        for segment in highlight_segments:
            all_highlight_segments.append({
                'half': half_idx,
                'video_path': video_path,
                'start': segment['start'],
                'end': segment['end'],
                'event': segment['event'],
                'score': segment['score']
            })

        print(f"\nDetected {len(scored_events)} events in half {half_idx}:")
        for i, event in enumerate(scored_events[:10]):
            print(f"{i+1}. {event['event_name']} at {event['time']:.1f}s, Score: {event['combined_score']:.4f}")

        if video_path and os.path.exists(video_path) and highlight_segments:
            output_path = os.path.join(output_dir, f"{match_name}_half{half_idx}_highlights.mp4")
            print(f"\nCreating highlight video for half {half_idx}...")
            create_highlight_video(video_path, highlight_segments, output_path)
        else:
            print(f"Skipping video creation for half {half_idx}: Video path valid={os.path.exists(video_path)}, Segments={len(highlight_segments)}")

        highlight_json_path = os.path.join(output_dir, f"{match_name}_half{half_idx}_highlights.json")
        with open(highlight_json_path, 'w') as f:
            json.dump({
                'match': match_name,
                'half': half_idx,
                'events': scored_events,
                'highlight_segments': highlight_segments
            }, f, indent=2)
        print(f"Highlight data saved to {highlight_json_path}")

        plt.figure(figsize=(15, 6))
        for event in scored_events:
            plt.axvline(x=event['time'], color='r', alpha=0.3, linestyle='--')
            plt.text(event['time'], 1.0, event['event_name'], rotation=90, alpha=0.7)

        for segment in highlight_segments:
            plt.axvspan(segment['start'], segment['end'], alpha=0.2, color='green')

        plt.title(f"Highlights for {match_name} - Half {half_idx}")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Event Importance")
        plt.grid(True, alpha=0.3)
        plt.xlim(0, features.shape[0] / FPS)
        vis_path = os.path.join(output_dir, f"{match_name}_half{half_idx}_highlights_vis.png")
        plt.savefig(vis_path)
        plt.close()

    # Combine existing highlight videos
    half1_highlight_path = os.path.join(output_dir, f"{match_name}_half1_highlights.mp4")
    half2_highlight_path = os.path.join(output_dir, f"{match_name}_half2_highlights.mp4")
    combined_output_path = os.path.join(output_dir, f"{match_name}_combined_highlights.mp4")
    combine_highlight_videos(half1_highlight_path, half2_highlight_path, combined_output_path)

def main():
    """Main entry point"""
    try:
        drive.mount('/content/drive', force_remount=False)
    except Exception as e:
        print(f"Drive mounting issue (may already be mounted): {e}")

    event_model_path = '/content/drive/MyDrive/soccernet/soccernet_event_detection_model.pth'
    highlight_model_path = '/content/drive/MyDrive/soccernet/highlight_model.pth'

    match_dir = '/content/drive/MyDrive/soccernet/england_epl/2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley'
    output_dir = '/content/drive/MyDrive/soccernet/highlights'

    os.makedirs(output_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("Loading models...")
    event_model, highlight_model, event_mapper = load_models(event_model_path, highlight_model_path, device)

    generate_highlights(match_dir, event_model, highlight_model, event_mapper, output_dir, device)

    print("Highlight generation complete!")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cpu
Loading models...

Processing half 1 of 2015-02-21 - 18-00 Chelsea 1 - 1 Burnley...
Detecting events...


  audio_tensor = torch.tensor(match_audio_features, dtype=torch.float32).unsqueeze(0).to(device)



Filtering events...
Scoring events...
Generating highlight timestamps...
Generating highlight timestamps...
Evaluating event: Corner at 807.5s, Combined Score: 0.2830, Model Score: 0.4253, Predefined Score: 0.0695
Added segment: Corner from 802.5s to 814.5s, Duration: 12.0s, Total Duration: 12.0s
Evaluating event: Corner at 879.0s, Combined Score: 0.2829, Model Score: 0.4252, Predefined Score: 0.0695
Added segment: Corner from 874.0s to 886.0s, Duration: 12.0s, Total Duration: 24.0s
Evaluating event: Direct free-kick at 914.5s, Combined Score: 0.3200, Model Score: 0.4249, Predefined Score: 0.1627
Added segment: Direct free-kick from 909.5s to 921.5s, Duration: 12.0s, Total Duration: 36.0s
Evaluating event: Corner at 1113.5s, Combined Score: 0.2828, Model Score: 0.4250, Predefined Score: 0.0695
Added segment: Corner from 1108.5s to 1120.5s, Duration: 12.0s, Total Duration: 48.0s
Evaluating event: Foul at 2530.5s, Combined Score: 0.4688, Model Score: 0.4278, Predefined Score: 0.5302
Adde



MoviePy - Done.
Moviepy - Writing video /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.mp4





Moviepy - Done !
Moviepy - video ready /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.mp4
Highlight video successfully saved to /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.mp4
Highlight data saved to /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.json

Processing half 2 of 2015-02-21 - 18-00 Chelsea 1 - 1 Burnley...
Detecting events...
Filtering events...
Scoring events...
Generating highlight timestamps...
Generating highlight timestamps...
Evaluating event: Offside at 317.5s, Combined Score: 0.4958, Model Score: 0.4254, Predefined Score: 0.6015
Added segment: Offside from 312.5s to 324.5s, Duration: 12.0s, Total Duration: 12.0s
Evaluating event: Offside at 322.5s, Combined Score: 0.4959, Model Score: 0.4254, Predefined Score: 0.6015
Added segment: Offside from 317.5s to 329.5s, Duration: 12.0s, Total Duration: 24.0s
Ev



MoviePy - Done.
Moviepy - Writing video /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half2_highlights.mp4





Moviepy - Done !
Moviepy - video ready /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half2_highlights.mp4
Highlight video successfully saved to /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half2_highlights.mp4
Highlight data saved to /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half2_highlights.json
Attempting to combine highlight videos into: /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_combined_highlights.mp4
Half 1 video: /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.mp4
Half 2 video: /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half2_highlights.mp4
Loading Half 1 video: /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_half1_highlights.mp4
Half 1 clip loaded: Duration 60.0s
Loading Half 2 video: /



MoviePy - Done.
Moviepy - Writing video /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_combined_highlights.mp4





Moviepy - Done !
Moviepy - video ready /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_combined_highlights.mp4
Combined highlight video saved to /content/drive/MyDrive/soccernet/highlights/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley_combined_highlights.mp4
Highlight generation complete!
