# ECAPA-TDNN + TCAV on Mel Concepts

This notebook:
- Loads a pretrained SpeechBrain ECAPA-TDNN model
- Uses synthetic mel-based concepts (rising / dropping / constant lines, etc.)
- Computes CAVs on 3 internal SE-Res2 blocks
- Computes TCAV scores:
  - Globally (all 90 recordings)
  - Per speaker (eden / idan / yoav)
- Saves results to `tcav_results.csv`
- Draws heatmaps for:
  - Concept × Layer (global)
  - Concept × Layer per speaker


In [7]:
from pathlib import Path
import sys
import numpy as np
import torch

from tqdm import tqdm
from sklearn.linear_model import SGDClassifier

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Project root = parent of this notebook folder
ROOT = Path("..").resolve()

# Add concept folder to Python path
sys.path.append(str(ROOT / "concept"))

CONCEPT_ROOT = ROOT / "concept" / "positive concepts dataset"
DATA_DIR = ROOT / "data"

from Preprocess import audio_to_mel_spectrogram
from PreprocessParams import FREQUENCY_BIN_COUNT

np.random.seed(42)
torch.manual_seed(42)


if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Running on GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Running on CPU")



Running on CPU


In [None]:
import torchaudio

# Ensure torchaudio backend works
try:
    print("Torchaudio backends:", torchaudio.list_audio_backends())
except Exception as e:
    print("Warning: torchaudio backend issue:", e)
    print("Continuing with soundfile backend only...")

from speechbrain.pretrained import EncoderClassifier
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN

print("Loading SpeechBrain ECAPA-TDNN...")

try:
    # The SAFE way to load SpeechBrain models
    model = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir=str(ROOT / "ecapa_pretrained"),
        run_opts={"device": str(device)},  # <-- CRITICAL
    )
except Exception as e:
    print("ERROR loading ECAPA:", e)
    raise

# Extract ECAPA backbone
backbone: ECAPA_TDNN = model.mods["embedding_model"]
backbone.to(device)
backbone.eval()

print("ECAPA loaded successfully.")
print("Running on device:", device)
print()
print("Printing backbone architecture:")
print(backbone)
print()
print("Backbone blocks:")
print(backbone.blocks)


AttributeError: module 'torchaudio' has no attribute 'list_audio_backends'

In [None]:
def mel_to_input(mel: np.ndarray) -> torch.Tensor:
    """
    mel: (F, T) = (64, T) → Tensor (1, T, 64) on device
    """
    mel = torch.tensor(mel, dtype=torch.float32)  # (F, T)
    return mel.T.unsqueeze(0).to(device)  # (1, T, F)


def get_activation_from_layer(mel: np.ndarray, layer: torch.nn.Module) -> np.ndarray:
    """
    Forward a single mel through ECAPA backbone and capture
    the output activation of the given layer (no gradient).
    """
    activations = {}

    def hook_fn(_, __, out):
        activations["A"] = out.detach().cpu().numpy()

    handle = layer.register_forward_hook(hook_fn)
    x = mel_to_input(mel)
    with torch.no_grad():
        _ = backbone(x)
    handle.remove()

    return activations["A"]  # shape (1, C, T') or (1, T', C)


def activation_to_vec(A: np.ndarray) -> np.ndarray:
    """
    Convert a layer activation to a 1D vector by averaging over time.
    """
    A = A.squeeze(0)  # remove batch → (C, T') or (T', C)
    if A.ndim != 2:
        raise RuntimeError(f"Expected 2D activation, got shape {A.shape}")

    # If first dim is channels
    if A.shape[0] < A.shape[1]:
        return A.mean(axis=1)  # (C,)
    else:
        return A.mean(axis=0)  # (C,)


In [None]:
# All concept subfolders in positive concepts dataset
concept_dirs = sorted([d for d in CONCEPT_ROOT.iterdir() if d.is_dir()])

print("Found concept dirs:")
for d in concept_dirs:
    print(" -", d.name)

if len(concept_dirs) < 2:
    raise RuntimeError("Need at least 2 concept folders for TCAV (pos vs neg).")


In [None]:
from collections import defaultdict

def load_eval_mels_and_group():
    wavs = sorted(DATA_DIR.glob("*.wav"))
    if len(wavs) == 0:
        raise RuntimeError(f"No wav files found in {DATA_DIR}")

    wav_files = wavs
    eval_mels = []

    for w in wav_files:
        mel = audio_to_mel_spectrogram(w)
        if mel.shape[0] != FREQUENCY_BIN_COUNT:
            raise RuntimeError(f"Mel dimension mismatch in {w}, got {mel.shape}")
        eval_mels.append(mel)

    # Group by speaker prefix: eden / idan / yoav
    speaker_mels = defaultdict(list)
    speaker_wavs = defaultdict(list)

    for w, mel in zip(wav_files, eval_mels):
        name = w.name.lower()
        if name.startswith("eden"):
            speaker = "eden"
        elif name.startswith("idan"):
            speaker = "idan"
        elif name.startswith("yoav"):
            speaker = "yoav"
        else:
            speaker = "other"
        speaker_mels[speaker].append(mel)
        speaker_wavs[speaker].append(w.name)

    return wav_files, eval_mels, speaker_mels, speaker_wavs

wav_files, eval_mels, speaker_mels, speaker_wavs = load_eval_mels_and_group()

print("Total eval mels:", len(eval_mels))
for spk, files in speaker_wavs.items():
    print(f"Speaker {spk}: {len(files)} files")


In [None]:
def build_cav_for_concept_and_layer(layer_name: str,
                                    concept_dir: Path,
                                    all_concept_dirs: list[Path]) -> np.ndarray:
    """
    Build a CAV vector for:
      - positive examples from concept_dir
      - negative examples from all other concept dirs (sampled)
    """
    layer = TARGET_LAYERS[layer_name]

    pos_paths = sorted(concept_dir.glob("*.npy"))
    if len(pos_paths) == 0:
        raise RuntimeError(f"No .npy files in {concept_dir}")

    # Collect all negative npy files from other concept dirs
    neg_paths_all = []
    for d in all_concept_dirs:
        if d == concept_dir:
            continue
        neg_paths_all.extend(sorted(d.glob("*.npy")))

    if len(neg_paths_all) == 0:
        raise RuntimeError("No negative concept samples found.")

    # Sample negatives to have roughly balanced classes
    n_pos = len(pos_paths)
    n_neg = min(len(neg_paths_all), n_pos * 2)  # up to 2x negatives
    neg_indices = np.random.choice(len(neg_paths_all), size=n_neg, replace=False)
    neg_paths = [neg_paths_all[i] for i in neg_indices]

    X = []
    Y = []

    # Positive = concept
    for p in pos_paths:
        mel = np.load(p)
        A = get_activation_from_layer(mel, layer)
        X.append(activation_to_vec(A))
        Y.append(1)

    # Negative = all other concepts
    for p in neg_paths:
        mel = np.load(p)
        A = get_activation_from_layer(mel, layer)
        X.append(activation_to_vec(A))
        Y.append(0)

    X = np.vstack(X)
    Y = np.array(Y)

    clf = SGDClassifier(loss="hinge", alpha=1e-4, max_iter=2000)
    clf.fit(X, Y)

    cav = clf.coef_.reshape(-1)
    cav /= (np.linalg.norm(cav) + 1e-8)
    return cav


In [None]:
def tcav_score_all(layer_name: str, cav: np.ndarray) -> float:
    """
    Compute TCAV score on all eval_mels:
    fraction of examples where directional derivative > 0.
    """
    layer = TARGET_LAYERS[layer_name]
    positives = 0
    total = 0

    for mel in eval_mels:
        x = mel_to_input(mel)
        x.requires_grad_(True)

        activations = {}

        def hook_fn(_, __, out):
            out.retain_grad()
            activations["A"] = out

        handle = layer.register_forward_hook(hook_fn)
        out = backbone(x)
        handle.remove()

        # simple scalar "task" = norm of embedding
        loss = out.norm()
        loss.backward()

        grad = activations["A"].grad.detach().cpu().numpy().squeeze(0)

        # time-mean
        if grad.shape[0] < grad.shape[1]:
            g_vec = grad.mean(axis=1)
        else:
            g_vec = grad.mean(axis=0)

        directional_derivative = np.dot(g_vec, cav)
        positives += (directional_derivative > 0)
        total += 1

        # clear gradients for safety
        backbone.zero_grad()
        model.zero_grad()

    return positives / total if total > 0 else 0.0


In [None]:
def tcav_score_for_speaker(layer_name: str,
                           cav: np.ndarray,
                           speaker: str) -> float:
    """
    TCAV score restricted to a given speaker's mel list.
    """
    if speaker not in speaker_mels:
        raise RuntimeError(f"No mels for speaker '{speaker}'")

    layer = TARGET_LAYERS[layer_name]
    mels = speaker_mels[speaker]

    positives = 0
    total = 0

    for mel in mels:
        x = mel_to_input(mel)
        x.requires_grad_(True)

        activations = {}

        def hook_fn(_, __, out):
            out.retain_grad()
            activations["A"] = out

        handle = layer.register_forward_hook(hook_fn)
        out = backbone(x)
        handle.remove()

        loss = out.norm()
        loss.backward()

        grad = activations["A"].grad.detach().cpu().numpy().squeeze(0)

        if grad.shape[0] < grad.shape[1]:
            g_vec = grad.mean(axis=1)
        else:
            g_vec = grad.mean(axis=0)

        directional_derivative = np.dot(g_vec, cav)
        positives += (directional_derivative > 0)
        total += 1

        backbone.zero_grad()
        model.zero_grad()

    return positives / total if total > 0 else 0.0


In [None]:
results_global = []   # list of dicts: {concept, layer, tcav}
cavs = {}            # (concept_name, layer_name) -> cav vector

for cdir in tqdm(concept_dirs, desc="Concepts"):
    cname = cdir.name
    print(f"\n=== Concept: {cname} ===")

    for layer_name in TARGET_LAYERS.keys():
        print(f"  Building CAV for layer {layer_name} ...")
        cav = build_cav_for_concept_and_layer(layer_name, cdir, concept_dirs)
        cavs[(cname, layer_name)] = cav

        print("  Computing global TCAV...")
        score = tcav_score_all(layer_name, cav)
        results_global.append({
            "Concept": cname,
            "Layer": layer_name,
            "Speaker": "all",
            "TCAV": score,
        })
        print(f"    TCAV(all) = {score:.3f}")


In [None]:
results_speaker = []

for (cname, layer_name), cav in cavs.items():
    for speaker in ["eden", "idan", "yoav"]:
        if speaker not in speaker_mels or len(speaker_mels[speaker]) == 0:
            continue

        s_score = tcav_score_for_speaker(layer_name, cav, speaker)
        results_speaker.append({
            "Concept": cname,
            "Layer": layer_name,
            "Speaker": speaker,
            "TCAV": s_score,
        })
        print(f"{cname} | {layer_name} | {speaker} → {s_score:.3f}")


In [None]:
df_global = pd.DataFrame(results_global)
df_speaker = pd.DataFrame(results_speaker)

df_all = pd.concat([df_global, df_speaker], ignore_index=True)

csv_path = ROOT / "tcav_results.csv"
df_all.to_csv(csv_path, index=False)

print("Saved all TCAV results to:", csv_path)
df_all.head()


In [None]:
# Filter only global rows (Speaker == "all")
df_global_only = df_all[df_all["Speaker"] == "all"]

heatmap_df = df_global_only.pivot(
    index="Concept",
    columns="Layer",
    values="TCAV"
).sort_index()

plt.figure(figsize=(10, max(4, 0.4 * len(heatmap_df))))
sns.heatmap(
    heatmap_df,
    annot=True,
    cmap="viridis",
    fmt=".2f",
    linewidths=0.5
)
plt.title("Global TCAV Heatmap (all speakers)")
plt.tight_layout()
plt.show()


In [None]:
for speaker in ["eden", "idan", "yoav"]:
    df_spk = df_all[df_all["Speaker"] == speaker]
    if df_spk.empty:
        continue

    pivot_df = df_spk.pivot(
        index="Concept",
        columns="Layer",
        values="TCAV"
    ).sort_index()

    plt.figure(figsize=(10, max(4, 0.4 * len(pivot_df))))
    sns.heatmap(
        pivot_df,
        annot=True,
        cmap="mako",
        fmt=".2f",
        linewidths=0.5
    )
    plt.title(f"TCAV Heatmap – Speaker: {speaker}")
    plt.tight_layout()
    plt.show()
