In [1]:
import random
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

from captum.concept import TCAV, Concept


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -------- Project paths / device / reproducibility --------
# NOTE: the repo root is taken to be two levels above the current working
# directory, so this cell depends on where the notebook kernel was started.
PROJECT_ROOT = Path.cwd().parents[1]
if str(PROJECT_ROOT) not in sys.path:  # re-running the cell must not stack duplicates
    sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Seed every RNG this analysis can touch. The random concept is drawn with
# torch.randn; the Python and NumPy generators are seeded as well because CAV
# training inside Captum / scikit-learn may rely on them (assumption — confirm
# against the installed Captum version).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ATTR_CSV_PATH = (
    PROJECT_ROOT
    / "redimnet"
    / "grad_cam"
    / "2.0"
    / "output"
    / "speaker_similarity_ranking_vox2_10_20_ids.csv"
)
CONCEPT_ROOT = PROJECT_ROOT / "concept" / "temp_concepts"

CONCEPT_SAMPLES = 100      # max .npy files used per concept folder
RANDOM_SAMPLES = 100       # number of synthetic samples in the random concept
BATCH_SIZE_CONCEPT = 1     # keep 1 (safe if variable T)
FORCE_TRAIN_CAVS = True    # passed to TCAV.compute_cavs(force_train=...)

assert ATTR_CSV_PATH.exists(), f"Missing {ATTR_CSV_PATH}"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"

PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda


In [3]:
# Fetch the pretrained ReDimNet (model_name "b5", train_type "ptn", dataset
# "vox2") through torch.hub, then place it on DEVICE in eval mode.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
)
redim_model = redim_model.to(DEVICE).eval()
print("Loaded ReDimNet successfully.")

# Push a 16000-sample all-zero waveform through the model's own spectrogram
# front-end purely to read off how many mel bins it produces.
with torch.no_grad():
    dummy_wav = torch.zeros(1, 16000, device=DEVICE)
    dummy_mel = redim_model.spec(dummy_wav)  # (1, N_MELS, T)
N_MELS = int(dummy_mel.size(1))
print("ReDimNet spec N_MELS =", N_MELS)


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet successfully.
ReDimNet spec N_MELS = 72


In [4]:
# -------- Load your speaker head ckpt --------
HEAD_PATH = PROJECT_ROOT.joinpath("data", "heads", "redim_speaker_head_vox2_10_20.pt")
assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Passing weights_only=True would be safer — confirm that the
# checkpoint contains nothing beyond tensors and plain Python containers first.
ckpt = torch.load(HEAD_PATH, map_location=DEVICE)

# Speaker name <-> class index lookups stored alongside the weights.
speaker_to_id = ckpt["speaker_to_id"]
id_to_speaker = dict((idx, name) for name, idx in speaker_to_id.items())
SPEAKERS = list(speaker_to_id)
l2_norm_emb = bool(ckpt.get("l2_norm_emb", True))

# The linear layer's weight is (num_classes, in_dim); read both sizes from it
# instead of hard-coding them.
fc_w = ckpt["state_dict"]["fc.weight"]
num_classes, in_dim = (int(n) for n in fc_w.shape)

print("Loaded head:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("Head in_dim:", in_dim, "num_classes:", num_classes, "l2_norm_emb:", l2_norm_emb)

class SpeakerHead(nn.Module):
    """Single linear layer mapping an embedding vector to per-speaker logits.

    The layer is registered as ``fc`` so its parameters are stored under the
    ``fc.weight`` / ``fc.bias`` state-dict keys.
    """

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.fc(x)
        return logits

# Build the head with the dimensions read from the checkpoint, restore its
# weights, then move it to DEVICE. The trailing .eval() both switches the head
# to inference mode and is the value the cell displays.
head = SpeakerHead(in_dim=in_dim, num_classes=num_classes)
head.load_state_dict(ckpt["state_dict"])
head = head.to(DEVICE)
head.eval()

Loaded head: /home/SpeakerRec/BioVoice/data/heads/redim_speaker_head_vox2_10_20.pt
Speakers: ['id00012', 'id00016', 'id00018', 'id00019', 'id00020', 'id00021', 'id00022', 'id00024', 'id00025', 'id00026']
Head in_dim: 192 num_classes: 10 l2_norm_emb: True


  ckpt = torch.load(HEAD_PATH, map_location=DEVICE)


SpeakerHead(
  (fc): Linear(in_features=192, out_features=10, bias=True)
)

In [5]:
# -------- Wrap ReDimNet -> logits --------
class ReDimNetMelLogitsWrapper(nn.Module):
    """Chain the ReDimNet embedding path with a classification head.

    Input:  mel4d [B, 1, N_MELS, T]
    Output: logits [B, num_speakers]

    The sub-modules are taken from ``redim_model`` (``backbone``, ``pool``,
    ``bn``, ``linear``) and registered under those same names; ``head`` is
    applied to the resulting embedding.
    """

    def __init__(self, redim_model, head: nn.Module, l2_norm_emb: bool):
        super().__init__()
        self.backbone = redim_model.backbone
        self.pool = redim_model.pool
        self.bn = redim_model.bn
        self.linear = redim_model.linear
        self.head = head
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d: torch.Tensor) -> torch.Tensor:
        """Run backbone -> pool -> bn -> linear, optionally L2-normalise, then classify."""
        emb = mel4d
        for stage in (self.backbone, self.pool, self.bn, self.linear):
            emb = stage(emb)
        if self.l2_norm_emb:
            # The small epsilon guards against division by zero for an all-zero embedding.
            scale = emb.norm(p=2, dim=1, keepdim=True) + 1e-12
            emb = emb / scale
        return self.head(emb)

# Assemble the mel -> logits model, then place it on DEVICE in eval mode.
wrapped_model = ReDimNetMelLogitsWrapper(redim_model, head, l2_norm_emb=l2_norm_emb)
wrapped_model = wrapped_model.to(DEVICE).eval()
print("wrapped_model ready.")


wrapped_model ready.


In [6]:
# Layers probed by TCAV, keyed by a short label: element 0 of the stem and
# element 2 of each of stage0..stage5, in that order.
_backbone = wrapped_model.backbone
TARGET_LAYERS = {"stem": _backbone.stem[0]}
TARGET_LAYERS.update(
    (f"stage{i}", getattr(_backbone, f"stage{i}")[2]) for i in range(6)
)

In [7]:
# -------- Resolve layer name string for Captum --------
def module_name_in_model(model: torch.nn.Module, target_module: torch.nn.Module) -> str:
    """Return the dotted name under which ``target_module`` appears in ``model``.

    Matching is by object identity over ``model.named_modules()``; the first
    match wins (the root module itself has the empty name).

    Raises:
        RuntimeError: if ``target_module`` is not part of ``model``.
    """
    found = next(
        (name for name, mod in model.named_modules() if mod is target_module),
        None,
    )
    if found is None:
        raise RuntimeError("Could not find the selected layer module in wrapped_model.named_modules()")
    return found


In [8]:
# TCAV is run on CPU: move both models there and rebind DEVICE so the helper
# functions defined later create their tensors on the same device.
TCAV_DEVICE = torch.device("cpu")
print("TCAV_DEVICE =", TCAV_DEVICE)

# nn.Module.to() moves the module in place and returns it, so the rebinding
# below keeps the same objects.
redim_model = redim_model.to(TCAV_DEVICE)
redim_model = redim_model.eval()
wrapped_model = wrapped_model.to(TCAV_DEVICE)
wrapped_model = wrapped_model.eval()

DEVICE = TCAV_DEVICE


TCAV_DEVICE = cpu


In [9]:
class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: int | None = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel = np.load(self.files[idx]).astype(np.float32)  # (N_MELS, T)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(f"{self.files[idx].name}: expected {N_MELS} bins, got {mel.shape}")
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, N_MELS, T) on CPU
        return x



def infer_frames_for_random(concept_dirs: list[Path]) -> int:
    """Return the size of axis 1 of the first ``*.npy`` file found.

    Folders are scanned in the given order; the first file yielded by the
    first folder that contains any ``*.npy`` is loaded and the length of its
    second axis is returned.

    Raises:
        RuntimeError: if none of the folders contains a ``*.npy`` file.
    """
    first_file = next(
        (npy for folder in concept_dirs for npy in folder.glob("*.npy")),
        None,
    )
    if first_file is None:
        raise RuntimeError("Could not infer frames from concept dirs")
    return int(np.load(first_file).shape[1])

class RandomMelDataset(Dataset):
    def __init__(self, n_samples: int, frames: int):
        self.n_samples = n_samples
        self.frames = frames

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        mel = torch.randn(N_MELS, self.frames, dtype=torch.float32)  
        return mel.unsqueeze(0) 


In [10]:
# Every sub-folder of CONCEPT_ROOT is one concept; sort for a stable order.
concept_dirs = sorted(entry for entry in CONCEPT_ROOT.iterdir() if entry.is_dir())
if len(concept_dirs) == 0:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")

concept_names = [folder.name for folder in concept_dirs]
print("Concepts:", concept_names)

# Frame count read from one concept file; it is used for the random concept
# and for cropping/padding the evaluation inputs.
TARGET_FRAMES = infer_frames_for_random(concept_dirs)
print("Using fixed frames for TCAV (from concepts):", TARGET_FRAMES)

Concepts: ['long_constant_thick', 'long_constant_thick_Vibrato', 'long_dropping_flat_thick', 'long_dropping_flat_thick_Vibrato', 'long_dropping_steep_thick', 'long_dropping_steep_thin', 'long_rising_flat_thick', 'long_rising_steep_thick', 'long_rising_steep_thin', 'short_constant_thick', 'short_dropping_steep_thick', 'short_dropping_steep_thin', 'short_rising_steep_thick', 'short_rising_steep_thin']
Using fixed frames for TCAV (from concepts): 304


In [11]:
# -------- Prepare concepts (same for all layers) --------
# One Concept per folder; its id is the folder's position in concept_dirs.
positive_concepts = [
    Concept(
        id=concept_id,
        name=concept_dir.name,
        data_iter=DataLoader(
            ConceptNPYDataset(concept_dir, limit=CONCEPT_SAMPLES),
            batch_size=BATCH_SIZE_CONCEPT,
            shuffle=False,
            num_workers=0,
        ),
    )
    for concept_id, concept_dir in enumerate(concept_dirs)
]

# The random (negative) concept gets the next free id.
rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES, frames=TARGET_FRAMES)
rand_dl = DataLoader(rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

# Each experimental set pits one positive concept against the random concept.
experimental_sets = [[concept, random_concept] for concept in positive_concepts]

# -------- One TCAV instance per target layer --------
all_acc_dfs = []  # Accumulate accuracy DataFrames
# Despite its name, this dict holds the TCAV objects (keyed by layer label),
# not computed results.
all_tcav_results = {
    layer_key: TCAV(
        wrapped_model,
        [module_name_in_model(wrapped_model, layer_module)],  # Captum wants the dotted layer name
        test_split_ratio=0.33,
    )
    for layer_key, layer_module in TARGET_LAYERS.items()
}




In [12]:
# === Output & cache paths ===
# Folder receiving the per-layer CAV accuracy CSVs.
CAV_ROOT = PROJECT_ROOT.joinpath("data", "cavs", "vox2_10_20")
# Folder receiving the TCAV score CSVs (partial checkpoint + final table).
TCAV_OUT_DIR = PROJECT_ROOT.joinpath("data", "tcav")
for _out_dir in (CAV_ROOT, TCAV_OUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

OUT_PARTIAL = TCAV_OUT_DIR.joinpath("tcav_partial_vox2_10_20.csv")
OUT_FINAL = TCAV_OUT_DIR.joinpath("tcav_all_layers_results_vox2_10_20.csv")

print("CAV_ROOT:", CAV_ROOT)
print("TCAV_OUT_DIR:", TCAV_OUT_DIR)

CAV_ROOT: /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20
TCAV_OUT_DIR: /home/SpeakerRec/BioVoice/data/tcav


In [13]:
def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Force the last axis of ``mel_3d`` to exactly ``target_frames``.

    mel_3d: (1, N_MELS, T)
    returns: (1, N_MELS, target_frames)

    Longer inputs are centre-cropped, shorter inputs are zero-padded on the
    right, and an input of the right length is returned unchanged (same object).
    """
    n_frames = int(mel_3d.shape[-1])
    missing = target_frames - n_frames

    if missing > 0:
        # Too short: append zeros at the end of the time axis.
        return F.pad(mel_3d, (0, missing), mode="constant", value=0.0)
    if missing < 0:
        # Too long: keep the centred window of target_frames frames.
        offset = (n_frames - target_frames) // 2
        return mel_3d[..., offset : offset + target_frames]
    return mel_3d


def wav_path_to_mel4d(path: Path, target_sr: int = 16000) -> torch.Tensor:
    """Load an audio file and convert it to the model's fixed-length mel input.

    Steps: load with torchaudio, keep the first channel, resample to
    ``target_sr`` when the file's rate differs, compute the mel spectrogram
    with ``redim_model.spec`` and crop/pad it to ``TARGET_FRAMES``.

    Args:
        path: audio file to load.
        target_sr: sample rate fed to the spectrogram front-end. The default
            of 16000 matches the 16000-sample dummy input used when probing
            the front-end — NOTE(review): confirm this is the rate the
            pretrained model expects.

    Returns:
        Tensor of shape (1, 1, N_MELS, TARGET_FRAMES) on ``DEVICE``.
    """
    wav, sr = torchaudio.load(str(path))
    wav = wav[:1, :].float()  # first channel only, shape (1, num_samples)
    if sr != target_sr:
        # Previously the file's sample rate was ignored, so audio at any other
        # rate was passed to the front-end as if it were already at target_sr.
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    wav = wav.to(DEVICE)
    with torch.no_grad():
        mel = redim_model.spec(wav)  # (1, N_MELS, T)
    mel = fix_mel_frames(mel, TARGET_FRAMES)  # (1, N_MELS, TARGET_FRAMES)
    return mel.unsqueeze(0)  # (1, 1, N_MELS, TARGET_FRAMES)


def predict_speaker(path: Path) -> tuple[str, float]:
    """Classify one audio file with ``wrapped_model``.

    Returns:
        (speaker name of the arg-max class, softmax probability of that class).
    """
    mel4d = wav_path_to_mel4d(path)
    with torch.no_grad():
        # Softmax over the speaker axis; [0] drops the batch dimension.
        probs = F.softmax(wrapped_model(mel4d), dim=1)[0]
        winner = int(probs.argmax().item())
        winner_prob = float(probs[winner].item())
    return id_to_speaker[winner], winner_prob

In [14]:
def compute_cav_acc_df(
    tcav: TCAV,
    positive_concepts: list[Concept],
    random_concept: Concept,
    layer_key: str,
) -> pd.DataFrame:
    """Compute CAVs for every (concept, random) pair and tabulate their accuracy.

    Calls ``tcav.compute_cavs`` with one experimental set per positive concept
    (``force_train`` is taken from the module-level ``FORCE_TRAIN_CAVS``),
    reads the classifier accuracy from each returned CAV's ``stats`` and
    writes the table to ``CAV_ROOT / f"cav_acc_{layer_key}.csv"``.

    Args:
        tcav: TCAV instance bound to the layer(s) of interest.
        positive_concepts: concepts to contrast with ``random_concept``.
        random_concept: the shared negative concept.
        layer_key: short label stored in the ``layer_key`` column and used in
            the output file name.

    Returns:
        DataFrame with columns ``layer_key``, ``concept name``, ``layer name``
        and ``cav acc``. The columns are present even when no CAV was returned.
    """
    columns = ["layer_key", "concept name", "layer name", "cav acc"]
    # Resolve concepts by their id rather than by list position, so the lookup
    # stays correct even if ids are not 0..n-1 in list order.
    name_by_id = {concept.id: concept.name for concept in positive_concepts}

    print(f"    Training CAVs for {layer_key}")
    cavs_dict = tcav.compute_cavs(
        [[c, random_concept] for c in positive_concepts],
        force_train=FORCE_TRAIN_CAVS,
    )

    rows = []
    for concepts_key, layer_map in cavs_dict.items():
        # Keys look like "<positive id>-<random id>"; the first field is the
        # positive concept's id. Keys that do not parse are skipped.
        try:
            pos_id = int(str(concepts_key).split("-")[0])
        except ValueError:
            continue

        concept_name = name_by_id.get(pos_id)
        if concept_name is None:
            continue

        for layer_name, cav_obj in layer_map.items():
            if cav_obj is None or cav_obj.stats is None:
                continue

            # Explicit None checks: the previous `a or b` fallback treated a
            # legitimate accuracy of 0.0 as missing and recorded NaN.
            acc = cav_obj.stats.get("accs")
            if acc is None:
                acc = cav_obj.stats.get("acc")
            if isinstance(acc, torch.Tensor):
                acc = acc.detach().cpu().item()

            rows.append(
                {
                    "layer_key": layer_key,
                    "concept name": concept_name,
                    "layer name": layer_name,
                    "cav acc": float(acc) if acc is not None else np.nan,
                }
            )

    # Fixing the columns keeps downstream concat/merge working on an empty result.
    acc_df = pd.DataFrame(rows, columns=columns)

    # Only the accuracy statistics are saved here, not the CAV vectors.
    out_csv = CAV_ROOT / f"cav_acc_{layer_key}.csv"
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    acc_df.to_csv(out_csv, index=False)
    print(f"    Saved CAV acc → {out_csv}")

    return acc_df

In [15]:
print("Computing / loading CAV accuracies...")

# One accuracy table per layer, stacked into a single frame afterwards.
all_acc_dfs = []
for layer_key, tcav in all_tcav_results.items():
    print(f"  Layer: {layer_key}")
    all_acc_dfs.append(
        compute_cav_acc_df(tcav, positive_concepts, random_concept, layer_key)
    )
    print(f"    → {len(all_acc_dfs[-1])} rows")

if all_acc_dfs:
    acc_df_combined = pd.concat(all_acc_dfs, ignore_index=True)
else:
    acc_df_combined = pd.DataFrame()

print("Total CAV acc rows:", len(acc_df_combined))

Computing / loading CAV accuracies...
  Layer: stem
    Training CAVs for stem


  av = torch.load(fl)
  bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stem.csv
    → 14 rows
  Layer: stage0
    Training CAVs for stage0


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage0.csv
    → 14 rows
  Layer: stage1
    Training CAVs for stage1


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage1.csv
    → 14 rows
  Layer: stage2
    Training CAVs for stage2


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage2.csv
    → 14 rows
  Layer: stage3
    Training CAVs for stage3


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage3.csv
    → 14 rows
  Layer: stage4
    Training CAVs for stage4


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage4.csv
    → 14 rows
  Layer: stage5
    Training CAVs for stage5


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Saved CAV acc → /home/SpeakerRec/BioVoice/data/cavs/vox2_10_20/cav_acc_stage5.csv
    → 14 rows
Total CAV acc rows: 98


In [16]:
# -------- Load attribute CSV (inputs for TCAV scoring) --------
# ATTR_CSV_PATH comes from the configuration cell; it is deliberately not
# redefined here so the two copies of the path cannot drift apart.
assert ATTR_CSV_PATH.exists(), f"Missing CSV: {ATTR_CSV_PATH}"

df_attr = pd.read_csv(ATTR_CSV_PATH)

# The scoring loop reads `path` and `speaker` from every row; fail here with a
# clear message rather than with an AttributeError in the middle of that loop.
_missing_attr_cols = {"path", "speaker"} - set(df_attr.columns)
assert not _missing_attr_cols, (
    f"{ATTR_CSV_PATH} is missing required columns: {sorted(_missing_attr_cols)}"
)

print("Loaded df_attr:", df_attr.shape)

Loaded df_attr: (200, 4)


In [17]:
# Score every attribute-CSV file against every concept at every target layer.
# Results accumulate in `rows` (one dict per file x layer x concept) and are
# checkpointed to OUT_PARTIAL so an interrupted run can resume.
CHECKPOINT_EVERY = 100  # write the partial CSV after this many attribute rows

rows = []
done_paths = set()  # files whose rows were restored from the checkpoint

if OUT_PARTIAL.exists():
    print("Resuming from partial CSV")
    try:
        partial_df = pd.read_csv(OUT_PARTIAL)
    except pd.errors.EmptyDataError:
        partial_df = pd.DataFrame()
    rows = partial_df.to_dict("records")
    if "path" in partial_df.columns:
        # Checkpoints are only written between files, so every path present in
        # the partial CSV has its full set of rows and can be skipped. Without
        # this, resuming re-scored everything and duplicated the restored rows.
        done_paths = set(partial_df["path"].astype(str))
    print(f"  {len(done_paths)} files already scored; they will be skipped")
    # NOTE(review): a partial CSV produced with different layers/concepts is
    # reused as-is — delete OUT_PARTIAL after changing the configuration.

# Resolve concepts by id rather than by list position.
concept_name_by_id = {concept.id: concept.name for concept in positive_concepts}

for i, r in enumerate(df_attr.itertuples(), start=1):
    path = Path(r.path)
    true_label = str(r.speaker)

    if str(path) in done_paths:
        continue
    if not path.exists() or true_label not in speaker_to_id:
        continue

    pred_label, pred_prob = predict_speaker(path)
    x = wav_path_to_mel4d(path)
    target_idx = speaker_to_id[true_label]

    for layer_key, tcav in all_tcav_results.items():
        scores = tcav.interpret(
            inputs=x,
            experimental_sets=experimental_sets,
            target=target_idx,
        )

        for exp_key, layer_dict in scores.items():
            # Keys look like "<positive id>-<random id>"; keys that do not
            # parse, or that name an unknown concept, are skipped.
            try:
                pos_idx = int(str(exp_key).split("-")[0])
            except ValueError:
                continue

            concept_name = concept_name_by_id.get(pos_idx)
            if concept_name is None:
                continue

            for layer_name, metrics in layer_dict.items():
                sc = metrics.get("sign_count")
                mg = metrics.get("magnitude")
                if sc is None or mg is None:
                    continue

                if isinstance(sc, torch.Tensor):
                    sc = sc.detach().cpu().tolist()
                if isinstance(mg, torch.Tensor):
                    mg = mg.detach().cpu().tolist()

                # Element 0 is the score of the positive concept (listed first
                # in each experimental set).
                rows.append(
                    {
                        "path": str(path),
                        "layer_key": layer_key,
                        "concept name": concept_name,
                        "layer name": layer_name,
                        "positive percentage": float(sc[0]),
                        "magnitude": float(mg[0]),
                        "true label": true_label,
                        "predicted label": pred_label,
                        "predicted probability": float(pred_prob),
                    }
                )

    # Checkpoint periodically. Skip when there is nothing to write, since an
    # empty CSV cannot be read back on resume.
    if i % CHECKPOINT_EVERY == 0 and rows:
        pd.DataFrame(rows).to_csv(OUT_PARTIAL, index=False)
        print(f"Saved partial → {OUT_PARTIAL} ({len(rows)} rows)")

  save_dict = torch.load(cavs_path)


Saved partial → /home/SpeakerRec/BioVoice/data/tcav/tcav_partial_vox2_10_20.csv (9800 rows)


  save_dict = torch.load(cavs_path)


Saved partial → /home/SpeakerRec/BioVoice/data/tcav/tcav_partial_vox2_10_20.csv (19600 rows)


In [18]:
# Attach the CAV accuracy of the matching (layer, concept) to every TCAV score
# row; the left join keeps score rows that have no accuracy entry.
df_tcav = pd.merge(
    pd.DataFrame(rows),
    acc_df_combined,
    how="left",
    on=["layer_key", "concept name", "layer name"],
)

df_tcav.to_csv(OUT_FINAL, index=False)

print(f"Saved FINAL → {OUT_FINAL}")
print("Final shape:", df_tcav.shape)
display(df_tcav.head())

Saved FINAL → /home/SpeakerRec/BioVoice/data/tcav/tcav_all_layers_results_vox2_10_20.csv
Final shape: (19600, 10)


Unnamed: 0,path,layer_key,concept name,layer name,positive percentage,magnitude,true label,predicted label,predicted probability,cav acc
0,/home/SpeakerRec/BioVoice/data/datasets/voxcel...,stem,long_constant_thick,backbone.stem.0,1.0,0.008734,id00012,id00012,0.179189,0.288462
1,/home/SpeakerRec/BioVoice/data/datasets/voxcel...,stem,long_constant_thick_Vibrato,backbone.stem.0,0.0,-0.035952,id00012,id00012,0.179189,0.423077
2,/home/SpeakerRec/BioVoice/data/datasets/voxcel...,stem,long_dropping_flat_thick,backbone.stem.0,1.0,0.022522,id00012,id00012,0.179189,0.423077
3,/home/SpeakerRec/BioVoice/data/datasets/voxcel...,stem,long_dropping_flat_thick_Vibrato,backbone.stem.0,0.0,-0.060824,id00012,id00012,0.179189,0.442308
4,/home/SpeakerRec/BioVoice/data/datasets/voxcel...,stem,long_dropping_steep_thick,backbone.stem.0,1.0,0.008886,id00012,id00012,0.179189,0.403846
