In [71]:
# %%
from typing import Optional, Tuple, List
from pathlib import Path
import sys
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T

from torch.utils.data import Dataset, DataLoader
from captum.concept import TCAV, Concept
from wespeaker.cli.speaker import load_model

from tqdm import tqdm

In [72]:
# %%
# -------- Project paths / device --------
PROJECT_ROOT = Path.cwd().parents[1]
sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

TCAV_DEVICE = torch.device("cpu")
print("TCAV_DEVICE =", TCAV_DEVICE)

ATTR_CSV_PATH = Path(
    PROJECT_ROOT / "resnet_293" / "speaker_similarity_ranking_team.csv"
)
WAV_FOLDER = Path(PROJECT_ROOT / "data" / "wavs")
CONCEPT_ROOT = Path(PROJECT_ROOT / "concept" / "concepts_dataset_resnet_293")

HEAD_PATH = Path(PROJECT_ROOT / "data" / "heads" / "resnet_293_speaker_head.pt")

assert ATTR_CSV_PATH.exists(), f"Missing {ATTR_CSV_PATH}"
assert WAV_FOLDER.exists(), f"Missing {WAV_FOLDER}"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"
assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

# ====== choose TCAV layer ======
# LAYER_KEY = "layer3.63.conv3"  
# OUT_CSV = Path(f"concepts_dataset_resnet_293_{LAYER_KEY}.csv")

CONCEPT_SAMPLES = 100
RANDOM_SAMPLES = 100
BATCH_SIZE_CONCEPT = 1
FORCE_TRAIN_CAVS = False

PROJECT_ROOT = /home/SpeakerRec/BioVoice
TCAV_DEVICE = cpu


In [73]:
# %%
# -------- Load WeSpeaker ResNet-293 backbone --------
speaker = load_model(PROJECT_ROOT / "wespeaker-voxceleb-resnet293-LM")
backbone = speaker.model

backbone = backbone.to(TCAV_DEVICE).eval()

print("ResNet-293 backbone loaded")

{'data_type': 'shard', 'dataloader_args': {'batch_size': 32, 'drop_last': True, 'num_workers': 16, 'pin_memory': False, 'prefetch_factor': 8}, 'dataset_args': {'aug_prob': 0.6, 'fbank_args': {'dither': 1.0, 'frame_length': 25, 'frame_shift': 10, 'num_mel_bins': 80}, 'num_frms': 200, 'shuffle': True, 'shuffle_args': {'shuffle_size': 2500}, 'spec_aug': False, 'spec_aug_args': {'max_f': 8, 'max_t': 10, 'num_f_mask': 1, 'num_t_mask': 1, 'prob': 0.6}, 'speed_perturb': True}, 'exp_dir': 'exp/ResNet293-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150', 'gpus': [0, 1], 'log_batch_interval': 100, 'loss': 'CrossEntropyLoss', 'loss_args': {}, 'margin_scheduler': 'MarginScheduler', 'margin_update': {'epoch_iter': 17062, 'final_margin': 0.2, 'fix_start_epoch': 40, 'increase_start_epoch': 20, 'increase_type': 'exp', 'initial_margin': 0.0, 'update_margin': True}, 'model': 'ResNet293', 'model_args': {'embed_dim': 256, 'feat_dim': 80, 'pooling_func': 'TSTP', 'two_emb_layer':

  checkpoint = torch.load(path, map_location="cpu")


ResNet-293 backbone loaded


In [74]:
# %%
# -------- Load your speaker head ckpt --------
ckpt = torch.load(HEAD_PATH, map_location="cpu")

speaker_to_id = ckpt["speaker_to_id"]
SPEAKERS = ckpt.get("speakers", list(speaker_to_id.keys()))

# build reverse mapping (THIS IS THE FIX)
id_to_speaker = {i: s for s, i in speaker_to_id.items()}

fc_w = ckpt["state_dict"]["fc.weight"]
in_dim = int(fc_w.shape[1])
num_classes = int(fc_w.shape[0])

print("Loaded head:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("Head in_dim:", in_dim, "num_classes:", num_classes)


class SpeakerHead(nn.Module):
    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


head = SpeakerHead(in_dim=in_dim, num_classes=num_classes).to(TCAV_DEVICE)
head.load_state_dict(ckpt["state_dict"])
head.eval()

Loaded head: /home/SpeakerRec/BioVoice/data/heads/resnet_293_speaker_head.pt
Speakers: ['eden', 'idan', 'yoav']
Head in_dim: 256 num_classes: 3


  ckpt = torch.load(HEAD_PATH, map_location="cpu")


SpeakerHead(
  (fc): Linear(in_features=256, out_features=3, bias=True)
)

In [75]:
# %%
class WeSpeakerForTCAV(nn.Module):
    """
    TCAV-safe wrapper:
    - Uses WeSpeaker backbone
    - DOES NOT apply the speaker head
    - Returns fake logits with correct shape
    """

    def __init__(self, backbone: nn.Module, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes

    def forward(self, x):
        """
        x: (B, T, F)
        """

        out = self.backbone(x)

        # unwrap wespeaker output
        if isinstance(out, (tuple, list)):
            emb = out[0]
        else:
            emb = out

        # flatten anything strange
        if emb.ndim == 0:
            emb = emb.view(1, 1)
        elif emb.ndim == 1:
            emb = emb.unsqueeze(0)
        elif emb.ndim > 2:
            emb = emb.reshape(emb.size(0), -1)

        B = emb.shape[0]

        # ---- CRITICAL PART ----
        # TCAV only needs gradients + target index.
        # So return dummy logits with correct batch size.
        logits = emb[:, :self.num_classes]


        return logits

wrapped_model = (
    WeSpeakerForTCAV(
        backbone=speaker.model,
        num_classes=len(SPEAKERS),
    )
    .to(TCAV_DEVICE)
    .eval()
)



# For TCAV we DO need gradients through backbone
for p in wrapped_model.backbone.parameters():
    p.requires_grad_(True)

# # head grads not needed
# for p in wrapped_model.head.parameters():
#     p.requires_grad_(False)

print("wrapped_model ready on", next(wrapped_model.parameters()).device)

wrapped_model ready on cpu


In [76]:
# %%
# -------- Resolve layer name strings for Captum (ALL layers) --------

TARGET_LAYERS = {
    "layer1.9.conv3": wrapped_model.backbone.layer1[9].conv3,
    "layer2.19.conv3": wrapped_model.backbone.layer2[19].conv3,
    "layer3.63.conv3": wrapped_model.backbone.layer3[63].conv3,
    "layer4.2.conv3": wrapped_model.backbone.layer4[2].conv3,
}


def module_name_in_model(model: nn.Module, target_module: nn.Module) -> str:
    for name, mod in model.named_modules():
        if mod is target_module:
            return name
    raise RuntimeError("Could not find target module in wrapped_model.named_modules()")


# Resolve ALL layer names
LAYER_NAMES = []
for key, module in TARGET_LAYERS.items():
    layer_name = module_name_in_model(wrapped_model, module)
    LAYER_NAMES.append(layer_name)
    print(f"{key} -> {layer_name}")

print("\nUsing layers for TCAV:")
for ln in LAYER_NAMES:
    print(" ", ln)

layer1.9.conv3 -> backbone.layer1.9.conv3
layer2.19.conv3 -> backbone.layer2.19.conv3
layer3.63.conv3 -> backbone.layer3.63.conv3
layer4.2.conv3 -> backbone.layer4.2.conv3

Using layers for TCAV:
  backbone.layer1.9.conv3
  backbone.layer2.19.conv3
  backbone.layer3.63.conv3
  backbone.layer4.2.conv3


In [77]:
# %%
# -------- Infer N_MELS + TARGET_FRAMES from your concepts (avoid mismatch bugs) --------
concept_dirs = sorted([d for d in CONCEPT_ROOT.iterdir() if d.is_dir()])
if not concept_dirs:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")


def infer_mels_and_frames(concept_dirs: List[Path]) -> Tuple[int, int]:
    for d in concept_dirs:
        f = next(d.glob("*.npy"), None)
        if f is None:
            continue
        arr = np.load(f)
        if arr.ndim != 2:
            raise RuntimeError(
                f"Concept file {f} expected 2D [MELS, FRAMES], got {arr.shape}"
            )
        return int(arr.shape[0]), int(arr.shape[1])
    raise RuntimeError("Could not infer mel bins/frames from concept dirs")


N_MELS, TARGET_FRAMES = infer_mels_and_frames(concept_dirs)
print("Inferred from concepts: N_MELS =", N_MELS, "TARGET_FRAMES =", TARGET_FRAMES)

concept_names = [d.name for d in concept_dirs]
print("Concepts:", concept_names)

Inferred from concepts: N_MELS = 80 TARGET_FRAMES = 304
Concepts: ['long_constant_thick', 'long_constant_thick_Vibrato', 'long_dropping_flat_thick', 'long_dropping_flat_thick_Vibrato', 'long_dropping_steep_thick', 'long_dropping_steep_thin', 'long_rising_flat_thick', 'long_rising_steep_thick', 'long_rising_steep_thin', 'short_constant_thick', 'short_dropping_steep_thick', 'short_dropping_steep_thin', 'short_rising_steep_thick', 'short_rising_steep_thin']


In [78]:
# %%
# -------- Audio -> mel pipeline (MUST match concept space) --------
SAMPLE_RATE = 16000
FRAME_LENGTH_MS = 25
FRAME_SHIFT_MS = 10
WIN_LENGTH = int(SAMPLE_RATE * FRAME_LENGTH_MS / 1000)  # 400
HOP_LENGTH = int(SAMPLE_RATE * FRAME_SHIFT_MS / 1000)  # 160
N_FFT = WIN_LENGTH  # keep same as before (works for torchaudio)

mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    n_fft=N_FFT,
    win_length=WIN_LENGTH,
    hop_length=HOP_LENGTH,
    center=True,  # אם אצלך ב-Activation-CAM זה עבד טוב, תשאיר True
    power=2.0,
).to(TCAV_DEVICE)


def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """
    mel_3d: [1, N_MELS, T]
    returns: [1, N_MELS, target_frames]
    """
    Tcur = int(mel_3d.shape[-1])
    if Tcur == target_frames:
        return mel_3d
    if Tcur > target_frames:
        start = (Tcur - target_frames) // 2
        return mel_3d[..., start : start + target_frames]
    pad = target_frames - Tcur
    return F.pad(mel_3d, (0, pad), mode="constant", value=0.0)


def postprocess_like_concepts(mel: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    mel: [1, N_MELS, T] (energy)
    -> log + CMN over time (per mel bin)
    """
    mel = torch.clamp(mel, min=0.0)
    mel = torch.log(mel + eps)
    mel = mel - mel.mean(dim=-1, keepdim=True)
    return mel


def wav_path_to_tcav_input(path: Path) -> torch.Tensor:
    """
    Returns input for TCAV / WeSpeaker backbone
    Shape: [1, T, F]
    """
    wav, sr = torchaudio.load(str(path))
    wav = wav[:1].float()

    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    wav = wav.to(TCAV_DEVICE)

    mel = mel_transform(wav)  # [1, F, T]
    mel = postprocess_like_concepts(mel)
    mel = fix_mel_frames(mel, TARGET_FRAMES)  # [1, F, T]

    x_tf = mel.squeeze(0).transpose(0, 1)  # [T, F]
    return x_tf.unsqueeze(0)  # [1, T, F]

In [79]:
# %%
# -------- Concept datasets (FIXED for WeSpeaker ResNet293) --------
# WeSpeaker expects x shape = (B, T, F)  (time first)
# Your saved .npy concepts are (F, T) so we transpose -> (T, F).
# Also: DO NOT unsqueeze(0) here, because DataLoader already creates batch dim.


class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: Optional[int] = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel_ft = np.load(self.files[idx]).astype(np.float32)  # (F, T)
        if mel_ft.ndim != 2:
            raise RuntimeError(
                f"{self.files[idx].name}: expected 2D (F,T), got {mel_ft.shape}"
            )

        F_bins, T_frames = mel_ft.shape
        if F_bins != N_MELS:
            raise RuntimeError(
                f"{self.files[idx].name}: expected F={N_MELS} mel bins, got {F_bins}"
            )
        if T_frames != TARGET_FRAMES:
            raise RuntimeError(
                f"{self.files[idx].name}: expected T={TARGET_FRAMES} frames, got {T_frames}"
            )

        mel_tf = mel_ft.T  # (T, F)
        x = torch.from_numpy(mel_tf)  # (T, F) on CPU (or TCAV_DEVICE)
        return x.to(TCAV_DEVICE)


class RandomMelDataset(Dataset):
    def __init__(self, n_samples: int):
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # random in (T, F) because model expects (B, T, F)
        x = torch.randn(TARGET_FRAMES, N_MELS, dtype=torch.float32)
        return x.to(TCAV_DEVICE)

In [80]:
# %%
# -------- Build TCAV concepts --------
tcav = TCAV(wrapped_model, LAYER_NAMES, test_split_ratio=0.33)

positive_concepts: List[Concept] = []
for idx, cdir in enumerate(concept_dirs):
    ds = ConceptNPYDataset(cdir, limit=CONCEPT_SAMPLES)
    dl = DataLoader(ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
    positive_concepts.append(Concept(id=idx, name=cdir.name, data_iter=dl))

rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES)
rand_dl = DataLoader(
    rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0
)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

experimental_sets = [[c, random_concept] for c in positive_concepts]

print("Built experimental sets:", len(experimental_sets))

Built experimental sets: 14




In [81]:
# %%
# -------- Compute CAV accuracies (sanity) --------
def compute_cav_acc_df(
    tcav: TCAV, positive_concepts: List[Concept], random_concept: Concept
) -> pd.DataFrame:
    cavs_dict = tcav.compute_cavs(
        [[c, random_concept] for c in positive_concepts], force_train=FORCE_TRAIN_CAVS
    )

    rows = []
    for concepts_key, layer_map in cavs_dict.items():
        try:
            pos_id = int(str(concepts_key).split("-")[0])
        except Exception:
            continue
        if not (0 <= pos_id < len(positive_concepts)):
            continue
        concept_name = positive_concepts[pos_id].name

        for layer_name, cav_obj in layer_map.items():
            if cav_obj is None or cav_obj.stats is None:
                continue
            acc = cav_obj.stats.get("accs", None)
            if acc is None:
                acc = cav_obj.stats.get("acc", None)
            if isinstance(acc, torch.Tensor):
                acc = acc.detach().cpu().item()

            rows.append(
                {
                    "concept name": concept_name,
                    "layer name": layer_name,
                    "cav acc": float(acc) if acc is not None else np.nan,
                }
            )

    return pd.DataFrame(rows, columns=["concept name", "layer name", "cav acc"])


acc_df = compute_cav_acc_df(tcav, positive_concepts, random_concept)
print(acc_df.sort_values("cav acc", ascending=False).head(10))

  save_dict = torch.load(cavs_path)


                        concept name                layer name   cav acc
34            long_rising_steep_thin  backbone.layer3.63.conv3  0.480769
29           long_rising_steep_thick  backbone.layer2.19.conv3  0.480769
27            long_rising_flat_thick   backbone.layer4.2.conv3  0.480769
13  long_dropping_flat_thick_Vibrato  backbone.layer2.19.conv3  0.480769
10          long_dropping_flat_thick  backbone.layer3.63.conv3  0.480769
52           short_rising_steep_thin   backbone.layer1.9.conv3  0.442308
51          short_rising_steep_thick   backbone.layer4.2.conv3  0.442308
39              short_constant_thick   backbone.layer4.2.conv3  0.442308
22          long_dropping_steep_thin  backbone.layer3.63.conv3  0.442308
9           long_dropping_flat_thick  backbone.layer2.19.conv3  0.442308


In [82]:
# %%
# -------- Prediction helper (optional) --------


@torch.no_grad()
def predict_speaker_from_wav(path: Path) -> Tuple[str, float]:
    wav, sr = torchaudio.load(str(path))
    wav = wav[:1].float()

    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)

    wav = wav.to(TCAV_DEVICE)

    # THIS is the correct embedding extraction
    emb = speaker.extract_embedding_from_pcm(wav, 16000)

    emb = emb / (emb.norm(p=2) + 1e-12)  # same as training
    emb = emb.unsqueeze(0) if emb.ndim == 1 else emb  # [1, D]

    logits = head(emb)
    probs = F.softmax(logits, dim=1)[0]

    pred_id = int(torch.argmax(probs).item())
    return id_to_speaker[pred_id], float(probs[pred_id].item())

In [85]:
# %%
# -------- Run TCAV on your ranked CSV --------
df_attr = pd.read_csv(ATTR_CSV_PATH)

if "path" not in df_attr.columns or "speaker" not in df_attr.columns:
    raise RuntimeError(
        f"CSV must contain columns ['path','speaker']. Got: {list(df_attr.columns)}"
    )

rows = []

for _, r in tqdm(df_attr.iterrows(), total=len(df_attr)):
    path = Path(r["path"])
    true_label = str(r["speaker"])

    if not path.exists():
        continue
    if true_label not in speaker_to_id:
        continue

    # 1) build mel in SAME space as concepts: (B,T,F)
    mel_ft = wav_path_to_mel3d(path)  # (1, F, T)
    x = mel_ft.transpose(1, 2)  # (1, T, F)

    # 2) make it require grad BUT not a leaf (avoid wespeaker in-place error)
    x = x.detach().requires_grad_(True)
    x = x * 1.0

    target_idx = speaker_to_id[true_label]

    # optional prediction (not used by TCAV graph)
    pred_label, pred_prob = predict_speaker_from_wav(path)

    # 3) TCAV interpret
    score_for_label = tcav.interpret(
        inputs=x,
        experimental_sets=experimental_sets,
        target=target_idx,
    )

    for exp_key, layer_dict in score_for_label.items():
        try:
            pos_idx = int(str(exp_key).split("-")[0])
        except Exception:
            continue
        if not (0 <= pos_idx < len(positive_concepts)):
            continue

        concept_name = positive_concepts[pos_idx].name

        for layer_name, metrics in layer_dict.items():
            sc = metrics.get("sign_count")
            mg = metrics.get("magnitude")
            if sc is None or mg is None:
                continue

            if isinstance(sc, torch.Tensor):
                sc = sc.detach().cpu().tolist()
            if isinstance(mg, torch.Tensor):
                mg = mg.detach().cpu().tolist()

            rows.append(
                {
                    "path": str(path),
                    "concept name": concept_name,
                    "layer name": layer_name,
                    "positive percentage": float(sc[0]),
                    "magnitude": float(mg[0]),
                    "true label": true_label,
                    "predicted label": pred_label,
                    "predicted probability": float(pred_prob),
                }
            )

df_tcav = pd.DataFrame(
    rows,
    columns=[
        "path",
        "concept name",
        "layer name",
        "positive percentage",
        "magnitude",
        "true label",
        "predicted label",
        "predicted probability",
    ],
)

df_tcav = df_tcav.merge(acc_df, on=["concept name", "layer name"], how="left")
df_tcav.to_csv(OUT_CSV, index=False)

print("Saved →", OUT_CSV)
df_tcav.head()

  save_dict = torch.load(cavs_path)
  0%|          | 0/90 [00:02<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn