# Joint Embeddings (Projection Heads + Species Supervision)

This notebook learns a shared audio/video embedding space using projection heads and a multi-positive CLIP-style contrastive loss.
It keeps baseline analyses (audio-only, video-only, concat when aligned), UMAP visualizations, clustering metrics, and linear probes.


In [30]:
import math
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


In [31]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score, normalized_mutual_info_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

In [32]:
# Config
# Data root: defaults to the original local path, but can be redirected by
# setting the FISH_PROCESSED_DIR environment variable so the notebook is not
# tied to one machine's directory layout.
root = Path(os.environ.get("FISH_PROCESSED_DIR", "/Users/wendycao/fish/processed"))
audio_dir = root / "surfperch_embeddings"
video_dir = root / "dinov2_embeddings"

# All artefacts produced by this notebook are written under out_dir.
out_dir = root / "joint_embeddings_projection"
out_dir.mkdir(parents=True, exist_ok=True)

# Projection-head and optimiser hyperparameters.
proj_dim = 256
hidden_dim = 512  # set to None for single Linear
lr = 1e-3
weight_decay = 1e-4
epochs = 30
steps_per_epoch = 200

# Each training batch holds species_per_batch species with items_per_species
# samples per modality for each of them.
species_per_batch = 8
items_per_species = 4
batch_size = species_per_batch * items_per_species

tau = 0.07  # temperature dividing the similarity logits in the contrastive loss
seed = 42
split_frac = 0.3  # held-out fraction

# Seed both the NumPy generator (batch sampling) and torch (weight init).
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

# Load base embeddings + metadata
audio_emb = np.load(audio_dir / "embeddings.npy")
meta_audio = pd.read_csv(audio_dir / "metadata.csv")

video_emb = np.load(video_dir / "embeddings.npy")
meta_video = pd.read_csv(video_dir / "metadata.csv")

print("audio_emb:", audio_emb.shape, "meta_audio:", meta_audio.shape)
print("video_emb:", video_emb.shape, "meta_video:", meta_video.shape)


audio_emb: (219, 1280) meta_audio: (219, 2)
video_emb: (265, 768) meta_video: (265, 2)


In [33]:
# Labels and train/held-out split

def ensure_species_column(df: pd.DataFrame, path_col_candidates=("clip_path", "path", "file_path")) -> pd.DataFrame:
    """Return a copy of ``df`` that is guaranteed to have a ``species`` column.

    If ``species`` already exists the copy is returned unchanged. Otherwise the
    species is taken to be the name of the directory that directly contains
    each file, using the first column of ``path_col_candidates`` present in
    ``df``.

    Raises:
        ValueError: if ``df`` has neither ``species`` nor any candidate path column.
    """
    result = df.copy()
    if "species" in result.columns:
        return result

    # First candidate (in the given order) that exists in the frame, if any.
    path_col = next((name for name in path_col_candidates if name in result.columns), None)
    if path_col is None:
        raise ValueError("No species column and no path column found to derive species.")

    def _parent_dir_name(p):
        return Path(p).parent.name

    result["species"] = result[path_col].apply(_parent_dir_name)
    return result

# Make sure both metadata frames carry a "species" column (derived from the
# file path's parent directory when it is not already present).
meta_audio = ensure_species_column(meta_audio)
meta_video = ensure_species_column(meta_video)

# Metadata rows are matched to embedding rows by position throughout the
# notebook, so a length mismatch would silently mislabel samples.
if len(meta_audio) != len(audio_emb):
    raise ValueError(f"audio metadata length {len(meta_audio)} != audio embeddings {len(audio_emb)}")
if len(meta_video) != len(video_emb):
    raise ValueError(f"video metadata length {len(meta_video)} != video embeddings {len(video_emb)}")

# Train/held-out split per modality, stratified by species

def stratified_split_indices(labels, test_size=0.3, seed=42, min_per_class=2):
    """Split sample positions into train/test sets, stratified by label.

    Classes with fewer than ``min_per_class`` samples are excluded from the
    split and their positions are reported separately.

    Args:
        labels: sequence of class labels, one per sample.
        test_size: passed to ``StratifiedShuffleSplit`` as the held-out size.
        seed: random state for the splitter.
        min_per_class: minimum class size required to take part in the split.

    Returns:
        ``(train_idx, test_idx, dropped)`` — arrays of positions into
        ``labels``. ``dropped`` holds the positions of samples whose class was
        too small.

    Raises:
        ValueError: if no class has at least ``min_per_class`` samples.
    """
    labels = np.array(labels)
    idx = np.arange(len(labels))

    # Drop classes with fewer than min_per_class to avoid stratify errors
    counts = pd.Series(labels).value_counts()
    keep_labels = counts[counts >= min_per_class].index
    keep_mask = np.isin(labels, keep_labels)
    dropped = idx[~keep_mask]

    if not keep_mask.any():
        # Report the threshold actually in effect rather than a fixed "2".
        raise ValueError(f"No classes with at least {min_per_class} samples; cannot stratify.")

    idx_keep = idx[keep_mask]
    labels_keep = labels[keep_mask]

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    # The splitter yields positions relative to the kept subset; map them back
    # to positions in the original label sequence.
    train_rel, test_rel = next(splitter.split(idx_keep, labels_keep))
    train_idx = idx_keep[train_rel]
    test_idx = idx_keep[test_rel]

    return train_idx, test_idx, dropped

# Per-sample species labels, in the same row order as the embedding arrays.
species_audio = meta_audio["species"].tolist()
species_video = meta_video["species"].tolist()

# Each modality is split independently (same seed and held-out fraction);
# the *_dropped arrays hold samples whose species was too rare to stratify.
audio_train_idx, audio_test_idx, audio_dropped = stratified_split_indices(
    species_audio, test_size=split_frac, seed=seed
)
video_train_idx, video_test_idx, video_dropped = stratified_split_indices(
    species_video, test_size=split_frac, seed=seed
)

# Persist the split so downstream work can reuse exactly the same partition.
np.save(out_dir / "audio_train_idx.npy", audio_train_idx)
np.save(out_dir / "audio_test_idx.npy", audio_test_idx)
np.save(out_dir / "audio_dropped_idx.npy", audio_dropped)
np.save(out_dir / "video_train_idx.npy", video_train_idx)
np.save(out_dir / "video_test_idx.npy", video_test_idx)
np.save(out_dir / "video_dropped_idx.npy", video_dropped)

print("Saved split indices to:", out_dir)
print("Dropped audio samples (singletons):", len(audio_dropped))
print("Dropped video samples (singletons):", len(video_dropped))


Saved split indices to: /Users/wendycao/fish/processed/joint_embeddings_projection
Dropped audio samples (singletons): 3
Dropped video samples (singletons): 3


In [34]:
# Optional: aligned pairs for concat baseline

def add_key_column(df, path_col="clip_path"):
    """Return a copy of ``df`` with a ``key`` column derived from ``path_col``.

    The key is the file stem (name without directory or extension) with a
    trailing ``_cropped`` removed when present.
    """
    suffix = "_cropped"

    def _key_for(path):
        stem = Path(path).stem
        return stem[:-len(suffix)] if stem.endswith(suffix) else stem

    keyed = df.copy()
    keyed["key"] = [_key_for(p) for p in df[path_col].tolist()]
    return keyed

# Defaults used when no audio/video alignment can be established; later cells
# check concat_available before running the concat baseline.
concat_available = False
concat_audio = None
concat_video = None
concat_species = None

if "clip_path" in meta_audio.columns and "clip_path" in meta_video.columns:
    # reset_index() exposes each row's position as a column so the original
    # embedding row can be recovered after the merge.
    meta_audio_k = add_key_column(meta_audio, "clip_path").reset_index().rename(columns={"index": "audio_idx"})
    meta_video_k = add_key_column(meta_video, "clip_path").reset_index().rename(columns={"index": "video_idx"})

    # Inner join: an audio row and a video row are paired when they share both
    # species and key.
    # NOTE(review): if keys are not unique within a species this join yields
    # every combination — confirm keys are unique per modality.
    merged = meta_audio_k.merge(
        meta_video_k,
        on=["species", "key"],
        suffixes=("_audio", "_video"),
    )

    if len(merged) > 0:
        audio_indices = merged["audio_idx"].to_numpy()
        video_indices = merged["video_idx"].to_numpy()
        # Row i of concat_audio and row i of concat_video belong to the same pair.
        concat_audio = audio_emb[audio_indices]
        concat_video = video_emb[video_indices]
        concat_species = merged["species"].tolist()
        concat_available = True

    print("Aligned pairs for concat baseline:", len(merged))
else:
    print("No clip_path in one or both metadata files; concat baseline skipped.")


Aligned pairs for concat baseline: 219


In [35]:
# Helper functions

def l2norm(x, eps=1e-12):
    """Scale each row of the 2-D array ``x`` to unit L2 norm.

    ``eps`` is added to every row norm so all-zero rows do not divide by zero.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / (row_norms + eps)


def recall_at_k(query_emb, query_species, db_emb, db_species, ks=(1, 5, 10), batch_size=256, exclude_self=False):
    """Fraction of queries with at least one same-species item among the top-k.

    Similarity is the plain dot product between query and database rows, so
    callers should pass normalized embeddings if cosine similarity is wanted.

    Args:
        query_emb: 2-D NumPy array of query embeddings.
        query_species: labels for the query rows.
        db_emb: 2-D NumPy array of database embeddings.
        db_species: labels for the database rows.
        ks: cut-offs to evaluate. A cut-off larger than the database simply
            considers the whole database.
        batch_size: number of queries scored per matrix multiply.
        exclude_self: when True, query ``j`` is prevented from retrieving
            database row ``j`` (assumes query and db are the same set and
            aligned by index).

    Returns:
        Dict mapping each ``k`` to the hit rate in ``[0, 1]``. All values are
        ``0.0`` when there are no queries or the database is empty.
    """
    q_species = np.array(query_species)
    db_species = np.array(db_species)
    n = len(query_emb)
    n_db = len(db_emb)

    # Nothing to retrieve or nothing to score: avoid topk on an empty axis
    # and the division by zero at the end.
    if n == 0 or n_db == 0:
        return {k: 0.0 for k in ks}

    # torch.topk raises if k exceeds the size of the dimension, so clamp.
    max_k = min(max(ks), n_db)

    db_t = torch.from_numpy(db_emb).float()
    correct = {k: 0 for k in ks}

    for i in range(0, n, batch_size):
        q = torch.from_numpy(query_emb[i:i+batch_size]).float()
        sims = q @ db_t.T

        if exclude_self:
            # assumes query and db are the same set and aligned by index
            rows = torch.arange(sims.shape[0])
            sims[rows, rows + i] = -1e9

        topk = torch.topk(sims, k=max_k, dim=1).indices.numpy()

        # matches[r, j] is True when the j-th ranked item shares the query's species.
        labels = q_species[i:i + topk.shape[0]]
        matches = db_species[topk] == labels[:, None]
        for k in ks:
            correct[k] += int(matches[:, :k].any(axis=1).sum())

    return {k: correct[k] / n for k in ks}


def cluster_metrics(X, y, n_clusters):
    """Cluster ``X`` with k-means and score the assignment against labels ``y``.

    Returns:
        ``(ami, nmi)`` — adjusted and normalized mutual information between
        ``y`` and the k-means cluster assignments.
    """
    # Fixed random_state so repeated runs give the same clustering.
    assignments = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit_predict(X)
    return (
        adjusted_mutual_info_score(y, assignments),
        normalized_mutual_info_score(y, assignments),
    )


def linear_probe(X, y, train_frac=0.4, seed=42, min_per_class=2):
    """Fit a logistic-regression probe on ``X`` and report held-out quality.

    Samples from classes with fewer than ``min_per_class`` members are removed,
    non-finite features are replaced, the data is split (stratified when every
    remaining class has at least two samples), features are standardized using
    training-set statistics, and a ``LogisticRegression`` is fitted.

    Args:
        X: 2-D feature array, one row per sample.
        y: class labels aligned with the rows of ``X``.
        train_frac: fraction of the kept samples used for training.
        seed: random state for the split.
        min_per_class: minimum class size to keep a class.

    Returns:
        ``(accuracy, macro_f1)`` measured on the test portion.
    """
    from sklearn.preprocessing import StandardScaler

    y = np.array(y)

    # Remove samples whose class is too small to take part in the split.
    class_sizes = pd.Series(y).value_counts()
    eligible = class_sizes[class_sizes >= min_per_class].index
    keep_mask = np.isin(y, eligible)
    dropped = len(y) - keep_mask.sum()
    X, y = X[keep_mask], y[keep_mask]

    if dropped > 0:
        print(f"linear_probe: dropped {dropped} samples from classes with <{min_per_class} examples")

    # Replace inf with 0 (nan_to_num also maps NaN to 0 by default).
    X = np.nan_to_num(X, copy=False, posinf=0.0, neginf=0.0)

    # Stratify only when it is possible; otherwise fall back to a plain shuffle split.
    class_sizes = pd.Series(y).value_counts()
    can_stratify = len(class_sizes) >= 2 and class_sizes.min() >= 2
    if can_stratify:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, train_size=train_frac, random_state=seed, stratify=y
        )
    else:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, train_size=train_frac, random_state=seed, shuffle=True
        )

    # Standardize with statistics from the training portion only.
    scaler = StandardScaler(with_mean=True, with_std=True)
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    probe = LogisticRegression(max_iter=1000, n_jobs=-1)
    probe.fit(X_train, y_train)
    y_hat = probe.predict(X_test)
    return accuracy_score(y_test, y_hat), f1_score(y_test, y_hat, average="macro")


In [36]:
# Projection heads + multi-positive contrastive loss

from collections import defaultdict

def sample_species_batch(species_list, audio_map, video_map, m, k, rng):
    """Draw a species-balanced batch of audio and video sample indices.

    ``m`` species are drawn from ``species_list`` and, for each of them, ``k``
    audio indices and ``k`` video indices are drawn from that species' pools.
    Sampling is without replacement unless the pool (or the species list) is
    smaller than the number requested.

    Args:
        species_list: candidate species names.
        audio_map: mapping species -> list of audio sample indices.
        video_map: mapping species -> list of video sample indices.
        m: number of species per batch.
        k: number of samples per species and modality.
        rng: NumPy random ``Generator`` used for all draws.

    Returns:
        ``(a_idx, v_idx, a_lab, v_lab)`` — audio indices, video indices, and
        the species label of every audio / video entry, each of length ``m * k``.
    """
    picked_species = rng.choice(species_list, size=m, replace=len(species_list) < m)

    audio_indices, video_indices = [], []
    audio_labels, video_labels = [], []
    for name in picked_species:
        audio_pool, video_pool = audio_map[name], video_map[name]
        # Audio is drawn before video for every species; keeping this order
        # keeps the random stream (and hence the batches) reproducible.
        audio_indices += list(rng.choice(audio_pool, size=k, replace=len(audio_pool) < k))
        video_indices += list(rng.choice(video_pool, size=k, replace=len(video_pool) < k))
        audio_labels += [name] * k
        video_labels += [name] * k

    return (
        np.array(audio_indices),
        np.array(video_indices),
        np.array(audio_labels),
        np.array(video_labels),
    )


class ProjectionHead(nn.Module):
    """Maps base embeddings into the shared space and L2-normalizes them.

    With ``hidden_dim=None`` the head is a single linear layer; otherwise it
    is Linear -> ReLU -> Linear with ``hidden_dim`` hidden units.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=512):
        super().__init__()
        if hidden_dim is None:
            projector = nn.Linear(in_dim, out_dim)
        else:
            projector = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )
        self.net = projector

    def forward(self, x):
        """Project ``x`` and scale every row to unit length."""
        return F.normalize(self.net(x), dim=1)


def _directional_multi_positive_loss(logits, pos_mask):
    """Mean over rows of -log(softmax mass assigned to that row's positives).

    Rows without any positive are excluded from the mean. When no row has a
    positive, a zero that is still attached to ``logits`` is returned so that
    calling ``backward()`` on the result remains valid.
    """
    # A large finite fill (rather than -inf) is used for masked entries so
    # rows with no positives still produce a finite logsumexp.
    fill = torch.finfo(logits.dtype).min
    logits_pos = logits.masked_fill(~pos_mask, fill)
    log_prob_pos = torch.logsumexp(logits_pos, dim=1) - torch.logsumexp(logits, dim=1)

    has_positive = pos_mask.any(dim=1)
    if not has_positive.any():
        return logits.sum() * 0.0
    return -log_prob_pos[has_positive].mean()


def multi_positive_clip_loss(z_a, z_v, y_a, y_v, tau=0.07):
    """Symmetric CLIP-style contrastive loss with label-defined positives.

    Every audio/video pair that shares a label counts as a positive. For each
    direction the loss is the negative log of the softmax probability mass
    that a row places on its positives.

    Args:
        z_a: audio embeddings, one row per sample.
        z_v: video embeddings, one row per sample.
        y_a: 1-D tensor of labels for the rows of ``z_a``.
        y_v: 1-D tensor of labels for the rows of ``z_v``.
        tau: temperature dividing the similarity logits.

    Returns:
        Scalar tensor: the average of the audio->video and video->audio losses.
    """
    logits = (z_a @ z_v.T) / tau
    # pos_mask[i, j] is True when audio row i and video row j share a label.
    pos_mask = y_a.view(-1, 1) == y_v.view(1, -1)

    loss_a = _directional_multi_positive_loss(logits, pos_mask)      # audio -> video
    loss_v = _directional_multi_positive_loss(logits.T, pos_mask.T)  # video -> audio

    return 0.5 * (loss_a + loss_v)


In [37]:
# Train projection heads (train split only)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Build species -> indices (train only)
audio_by_species = defaultdict(list)
for i in audio_train_idx:
    audio_by_species[species_audio[i]].append(i)

video_by_species = defaultdict(list)
for i in video_train_idx:
    video_by_species[species_video[i]].append(i)

# Only species that have training samples in BOTH modalities can be used,
# because every batch draws audio and video items for each chosen species.
species_shared = sorted(set(audio_by_species) & set(video_by_species))
print("Shared species (train):", len(species_shared))

if len(species_shared) == 0:
    raise ValueError("No shared species between audio/video in train split.")

# Map species -> int id for labels
species_to_id = {s: i for i, s in enumerate(species_shared)}

# Heads
# One projection head per modality; both map into the same proj_dim space.
audio_head = ProjectionHead(audio_emb.shape[1], proj_dim, hidden_dim=hidden_dim).to(device)
video_head = ProjectionHead(video_emb.shape[1], proj_dim, hidden_dim=hidden_dim).to(device)

# A single optimizer updates the parameters of both heads jointly.
opt = torch.optim.AdamW(list(audio_head.parameters()) + list(video_head.parameters()),
                        lr=lr, weight_decay=weight_decay)

# There is no fixed dataset pass: an "epoch" is steps_per_epoch freshly
# sampled species-balanced batches.
for epoch in range(1, epochs + 1):
    audio_head.train()
    video_head.train()
    total = 0.0
    n_steps = 0

    for _ in range(steps_per_epoch):
        a_idx, v_idx, a_lab, v_lab = sample_species_batch(
            species_shared, audio_by_species, video_by_species,
            species_per_batch, items_per_species, rng
        )
        a = torch.from_numpy(audio_emb[a_idx]).float().to(device)
        v = torch.from_numpy(video_emb[v_idx]).float().to(device)
        # Species names are converted to integer ids so the loss can compare
        # audio and video labels with a tensor equality.
        y_a = torch.from_numpy(np.array([species_to_id[s] for s in a_lab])).long().to(device)
        y_v = torch.from_numpy(np.array([species_to_id[s] for s in v_lab])).long().to(device)

        z_a = audio_head(a)
        z_v = video_head(v)
        loss = multi_positive_clip_loss(z_a, z_v, y_a, y_v, tau=tau)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total += loss.item()
        n_steps += 1

    # Mean training loss over the epoch's steps. No held-out loss is tracked
    # here, so this number alone does not show generalization.
    avg = total / max(1, n_steps)
    print(f"Epoch {epoch:02d}/{epochs}  loss={avg:.4f}")


device: cpu
Shared species (train): 43
Epoch 01/30  loss=1.1455
Epoch 02/30  loss=0.1143
Epoch 03/30  loss=0.0291
Epoch 04/30  loss=0.0052
Epoch 05/30  loss=0.0027
Epoch 06/30  loss=0.0020
Epoch 07/30  loss=0.0014
Epoch 08/30  loss=0.0012
Epoch 09/30  loss=0.0010
Epoch 10/30  loss=0.0008
Epoch 11/30  loss=0.0007
Epoch 12/30  loss=0.0006
Epoch 13/30  loss=0.0006
Epoch 14/30  loss=0.0005
Epoch 15/30  loss=0.0004
Epoch 16/30  loss=0.0004
Epoch 17/30  loss=0.0003
Epoch 18/30  loss=0.0003
Epoch 19/30  loss=0.0003
Epoch 20/30  loss=0.0003
Epoch 21/30  loss=0.0002
Epoch 22/30  loss=0.0002
Epoch 23/30  loss=0.0002
Epoch 24/30  loss=0.0002
Epoch 25/30  loss=0.0002
Epoch 26/30  loss=0.0001
Epoch 27/30  loss=0.0001
Epoch 28/30  loss=0.0001
Epoch 29/30  loss=0.0001
Epoch 30/30  loss=0.0001


In [38]:
# Export learned joint embeddings (L2-normalized)

@torch.no_grad()
def encode_all(emb, head, batch_size=1024):
    """Run ``head`` over every row of ``emb`` and return the stacked outputs.

    The head is switched to eval mode, inputs are moved to the notebook-level
    ``device`` in chunks of ``batch_size`` rows, and results are returned as a
    single NumPy array.
    """
    head.eval()
    chunks = [
        head(torch.from_numpy(emb[start:start + batch_size]).float().to(device)).cpu().numpy()
        for start in range(0, len(emb), batch_size)
    ]
    return np.vstack(chunks)


def l2_normalize(X, eps=1e-12):
    """Divide each row of the 2-D array ``X`` by its L2 norm plus ``eps``."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (norms + eps)


# Encode ALL samples (train, held-out and dropped) with the trained heads.
# ProjectionHead.forward already normalizes its output, so the extra
# l2_normalize only guards against numerical drift.
audio_joint = l2_normalize(encode_all(audio_emb, audio_head))
video_joint = l2_normalize(encode_all(video_emb, video_head))

print("audio_joint:", audio_joint.shape, "video_joint:", video_joint.shape)

np.save(out_dir / "audio_joint.npy", audio_joint)
np.save(out_dir / "video_joint.npy", video_joint)

# Save metadata with species
# Rows stay in the same order as the saved embedding arrays.
meta_audio_out = meta_audio.copy()
meta_video_out = meta_video.copy()
meta_audio_out.to_csv(out_dir / "audio_metadata.csv", index=False)
meta_video_out.to_csv(out_dir / "video_metadata.csv", index=False)

print("Saved joint embeddings and metadata to:", out_dir)


audio_joint: (219, 256) video_joint: (265, 256)
Saved joint embeddings and metadata to: /Users/wendycao/fish/processed/joint_embeddings_projection


## Cross-modal retrieval evaluation (held-out)


In [39]:
# Reusable helpers

def l2_normalize(X, eps=1e-12):
    """Row-wise L2 normalization of a 2-D array (``eps`` avoids division by zero).

    This re-defines the helper of the same name from the export cell with
    identical behavior, so this cell does not depend on that definition.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (norms + eps)


def get_path_column(df, candidates=("clip_path", "path", "file_path")):
    """Return the first name in ``candidates`` that is a column of ``df``.

    Raises:
        ValueError: if none of the candidates is present.
    """
    found = next((name for name in candidates if name in df.columns), None)
    if found is None:
        raise ValueError("No path column found in metadata.")
    return found


def cross_modal_metrics(query_emb, query_species, db_emb, db_species, ks=(1, 5, 10), batch_size=256):
    """Species-level retrieval metrics for queries ranked against a database.

    Both sets are L2-normalized, so ranking is by cosine similarity. A
    database item is relevant to a query when it has the same species.

    Returns:
        Dict with
        - ``"recall"``: for each ``k``, the fraction of queries with at least
          one relevant item in the top ``k``;
        - ``"precision"``: for each ``k``, the mean fraction of relevant items
          among the top ``k``;
        - ``"median_rank"``: median 1-based rank of the first relevant item
          (NaN when there are no queries);
        - ``"ranks"``: the per-query ranks as an int array.
        Queries whose species never appears in the database score 0 and get
        rank ``len(db) + 1``.
    """
    queries = l2_normalize(query_emb)
    database = l2_normalize(db_emb)

    db_labels = np.array(db_species)
    query_labels = np.array(query_species)
    miss_rank = len(db_labels) + 1

    hit_at = {k: [] for k in ks}
    prec_at = {k: [] for k in ks}
    first_match_ranks = []

    for start in range(0, len(queries), batch_size):
        sims = queries[start:start + batch_size] @ database.T
        # Each row of `order` lists database positions, most similar first.
        order = np.argsort(-sims, axis=1)

        for offset, ranked_idx in enumerate(order):
            target = query_labels[start + offset]
            relevant = db_labels[ranked_idx] == target
            hits = np.flatnonzero(relevant)

            if hits.size == 0:
                # Species absent from the database: worst possible outcome.
                first_match_ranks.append(miss_rank)
                for k in ks:
                    hit_at[k].append(0.0)
                    prec_at[k].append(0.0)
                continue

            first_match_ranks.append(int(hits[0] + 1))
            for k in ks:
                top = relevant[:k]
                hit_at[k].append(1.0 if top.any() else 0.0)
                prec_at[k].append(float(np.mean(top)))

    return {
        "recall": {k: float(np.mean(hit_at[k])) for k in ks},
        "precision": {k: float(np.mean(prec_at[k])) for k in ks},
        "median_rank": float(np.median(first_match_ranks)) if len(first_match_ranks) else float("nan"),
        "ranks": np.array(first_match_ranks, dtype=int),
    }


def average_chance(query_species, db_species):
    """Chance-level top-1 hit rate for species retrieval.

    For each query this is the fraction of database items sharing its species
    (0 if the species is absent); the result is the mean over queries, or
    ``0.0`` when either the queries or the database are empty.
    """
    n_db = len(db_species)
    if n_db == 0:
        return 0.0

    species_counts = pd.Series(db_species).value_counts()
    per_query = [species_counts.get(label, 0) / n_db for label in query_species]
    if not per_query:
        return 0.0
    return float(np.mean(per_query))


def save_cross_modal_examples(
    query_emb,
    query_species,
    query_paths,
    db_emb,
    db_species,
    db_paths,
    out_csv,
    seed=42,
    n_samples=20,
    topk=5,
):
    """Write a CSV of retrieval results for a random sample of queries.

    Up to ``n_samples`` queries are drawn without replacement (seeded by
    ``seed``). For each one the database is ranked by cosine similarity and a
    row is recorded with the query, its best match, the ``topk`` best matches
    as ``path|species|score`` entries joined by ``;``, and the 1-based rank of
    the first same-species item (``len(db) + 1`` if there is none).

    Note that the column holding the ``topk`` list is always named ``top5``,
    whatever value ``topk`` has. Nothing is written when there are no queries.
    """
    queries = l2_normalize(query_emb)
    database = l2_normalize(db_emb)

    rng = np.random.default_rng(seed)
    n_samples = min(n_samples, len(queries))
    if n_samples == 0:
        print(f"No queries available for examples: {out_csv}")
        return

    db_labels = np.array(db_species)
    miss_rank = len(db_species) + 1
    records = []

    for qi in rng.choice(len(queries), size=n_samples, replace=False):
        sims = queries[qi] @ database.T
        order = np.argsort(-sims)
        best = order[0]

        neighbours = ";".join(
            f"{db_paths[j]}|{db_species[j]}|{sims[j]:.4f}" for j in order[:topk]
        )

        # Positions (in ranked order) of database items sharing the query's species.
        hits = np.flatnonzero(db_labels[order] == query_species[qi])
        first_rank = int(hits[0] + 1) if hits.size > 0 else miss_rank

        records.append({
            "query_clip_path": query_paths[qi],
            "query_species": query_species[qi],
            "top1_clip_path": db_paths[best],
            "top1_species": db_species[best],
            "top1_score": float(sims[best]),
            "top5": neighbours,
            "rank_of_first_same_species": first_rank,
        })

    pd.DataFrame(records).to_csv(out_csv, index=False)
    print("Saved examples:", out_csv)


# Build held-out query/db splits
path_audio = get_path_column(meta_audio)
path_video = get_path_column(meta_video)

audio_test_idx = np.array(audio_test_idx)
video_test_idx = np.array(video_test_idx)

# Held-out embeddings. Each modality's held-out set serves both as the query
# set (for its own direction) and as the database (for the other direction).
audio_q = audio_joint[audio_test_idx]
video_q = video_joint[video_test_idx]

video_db = video_joint[video_test_idx]
audio_db = audio_joint[audio_test_idx]

# The *_test_idx arrays hold row POSITIONS (they index the embedding arrays
# above), so metadata must be selected positionally with .iloc. Label-based
# .loc only gives the same rows while the frame keeps its default RangeIndex.
audio_q_species = meta_audio["species"].iloc[audio_test_idx].to_numpy()
video_q_species = meta_video["species"].iloc[video_test_idx].to_numpy()

video_db_species = meta_video["species"].iloc[video_test_idx].to_numpy()
audio_db_species = meta_audio["species"].iloc[audio_test_idx].to_numpy()

audio_q_paths = meta_audio[path_audio].iloc[audio_test_idx].to_numpy()
video_q_paths = meta_video[path_video].iloc[video_test_idx].to_numpy()

video_db_paths = meta_video[path_video].iloc[video_test_idx].to_numpy()
audio_db_paths = meta_audio[path_audio].iloc[audio_test_idx].to_numpy()

# Metrics
# Retrieval is scored at species level: a database item counts as correct
# when it has the same species as the query. "R@k" is the share of queries
# with at least one correct item in the top k; "P@k" is the mean share of
# correct items within the top k.
metrics_a2v = cross_modal_metrics(audio_q, audio_q_species, video_db, video_db_species)
metrics_v2a = cross_modal_metrics(video_q, video_q_species, audio_db, audio_db_species)

# Chance reference: average frequency of the query's species in the database.
chance_a2v = average_chance(audio_q_species, video_db_species)
chance_v2a = average_chance(video_q_species, audio_db_species)

print("Audio -> Video (held-out):",
      f"R@1={metrics_a2v['recall'][1]:.3f}",
      f"R@5={metrics_a2v['recall'][5]:.3f}",
      f"R@10={metrics_a2v['recall'][10]:.3f}",
      f"P@1={metrics_a2v['precision'][1]:.3f}",
      f"P@5={metrics_a2v['precision'][5]:.3f}",
      f"P@10={metrics_a2v['precision'][10]:.3f}",
      f"median_rank={metrics_a2v['median_rank']:.1f}")
print("Audio -> Video chance (avg species freq in DB):", f"{chance_a2v:.3f}")

print("Video -> Audio (held-out):",
      f"R@1={metrics_v2a['recall'][1]:.3f}",
      f"R@5={metrics_v2a['recall'][5]:.3f}",
      f"R@10={metrics_v2a['recall'][10]:.3f}",
      f"P@1={metrics_v2a['precision'][1]:.3f}",
      f"P@5={metrics_v2a['precision'][5]:.3f}",
      f"P@10={metrics_v2a['precision'][10]:.3f}",
      f"median_rank={metrics_v2a['median_rank']:.1f}")
print("Video -> Audio chance (avg species freq in DB):", f"{chance_v2a:.3f}")

# Qualitative retrieval examples
# Both directions use the same seed, so each file is reproducible.
save_cross_modal_examples(
    audio_q,
    audio_q_species,
    audio_q_paths,
    video_db,
    video_db_species,
    video_db_paths,
    out_dir / "cross_modal_examples_audio_to_video.csv",
    seed=seed,
)

save_cross_modal_examples(
    video_q,
    video_q_species,
    video_q_paths,
    audio_db,
    audio_db_species,
    audio_db_paths,
    out_dir / "cross_modal_examples_video_to_audio.csv",
    seed=seed,
)


Audio -> Video (held-out): R@1=0.138 R@5=0.262 R@10=0.323 P@1=0.138 P@5=0.098 P@10=0.080 median_rank=17.0
Audio -> Video chance (avg species freq in DB): 0.032
Video -> Audio (held-out): R@1=0.076 R@5=0.304 R@10=0.481 P@1=0.076 P@5=0.094 P@10=0.090 median_rank=14.0
Video -> Audio chance (avg species freq in DB): 0.032
Saved examples: /Users/wendycao/fish/processed/joint_embeddings_projection/cross_modal_examples_audio_to_video.csv
Saved examples: /Users/wendycao/fish/processed/joint_embeddings_projection/cross_modal_examples_video_to_audio.csv


  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[qi] @ DB.T
  sims = Q[qi] @ DB.T
  sims = Q[qi] @ DB.T


In [44]:
# Padding baseline (not a learned joint space)

# Raw embeddings (audio 1280-D, video 768-D)
# Work on float32 versions of the base embeddings for the baseline.
audio_raw = audio_emb.astype(np.float32)
video_raw = video_emb.astype(np.float32)

# Wrong rank is fatal; an unexpected feature width only triggers a warning
# because the padding below adapts to the actual audio width.
if audio_raw.ndim != 2 or video_raw.ndim != 2:
    raise ValueError(f"Expected 2D embeddings, got audio {audio_raw.shape}, video {video_raw.shape}")
if audio_raw.shape[1] != 1280:
    print(f"Warning: expected audio dim 1280, got {audio_raw.shape[1]}")
if video_raw.shape[1] != 768:
    print(f"Warning: expected video dim 768, got {video_raw.shape[1]}")

def pad_video_to_dim(video_emb, target_dim=1280):
    """Zero-pad the feature axis of a 2-D array on the right up to ``target_dim``.

    The input is returned as-is (same object) when it already has
    ``target_dim`` columns.

    Raises:
        ValueError: if the array has more than ``target_dim`` columns.
    """
    _, width = video_emb.shape
    missing = target_dim - width
    if missing < 0:
        raise ValueError(f"Video dim {width} exceeds target dim {target_dim}")
    if missing == 0:
        return video_emb
    return np.pad(video_emb, ((0, 0), (0, missing)), mode='constant')

# Zero-pad video features up to the audio width so the two raw spaces can be
# compared with a dot product at all.
video_padded = pad_video_to_dim(video_raw, target_dim=audio_raw.shape[1])

# Full-set retrieval (no training, so evaluate on all samples)
audio_q_pad = audio_raw
video_q_pad = video_padded

video_db_pad = video_padded
audio_db_pad = audio_raw

audio_q_species_full = meta_audio["species"].to_numpy()
video_q_species_full = meta_video["species"].to_numpy()

video_db_species_full = meta_video["species"].to_numpy()
audio_db_species_full = meta_audio["species"].to_numpy()

# Metrics
metrics_pad_a2v = cross_modal_metrics(audio_q_pad, audio_q_species_full, video_db_pad, video_db_species_full)
metrics_pad_v2a = cross_modal_metrics(video_q_pad, video_q_species_full, audio_db_pad, audio_db_species_full)

print("Padding baseline (not a learned joint space)")
print("Audio -> Video (padded, full set)",
      f"R@1={metrics_pad_a2v['recall'][1]:.3f}",
      f"R@5={metrics_pad_a2v['recall'][5]:.3f}",
      f"R@10={metrics_pad_a2v['recall'][10]:.3f}",
      f"median_rank={metrics_pad_a2v['median_rank']:.1f}")
print("Video -> Audio (padded, full set)",
      f"R@1={metrics_pad_v2a['recall'][1]:.3f}",
      f"R@5={metrics_pad_v2a['recall'][5]:.3f}",
      f"R@10={metrics_pad_v2a['recall'][10]:.3f}",
      f"median_rank={metrics_pad_v2a['median_rank']:.1f}")

# Shuffle sanity check: permute padded video embeddings
# The embeddings are permuted while the species labels keep their order, which
# breaks the embedding/label correspondence and gives a null reference.
# A dedicated generator is used instead of rebinding the notebook-level `rng`,
# which drives training batch sampling; re-running cells out of order would
# otherwise silently reset the training random stream.
shuffle_rng = np.random.default_rng(seed)

perm_db = shuffle_rng.permutation(len(video_db_pad))
video_db_pad_shuf = video_db_pad[perm_db]
metrics_pad_a2v_shuf = cross_modal_metrics(audio_q_pad, audio_q_species_full, video_db_pad_shuf, video_db_species_full)

perm_q = shuffle_rng.permutation(len(video_q_pad))
video_q_pad_shuf = video_q_pad[perm_q]
metrics_pad_v2a_shuf = cross_modal_metrics(video_q_pad_shuf, video_q_species_full, audio_db_pad, audio_db_species_full)

print("Shuffle sanity check (padded video permuted)")
print("Audio -> Video (shuffled, full set)",
      f"R@1={metrics_pad_a2v_shuf['recall'][1]:.3f}",
      f"R@5={metrics_pad_a2v_shuf['recall'][5]:.3f}",
      f"R@10={metrics_pad_a2v_shuf['recall'][10]:.3f}",
      f"median_rank={metrics_pad_a2v_shuf['median_rank']:.1f}")
print("Video -> Audio (shuffled, full set)",
      f"R@1={metrics_pad_v2a_shuf['recall'][1]:.3f}",
      f"R@5={metrics_pad_v2a_shuf['recall'][5]:.3f}",
      f"R@10={metrics_pad_v2a_shuf['recall'][10]:.3f}",
      f"median_rank={metrics_pad_v2a_shuf['median_rank']:.1f}")

# Qualitative retrieval examples (padding baseline)
save_cross_modal_examples(
    audio_q_pad,
    audio_q_species_full,
    meta_audio[get_path_column(meta_audio)].to_numpy(),
    video_db_pad,
    video_db_species_full,
    meta_video[get_path_column(meta_video)].to_numpy(),
    out_dir / "padding_baseline_examples_audio_to_video.csv",
    seed=seed,
    n_samples=20,
    topk=5,
)

save_cross_modal_examples(
    video_q_pad,
    video_q_species_full,
    meta_video[get_path_column(meta_video)].to_numpy(),
    audio_db_pad,
    audio_db_species_full,
    meta_audio[get_path_column(meta_audio)].to_numpy(),
    out_dir / "padding_baseline_examples_video_to_audio.csv",
    seed=seed,
    n_samples=20,
    topk=5,
)


Padding baseline (not a learned joint space)
Audio -> Video (padded, full set) R@1=0.009 R@5=0.087 R@10=0.132 median_rank=34.0
Video -> Audio (padded, full set) R@1=0.011 R@5=0.091 R@10=0.325 median_rank=19.0
Shuffle sanity check (padded video permuted)
Audio -> Video (shuffled, full set) R@1=0.046 R@5=0.169 R@10=0.265 median_rank=27.0
Video -> Audio (shuffled, full set) R@1=0.030 R@5=0.091 R@10=0.257 median_rank=26.0
Saved examples: /Users/wendycao/fish/processed/joint_embeddings_projection/padding_baseline_examples_audio_to_video.csv
Saved examples: /Users/wendycao/fish/processed/joint_embeddings_projection/padding_baseline_examples_video_to_audio.csv


  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[i:i+batch_size] @ DB.T
  sims = Q[qi] @ DB.T
  sims = Q[qi] @ DB.T
  sims = Q[qi] @ DB.T


In [41]:
# Cluster metrics (AMI/NMI)

# Base (pre-projection) embeddings, L2-normalized. They are always recomputed
# from the raw arrays instead of being reused from `globals()`, so a stale
# value left in the kernel by an earlier session can never leak into the results.
audio_base = l2norm(audio_emb.astype(np.float32))
video_base = l2norm(video_emb.astype(np.float32))

# Per-modality integer labels (also used by the linear-probe cell).
le_audio = LabelEncoder().fit(species_audio)
y_audio = le_audio.transform(species_audio)
le_video = LabelEncoder().fit(species_video)
y_video = le_video.transform(species_video)

print("=== AMI/NMI ===")

# Audio-only
ami_a, nmi_a = cluster_metrics(audio_base, y_audio, n_clusters=len(le_audio.classes_))
print(f"audio-only | AMI={ami_a:.4f}  NMI={nmi_a:.4f}")

# Video-only
ami_v, nmi_v = cluster_metrics(video_base, y_video, n_clusters=len(le_video.classes_))
print(f"video-only | AMI={ami_v:.4f}  NMI={nmi_v:.4f}")

# Learned joint (audio and video separately)
# NOTE(review): audio_joint/video_joint cover every sample, including the ones
# the heads were trained on, so these scores are not purely held-out.
ami_ja, nmi_ja = cluster_metrics(audio_joint, y_audio, n_clusters=len(le_audio.classes_))
ami_jv, nmi_jv = cluster_metrics(video_joint, y_video, n_clusters=len(le_video.classes_))
print(f"joint(audio) | AMI={ami_ja:.4f}  NMI={nmi_ja:.4f}")
print(f"joint(video) | AMI={ami_jv:.4f}  NMI={nmi_jv:.4f}")

# Optional combined joint
# y_audio and y_video come from two independently fitted encoders, so the same
# integer can stand for different species in each modality. The combined set
# is therefore labelled with ONE encoder fitted on the union of species.
le_joint = LabelEncoder().fit(list(species_audio) + list(species_video))
combined_joint = np.vstack([audio_joint, video_joint])
combined_labels = np.concatenate([le_joint.transform(species_audio), le_joint.transform(species_video)])
ami_c, nmi_c = cluster_metrics(combined_joint, combined_labels, n_clusters=len(np.unique(combined_labels)))
print(f"joint(combined) | AMI={ami_c:.4f}  NMI={nmi_c:.4f}")

# Concat baseline if available
if concat_available:
    le_combo = LabelEncoder().fit(concat_species)
    y_combo = le_combo.transform(concat_species)
    # Each modality is normalized on its own before concatenation so neither
    # dominates the distance purely through scale.
    combo = np.concatenate([l2norm(concat_audio.astype(np.float32)), l2norm(concat_video.astype(np.float32))], axis=1)
    ami_c2, nmi_c2 = cluster_metrics(combo, y_combo, n_clusters=len(le_combo.classes_))
    print(f"concat | AMI={ami_c2:.4f}  NMI={nmi_c2:.4f}")
else:
    print("concat | skipped (no aligned pairs)")


=== AMI/NMI ===


  ret = a @ b
  ret = a @ b
  ret = a @ b


audio-only | AMI=0.0944  NMI=0.5638


  ret = a @ b
  ret = a @ b
  ret = a @ b


video-only | AMI=0.4630  NMI=0.7196
joint(audio) | AMI=0.5754  NMI=0.8135
joint(video) | AMI=0.7605  NMI=0.8758


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


joint(combined) | AMI=0.6963  NMI=0.8076


  ret = a @ b
  ret = a @ b
  ret = a @ b


concat | AMI=0.3722  NMI=0.7143


In [42]:
# Linear classifier evaluation (40% train, stratified)
# Every representation is probed with the same train fraction and seed.
# NOTE(review): linear_probe draws its own split, independent of the split
# used to train the projection heads, and the heads were trained with species
# labels. Probe test rows can therefore include samples the heads were trained
# on, so the joint(...) scores are likely optimistic — confirm before comparing.

print("\n=== Linear Probe (LogReg; 40% train) ===")

acc, f1 = linear_probe(audio_base, y_audio, train_frac=0.4, seed=seed)
print(f"audio-only | acc={acc:.4f}  macroF1={f1:.4f}")

acc, f1 = linear_probe(video_base, y_video, train_frac=0.4, seed=seed)
print(f"video-only | acc={acc:.4f}  macroF1={f1:.4f}")

acc, f1 = linear_probe(audio_joint, y_audio, train_frac=0.4, seed=seed)
print(f"joint(audio) | acc={acc:.4f}  macroF1={f1:.4f}")

acc, f1 = linear_probe(video_joint, y_video, train_frac=0.4, seed=seed)
print(f"joint(video) | acc={acc:.4f}  macroF1={f1:.4f}")

if concat_available:
    # Rebuilt here so this cell does not depend on the clustering cell having run.
    le_combo = LabelEncoder().fit(concat_species)
    y_combo = le_combo.transform(concat_species)
    combo = np.concatenate([l2norm(concat_audio.astype(np.float32)), l2norm(concat_video.astype(np.float32))], axis=1)
    acc, f1 = linear_probe(combo, y_combo, train_frac=0.4, seed=seed)
    print(f"concat | acc={acc:.4f}  macroF1={f1:.4f}")
else:
    print("concat | skipped (no aligned pairs)")



=== Linear Probe (LogReg; 40% train) ===
linear_probe: dropped 3 samples from classes with <2 examples


python(89606) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89607) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89608) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89609) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89610) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89611) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89612) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89613) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89614) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(89615) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
  raw_prediction = X

audio-only | acc=0.1615  macroF1=0.0896
linear_probe: dropped 3 samples from classes with <2 examples


  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  ret = a @ b
  ret = a @ b
  ret = a @ b


video-only | acc=0.5696  macroF1=0.4036
linear_probe: dropped 3 samples from classes with <2 examples


  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  ret = a @ b
  ret = a @ b
  ret = a @ b
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  ret = a @ b
  ret = a @ b
  ret = a @ b


joint(audio) | acc=0.6231  macroF1=0.4654
linear_probe: dropped 3 samples from classes with <2 examples
joint(video) | acc=0.7595  macroF1=0.5994
linear_probe: dropped 3 samples from classes with <2 examples


  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
  grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights


concat | acc=0.3692  macroF1=0.2603


  ret = a @ b
  ret = a @ b
  ret = a @ b
