In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
pip install torch numpy librosa moviepy av wandb



In [3]:
# pip install av

In [4]:
import torch
import torch.nn as nn
import numpy as np
import os
import json
from torch.utils.data import Dataset, DataLoader
import librosa
from moviepy.editor import VideoFileClip

In [5]:
!pip uninstall wandb -y
!pip install wandb --upgrade

Found existing installation: wandb 0.19.10
Uninstalling wandb-0.19.10:
  Successfully uninstalled wandb-0.19.10
Collecting wandb
  Using cached wandb-0.19.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Using cached wandb-0.19.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.3 MB)
Installing collected packages: wandb
Successfully installed wandb-0.19.10


In [None]:
import wandb
# Initialize WandB
wandb.login(key='add api key')
wandb.init(project="soccernet_highlights", config={
    "batch_size": 4,
    "epochs": 3,
    "learning_rate": 1e-3,
    "video_dim": 2048,
    "audio_dim": 20,
    "hidden_dim": 512
})


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msupraja2010341[0m ([33mfyproject[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from torchvision.models import resnet50
from torchvision.io import read_video
from torchvision import transforms

In [8]:
# Model Definition
class TemporalEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, time, dim) -> (batch, dim, time)
        x = self.conv(x).relu()
        x = x.permute(0, 2, 1)  # (batch, time, hidden_dim)
        x, _ = self.lstm(x)
        x = self.norm(x)
        return x

In [9]:
class HighlightModel(nn.Module):
    def __init__(self, video_dim, audio_dim, hidden_dim):
        super().__init__()
        self.video_encoder = TemporalEncoder(video_dim, hidden_dim)
        self.audio_encoder = TemporalEncoder(audio_dim, hidden_dim)
        self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)
        self.scorer = nn.Linear(hidden_dim, 1)

    def forward(self, video, audio, event_timestamps):
        event_timestamps=event_timestamps.long()
        video_features = self.video_encoder(video)  # (batch, time, hidden)
        audio_features = self.audio_encoder(audio)  # (batch, time, hidden)
        fused_features = torch.cat([video_features, audio_features], dim=-1)  # (batch, time, hidden*2)
        fused_features = self.fusion(fused_features).relu()  # (batch, time, hidden)
        event_features = fused_features[torch.arange(fused_features.size(0))[:, None], event_timestamps]
        scores = self.scorer(event_features).sigmoid()  # (batch, num_events, 1)
        return scores.squeeze(-1)  # (batch, num_events)

In [10]:
import glob
class SoccerNetDataset(Dataset):
    def __init__(self, data_dir, split_file, feature_rate=1, max_events=50, device="cuda",top=10):
        self.data_dir = data_dir
        self.feature_rate = feature_rate
        self.max_events = max_events
        self.device = device
        self.game_dirs = self._load_games(split_file)[:top]  # Limit to 10 games as per original
        # self.top=top

    def _load_games(self, split_file):
        with open(split_file) as f:
            game_dirs = json.load(f)  # List of absolute game folder paths
        valid_dirs = [d for d in game_dirs if os.path.exists(d)]
        if len(valid_dirs) < len(game_dirs):
            print(f"Warning: {len(game_dirs) - len(valid_dirs)} game directories not found")
        return valid_dirs

    def _load_video_features(self, game_dir):
        feature_files = sorted(glob.glob(os.path.join(game_dir, "*_ResNET_TF2.npy")))
        if len(feature_files) < 2:
            raise FileNotFoundError(f"Expected 1_ResNET_TF2.npy and 2_ResNET_TF2.npy in {game_dir}, found {len(feature_files)}")
        features1 = np.load(feature_files[0])
        features2 = np.load(feature_files[1])
        video_features = np.concatenate([features1, features2], axis=0)  # (T, 2048)
        return torch.tensor(video_features, dtype=torch.float32).to(self.device)

    def _load_audio_features(self, game_dir, target_length):
        audio_files = sorted(glob.glob(os.path.join(game_dir, "*_224p.wav")))
        if len(audio_files) < 2:
            raise FileNotFoundError(f"Expected 1_224p.wav and 2_224p.wav in {game_dir}, found {len(audio_files)}")
        mfccs = []
        audio_signals = []
        sr = None
        for audio_file in audio_files:
            y, sr = librosa.load(audio_file, sr=None)
            hop_length = int(sr / self.feature_rate)
            mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20, hop_length=hop_length)
            mfccs.append(mfcc.T)
            audio_signals.append(y)
        audio_features = np.concatenate(mfccs, axis=0)
        audio_features = librosa.util.fix_length(audio_features.T, size=target_length, axis=1).T
        audio_signal = np.concatenate(audio_signals)
        return (torch.tensor(audio_features, dtype=torch.float32).to(self.device),
                audio_signal, sr)

    def _parse_game_time(self, game_time):
        """Convert gameTime (e.g., '1 - 00:41') to seconds."""
        half, time_str = game_time.split(" - ")
        minutes, seconds = map(int, time_str.split(":"))
        total_seconds = minutes * 60 + seconds
        return int(half), total_seconds

    def _load_event_timestamps(self, game_dir, video_length, audio_signal, sr):
        annotation_file = os.path.join(game_dir, "Labels-v2.json")
        if not os.path.exists(annotation_file):
            raise FileNotFoundError(f"No Labels-v2.json found in {game_dir}")
        with open(annotation_file) as f:
            data = json.load(f)
        events = data.get("annotations", [])
        half1_duration = self._get_half1_duration(os.path.join(game_dir, "1_224p.mp4"))
        half1_duration_frames = int(half1_duration * self.feature_rate)

        timestamps = []
        audio_scores = []
        window_seconds = 5  # ±5 seconds window for audio energy
        window_samples = int(window_seconds * sr)

        for event in events:
            half, time_seconds = self._parse_game_time(event["gameTime"])
            timestamp = int(time_seconds * self.feature_rate)
            if half == 2:
                timestamp += half1_duration_frames
            if timestamp < video_length:
                timestamps.append(timestamp)
                # Compute audio energy in ±5s window
                audio_time = time_seconds + (half1_duration if half == 2 else 0)
                center_sample = int(audio_time * sr)
                start_sample = max(0, center_sample - window_samples)
                end_sample = min(len(audio_signal), center_sample + window_samples)
                window = audio_signal[start_sample:end_sample]
                rms = librosa.feature.rms(y=window, frame_length=2048, hop_length=512)
                score = np.mean(rms) if rms.size > 0 else 0.0
                audio_scores.append(score)

        # Normalize audio scores to [0, 1]
        if audio_scores:
            audio_scores = np.array(audio_scores)
            min_score, max_score = audio_scores.min(), audio_scores.max()
            if max_score > min_score:
                audio_scores = (audio_scores - min_score) / (max_score - min_score)
            else:
                audio_scores = np.zeros_like(audio_scores)

        num_events = len(timestamps)
        if num_events == 0:
            timestamps = [0]
            audio_scores = [0.0]
            num_events = 0
        if num_events > self.max_events:
            timestamps = timestamps[:self.max_events]
            audio_scores = audio_scores[:self.max_events]
            num_events = self.max_events
        else:
            timestamps += [0] * (self.max_events - num_events)
            audio_scores += [0.0] * (self.max_events - num_events)

        return (torch.tensor(timestamps, dtype=torch.long),
                torch.tensor(audio_scores, dtype=torch.float32),
                num_events)

    def _get_half1_duration(self, video_path):
        try:
            video = VideoFileClip(video_path)
            duration = video.duration
            video.close()
            return duration
        except Exception:
            return 2700  # Default: 45 minutes

    def __len__(self):
        return len(self.game_dirs)

    def __getitem__(self, idx):
        game_dir = self.game_dirs[idx]
        try:
            video_tensor = self._load_video_features(game_dir)
            audio_tensor, audio_signal, sr = self._load_audio_features(game_dir, video_tensor.shape[0])
            timestamps, scores, num_events = self._load_event_timestamps(game_dir, video_tensor.shape[0], audio_signal, sr)
            return {
                "video": video_tensor,
                "audio": audio_tensor,
                "timestamps": timestamps.to(self.device),
                "scores": scores.to(self.device),
                "num_events": num_events
            }
        except Exception as e:
            print(f"Skipping {game_dir} due to error: {e}")
            return None

In [11]:
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch):
    batch = [b for b in batch if b is not None]
    if not batch:
        return None

    # Convert each item to torch.Tensor with consistent dtype
    videos = [b["video"].clone().detach() for b in batch]
    audios = [b["audio"].clone().detach() for b in batch]
    timestamps = [b["timestamps"].clone().detach().long() for b in batch]
    scores = [b["scores"].clone().detach() for b in batch]  # (Ti,)
    num_events = [b["num_events"] for b in batch]

    # Pad all sequences to max length in batch
    padded_videos = pad_sequence(videos, batch_first=True)        # (B, T_max, 2048)
    padded_audios = pad_sequence(audios, batch_first=True)        # (B, T_max, 20)
    padded_timestamps = pad_sequence(timestamps, batch_first=True) # (B, T_max)
    padded_scores = pad_sequence(scores, batch_first=True)         # (B, T_max)

    return {
        "video": padded_videos,
        "audio": padded_audios,
        "timestamps": padded_timestamps,
        "scores": padded_scores,
        "num_events": num_events
    }


In [12]:
def train_model(data_dir, train_split, valid_split, batch_size=1, epochs=10, lr=1e-3, device="cuda"):
    # Initialize datasets and dataloaders
    train_dataset = SoccerNetDataset(data_dir, train_split, device=device,top=50)
    val_dataset = SoccerNetDataset(data_dir, valid_split, device=device,top=10)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=custom_collate)

    print(f"Training dataset size: {len(train_dataset)} games")
    print(f"Validation dataset size: {len(val_dataset)} games")

    # Initialize model
    model = HighlightModel(video_dim=2048, audio_dim=20, hidden_dim=512).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("Started training")

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_batches = 0
        train_correct = 0
        train_total = 0

        for batch_idx, batch in enumerate(train_loader):
            if batch is None:
                continue
            video = batch["video"].to(device)             # (B, T, 2048)
            audio = batch["audio"].to(device)             # (B, T, 20)
            timestamps = batch["timestamps"].to(device)   # (B, T)
            scores = batch["scores"].to(device)           # (B, T)
            num_events = batch["num_events"]              # List of actual lengths per sample

            if all(n == 0 for n in num_events):
                continue

            optimizer.zero_grad()
            pred_scores = model(video, audio, timestamps) # (B, T)

            loss = torch.tensor(0.0, device=device, requires_grad=True)
            valid_samples = 0
            for i in range(len(num_events)):
                if num_events[i] > 0:
                    valid_scores = pred_scores[i, :num_events[i]]
                    valid_gt = scores[i, :num_events[i]]
                    loss = loss + criterion(valid_scores, valid_gt)

                    # Accuracy
                    pred_labels = (valid_scores >= 0.5).float()
                    train_correct += (pred_labels == valid_gt).float().sum().item()
                    train_total += valid_gt.numel()

                    valid_samples += 1

            if valid_samples > 0:
                loss = loss / valid_samples
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_batches += 1
                # train_accuracy = train_correct / train_total if train_total > 0 else 0.0
                wandb.log({"batch_train_loss": loss.item()})

                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx+1}/{len(train_loader)}, "
                      f"Train Loss: {loss.item():.4f}")

        train_loss = train_loss / train_batches if train_batches > 0 else 0.0
        # train_accuracy = train_correct / train_total if train_total > 0 else 0.0

        print("Started validation")
        # Validation
        model.eval()
        val_loss = 0.0
        val_mae = 0.0
        val_count = 0
        val_batches = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                if batch is None:
                    continue
                video = batch["video"].to(device)
                audio = batch["audio"].to(device)
                timestamps = batch["timestamps"].to(device)
                scores = batch["scores"].to(device)
                num_events = batch["num_events"]

                if all(n == 0 for n in num_events):
                    continue

                pred_scores = model(video, audio, timestamps)
                loss = torch.tensor(0.0, device=device)
                for i in range(len(num_events)):
                    if num_events[i] > 0:
                        valid_scores = pred_scores[i, :num_events[i]]
                        valid_gt = scores[i, :num_events[i]]
                        loss = loss + criterion(valid_scores, valid_gt)

                        val_mae += torch.mean(torch.abs(valid_scores - valid_gt)).item()
                        val_count += 1

                        pred_labels = (valid_scores >= 0.5).float()
                        val_correct += (pred_labels == valid_gt).float().sum().item()
                        val_total += valid_gt.numel()

                if val_count > 0:
                    val_loss += (loss / len(num_events)).item()
                    val_batches += 1
                    # val_accuracy = val_correct / val_total if val_total > 0 else 0.0
                    print(f"Epoch {epoch+1}/{epochs}, Validation Batch {batch_idx+1}/{len(val_loader)}, "
                          f"Val Loss: {(loss / len(num_events)).item():.4f}, Val MAE: {val_mae:.4f}")

        val_loss = val_loss / val_batches if val_batches > 0 else 0.0
        val_mae = val_mae / val_count if val_count > 0 else 0.0
        # val_accuracy = val_correct / val_total if val_total > 0 else 0.0

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            # "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_mae": val_mae,
            # "val_accuracy": val_accuracy
        })

        print(f"Epoch {epoch+1}/{epochs}, "
              f"Train Loss: {train_loss:.4f},"
              f"Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}")
        torch.cuda.empty_cache()

    # Save model
    torch.save(model.state_dict(), "highlight_model.pth")
    wandb.save("highlight_model.pth")

In [13]:
if __name__ == "__main__":
    data_dir = "/content/drive/MyDrive/soccernet"  # Update with your path
    train_split = os.path.join(data_dir, "train.json")
    valid_split = os.path.join(data_dir, "valid.json")
    # print(train_split)
    train_model(data_dir, train_split, valid_split, batch_size=4, epochs=3, lr=1e-3, device="cuda")

Training dataset size: 50 games
Validation dataset size: 10 games
Started training
Epoch 1/3, Batch 1/13, Train Loss: 0.6910
Epoch 1/3, Batch 2/13, Train Loss: 1.1279
Epoch 1/3, Batch 3/13, Train Loss: 1.1055
Epoch 1/3, Batch 4/13, Train Loss: 0.7002
Epoch 1/3, Batch 5/13, Train Loss: 0.7086
Epoch 1/3, Batch 6/13, Train Loss: 0.7347
Epoch 1/3, Batch 7/13, Train Loss: 0.6927
Epoch 1/3, Batch 8/13, Train Loss: 0.7411
Epoch 1/3, Batch 9/13, Train Loss: 0.6361
Epoch 1/3, Batch 10/13, Train Loss: 0.7127
Epoch 1/3, Batch 11/13, Train Loss: 0.7064
Epoch 1/3, Batch 12/13, Train Loss: 0.6978
Epoch 1/3, Batch 13/13, Train Loss: 0.6933
Started validation
Epoch 1/3, Validation Batch 1/3, Val Loss: 0.6995, Val MAE: 0.8144
Epoch 1/3, Validation Batch 2/3, Val Loss: 0.7103, Val MAE: 1.5820
Epoch 1/3, Validation Batch 3/3, Val Loss: 0.7288, Val MAE: 2.0801
Epoch 1/3, Train Loss: 0.7652,Val Loss: 0.7129, Val MAE: 0.2080
Epoch 2/3, Batch 1/13, Train Loss: 0.7028
Epoch 2/3, Batch 2/13, Train Loss: 0.6944