In [1]:
import torch
import torch.nn as nn
import torchvision.models.video as video_models
import openl3
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torchvision.io import read_video
import torchaudio
from pathlib import Path
import random

In [2]:
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim)
        )

    def forward(self, x):
        return self.net(x)

In [3]:
class VideoEncoder(nn.Module):
    def __init__(self, proj_dim=128):
        super().__init__()
        base_model = video_models.r3d_18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])  # remove final FC
        self.proj_head = ProjectionHead(512, proj_dim)

    def forward(self, x):  # (B, C, T, H, W)
        x = self.feature_extractor(x)  # (B, 512, 1, 1, 1)
        x = x.view(x.size(0), -1)
        return self.proj_head(x)

In [19]:
class OpenL3Encoder(nn.Module):
    def __init__(self, proj_dim=128, input_repr="mel256", content_type="music", embedding_size=512):
        super().__init__()
        self.sr = 48000
        self.model = openl3.models.load_audio_embedding_model(
            input_repr=input_repr,
            content_type=content_type,
            embedding_size=embedding_size
        )

        self.proj_head = ProjectionHead(embedding_size, proj_dim)

    def forward(self, x):  # x: (B, L)
        embeddings = []
        batch_size = len(x) # Get batch size from list length

        # Check if the batch is empty
        if batch_size == 0:
            # Handle empty batch: maybe return an empty tensor of the expected shape
            # The projection head expects input dimension embedding_size
            # Example: return an empty tensor for the projection head input
             empty_proj_input = torch.empty(0, self.proj_head.net[0].in_features)
             # If x was expected to be on a specific device, match it if possible, else use default
             device = self.proj_head.net[0].weight.device # Get device from proj_head params
             empty_proj_input = empty_proj_input.to(device)
             return self.proj_head(empty_proj_input)


        # Iterate through the list of tensors
        for i in range(batch_size):
            audio_tensor = x[i] # Get the i-th tensor
            # Ensure tensor is 1D numpy array for openl3
            # Squeeze potentially removes channel dim if present (e.g., [1, L] -> [L])
            audio_np = audio_tensor.squeeze().detach().cpu().numpy()

            # Check if audio_np is actually 1D after squeeze
            if audio_np.ndim != 1:
                # Handle unexpected dimensions, e.g. could be empty after processing
                print(f"Warning: Audio sample {i} has unexpected shape {audio_np.shape} after processing. Skipping.")
                # Option: append a zero tensor or skip. Skipping requires careful handling later.
                # For simplicity, let's try to add zeros if shape is bad, but this might hide issues.
                # A better approach might be to ensure the dataset always returns valid 1D audio.
                # Assuming embedding_size is the dimension needed:
                emb_mean = torch.zeros(self.proj_head.net[0].in_features, device=audio_tensor.device).float()

            elif audio_np.size == 0:
                 print(f"Warning: Audio sample {i} is empty after processing. Skipping.")
                 emb_mean = torch.zeros(self.proj_head.net[0].in_features, device=audio_tensor.device).float()
            else:
                # Process with OpenL3 using the correct model attribute name
                emb, _ = openl3.get_audio_embedding(audio_np, self.sr, model=self.model, center=True)
                # Aggregate embeddings and move to the correct device
                emb_mean = torch.tensor(emb.mean(axis=0), device=audio_tensor.device).float()

            embeddings.append(emb_mean)

        # Stack embeddings for the batch
        embeddings_batch = torch.stack(embeddings)  # (B, embedding_size)

        # Apply the PyTorch projection head
        return self.proj_head(embeddings_batch)

In [5]:
class AVContrastiveModel(nn.Module):
    def __init__(self, proj_dim=128):
        super().__init__()
        self.video_encoder = VideoEncoder(proj_dim)
        self.audio_encoder = OpenL3Encoder(proj_dim)

    def forward(self, video, audio):
        z_video = self.video_encoder(video)  # (B, proj_dim)
        z_audio = self.audio_encoder(audio)  # (B, proj_dim)
        return z_video, z_audio


In [21]:
def contrastive_loss(z1, z2, temperature=0.07):
    # z1, z2 shapes: (B, D)
    B = z1.size(0)
    D = z1.size(1)

    # Normalize features
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Concatenate features: Video embeddings first, then Audio embeddings
    z = torch.cat([z1, z2], dim=0)  # Shape: (2B, D)

    # Calculate pairwise cosine similarity
    # sim[i, j] = similarity between z[i] and z[j]
    sim = torch.matmul(z, z.T)  # Shape: (2B, 2B)

    # --- REMOVE MASKING AND RESHAPING ---
    # mask = ~torch.eye(2 * B, dtype=bool, device=z.device)
    # logits = sim[mask].view(2 * B, -1) / temperature # Incorrect shape (2B, 2B-1)
    # --- FIX: Use the full similarity matrix ---
    logits = sim / temperature # Shape: (2B, 2B)

    # Create targets:
    # For the first B rows (videos z1), the positive match is the corresponding audio (z2) at index i+B
    # For the second B rows (audios z2), the positive match is the corresponding video (z1) at index i
    targets_arange = torch.arange(B, device=z.device)
    # Targets for rows 0 to B-1 should be B to 2B-1
    # Targets for rows B to 2B-1 should be 0 to B-1
    targets = torch.cat([targets_arange + B, targets_arange]) # Shape: (2B,) Correct indices for (2B, 2B) logits

    # Calculate cross-entropy loss
    # logits shape (2B, 2B), targets shape (2B,) with values in [0, 2B-1]
    return F.cross_entropy(logits, targets)


In [15]:
class AVContrastiveDataset(Dataset):
    def __init__(
        self,
        root_dir="clipped_data",      # Original data structure (for finding samples)
        aug_root="augmented_data",    # Root for augmented files
        video_aug_dirs=["crop_color", "crop_sobel"], # Subdirs under aug_root
        audio_aug_dirs=["bg_noise", "drc"],          # Subdirs under aug_root
        num_frames=16,
        video_size=(112, 112),
        audio_sr=48000
    ):
        self.root = Path(root_dir)
        self.aug_root = Path(aug_root)
        self.video_aug_dirs = video_aug_dirs
        self.audio_aug_dirs = audio_aug_dirs
        self.num_frames = num_frames
        self.video_size = video_size
        self.audio_sr = audio_sr

        # Collect all .mp4 files under class folders in the ORIGINAL directory
        self.samples = []
        # Ensure root directory exists
        if not self.root.is_dir():
             raise FileNotFoundError(f"Root directory '{self.root}' not found.")

        self.classes = sorted([d.name for d in self.root.iterdir() if d.is_dir()])
        if not self.classes:
             print(f"Warning: No class subdirectories found in '{self.root}'. Dataset will be empty.")

        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for cls in self.classes:
            class_path = self.root / cls
            if not class_path.is_dir():
                print(f"Warning: Expected directory, but found file: {class_path}")
                continue
            found_files = False
            for video_path in class_path.glob("*.mp4"):
                base_name = video_path.stem
                self.samples.append((cls, base_name)) # Store class and base filename
                found_files = True
            if not found_files:
                 print(f"Warning: No .mp4 files found in class directory: {class_path}")


        if not self.samples:
             print(f"Warning: No samples collected. Check '{self.root}' structure and content.")


    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Retrieve the class name (e.g., 'violin') and base filename (e.g., 'video_001')
        cls, base = self.samples[idx]
        label = self.class_to_idx[cls]

        # Choose random augmentations
        video_aug = random.choice(self.video_aug_dirs)
        audio_aug = random.choice(self.audio_aug_dirs)

        # --- CORRECTED PATH CONSTRUCTION ---
        # Assumes structure: augmented_data/<aug_type>/<class>/<base_name>.mp4
        video_path = self.aug_root / video_aug / cls / f"{base}.mp4"
        audio_path = self.aug_root / audio_aug / cls / f"{base}.wav"
        # --- END CORRECTION ---

        try:
            # Load video
            video, _, _ = read_video(str(video_path), pts_unit="sec")

            # Check for empty video immediately after loading
            if video.nelement() == 0:
                print(f"WARNING: Loaded video tensor is empty for {video_path}!")
                # Decide how to handle: skip, raise error, return dummy?
                # For now, raising an error is safest to alert you.
                raise RuntimeError(f"Video file {video_path} loaded with 0 frames or elements.")

            video = video.permute(0, 3, 1, 2).float() / 255.0  # T x C x H x W

            T_total = video.size(0)

            # Explicitly check for T_total <= 0 before division/modulo
            if T_total <= 0:
                 raise RuntimeError(f"Video file {video_path} resulted in T_total={T_total} after permute.")

            # Resize and crop video frames
            if T_total > self.num_frames:
                start = random.randint(0, T_total - self.num_frames)
                video = video[start : start + self.num_frames]
            else:
                # Now T_total > 0 is guaranteed here
                repeat = (self.num_frames + T_total - 1) // T_total
                video = video.repeat((repeat, 1, 1, 1))[:self.num_frames]

            video = torch.nn.functional.interpolate(video, size=self.video_size, mode='bilinear', align_corners=False) # Added align_corners=False often recommended
            video = video.permute(1, 0, 2, 3)  # → C x T x H x W

            # Load and resample audio
            waveform, sr = torchaudio.load(str(audio_path))
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # Mono
            if sr != self.audio_sr:
                resampler = torchaudio.transforms.Resample(sr, self.audio_sr)
                waveform = resampler(waveform)
            # Ensure waveform is 1D (B, L) -> (L) for OpenL3 later? Check model input.
            # The OpenL3Encoder seems to handle batching internally by iterating,
            # so returning individual waveforms might be correct.
            # waveform = waveform.squeeze(0) # If needed

        except FileNotFoundError:
             print(f"ERROR: File not found at path: {video_path} or {audio_path}")
             # Depending on desired behavior, you might want to return None
             # and handle it in the DataLoader's collate_fn, or raise the error.
             raise # Re-raise the error for now
        except Exception as e:
            print(f"ERROR processing sample idx {idx} (cls='{cls}', base='{base}') at path {video_path} / {audio_path}: {e}")
            raise # Re-raise the error

        return video, waveform.squeeze(0), label # Return waveform likely needs to be (L,) not (1, L)

In [22]:
# Assume all model components are defined:
# - AVContrastiveModel
# - contrastive_loss
# - AVContrastiveDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and optimizer
model = AVContrastiveModel(proj_dim=128).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Dataset and DataLoader
dataset = AVContrastiveDataset()
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for i, (video, audio, _) in enumerate(dataloader):
        # video: (B, C, T, H, W), audio: (B, L)
        video = video.to(device)
        audio = [a.to(device) for a in audio]  # individual waveforms (already variable length)

        # Forward pass
        optimizer.zero_grad()
        z_video, z_audio = model(video, audio)  # OpenL3 handles audio per-sample
        loss = contrastive_loss(z_video, z_audio)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if i % 10 == 0:
            print(f"[Epoch {epoch+1}] Step {i}/{len(dataloader)} - Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}")




[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 886ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 811ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 788ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 793ms/step
[Epoch 1] Step 0/258 - Loss: 14.9666
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 783ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 788ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 808ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 807ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 798ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 867ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 842ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 830ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 809ms/step
[1m2/2[0m [

KeyboardInterrupt: 