In [None]:
# %%

from typing import Optional


from captum.concept import TCAV, Concept

import sys
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader, Dataset
from wespeaker.cli.speaker import load_model
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F


import pandas as pd
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# %%
# -------- Project paths / device --------
# PROJECT_ROOT is the grandparent of the current working directory.
PROJECT_ROOT = Path.cwd().parents[1]
sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)
ATTR_CSV_PATH = PROJECT_ROOT / "resnet_293" / "speaker_similarity_ranking_team.csv"

WAV_FOLDER = PROJECT_ROOT / "data" / "wavs"
CONCEPT_ROOT = PROJECT_ROOT / "concept" / "temp_concepts"

# Pick one layer key you want TCAV on. It must be one of the TARGET_LAYERS keys
# declared further down: "layer1.9.conv3", "layer2.19.conv3", "layer3.63.conv3"
# or "layer4.2.conv3". Previously the only assignment was commented out, which
# made OUT_CSV (and every later use of LAYER_KEY) fail with NameError.
LAYER_KEY = "layer4.2.conv3"

CONCEPT_SAMPLES = 100
RANDOM_SAMPLES = 100
BATCH_SIZE_CONCEPT = 1  # keep 1 (safe if variable T)
FORCE_TRAIN_CAVS = True  # set True if you want to retrain CAVs

# The output file name embeds the analysed layer key.
OUT_CSV = Path(f"stage5_temp_concepts_{LAYER_KEY}.csv")

# Fail early when a required input is missing.
assert ATTR_CSV_PATH.exists(), f"Missing {ATTR_CSV_PATH}"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"
HEAD_PATH = PROJECT_ROOT / "data" / "heads" / "resnet_293_speaker_head.pt"

assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

# %%
# Load the WeSpeaker ResNet-293 speaker object from the path under PROJECT_ROOT
# and move its network to the selected device.
speaker = load_model(PROJECT_ROOT / "wespeaker-voxceleb-resnet293-LM")
net = speaker.model.to(DEVICE)

print("ResNet-293 loaded from HF")


# %%
# -------- Load your speaker head ckpt --------
ckpt = torch.load(HEAD_PATH, map_location=DEVICE)

# Label maps stored next to the weights.
speaker_to_id, id_to_speaker = ckpt["speaker_to_id"], ckpt["id_to_speaker"]
SPEAKERS = [*speaker_to_id]
# NOTE(review): this flag is read and printed, but nothing in the visible code
# applies L2 normalisation to the embeddings — confirm whether it should.
l2_norm_emb = bool(ckpt.get("l2_norm_emb", True))

# infer both head dimensions from the stored weight: (num_classes, in_dim)
fc_w = ckpt["state_dict"]["fc.weight"]
num_classes, in_dim = int(fc_w.shape[0]), int(fc_w.shape[1])

print("Loaded head:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("Head in_dim:", in_dim, "num_classes:", num_classes, "l2_norm_emb:", l2_norm_emb)


class SpeakerHead(nn.Module):
    """Single linear layer that maps an embedding to per-speaker logits."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        # The attribute must stay named `fc`: the checkpoint stores "fc.weight".
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the linear layer for `x`."""
        logits = self.fc(x)
        return logits


# Build the head with the checkpoint's dimensions, restore its weights and
# switch it to inference mode.
head = SpeakerHead(in_dim=in_dim, num_classes=num_classes)
head = head.to(DEVICE)
head.load_state_dict(ckpt["state_dict"])
head.eval()


# %%
class WeSpeakerWithHeadForGradCAM(nn.Module):
    """Chain a speaker-embedding backbone and a classification head.

    The module returns logits of shape (batch, num_classes), so attribution
    tools can treat backbone + head as a single classifier.

    Args:
        backbone: Embedding network. It may return a tensor or a tuple/list
            of tensors.
        head: Module that maps a (batch, dim) embedding to logits.
        try_transpose: When True, a failing backbone call is retried once
            with dimensions 1 and 2 of the input swapped.
    """

    def __init__(
        self, backbone: nn.Module, head: nn.Module, try_transpose: bool = True
    ):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.try_transpose = try_transpose

    @staticmethod
    def _select_embedding(out):
        """Pick the embedding tensor out of a backbone output.

        A plain tensor is returned unchanged. For a tuple/list, the first
        element that is a tensor with at least one dimension is returned.
        WeSpeaker backbones are understood to return a 0-dim placeholder as
        the first tuple element in some configurations (confirm against the
        WeSpeaker source); blindly taking ``out[0]`` would then feed a
        constant scalar into the head. If no such element exists, ``out[0]``
        is returned as before.
        """
        if not isinstance(out, (tuple, list)):
            return out
        for item in out:
            if isinstance(item, torch.Tensor) and item.ndim >= 1:
                return item
        return out[0]

    def forward(self, feats):
        """Return head logits of shape (batch, num_classes) for `feats`."""
        # Run backbone; on failure optionally retry with axes 1 and 2 swapped.
        try:
            out = self.backbone(feats)
        except Exception:
            # transpose(1, 2) needs at least 3 dims; otherwise surface the
            # backbone's own error instead of an unrelated IndexError.
            if not self.try_transpose or feats.ndim < 3:
                raise
            out = self.backbone(feats.transpose(1, 2))

        emb = self._select_embedding(out)

        # ---- HARD SHAPE GUARD: the head always receives (batch, dim) ----
        if emb.ndim == 0:
            emb = emb.unsqueeze(0).unsqueeze(0)
        elif emb.ndim == 1:
            emb = emb.unsqueeze(0)
        elif emb.ndim > 2:
            # reshape (unlike view) also accepts non-contiguous tensors
            emb = emb.reshape(emb.size(0), -1)

        # ---- HARD DEVICE GUARD: keep the embedding on the head's device ----
        emb = emb.to(next(self.head.parameters()).device)

        logits = self.head(emb)

        if logits.ndim == 1:
            logits = logits.unsqueeze(0)

        return logits


# Combine backbone and head, move the result to DEVICE and use inference mode.
wrapped_model = WeSpeakerWithHeadForGradCAM(speaker.model, head)
wrapped_model = wrapped_model.to(DEVICE).eval()

# Backbone parameters keep gradients enabled (needed for Grad-CAM)
wrapped_model.backbone.requires_grad_(True)

# Head grads not required (optional)
wrapped_model.head.requires_grad_(False)

print("wrapped_model ready")

# %%
# %%
# ========== TARGET LAYERS (same idea, but via wrapped_model.backbone) ==========
# One candidate per backbone stage: the `conv3` of the listed block index,
# keyed as "<stage>.<block>.conv3".
TARGET_LAYERS = {
    f"{stage}.{block_idx}.conv3": getattr(wrapped_model.backbone, stage)[block_idx].conv3
    for stage, block_idx in (
        ("layer1", 9),
        ("layer2", 19),
        ("layer3", 63),
        ("layer4", 2),
    )
}

# The configured key has to name one of the candidates above.
assert (
    LAYER_KEY in TARGET_LAYERS
), f"{LAYER_KEY=} not in TARGET_LAYERS: {list(TARGET_LAYERS.keys())}"


# %%
# -------- Resolve layer name string for Captum --------
def module_name_in_model(model: torch.nn.Module, target_module: torch.nn.Module) -> str:
    """Return the dotted name under which `target_module` is registered in `model`.

    The lookup compares by identity over `model.named_modules()`; the first
    match wins (the model itself has the empty name).

    Raises:
        RuntimeError: If `target_module` is not part of `model`.
    """
    found = next(
        (name for name, candidate in model.named_modules() if candidate is target_module),
        None,
    )
    # Compare against None explicitly: "" is a valid name (the root module).
    if found is None:
        raise RuntimeError(
            "Could not find the selected layer module in wrapped_model.named_modules()"
        )
    return found


# Translate the chosen module object into the dotted name used to address it
# inside wrapped_model.
layer_module = TARGET_LAYERS[LAYER_KEY]
LAYER_NAME = module_name_in_model(model=wrapped_model, target_module=layer_module)
print(f"Using layer: {LAYER_KEY} -> {LAYER_NAME}")


# %%
# TCAV is run on the CPU: the models are moved there and DEVICE is rebound.
TCAV_DEVICE = torch.device("cpu")
print("TCAV_DEVICE =", TCAV_DEVICE)

# NOTE(review): `redim_model` is not defined anywhere in the visible part of
# this file; unless an earlier cell/session provides it, the next line raises
# NameError — confirm where it is supposed to come from.
redim_model = redim_model.to(TCAV_DEVICE).eval()
wrapped_model = wrapped_model.to(TCAV_DEVICE).eval()

# Rebind DEVICE so later `.to(DEVICE)` calls place tensors on the CPU as well.
DEVICE = TCAV_DEVICE


# %%
class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: int | None = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel = np.load(self.files[idx]).astype(np.float32)  # (N_MELS, T)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(
                f"{self.files[idx].name}: expected {N_MELS} bins, got {mel.shape}"
            )
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, N_MELS, T) on CPU
        return x


def infer_frames_for_random(concept_dirs: list[Path]) -> int:
    """Return the size of axis 1 of the first `.npy` file found.

    Directories are visited in the given order; the first file yielded by
    the glob of the first directory that has any decides the result.

    Raises:
        RuntimeError: If none of the directories contains a `.npy` file.
    """
    for folder in concept_dirs:
        for sample_path in folder.glob("*.npy"):
            # Only the very first hit is inspected.
            return int(np.load(sample_path).shape[1])
    raise RuntimeError("Could not infer frames from concept dirs")


class RandomMelDataset(Dataset):
    """Dataset of `n_samples` standard-normal tensors of shape (1, N_MELS, frames).

    A fresh tensor is drawn on every access; the index is ignored.
    """

    def __init__(self, n_samples: int, frames: int):
        self.n_samples, self.frames = n_samples, frames

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # NOTE(review): N_MELS is not defined in the visible part of this file.
        noise = torch.randn(N_MELS, self.frames, dtype=torch.float32)
        return torch.unsqueeze(noise, 0)


# %%
# Every sub-directory of CONCEPT_ROOT is one concept, handled in sorted order.
concept_dirs = sorted(entry for entry in CONCEPT_ROOT.iterdir() if entry.is_dir())
if len(concept_dirs) == 0:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")

concept_names = [folder.name for folder in concept_dirs]
print("Concepts:", concept_names)

# The frame count of the stored concepts also fixes the random baseline length.
TARGET_FRAMES = infer_frames_for_random(concept_dirs)
print("Using fixed frames for TCAV (from concepts):", TARGET_FRAMES)
tcav = TCAV(wrapped_model, [LAYER_NAME], test_split_ratio=0.33)

# Concept ids equal the position of the directory in the sorted list.
positive_concepts = [
    Concept(
        id=idx,
        name=cdir.name,
        data_iter=DataLoader(
            ConceptNPYDataset(cdir, limit=CONCEPT_SAMPLES),
            batch_size=BATCH_SIZE_CONCEPT,
            shuffle=False,
            num_workers=0,
        ),
    )
    for idx, cdir in enumerate(concept_dirs)
]

# The random baseline gets the next free id after the positive concepts.
rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES, frames=TARGET_FRAMES)
rand_dl = DataLoader(
    rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0
)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

# Each experiment contrasts one positive concept with the random baseline.
experimental_sets = [[concept, random_concept] for concept in positive_concepts]


# %%
def compute_cav_acc_df(
    tcav: TCAV, positive_concepts: list[Concept], random_concept: Concept
) -> pd.DataFrame:
    """Train one CAV per (concept, random) pair and tabulate the accuracies.

    CAVs are computed with `force_train=FORCE_TRAIN_CAVS`. The text before
    the first "-" of each result key is parsed as an integer and used as an
    index into `positive_concepts`; keys that do not parse or fall outside
    the list are skipped.

    Returns:
        DataFrame with columns "concept name", "layer name" and "cav acc"
        (NaN when the stats hold neither "accs" nor "acc").
    """
    columns = ["concept name", "layer name", "cav acc"]
    pairs = [[concept, random_concept] for concept in positive_concepts]
    trained = tcav.compute_cavs(pairs, force_train=FORCE_TRAIN_CAVS)

    n_concepts = len(positive_concepts)
    records = []
    for key, per_layer in trained.items():
        try:
            concept_pos = int(str(key).split("-")[0])
        except Exception:
            continue
        if concept_pos < 0 or concept_pos >= n_concepts:
            continue
        name = positive_concepts[concept_pos].name

        for layer, cav in per_layer.items():
            if cav is None or cav.stats is None:
                continue
            stats = cav.stats
            # Prefer "accs"; fall back to "acc" when it is absent.
            accuracy = stats.get("accs", None)
            if accuracy is None:
                accuracy = stats.get("acc", None)
            if isinstance(accuracy, torch.Tensor):
                accuracy = accuracy.detach().cpu().item()
            records.append(
                {
                    "concept name": name,
                    "layer name": layer,
                    "cav acc": np.nan if accuracy is None else float(accuracy),
                }
            )
    return pd.DataFrame(records, columns=columns)


# Train the CAVs and keep their accuracies; they are merged into the final
# per-file table further down.
acc_df = compute_cav_acc_df(tcav, positive_concepts, random_concept)
print(acc_df.head())


# %%
def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """
    Force the last axis of `mel_3d` to `target_frames` entries.

    mel_3d: (1, N_MELS, T)
    returns: (1, N_MELS, target_frames)

    A longer input is center-cropped, a shorter one is zero-padded on the
    right, and an input of the right length is returned unchanged.
    """
    surplus = int(mel_3d.shape[-1]) - target_frames
    if surplus == 0:
        return mel_3d
    if surplus < 0:
        # Too short: append zeros at the end of the last axis only.
        return F.pad(mel_3d, (0, -surplus), mode="constant", value=0.0)
    # Too long: keep the central window.
    first = surplus // 2
    last = first + target_frames
    return mel_3d[..., first:last]


def wav_path_to_mel4d(path: Path) -> torch.Tensor:
    """Load an audio file and return its fixed-length spectrogram with a batch axis.

    Only the first channel is used. The spectrogram comes from
    `redim_model.spec`, is cropped/padded to TARGET_FRAMES by
    `fix_mel_frames`, and gets one extra leading axis.
    """
    # NOTE(review): `redim_model` is not defined in the visible part of this
    # file — confirm where it comes from.
    # NOTE(review): the sample rate returned by torchaudio is ignored; no
    # resampling happens here.
    waveform, _sample_rate = torchaudio.load(str(path))
    first_channel = waveform[:1, :].float().to(DEVICE)
    with torch.no_grad():
        spec = redim_model.spec(first_channel)  # expected (1, N_MELS, T)
    fixed = fix_mel_frames(spec, TARGET_FRAMES)  # (1, N_MELS, TARGET_FRAMES)
    return fixed.unsqueeze(0)  # (1, 1, N_MELS, TARGET_FRAMES)


def predict_speaker(path: Path) -> tuple[str, float]:
    """Return the most probable speaker name for an audio file and its probability.

    The file is converted with `wav_path_to_mel4d`, passed through
    `wrapped_model` without gradients, and the logits of the first batch
    entry are soft-maxed over the class axis.
    """
    features = wav_path_to_mel4d(path)
    with torch.no_grad():
        posterior = F.softmax(wrapped_model(features), dim=1)[0]
        best = int(torch.argmax(posterior).item())
        confidence = float(posterior[best].item())
    return id_to_speaker[best], confidence


# %%
df_attr = pd.read_csv(ATTR_CSV_PATH)

# Both columns are required: the audio file and its ground-truth speaker.
if not {"path", "speaker"}.issubset(df_attr.columns):
    raise RuntimeError(
        f"CSV must contain columns ['path','speaker']. Got: {list(df_attr.columns)}"
    )

rows = []

for _, r in df_attr.iterrows():
    path, true_label = Path(r["path"]), str(r["speaker"])

    # Skip rows whose audio is missing or whose speaker the head does not know.
    if not (path.exists() and true_label in speaker_to_id):
        continue

    pred_label, pred_prob = predict_speaker(path)

    # TCAV is always evaluated for the ground-truth class, not the prediction.
    x = wav_path_to_mel4d(path)
    target_idx = speaker_to_id[true_label]

    score_for_label = tcav.interpret(
        inputs=x, experimental_sets=experimental_sets, target=target_idx
    )

    for exp_key, layer_dict in score_for_label.items():
        # The text before the first "-" of the key indexes positive_concepts.
        try:
            pos_idx = int(str(exp_key).split("-")[0])
        except Exception:
            continue
        if pos_idx < 0 or pos_idx >= len(positive_concepts):
            continue

        concept_name = positive_concepts[pos_idx].name

        for layer_name, metrics in layer_dict.items():
            sc, mg = metrics.get("sign_count"), metrics.get("magnitude")
            if sc is None or mg is None:
                continue

            # Tensors become plain lists; entry 0 is the one that is stored.
            sc = sc.detach().cpu().tolist() if isinstance(sc, torch.Tensor) else sc
            mg = mg.detach().cpu().tolist() if isinstance(mg, torch.Tensor) else mg

            rows.append(
                {
                    "path": str(path),
                    "concept name": concept_name,
                    "layer name": layer_name,
                    "positive percentage": float(sc[0]),
                    "magnitude": float(mg[0]),
                    "true label": true_label,
                    "predicted label": pred_label,
                    "predicted probability": float(pred_prob),
                }
            )

df_tcav = pd.DataFrame(
    rows,
    columns=[
        "path",
        "concept name",
        "layer name",
        "positive percentage",
        "magnitude",
        "true label",
        "predicted label",
        "predicted probability",
    ],
)

# Attach the CAV accuracy of the matching (concept, layer) pair to every row.
df_tcav = df_tcav.merge(acc_df, on=["concept name", "layer name"], how="left")

df_tcav.to_csv(OUT_CSV, index=False)
print("Saved →", OUT_CSV)
df_tcav.head()

In [None]:
# -------- Project paths / device --------
# PROJECT_ROOT is the grandparent of the current working directory.
PROJECT_ROOT = Path.cwd().parents[1]
sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)
ATTR_CSV_PATH = PROJECT_ROOT / "resnet_293" / "speaker_similarity_ranking_team.csv"

WAV_FOLDER = PROJECT_ROOT / "data" / "wavs"
CONCEPT_ROOT = PROJECT_ROOT / "concept" / "temp_concepts"

# Pick one layer key you want TCAV on. It must be one of the TARGET_LAYERS keys
# declared in a later cell: "layer1.9.conv3", "layer2.19.conv3",
# "layer3.63.conv3" or "layer4.2.conv3". Previously the only assignment was
# commented out, which made OUT_CSV (and every later use of LAYER_KEY) fail
# with NameError.
LAYER_KEY = "layer4.2.conv3"

CONCEPT_SAMPLES = 100
RANDOM_SAMPLES = 100
BATCH_SIZE_CONCEPT = 1  # keep 1 (safe if variable T)
FORCE_TRAIN_CAVS = True  # set True if you want to retrain CAVs

# The output file name embeds the analysed layer key.
OUT_CSV = Path(f"stage5_temp_concepts_{LAYER_KEY}.csv")

# Fail early when a required input is missing.
assert ATTR_CSV_PATH.exists(), f"Missing {ATTR_CSV_PATH}"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"
HEAD_PATH = PROJECT_ROOT / "data" / "heads" / "resnet_293_speaker_head.pt"

assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

In [None]:
# Load the WeSpeaker ResNet-293 speaker object from the path under PROJECT_ROOT
# and move its network to the selected device.
speaker = load_model(PROJECT_ROOT / "wespeaker-voxceleb-resnet293-LM")
net = speaker.model.to(DEVICE)

print("ResNet-293 loaded from HF")



In [None]:
# -------- Load your speaker head ckpt --------
ckpt = torch.load(HEAD_PATH, map_location=DEVICE)

# Label maps stored next to the weights.
speaker_to_id, id_to_speaker = ckpt["speaker_to_id"], ckpt["id_to_speaker"]
SPEAKERS = [*speaker_to_id]
# NOTE(review): this flag is read and printed, but nothing in the visible code
# applies L2 normalisation to the embeddings — confirm whether it should.
l2_norm_emb = bool(ckpt.get("l2_norm_emb", True))

# infer both head dimensions from the stored weight: (num_classes, in_dim)
fc_w = ckpt["state_dict"]["fc.weight"]
num_classes, in_dim = int(fc_w.shape[0]), int(fc_w.shape[1])

print("Loaded head:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("Head in_dim:", in_dim, "num_classes:", num_classes, "l2_norm_emb:", l2_norm_emb)

class SpeakerHead(nn.Module):
    """Single linear layer that maps an embedding to per-speaker logits."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        # The attribute must stay named `fc`: the checkpoint stores "fc.weight".
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the linear layer for `x`."""
        logits = self.fc(x)
        return logits

# Build the head with the checkpoint's dimensions, restore its weights and
# switch it to inference mode.
head = SpeakerHead(in_dim=in_dim, num_classes=num_classes)
head = head.to(DEVICE)
head.load_state_dict(ckpt["state_dict"])
head.eval()


In [None]:
class WeSpeakerWithHeadForGradCAM(nn.Module):
    """Chain a speaker-embedding backbone and a classification head.

    The module returns logits of shape (batch, num_classes), so attribution
    tools can treat backbone + head as a single classifier.

    Args:
        backbone: Embedding network. It may return a tensor or a tuple/list
            of tensors.
        head: Module that maps a (batch, dim) embedding to logits.
        try_transpose: When True, a failing backbone call is retried once
            with dimensions 1 and 2 of the input swapped.
    """

    def __init__(
        self, backbone: nn.Module, head: nn.Module, try_transpose: bool = True
    ):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.try_transpose = try_transpose

    @staticmethod
    def _select_embedding(out):
        """Pick the embedding tensor out of a backbone output.

        A plain tensor is returned unchanged. For a tuple/list, the first
        element that is a tensor with at least one dimension is returned.
        WeSpeaker backbones are understood to return a 0-dim placeholder as
        the first tuple element in some configurations (confirm against the
        WeSpeaker source); blindly taking ``out[0]`` would then feed a
        constant scalar into the head. If no such element exists, ``out[0]``
        is returned as before.
        """
        if not isinstance(out, (tuple, list)):
            return out
        for item in out:
            if isinstance(item, torch.Tensor) and item.ndim >= 1:
                return item
        return out[0]

    def forward(self, feats):
        """Return head logits of shape (batch, num_classes) for `feats`."""
        # Run backbone; on failure optionally retry with axes 1 and 2 swapped.
        try:
            out = self.backbone(feats)
        except Exception:
            # transpose(1, 2) needs at least 3 dims; otherwise surface the
            # backbone's own error instead of an unrelated IndexError.
            if not self.try_transpose or feats.ndim < 3:
                raise
            out = self.backbone(feats.transpose(1, 2))

        emb = self._select_embedding(out)

        # ---- HARD SHAPE GUARD: the head always receives (batch, dim) ----
        if emb.ndim == 0:
            emb = emb.unsqueeze(0).unsqueeze(0)
        elif emb.ndim == 1:
            emb = emb.unsqueeze(0)
        elif emb.ndim > 2:
            # reshape (unlike view) also accepts non-contiguous tensors
            emb = emb.reshape(emb.size(0), -1)

        # ---- HARD DEVICE GUARD: keep the embedding on the head's device ----
        emb = emb.to(next(self.head.parameters()).device)

        logits = self.head(emb)

        if logits.ndim == 1:
            logits = logits.unsqueeze(0)

        return logits


# Combine backbone and head, move the result to DEVICE and use inference mode.
wrapped_model = WeSpeakerWithHeadForGradCAM(speaker.model, head)
wrapped_model = wrapped_model.to(DEVICE).eval()

# Backbone parameters keep gradients enabled (needed for Grad-CAM)
wrapped_model.backbone.requires_grad_(True)

# Head grads not required (optional)
wrapped_model.head.requires_grad_(False)

print("wrapped_model ready")

In [None]:
# %%
# ========== TARGET LAYERS (same idea, but via wrapped_model.backbone) ==========
# One candidate per backbone stage: the `conv3` of the listed block index,
# keyed as "<stage>.<block>.conv3".
TARGET_LAYERS = {
    f"{stage}.{block_idx}.conv3": getattr(wrapped_model.backbone, stage)[block_idx].conv3
    for stage, block_idx in (
        ("layer1", 9),
        ("layer2", 19),
        ("layer3", 63),
        ("layer4", 2),
    )
}

# The configured key has to name one of the candidates above.
assert (
    LAYER_KEY in TARGET_LAYERS
), f"{LAYER_KEY=} not in TARGET_LAYERS: {list(TARGET_LAYERS.keys())}"

In [None]:
# -------- Resolve layer name string for Captum --------
def module_name_in_model(model: torch.nn.Module, target_module: torch.nn.Module) -> str:
    """Return the dotted name under which `target_module` is registered in `model`.

    The lookup compares by identity over `model.named_modules()`; the first
    match wins (the model itself has the empty name).

    Raises:
        RuntimeError: If `target_module` is not part of `model`.
    """
    found = next(
        (name for name, candidate in model.named_modules() if candidate is target_module),
        None,
    )
    # Compare against None explicitly: "" is a valid name (the root module).
    if found is None:
        raise RuntimeError("Could not find the selected layer module in wrapped_model.named_modules()")
    return found


    
# Translate the chosen module object into the dotted name used to address it
# inside wrapped_model.
layer_module = TARGET_LAYERS[LAYER_KEY]
LAYER_NAME = module_name_in_model(model=wrapped_model, target_module=layer_module)
print(f"Using layer: {LAYER_KEY} -> {LAYER_NAME}")


In [None]:
# TCAV is run on the CPU: the models are moved there and DEVICE is rebound.
TCAV_DEVICE = torch.device("cpu")
print("TCAV_DEVICE =", TCAV_DEVICE)

# NOTE(review): `redim_model` is not defined anywhere in the visible part of
# this file; unless an earlier cell/session provides it, the next line raises
# NameError — confirm where it is supposed to come from.
redim_model = redim_model.to(TCAV_DEVICE).eval()
wrapped_model = wrapped_model.to(TCAV_DEVICE).eval()

# Rebind DEVICE so later `.to(DEVICE)` calls place tensors on the CPU as well.
DEVICE = TCAV_DEVICE


In [None]:
class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: int | None = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel = np.load(self.files[idx]).astype(np.float32)  # (N_MELS, T)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(f"{self.files[idx].name}: expected {N_MELS} bins, got {mel.shape}")
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, N_MELS, T) on CPU
        return x



def infer_frames_for_random(concept_dirs: list[Path]) -> int:
    """Return the size of axis 1 of the first `.npy` file found.

    Directories are visited in the given order; the first file yielded by
    the glob of the first directory that has any decides the result.

    Raises:
        RuntimeError: If none of the directories contains a `.npy` file.
    """
    for folder in concept_dirs:
        for sample_path in folder.glob("*.npy"):
            # Only the very first hit is inspected.
            return int(np.load(sample_path).shape[1])
    raise RuntimeError("Could not infer frames from concept dirs")

class RandomMelDataset(Dataset):
    """Dataset of `n_samples` standard-normal tensors of shape (1, N_MELS, frames).

    A fresh tensor is drawn on every access; the index is ignored.
    """

    def __init__(self, n_samples: int, frames: int):
        self.n_samples, self.frames = n_samples, frames

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # NOTE(review): N_MELS is not defined in the visible part of this file.
        noise = torch.randn(N_MELS, self.frames, dtype=torch.float32)
        return torch.unsqueeze(noise, 0)


In [None]:
# Every sub-directory of CONCEPT_ROOT is one concept, handled in sorted order.
concept_dirs = sorted(entry for entry in CONCEPT_ROOT.iterdir() if entry.is_dir())
if len(concept_dirs) == 0:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")

concept_names = [folder.name for folder in concept_dirs]
print("Concepts:", concept_names)

# The frame count of the stored concepts also fixes the random baseline length.
TARGET_FRAMES = infer_frames_for_random(concept_dirs)
print("Using fixed frames for TCAV (from concepts):", TARGET_FRAMES)
tcav = TCAV(wrapped_model, [LAYER_NAME], test_split_ratio=0.33)

# Concept ids equal the position of the directory in the sorted list.
positive_concepts = [
    Concept(
        id=idx,
        name=cdir.name,
        data_iter=DataLoader(
            ConceptNPYDataset(cdir, limit=CONCEPT_SAMPLES),
            batch_size=BATCH_SIZE_CONCEPT,
            shuffle=False,
            num_workers=0,
        ),
    )
    for idx, cdir in enumerate(concept_dirs)
]

# The random baseline gets the next free id after the positive concepts.
rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES, frames=TARGET_FRAMES)
rand_dl = DataLoader(rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

# Each experiment contrasts one positive concept with the random baseline.
experimental_sets = [[concept, random_concept] for concept in positive_concepts]



In [None]:
def compute_cav_acc_df(tcav: TCAV, positive_concepts: list[Concept], random_concept: Concept) -> pd.DataFrame:
    """Train one CAV per (concept, random) pair and tabulate the accuracies.

    CAVs are computed with `force_train=FORCE_TRAIN_CAVS`. The text before
    the first "-" of each result key is parsed as an integer and used as an
    index into `positive_concepts`; keys that do not parse or fall outside
    the list are skipped.

    Returns:
        DataFrame with columns "concept name", "layer name" and "cav acc"
        (NaN when the stats hold neither "accs" nor "acc").
    """
    columns = ["concept name", "layer name", "cav acc"]
    pairs = [[concept, random_concept] for concept in positive_concepts]
    trained = tcav.compute_cavs(pairs, force_train=FORCE_TRAIN_CAVS)

    n_concepts = len(positive_concepts)
    records = []
    for key, per_layer in trained.items():
        try:
            concept_pos = int(str(key).split("-")[0])
        except Exception:
            continue
        if concept_pos < 0 or concept_pos >= n_concepts:
            continue
        name = positive_concepts[concept_pos].name

        for layer, cav in per_layer.items():
            if cav is None or cav.stats is None:
                continue
            stats = cav.stats
            # Prefer "accs"; fall back to "acc" when it is absent.
            accuracy = stats.get("accs", None)
            if accuracy is None:
                accuracy = stats.get("acc", None)
            if isinstance(accuracy, torch.Tensor):
                accuracy = accuracy.detach().cpu().item()
            records.append({
                "concept name": name,
                "layer name": layer,
                "cav acc": np.nan if accuracy is None else float(accuracy),
            })
    return pd.DataFrame(records, columns=columns)

# Train the CAVs and keep their accuracies; they are merged into the final
# per-file table in a later cell.
acc_df = compute_cav_acc_df(tcav, positive_concepts, random_concept)
print(acc_df.head())


In [None]:
def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """
    Force the last axis of `mel_3d` to `target_frames` entries.

    mel_3d: (1, N_MELS, T)
    returns: (1, N_MELS, target_frames)

    A longer input is center-cropped, a shorter one is zero-padded on the
    right, and an input of the right length is returned unchanged.
    """
    surplus = int(mel_3d.shape[-1]) - target_frames
    if surplus == 0:
        return mel_3d
    if surplus < 0:
        # Too short: append zeros at the end of the last axis only.
        return F.pad(mel_3d, (0, -surplus), mode="constant", value=0.0)
    # Too long: keep the central window.
    first = surplus // 2
    last = first + target_frames
    return mel_3d[..., first:last]

def wav_path_to_mel4d(path: Path) -> torch.Tensor:
    """Load an audio file and return its fixed-length spectrogram with a batch axis.

    Only the first channel is used. The spectrogram comes from
    `redim_model.spec`, is cropped/padded to TARGET_FRAMES by
    `fix_mel_frames`, and gets one extra leading axis.
    """
    # NOTE(review): `redim_model` is not defined in the visible part of this
    # file — confirm where it comes from.
    # NOTE(review): the sample rate returned by torchaudio is ignored; no
    # resampling happens here.
    waveform, _sample_rate = torchaudio.load(str(path))
    first_channel = waveform[:1, :].float().to(DEVICE)
    with torch.no_grad():
        spec = redim_model.spec(first_channel)  # expected (1, N_MELS, T)
    fixed = fix_mel_frames(spec, TARGET_FRAMES)  # (1, N_MELS, TARGET_FRAMES)
    return fixed.unsqueeze(0)  # (1, 1, N_MELS, TARGET_FRAMES)

def predict_speaker(path: Path) -> tuple[str, float]:
    """Return the most probable speaker name for an audio file and its probability.

    The file is converted with `wav_path_to_mel4d`, passed through
    `wrapped_model` without gradients, and the logits of the first batch
    entry are soft-maxed over the class axis.
    """
    features = wav_path_to_mel4d(path)
    with torch.no_grad():
        posterior = F.softmax(wrapped_model(features), dim=1)[0]
        best = int(torch.argmax(posterior).item())
        confidence = float(posterior[best].item())
    return id_to_speaker[best], confidence


In [None]:
df_attr = pd.read_csv(ATTR_CSV_PATH)

# Both columns are required: the audio file and its ground-truth speaker.
if not {"path", "speaker"}.issubset(df_attr.columns):
    raise RuntimeError(f"CSV must contain columns ['path','speaker']. Got: {list(df_attr.columns)}")

rows = []

for _, r in df_attr.iterrows():
    path, true_label = Path(r["path"]), str(r["speaker"])

    # Skip rows whose audio is missing or whose speaker the head does not know.
    if not (path.exists() and true_label in speaker_to_id):
        continue

    pred_label, pred_prob = predict_speaker(path)

    # TCAV is always evaluated for the ground-truth class, not the prediction.
    x = wav_path_to_mel4d(path)
    target_idx = speaker_to_id[true_label]

    score_for_label = tcav.interpret(
        inputs=x, experimental_sets=experimental_sets, target=target_idx
    )

    for exp_key, layer_dict in score_for_label.items():
        # The text before the first "-" of the key indexes positive_concepts.
        try:
            pos_idx = int(str(exp_key).split("-")[0])
        except Exception:
            continue
        if pos_idx < 0 or pos_idx >= len(positive_concepts):
            continue

        concept_name = positive_concepts[pos_idx].name

        for layer_name, metrics in layer_dict.items():
            sc, mg = metrics.get("sign_count"), metrics.get("magnitude")
            if sc is None or mg is None:
                continue

            # Tensors become plain lists; entry 0 is the one that is stored.
            sc = sc.detach().cpu().tolist() if isinstance(sc, torch.Tensor) else sc
            mg = mg.detach().cpu().tolist() if isinstance(mg, torch.Tensor) else mg

            rows.append({
                "path": str(path),
                "concept name": concept_name,
                "layer name": layer_name,
                "positive percentage": float(sc[0]),
                "magnitude": float(mg[0]),
                "true label": true_label,
                "predicted label": pred_label,
                "predicted probability": float(pred_prob),
            })

df_tcav = pd.DataFrame(
    rows,
    columns=[
        "path", "concept name", "layer name", "positive percentage", "magnitude",
        "true label", "predicted label", "predicted probability"
    ],
)

# Attach the CAV accuracy of the matching (concept, layer) pair to every row.
df_tcav = df_tcav.merge(acc_df, on=["concept name", "layer name"], how="left")

df_tcav.to_csv(OUT_CSV, index=False)
print("Saved →", OUT_CSV)
df_tcav.head()
