In [3]:
import torch
import torch.nn as nn
import torchvision.models.video as video_models
import openl3
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torchvision.io import read_video
import torchaudio
from pathlib import Path
import random

In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) used as a projection head.

    Args:
        in_dim: size of the last dimension of the incoming features.
        proj_dim: width of the hidden layer and of the output.
    """

    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim),
        ]
        # Kept as an index-keyed nn.Sequential named ``net`` so state_dict keys
        # (net.0.*, net.2.*) match checkpoints saved from the original layout.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (..., in_dim) to (..., proj_dim)."""
        projected = self.net(x)
        return projected

In [None]:
class VideoEncoder(nn.Module):
    """R3D-18 video backbone with its final FC removed, followed by a projection head.

    Args:
        proj_dim: output size of the projection head.
        pretrained: when True (the default, as before) load Kinetics-400
            weights; when False use random initialisation (e.g. offline).
    """

    def __init__(self, proj_dim=128, pretrained=True):
        super().__init__()
        base_model = self._build_backbone(pretrained)
        # Derive the feature width instead of hard-coding 512.
        feat_dim = base_model.fc.in_features
        # Drop the final fc; keep everything up to and including pooling.
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])
        self.proj_head = ProjectionHead(feat_dim, proj_dim)

    @staticmethod
    def _build_backbone(pretrained):
        """Build r3d_18 via the ``weights=`` API, falling back to ``pretrained=``."""
        weights_enum = getattr(video_models, "R3D_18_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the (now deprecated) flag.
            return video_models.r3d_18(pretrained=pretrained)
        weights = weights_enum.KINETICS400_V1 if pretrained else None
        return video_models.r3d_18(weights=weights)

    def forward(self, x):  # (B, C, T, H, W)
        """Embed video clips of shape (B, C, T, H, W) into (B, proj_dim).

        NOTE(review): torchvision's pretrained R3D-18 weights were trained on
        mean/std-normalised clips; this module does not normalise its input —
        confirm the caller does.
        """
        feats = self.feature_extractor(x)  # (B, 512, 1, 1, 1)
        return self.proj_head(feats.flatten(1))

In [None]:
class OpenL3Encoder(nn.Module):
    """Frozen OpenL3 audio embedder followed by a trainable projection head.

    ``openl3.models.load_audio_embedding_model`` returns a Keras model, not a
    ``torch.nn.Module``. It is therefore not registered as a submodule: torch
    optimizers never see its weights, and it has no ``.eval()`` or
    ``.parameters()``. Only ``proj_head`` is trained.

    Args:
        proj_dim: output size of the projection head.
        input_repr, content_type, embedding_size: forwarded to
            ``openl3.models.load_audio_embedding_model``.
        sr: sample rate (Hz) of the waveforms passed to ``forward``.
    """

    def __init__(self, proj_dim=128, input_repr="mel256", content_type="music",
                 embedding_size=512, sr=48000):
        super().__init__()
        self.sr = sr
        self.model = openl3.models.load_audio_embedding_model(
            input_repr=input_repr,
            content_type=content_type,
            embedding_size=embedding_size
        )
        # Freeze with the Keras flag; torch-style .eval()/.parameters() do not exist.
        self.model.trainable = False

        self.proj_head = ProjectionHead(embedding_size, proj_dim)

    def forward(self, x):
        """Embed a batch of waveforms with OpenL3, mean-pool over time, and project.

        Args:
            x: a (B, L) tensor, or any sequence (e.g. a list) of per-sample
                waveform tensors shaped (L,) or (C, L). 2-D samples are averaged
                over their first axis. This assumes channels-first, matching the
                dataset's mono (1, L) waveforms — confirm for other callers.

        Returns:
            Tensor of shape (B, proj_dim) on the projection head's device.

        Raises:
            ValueError: if ``x`` is empty or a sample has more than two dims.
        """
        if len(x) == 0:
            raise ValueError("OpenL3Encoder.forward received an empty batch")
        device = next(self.proj_head.parameters()).device

        embeddings = []
        for sample in x:
            audio_np = sample.detach().cpu().numpy()
            # OpenL3 reads a 2-D array as (samples, channels), so a (1, L) row
            # would be misread as a single sample; collapse to 1-D first.
            if audio_np.ndim == 2:
                audio_np = audio_np.mean(axis=0)
            elif audio_np.ndim != 1:
                raise ValueError(f"expected a 1-D or 2-D waveform, got shape {audio_np.shape}")
            emb, _ = openl3.get_audio_embedding(
                audio_np, self.sr, model=self.model, center=True, verbose=False
            )
            # Mean over frames -> one (embedding_size,) vector per clip.
            embeddings.append(torch.as_tensor(emb.mean(axis=0), dtype=torch.float32, device=device))
        return self.proj_head(torch.stack(embeddings))  # (B, proj_dim)

In [None]:
class AVContrastiveModel(nn.Module):
    """Dual encoder that embeds a video clip and its audio into a shared space.

    Args:
        proj_dim: output size of both encoders' projection heads.
    """

    def __init__(self, proj_dim=128):
        super().__init__()
        # Video is registered first, which fixes parameter / state_dict order.
        self.video_encoder = VideoEncoder(proj_dim)
        self.audio_encoder = OpenL3Encoder(proj_dim)

    def forward(self, video, audio):
        """Return ``(z_video, z_audio)``; each is the respective encoder's (B, proj_dim) output."""
        return self.video_encoder(video), self.audio_encoder(audio)


In [None]:
def contrastive_loss(z1, z2, temperature=0.07):
    """Symmetric NT-Xent (InfoNCE) loss between two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` are a positive pair. Every other
    row of either batch is a negative, and each anchor's similarity with itself
    is excluded.

    Args:
        z1, z2: tensors of shape (B, D); row i of each forms a positive pair.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2B anchors.

    Raises:
        ValueError: if the inputs are not equal-shaped, non-empty 2-D tensors.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(f"z1 and z2 must be equal-shaped (B, D) tensors, got {tuple(z1.shape)} and {tuple(z2.shape)}")
    if z1.size(0) == 0:
        raise ValueError("contrastive_loss needs a non-empty batch")

    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)  # (2B, D)

    sim = torch.matmul(z, z.T)  # (2B, 2B) cosine similarity
    B = z1.size(0)
    n = 2 * B
    # Remove self-similarities (the diagonal); each row keeps 2B - 1 candidates.
    mask = ~torch.eye(n, dtype=torch.bool, device=z.device)
    logits = sim[mask].view(n, n - 1) / temperature

    # In the full matrix, the positive for anchor i is column (i + B) mod 2B.
    # Dropping the diagonal shifts every column right of it left by one:
    #   i <  B: positive i + B lies right of the diagonal -> index i + B - 1
    #   i >= B: positive i - B lies left of the diagonal  -> index unchanged
    idx = torch.arange(B, device=z.device)
    targets = torch.cat([idx + B - 1, idx])

    return F.cross_entropy(logits, targets)


In [None]:
class AVContrastiveDataset(Dataset):
    """(video clip, audio waveform, class label) triples for AV contrastive training.

    Sample identities (class folder + file stem) are discovered from
    ``root_dir/<class>/*.mp4``. The media itself is loaded from augmented copies
    at ``aug_root/<aug>/<stem>.mp4`` and ``aug_root/<aug>/<stem>.wav``. One
    video and one audio augmentation folder are chosen at random per access.

    Args:
        root_dir: directory whose sub-directories are the classes.
        aug_root: directory holding the augmentation sub-folders.
        video_aug_dirs: non-empty names of video augmentation folders.
        audio_aug_dirs: non-empty names of audio augmentation folders.
        num_frames: number of frames per returned clip.
        video_size: (H, W) that every frame is resized to.
        audio_sr: sample rate the waveform is resampled to.

    Raises:
        ValueError: if either augmentation list is empty.
    """

    def __init__(
        self,
        root_dir="clipped_data",
        aug_root="augmented_data",
        video_aug_dirs=("crop_color", "crop_sobel"),
        audio_aug_dirs=("bg_noise", "drc"),
        num_frames=16,
        video_size=(112, 112),
        audio_sr=48000
    ):
        self.root = Path(root_dir)
        self.aug_root = Path(aug_root)
        # Copy so instances never share a default list or alias the caller's list.
        self.video_aug_dirs = list(video_aug_dirs)
        self.audio_aug_dirs = list(audio_aug_dirs)
        if not self.video_aug_dirs or not self.audio_aug_dirs:
            raise ValueError("video_aug_dirs and audio_aug_dirs must be non-empty")
        self.num_frames = num_frames
        self.video_size = video_size
        self.audio_sr = audio_sr

        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        # Sorted so sample order does not depend on filesystem listing order.
        self.samples = [
            (cls, video_path.stem)
            for cls in self.classes
            for video_path in sorted((self.root / cls).glob("*.mp4"))
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(video, waveform, label)`` for sample ``idx``.

        ``video`` is a (C, num_frames, H, W) float tensor in [0, 1]. ``waveform``
        is a mono (1, L) tensor at ``audio_sr``, and ``label`` is the class index.
        """
        cls, base = self.samples[idx]
        label = self.class_to_idx[cls]

        # Choose random augmentations
        video_aug = random.choice(self.video_aug_dirs)
        audio_aug = random.choice(self.audio_aug_dirs)

        video_path = self.aug_root / video_aug / f"{base}.mp4"
        audio_path = self.aug_root / audio_aug / f"{base}.wav"

        video = self._load_clip(video_path)
        waveform = self._load_audio(audio_path)
        return video, waveform, label

    def _load_clip(self, video_path):
        """Decode a video into a (C, num_frames, H, W) float clip scaled to [0, 1]."""
        frames, _, _ = read_video(str(video_path), pts_unit="sec")
        if frames.size(0) == 0:
            raise RuntimeError(f"No frames decoded from {video_path}")
        frames = frames.permute(0, 3, 1, 2).float() / 255.0  # T x C x H x W

        total = frames.size(0)
        if total > self.num_frames:
            # Random temporal crop of num_frames consecutive frames.
            start = random.randint(0, total - self.num_frames)
            frames = frames[start : start + self.num_frames]
        else:
            # Loop a short clip until it covers num_frames, then trim.
            repeat = (self.num_frames + total - 1) // total
            frames = frames.repeat((repeat, 1, 1, 1))[:self.num_frames]

        frames = F.interpolate(frames, size=self.video_size, mode="bilinear", align_corners=False)
        return frames.permute(1, 0, 2, 3)  # -> C x T x H x W

    def _load_audio(self, audio_path):
        """Load an audio file as a mono (1, L) waveform at ``self.audio_sr``."""
        waveform, sr = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)  # Mono
        if sr != self.audio_sr:
            waveform = torchaudio.transforms.Resample(sr, self.audio_sr)(waveform)
        return waveform


In [None]:
# Assume all model components are defined:
# - AVContrastiveModel
# - contrastive_loss
# - AVContrastiveDataset

# Seed the RNGs that drive augmentation choice, temporal crops and init.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and optimizer
model = AVContrastiveModel(proj_dim=128).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


def av_collate(batch):
    """Stack fixed-shape video clips; keep variable-length waveforms as a list.

    The default collate would try to ``torch.stack`` the waveforms, which fails
    as soon as two clips in a batch have different audio lengths.
    """
    videos, waveforms, labels = zip(*batch)
    return torch.stack(videos), list(waveforms), torch.tensor(labels)


# Dataset and DataLoader
dataset = AVContrastiveDataset()
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=av_collate)
if len(dataloader) == 0:
    raise RuntimeError("AVContrastiveDataset is empty; check root_dir")

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for i, (video, audio, _) in enumerate(dataloader):
        # video: (B, C, T, H, W); audio: list of B per-sample waveforms
        video = video.to(device)
        audio = [a.to(device) for a in audio]

        # Forward pass
        optimizer.zero_grad()
        z_video, z_audio = model(video, audio)  # OpenL3 handles audio per-sample
        loss = contrastive_loss(z_video, z_audio)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        if i % 10 == 0:
            print(f"[Epoch {epoch+1}] Step {i}/{len(dataloader)} - Loss: {step_loss:.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}")
