In [18]:
import sys
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio

from tqdm import tqdm
from sklearn.linear_model import SGDClassifier


In [19]:
# ---------- Paths / Device ----------
# This notebook runs from <project>/redimnet/tcav, so the project root is two
# levels up. Only add it to sys.path once: re-running this cell used to append
# a duplicate entry each time.
PROJECT_ROOT = Path.cwd().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator's repr


PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda


<torch._C.Generator at 0x7f2162b0cd90>

In [20]:
# ---------- Load ReDimNet ----------
# Fetch the pretrained ReDimNet "b5" model through torch.hub (the local hub
# cache is reused when present), then move it to DEVICE in inference mode.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",   # model size variant
    train_type="ptn",  # checkpoint variant; see the ReDimNet hub entry for options
    dataset="vox2",    # training-set tag of the checkpoint (presumably VoxCeleb2)
)
redim_model = redim_model.to(DEVICE)
redim_model.eval()
print("Loaded ReDimNet successfully.")


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet successfully.


In [21]:
# ---------- Infer N_MELS from model.spec ----------
# Probe the model's own feature front-end with a silent 16000-sample waveform
# and read the mel-bin count off the result; every later mel (eval wavs and
# precomputed concept .npy files) is validated against this number.
dummy_wav = torch.zeros(1, 16000, device=DEVICE)
with torch.no_grad():
    dummy_mel = redim_model.spec(dummy_wav)  # (1, N_MELS, T)
N_MELS = int(dummy_mel.size(1))
print("ReDimNet spec N_MELS =", N_MELS)  # 72 for this checkpoint


ReDimNet spec N_MELS = 72


In [22]:
# ---------- Load trained speaker head ----------
HEAD_PATH = Path.cwd() / "output" / "redim_speaker_head_linear.pt"
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not HEAD_PATH.exists():
    raise FileNotFoundError(f"Missing head checkpoint: {HEAD_PATH}")

# weights_only=True restricts unpickling to tensors and plain containers. The
# entries read below (state dict, label maps, boolean flag) are such types.
# It also silences torch's FutureWarning about the changing default. If the
# checkpoint holds other objects, torch.load will refuse; only fall back to
# weights_only=False for files you trust.
ckpt = torch.load(HEAD_PATH, map_location=DEVICE, weights_only=True)
speaker_to_id = ckpt["speaker_to_id"]
id_to_speaker = ckpt["id_to_speaker"]
# Order speakers by class index so SPEAKERS[i] names logit column i
# (assumes the ids are 0..K-1 — NOTE(review): confirm with the training script).
SPEAKERS = sorted(speaker_to_id, key=speaker_to_id.__getitem__)

print("Loaded speaker head from:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("L2-normalized embeddings:", ckpt.get("l2_norm_emb", False))


Loaded speaker head from: /home/SpeakerRec/BioVoice/redimnet/tcav/output/redim_speaker_head_linear.pt
Speakers: ['eden', 'idan', 'yoav']
L2-normalized embeddings: True


  ckpt = torch.load(HEAD_PATH, map_location=DEVICE)


In [23]:
class SpeakerHead(nn.Module):
    """Single linear layer mapping a speaker embedding to per-speaker logits.

    The layer is registered as ``fc`` so its parameters load from state dicts
    keyed ``fc.weight`` / ``fc.bias``.
    """

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings of shape (..., in_dim) to logits of shape (..., num_classes)."""
        logits = self.fc(x)
        return logits

# The head's input width must equal ReDimNet's embedding size. Read it from
# the saved weights (nn.Linear stores fc.weight as (num_classes, in_dim))
# instead of hard-coding 192. Fail early with a clear message if it disagrees
# with the backbone's final linear layer.
head_state = ckpt["state_dict"]
head_in_dim = int(head_state["fc.weight"].shape[1])
emb_dim = getattr(redim_model.linear, "out_features", None)  # None if not an nn.Linear
if emb_dim is not None and emb_dim != head_in_dim:
    raise ValueError(
        f"Head expects {head_in_dim}-dim embeddings but ReDimNet's linear layer outputs {emb_dim}"
    )
head = SpeakerHead(in_dim=head_in_dim, num_classes=len(SPEAKERS)).to(DEVICE)
head.load_state_dict(head_state)
head.eval()


SpeakerHead(
  (fc): Linear(in_features=192, out_features=3, bias=True)
)

In [24]:
class ReDimNetMelLogitsWrapper(nn.Module):
    """
    Chain ReDimNet's post-front-end modules with a speaker head.

    Input:  mel4d [B, 1, N_MELS, T]
    Output: logits [B, num_speakers]

    ``backbone``, ``pool``, ``bn`` and ``linear`` are the very same module
    objects as on ``redim_model`` (shared, not copied). The embedding from
    ``linear`` is optionally divided by its per-row L2 norm (plus 1e-12) before
    being passed to ``head``.
    """
    def __init__(self, redim_model, head, l2_norm_emb: bool = True):
        super().__init__()
        # Register ReDimNet's own sub-modules on this wrapper, in pipeline order.
        for part in ("backbone", "pool", "bn", "linear"):
            setattr(self, part, getattr(redim_model, part))
        self.head = head
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d: torch.Tensor) -> torch.Tensor:
        pooled = self.bn(self.pool(self.backbone(mel4d)))
        emb = self.linear(pooled)  # [B, emb_dim]
        if not self.l2_norm_emb:
            return self.head(emb)
        row_norms = emb.norm(p=2, dim=1, keepdim=True)
        return self.head(emb / (row_norms + 1e-12))

# The head checkpoint records (key "l2_norm_emb") whether its embeddings were
# L2-normalized. Honour that flag instead of hard-coding it, so the wrapper
# preprocesses embeddings the way the head expects. Fall back to True, the
# previously hard-coded value, when the checkpoint lacks the key.
L2_NORM_EMB = bool(ckpt.get("l2_norm_emb", True))

wrapped_model = ReDimNetMelLogitsWrapper(
    redim_model=redim_model,
    head=head,
    l2_norm_emb=L2_NORM_EMB,
).to(DEVICE).eval()

print("Wrapper (logits) ready.")
print("  l2_norm_emb =", L2_NORM_EMB)


Wrapper (logits) ready.


In [25]:
# ---------- Choose target layers (name -> module) ----------
# Each pair is (backbone attribute, index of the sub-module to hook inside it).
# The indices are the ones chosen previously; if one is out of range, print
# the backbone and adjust. Uncomment pairs to probe more layers.
TARGET_LAYERS = {
    stage: getattr(wrapped_model.backbone, stage)[block_idx]
    for stage, block_idx in [
        # ("stem", 0),
        # ("stage0", 2),
        # ("stage1", 2),
        # ("stage2", 2),
        # ("stage3", 2),
        ("stage4", 2),
        # ("stage5", 2),
    ]
}
print("Target layers:", list(TARGET_LAYERS.keys()))


Target layers: ['stage4']


In [26]:
# ---------- Data / Concepts ----------
CONCEPT_ROOT = PROJECT_ROOT / "concept" / "positive_concepts_dataset_72"
DATA_DIR = PROJECT_ROOT / "data" / "wavs"
print("CONCEPT_ROOT =", CONCEPT_ROOT)
print("DATA_DIR     =", DATA_DIR)

# One sub-directory per concept; each concept is contrasted with the others,
# so at least two are needed.
concept_dirs = sorted(filter(Path.is_dir, CONCEPT_ROOT.iterdir()))
print(f"Found {len(concept_dirs)} concept dirs")
if len(concept_dirs) < 2:
    raise RuntimeError("Need at least 2 concept folders (pos vs neg) for TCAV.")

# Evaluation audio: every .wav directly inside DATA_DIR, in name order.
wav_files = sorted(DATA_DIR.glob("*.wav"))
if len(wav_files) == 0:
    raise RuntimeError(f"No wav files found in {DATA_DIR}")
print("Found wav files:", len(wav_files))


CONCEPT_ROOT = /home/SpeakerRec/BioVoice/concept/positive_concepts_dataset_72
DATA_DIR     = /home/SpeakerRec/BioVoice/data/wavs
Found 12 concept dirs
Found wav files: 90


In [27]:
# ---------- Helpers ----------
def speaker_from_filename(fname: str) -> str:
    """Map a wav file name to its speaker via a case-insensitive prefix match.

    Returns "eden", "idan" or "yoav" when the name starts with that string,
    otherwise "other".
    """
    lowered = fname.lower()
    return next(
        (name for name in ("eden", "idan", "yoav") if lowered.startswith(name)),
        "other",
    )

def wav_to_redim_mel_np(wav_path: Path, target_sr: int = 16000) -> np.ndarray:
    """Load a wav file and convert it with ReDimNet's own front-end (``redim_model.spec``).

    Only the first channel is used. ``spec`` never sees the file's sample
    rate, so audio at any other rate would silently shift every mel bin. The
    waveform is therefore resampled to ``target_sr`` when the file's rate
    differs. 16 kHz matches the 16000-sample probe used to infer N_MELS.
    NOTE(review): confirm 16 kHz against the ReDimNet model card.

    Returns:
        float32 array of shape (N_MELS, T).
    """
    wav, sr = torchaudio.load(str(wav_path))
    wav = wav[:1, :].float().to(DEVICE)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    with torch.no_grad():
        mel = redim_model.spec(wav)  # (1, N_MELS, T)
    return mel.squeeze(0).cpu().numpy().astype(np.float32)  # (N_MELS, T)

def mel_to_redim_input(mel: np.ndarray) -> torch.Tensor:
    """
    Wrap one mel spectrogram as a single-item, single-channel backbone input.

    mel: (N_MELS, T) numpy array
    returns: (1, 1, N_MELS, T) float32 torch tensor on DEVICE
    raises: ValueError if mel is not 2-D or its first axis is not N_MELS long
    """
    if mel.ndim != 2:
        raise ValueError(f"Expected mel (F,T), got {mel.shape}")
    n_bins = mel.shape[0]
    if n_bins != N_MELS:
        raise ValueError(f"Mel bins mismatch: expected {N_MELS}, got {n_bins}")
    batch = torch.from_numpy(mel).float()[None, None, :, :]
    return batch.to(DEVICE)

def tensor_to_channel_vec(t: torch.Tensor) -> np.ndarray:
    """
    Reduce a batched activation/gradient tensor to a 1-D channel vector (C,).

    Only the first batch element is used; every non-channel axis is averaged.
    Accepted layouts (batch axis first):
      (B,C,H,W) -> mean over H,W
      (B,C,T)   -> mean over T
      (B,T,C)   -> mean over T
      (B,C)     -> C
    For 3-D input the channel axis is guessed: the shorter of the two
    remaining axes is taken as channels (a tie is treated as (C,T)).
    Raises RuntimeError for rank < 2 or rank > 4.
    """
    if t.ndim < 2:
        raise RuntimeError(f"Unexpected tensor shape: {tuple(t.shape)}")

    sample = t[0]  # drop the batch axis
    rank = sample.ndim
    if rank == 1:
        vec = sample
    elif rank == 2:
        rows, cols = sample.shape
        vec = sample.mean(dim=1 if rows <= cols else 0)
    elif rank == 3:
        vec = sample.mean(dim=(1, 2))
    else:
        raise RuntimeError(f"Unsupported tensor shape after batch removed: {tuple(sample.shape)}")

    return vec.detach().cpu().numpy().astype(np.float32)

def get_activation_vec(mel: np.ndarray, layer: nn.Module) -> np.ndarray:
    """
    Run wrapped_model on one mel spectrogram (no gradients) and return the
    channel vector of ``layer``'s output (see tensor_to_channel_vec).

    The forward hook is removed in a ``finally`` block. An exception in the
    forward pass therefore cannot leave a stale hook on ``layer`` that would
    keep firing, and keep activations alive, on every later forward pass.

    Raises:
        ValueError: from mel_to_redim_input for a malformed mel.
        RuntimeError: if the hook never fired (layer not reached in forward).
    """
    x = mel_to_redim_input(mel)
    store = {}

    def hook_fn(_m, _inp, out):
        store["act"] = out.detach()

    handle = layer.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            wrapped_model(x)
    finally:
        handle.remove()

    if "act" not in store:
        raise RuntimeError("Hook did not capture activation.")
    return tensor_to_channel_vec(store["act"])


In [28]:
# ---------- Precompute eval mels (using model.spec for consistency) ----------
eval_items = []  # one dict per usable wav: {"wav": Path, "speaker": str, "mel": np.ndarray}
speaker_items = defaultdict(list)  # speaker -> that speaker's entries of eval_items

for w in tqdm(wav_files, desc="Precomputing eval mels"):
    spk = speaker_from_filename(w.name)
    # Skip unrecognised prefixes and speakers the trained head does not know.
    if spk == "other" or spk not in speaker_to_id:
        continue

    mel_np = wav_to_redim_mel_np(w)  # (N_MELS, T)
    if mel_np.shape[0] != N_MELS:
        raise RuntimeError(f"{w.name}: expected {N_MELS} mel bins, got {mel_np.shape}")

    item = dict(wav=w, speaker=spk, mel=mel_np)
    eval_items.append(item)
    speaker_items[spk].append(item)

print("Eval items:", len(eval_items))
for spk, items in speaker_items.items():
    print(f"{spk}: {len(items)}")


Precomputing eval mels: 100%|██████████| 90/90 [00:00<00:00, 98.30it/s] 

Eval items: 90
eden: 30
idan: 30
yoav: 30





In [29]:
# ---------- Build CAV ----------
def _concept_activations(paths: list[Path], layer: nn.Module) -> list[np.ndarray]:
    """Load each concept mel (.npy, expected (N_MELS, T)) and return its channel vector at ``layer``."""
    vecs = []
    for p in paths:
        mel = np.load(p).astype(np.float32)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(f"{p.name}: expected {N_MELS} mel bins, got {mel.shape}")
        vecs.append(get_activation_vec(mel, layer))
    return vecs


def build_cav_for_concept_and_layer(
    layer_name: str,
    concept_dir: Path,
    all_concept_dirs: list[Path],
    seed: "int | None" = None,
) -> np.ndarray:
    """
    Learn a unit-norm Concept Activation Vector for ``concept_dir`` at one layer.

    Positives: every ``*.npy`` mel in ``concept_dir``.
    Negatives: a random subset, at most twice the number of positives, of the
    ``*.npy`` mels in the other directories of ``all_concept_dirs``.
    Each mel is mapped to its channel vector at TARGET_LAYERS[layer_name]. A
    hinge-loss SGDClassifier then separates positives (label 1) from
    negatives (label 0). Its weight vector, scaled to unit length, is the CAV
    and points toward the concept side.

    Args:
        seed: optional seed for the negative subsampling and the classifier.
            ``None`` (default) keeps drawing from NumPy's global RNG, so the
            result depends on how many random draws happened since
            ``np.random.seed``. Pass an int to make each CAV reproducible on
            its own.

    Returns:
        float32 array of shape (C,), C being the channel count at that layer.

    Raises:
        RuntimeError: no positives, no negatives, or a mel with the wrong bin count.
    """
    layer = TARGET_LAYERS[layer_name]

    pos_paths = sorted(concept_dir.glob("*.npy"))
    if not pos_paths:
        raise RuntimeError(f"No .npy files in {concept_dir}")

    neg_paths_all = [
        p
        for d in all_concept_dirs
        if d != concept_dir
        for p in sorted(d.glob("*.npy"))
    ]
    if not neg_paths_all:
        raise RuntimeError("No negative samples found in other concept dirs.")

    # Global RNG when unseeded (previous behaviour); private RNG otherwise.
    rng = np.random if seed is None else np.random.RandomState(seed)
    n_neg = min(len(neg_paths_all), 2 * len(pos_paths))
    neg_idx = rng.choice(len(neg_paths_all), n_neg, replace=False)
    neg_paths = [neg_paths_all[i] for i in neg_idx]

    pos_vecs = _concept_activations(pos_paths, layer)
    neg_vecs = _concept_activations(neg_paths, layer)
    X = np.vstack(pos_vecs + neg_vecs).astype(np.float32)
    Y = np.array([1] * len(pos_vecs) + [0] * len(neg_vecs), dtype=np.int64)

    clf = SGDClassifier(loss="hinge", alpha=1e-4, max_iter=2000, tol=1e-3, random_state=seed)
    clf.fit(X, Y)

    cav = clf.coef_.reshape(-1).astype(np.float32)
    cav /= np.linalg.norm(cav) + 1e-8
    return cav


In [30]:
# ---------- TCAV Score (directional derivative sign on class logit) ----------
def tcav_score_for_items(layer_name: str, cav: np.ndarray, items: list[dict]) -> float:
    """
    Fraction of ``items`` whose own-speaker logit increases along ``cav``.

    For each item, the logit of its true speaker is back-propagated to the
    output of TARGET_LAYERS[layer_name]. That gradient is reduced to a channel
    vector (tensor_to_channel_vec) and dotted with ``cav``. The score is the
    share of items with a strictly positive directional derivative. Returns
    0.0 when ``items`` is empty.

    Raises:
        RuntimeError: if no gradient reaches the hooked layer.
    """
    layer = TARGET_LAYERS[layer_name]
    positives = 0
    total = 0

    for item in items:
        x = mel_to_redim_input(item["mel"])
        x.requires_grad_(True)
        store = {}

        def hook_fn(_m, _inp, out):
            out.retain_grad()  # keep .grad on this non-leaf activation
            store["act"] = out

        handle = layer.register_forward_hook(hook_fn)
        try:
            logits = wrapped_model(x)  # (1, num_speakers)
        finally:
            # Always detach the hook. A leaked one would call retain_grad() during
            # later no_grad forward passes (e.g. get_activation_vec) and raise.
            handle.remove()

        wrapped_model.zero_grad(set_to_none=True)
        logits[0, speaker_to_id[item["speaker"]]].backward()

        act = store.get("act")
        if act is None or act.grad is None:
            raise RuntimeError("No gradient captured. Layer might not be connected to output.")

        dd = float(np.dot(tensor_to_channel_vec(act.grad), cav))
        positives += int(dd > 0.0)
        total += 1

    return positives / total if total else 0.0

def tcav_score_all(layer_name: str, cav: np.ndarray) -> float:
    """TCAV score of ``cav`` at ``layer_name`` over every precomputed eval item."""
    return tcav_score_for_items(layer_name=layer_name, cav=cav, items=eval_items)

def tcav_score_for_speaker(layer_name: str, cav: np.ndarray, speaker: str) -> float:
    """TCAV score restricted to one speaker's eval items (an empty list when the speaker has none)."""
    # Membership test first so the defaultdict does not grow a new empty entry.
    spk_items = speaker_items[speaker] if speaker in speaker_items else []
    return tcav_score_for_items(layer_name, cav, spk_items)


In [31]:
# ---------- Run TCAV ----------
# For every concept and target layer: build a CAV, then score it on all eval
# items and per speaker. Rows (Concept, Layer, Speaker, TCAV) are collected in
# long format and written to tcav_redimnet.csv.
results = []
cavs = {}  # (concept_name, layer_name) -> unit-norm CAV

for cdir in tqdm(concept_dirs, desc="Concepts"):
    cname = cdir.name
    # tqdm.write prints above the progress bar instead of tearing it apart.
    tqdm.write(f"\n=== Concept: {cname} ===")

    for layer_name in TARGET_LAYERS:
        tqdm.write(f"  Building CAV for layer {layer_name} ...")
        cav = build_cav_for_concept_and_layer(layer_name, cdir, concept_dirs)
        cavs[(cname, layer_name)] = cav

        s_all = tcav_score_all(layer_name, cav)
        results.append({"Concept": cname, "Layer": layer_name, "Speaker": "all", "TCAV": s_all})
        tqdm.write(f"    TCAV(all) = {s_all:.3f}")

        # Speakers known to the trained head (SPEAKERS, from the checkpoint)
        # rather than a hard-coded name list.
        for spk in SPEAKERS:
            if spk not in speaker_items:
                continue
            s_spk = tcav_score_for_speaker(layer_name, cav, spk)
            results.append({"Concept": cname, "Layer": layer_name, "Speaker": spk, "TCAV": s_spk})
            tqdm.write(f"    TCAV({spk}) = {s_spk:.3f}")

df = pd.DataFrame(results)
out_csv = Path.cwd() / "tcav_redimnet.csv"
df.to_csv(out_csv, index=False)
print("Saved:", out_csv)


Concepts:   0%|          | 0/12 [00:00<?, ?it/s]


=== Concept: long_constant_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.356
    TCAV(eden) = 0.067
    TCAV(idan) = 0.133


Concepts:   8%|▊         | 1/12 [00:32<06:02, 32.98s/it]

    TCAV(yoav) = 0.867

=== Concept: long_dropping_flat_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.622
    TCAV(eden) = 0.867
    TCAV(idan) = 0.633


Concepts:  17%|█▋        | 2/12 [01:05<05:28, 32.82s/it]

    TCAV(yoav) = 0.367

=== Concept: long_dropping_steep_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.622
    TCAV(eden) = 0.833
    TCAV(idan) = 0.733


Concepts:  25%|██▌       | 3/12 [01:38<04:53, 32.64s/it]

    TCAV(yoav) = 0.300

=== Concept: long_dropping_steep_thin ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.744
    TCAV(eden) = 0.967
    TCAV(idan) = 0.767


Concepts:  33%|███▎      | 4/12 [02:10<04:20, 32.56s/it]

    TCAV(yoav) = 0.500

=== Concept: long_rising_flat_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.300
    TCAV(eden) = 0.267
    TCAV(idan) = 0.233


Concepts:  42%|████▏     | 5/12 [02:43<03:48, 32.64s/it]

    TCAV(yoav) = 0.400

=== Concept: long_rising_steep_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.356
    TCAV(eden) = 0.833
    TCAV(idan) = 0.200


Concepts:  50%|█████     | 6/12 [03:16<03:15, 32.65s/it]

    TCAV(yoav) = 0.033

=== Concept: long_rising_steep_thin ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.711
    TCAV(eden) = 0.967
    TCAV(idan) = 0.733


Concepts:  58%|█████▊    | 7/12 [03:48<02:43, 32.75s/it]

    TCAV(yoav) = 0.433

=== Concept: short_constant_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.378
    TCAV(eden) = 0.033
    TCAV(idan) = 0.200


Concepts:  67%|██████▋   | 8/12 [04:21<02:10, 32.64s/it]

    TCAV(yoav) = 0.900

=== Concept: short_dropping_steep_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.478
    TCAV(eden) = 0.333
    TCAV(idan) = 0.167


Concepts:  75%|███████▌  | 9/12 [04:54<01:38, 32.68s/it]

    TCAV(yoav) = 0.933

=== Concept: short_dropping_steep_thin ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.400
    TCAV(eden) = 0.033
    TCAV(idan) = 0.400


Concepts:  83%|████████▎ | 10/12 [05:27<01:05, 32.88s/it]

    TCAV(yoav) = 0.767

=== Concept: short_rising_steep_thick ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.289
    TCAV(eden) = 0.333
    TCAV(idan) = 0.300


Concepts:  92%|█████████▏| 11/12 [06:00<00:32, 32.93s/it]

    TCAV(yoav) = 0.233

=== Concept: short_rising_steep_thin ===
  Building CAV for layer stage4 ...
    TCAV(all) = 0.400
    TCAV(eden) = 0.167
    TCAV(idan) = 0.700


Concepts: 100%|██████████| 12/12 [06:33<00:00, 32.78s/it]

    TCAV(yoav) = 0.333
Saved: /home/SpeakerRec/BioVoice/redimnet/tcav/tcav_redimnet.csv





In [32]:
# Preview the first rows of the long-format TCAV table (one row per concept/layer/speaker).
df.head(n=5)


Unnamed: 0,Concept,Layer,Speaker,TCAV
0,long_constant_thick,stage4,all,0.355556
1,long_constant_thick,stage4,eden,0.066667
2,long_constant_thick,stage4,idan,0.133333
3,long_constant_thick,stage4,yoav,0.866667
4,long_dropping_flat_thick,stage4,all,0.622222


In [33]:
# %%
# -----------------------------
# WIDE CSV: one row per wav
# columns: wav_name, speaker, concept1, concept2, ...
# Each concept cell holds that wav's raw directional derivative at
# `selected_layer`: the CAV dotted with the channel-averaged gradient of the
# wav's own-speaker logit. Its sign is what the TCAV score counts.
# -----------------------------
import numpy as np
import pandas as pd
import torch

selected_layer = "stage4"  # change to the layer you want
layer = TARGET_LAYERS[selected_layer]

concept_names = sorted({c for (c, l) in cavs.keys() if l == selected_layer})
if not concept_names:
    raise RuntimeError(f"No CAVs for layer {selected_layer}")

# (num_concepts, C): one unit-norm CAV per row, in concept_names order.
CAV_MAT = np.stack([cavs[(c, selected_layer)] for c in concept_names], axis=0).astype(np.float32)
cav_dim = CAV_MAT.shape[1]

# tensor_to_channel_vec is the helper from the Helpers cell; it is deliberately
# not redefined here, so the whole notebook uses one reduction (with its
# rank check).

rows = []
for item in eval_items:  # {"wav": Path, "speaker": str, "mel": np.ndarray}
    wav_path = item["wav"]
    speaker = item["speaker"]
    mel_np = item["mel"]

    if speaker not in speaker_to_id:
        continue

    x = mel_to_redim_input(mel_np)
    x.requires_grad_(True)

    store = {}

    def hook_fn(_m, _inp, out):
        out.retain_grad()  # keep .grad on this non-leaf activation
        store["act"] = out

    h = layer.register_forward_hook(hook_fn)
    try:
        logits = wrapped_model(x)
    finally:
        # Never leave the retain_grad hook attached if the forward pass fails.
        h.remove()

    scalar = logits[0, speaker_to_id[speaker]]
    wrapped_model.zero_grad(set_to_none=True)
    scalar.backward()

    if "act" not in store or store["act"].grad is None:
        raise RuntimeError(f"No gradient captured at {selected_layer} for {wav_path.name}")

    g_vec = tensor_to_channel_vec(store["act"].grad)
    if g_vec.shape[0] != cav_dim:
        raise RuntimeError(f"Grad dim {g_vec.shape[0]} != CAV dim {cav_dim} at {selected_layer}")

    dd_vals = (CAV_MAT @ g_vec).astype(np.float32)

    row = {"wav_name": wav_path.name, "speaker": speaker}
    row.update({c: float(dd) for c, dd in zip(concept_names, dd_vals)})
    rows.append(row)

# Separate name so the summary `df` from the Run TCAV cell is not overwritten.
per_wav_df = pd.DataFrame(rows, columns=["wav_name", "speaker"] + concept_names)
csv_path = Path.cwd() / f"tcav_per_wav_{selected_layer}_wide.csv"
per_wav_df.to_csv(csv_path, index=False)
print("Saved →", csv_path)
per_wav_df.head()


Saved → tcav_per_wav_stage4_wide.csv


Unnamed: 0,wav_name,speaker,long_constant_thick,long_dropping_flat_thick,long_dropping_steep_thick,long_dropping_steep_thin,long_rising_flat_thick,long_rising_steep_thick,long_rising_steep_thin,short_constant_thick,short_dropping_steep_thick,short_dropping_steep_thin,short_rising_steep_thick,short_rising_steep_thin
0,eden_001.wav,eden,-4e-06,-1.9e-05,1.180272e-05,-1.7e-05,4.145732e-07,2.6e-05,-1.7e-05,9.265759e-07,8e-06,1.1e-05,2.3e-05,5.853912e-06
1,eden_002.wav,eden,-2.5e-05,2.9e-05,3.175466e-06,2.8e-05,4.479923e-06,2.2e-05,3e-05,-2.939742e-05,-1.8e-05,-4.4e-05,-2.2e-05,-2.57233e-05
2,eden_003.wav,eden,-1e-06,1.3e-05,2.771273e-06,7e-06,1.685402e-06,5e-06,4e-06,-7.89874e-06,1.2e-05,-7e-06,5e-06,-4.419076e-08
3,eden_004.wav,eden,-6e-06,1.3e-05,-3.250373e-07,8e-06,-1.168117e-05,-1e-06,5e-06,-1.068907e-05,-2e-06,-6e-06,3e-06,5.820491e-06
4,eden_005.wav,eden,-3.1e-05,5e-06,1.885032e-05,1.7e-05,-1.340914e-05,2.2e-05,1.8e-05,-2.873736e-05,-9e-06,-1.9e-05,-7e-06,-1.163866e-05
