# TCAV on ASVspoof 2019 (LA) — ReDimNet + Logistic Probe (Spoof vs Bonafide)

This notebook runs **TCAV** to explain a **spoof (fake) vs bonafide (real)** decision.

Pipeline:
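- Waveform (HF dataset) → `ReDimNet.spec` → mel, center-cropped or zero-padded to a fixed number of frames
- mel → ReDimNet backbone/pool/bn/linear → embedding
- embedding → L2 normalization → StandardScaler → LogisticRegression probe score `s` → 2-class logits `[-s, s]`, read as `[real, fake]`
- TCAV computes concept influence on **target class = Fake** (logit index 1, i.e. the probe's positive class `classes_[1]` — verify this is the spoof label the probe was trained with).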

Expected paths:
- Concepts: `/home/SpeakerRec/BioVoice/concept/temp_concepts/<concept_name>/*.npy`
- TCAV subset: `/home/SpeakerRec/BioVoice/data/datasets/asv_spoof_2019/tcav__20_speakers_10_real_10_fake`
- Probe model: `/home/SpeakerRec/BioVoice/data/models/asvspoof_probe_50_50/{scaler.pkl, logistic_regression.pkl}`


In [1]:
# %%
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
from captum.concept import TCAV, Concept

# Log the installed torch version.
print("torch:", torch.__version__)


  from .autonotebook import tqdm as notebook_tqdm


torch: 2.1.2+cu121


In [None]:
# %%
# -------- Paths / device --------
# PROJECT_ROOT defaults to the original absolute path; set the BIOVOICE_ROOT
# environment variable to run the notebook against another checkout.
PROJECT_ROOT = Path(os.environ.get("BIOVOICE_ROOT", "/home/SpeakerRec/BioVoice"))
# Make modules under PROJECT_ROOT importable. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

DEVICE = torch.device("cpu")  # TCAV on CPU
print("PROJECT_ROOT =", PROJECT_ROOT)
print("DEVICE =", DEVICE)

# One sub-folder of *.npy mel files per concept.
CONCEPT_ROOT = PROJECT_ROOT / "concept" / "temp_concepts"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"

# HF dataset (saved with save_to_disk) holding the utterances to explain.
TCAV_DATASET_PATH = PROJECT_ROOT / "data" / "datasets" / "asv_spoof_2019" / "tcav__20_speakers_10_real_10_fake"
assert TCAV_DATASET_PATH.exists(), f"Missing {TCAV_DATASET_PATH}"

# Pickled sklearn probe: StandardScaler + LogisticRegression.
MODEL_DIR = PROJECT_ROOT / "data" / "models" / "asvspoof_probe_50_50"
LOGREG_PATH = MODEL_DIR / "logistic_regression.pkl"
SCALER_PATH = MODEL_DIR / "scaler.pkl"
assert LOGREG_PATH.exists(), f"Missing {LOGREG_PATH}"
assert SCALER_PATH.exists(), f"Missing {SCALER_PATH}"

OUT_DIR = PROJECT_ROOT / "data" / "tcav"
OUT_DIR.mkdir(parents=True, exist_ok=True)
print("OUT_DIR =", OUT_DIR)

PROJECT_ROOT = /home/SpeakerRec/BioVoice
DEVICE = cpu
OUT_DIR = /home/SpeakerRec/BioVoice/output


In [3]:
# %%
# -------- Load TCAV subset (HF dataset) --------
# Dataset previously written with `save_to_disk`; one row per utterance.
tcav_subset = load_from_disk(TCAV_DATASET_PATH)
print("Loaded tcav_subset:", len(tcav_subset))
print("Columns:", tcav_subset.column_names)

# ASVspoof: key==1 bonafide(real), key==0 spoof(fake)
key_counts = pd.Series(tcav_subset["key"]).value_counts().to_dict()
print("key counts:", key_counts)


Loaded tcav_subset: 1150
Columns: ['speaker_id', 'audio_file_name', 'audio', 'system_id', 'key']
key counts: {0: 670, 1: 480}


In [4]:
# %%
# -------- Load ReDimNet --------
# Pretrained ReDimNet (b5, "ptn" weights, vox2) from torch.hub, moved to DEVICE
# and switched to eval mode.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
)
redim_model = redim_model.to(DEVICE)
redim_model.eval()

# Push a short all-zero waveform through the spectrogram front-end only to
# read off how many mel bins it produces.
with torch.no_grad():
    dummy_wav = torch.zeros(1, 16000, device=DEVICE)
    dummy_mel = redim_model.spec(dummy_wav)  # (1, N_MELS, T)

N_MELS = int(dummy_mel.size(1))
print("Loaded ReDimNet. N_MELS =", N_MELS)


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet. N_MELS = 72


In [5]:
# %%
# -------- Load scaler + logistic probe --------
# NOTE: pickle.load can execute arbitrary code; only load probe files you trust.
with LOGREG_PATH.open("rb") as fh:
    logreg_clf = pickle.load(fh)
with SCALER_PATH.open("rb") as fh:
    scaler = pickle.load(fh)

print("Loaded probe:", type(logreg_clf), type(scaler))


Loaded probe: <class 'sklearn.linear_model._logistic.LogisticRegression'> <class 'sklearn.preprocessing._data.StandardScaler'>


In [6]:
class ReDimNetSpoofWrapper(nn.Module):
    """Mel batch -> ReDimNet embedding -> scaled logistic probe -> 2 logits.

    Re-implements a fitted sklearn ``StandardScaler`` + binary
    ``LogisticRegression`` in torch so the whole path is differentiable
    (TCAV needs gradients from the logits back to the chosen layer).

    Args:
        redim_model: loaded ReDimNet; its ``backbone``, ``pool``, ``bn`` and
            ``linear`` submodules are applied in that order in ``forward``.
        W: probe weights, shape (1, D) or (D,) (e.g. ``LogisticRegression.coef_``).
        b: probe intercept, a single value (e.g. ``LogisticRegression.intercept_``).
        mean: per-feature ``StandardScaler.mean_``, length D.
        scale: per-feature ``StandardScaler.scale_``, length D.
        l2_norm_emb: L2-normalise the embedding before scaling.
            NOTE(review): must match how the probe's training embeddings were made.

    Raises:
        ValueError: if W/b describe more than one output (multiclass probe),
            or mean/scale do not have D entries.
    """

    def __init__(self, redim_model, W, b, mean, scale, l2_norm_emb=True):
        super().__init__()
        self.redim = redim_model
        self.l2_norm_emb = l2_norm_emb

        W = np.asarray(W, dtype=np.float32)
        if W.ndim == 1:
            W = W[None, :]  # accept a flat (D,) weight vector
        b = np.asarray(b, dtype=np.float32).reshape(-1)
        if W.ndim != 2 or W.shape[0] != 1 or b.shape != (1,):
            raise ValueError(
                "Expected a binary probe with W of shape (1, D) and a single "
                f"intercept, got W {W.shape} and b {b.shape}"
            )
        D = W.shape[1]

        mean = np.asarray(mean, dtype=np.float32).reshape(-1)
        scale = np.asarray(scale, dtype=np.float32).reshape(-1)
        if mean.shape != (D,) or scale.shape != (D,):
            raise ValueError(
                f"Scaler mean/scale must have {D} entries, got {mean.shape} and {scale.shape}"
            )

        # scaler parameters as buffers (move with .to(), not trained)
        self.register_buffer("mean", torch.tensor(mean))
        self.register_buffer("scale", torch.tensor(scale))

        # logistic regression as a single-output linear layer
        self.linear = nn.Linear(D, 1)
        with torch.no_grad():
            self.linear.weight.copy_(torch.tensor(W))
            self.linear.bias.copy_(torch.tensor(b))

    def forward(self, mel4d):
        """Return logits ``[-s, s]`` of shape (B, 2), where s is the probe score.

        Args:
            mel4d: mel batch passed straight to ``redim.backbone``; in this
                notebook it is built as (B, 1, N_MELS, frames).
        """
        x = self.redim.backbone(mel4d)
        x = self.redim.pool(x)
        x = self.redim.bn(x)
        emb = self.redim.linear(x)

        if self.l2_norm_emb:
            # Unit-length embedding; 1e-12 guards against division by zero.
            emb = emb / (emb.norm(p=2, dim=1, keepdim=True) + 1e-12)

        # StandardScaler.transform, done in torch so it stays differentiable.
        emb = (emb - self.mean) / self.scale

        score = self.linear(emb)  # [B, 1] probe decision value s

        # Two-logit view: argmax and gradients follow s, but softmax over
        # [-s, s] equals sigmoid(2s), not the probe's sigmoid(s).
        logits = torch.cat([-score, score], dim=1)  # [B, 2]

        return logits

# Fold the fitted StandardScaler + LogisticRegression probe into the torch wrapper.
# The wrapper uses coef_/intercept_ as the logit at index 1, i.e. the probe's
# positive class logreg_clf.classes_[1], which this notebook treats as "Fake".
# NOTE(review): ASVspoof keys are 1 = bonafide, 0 = spoof. If the probe was fit on
# `key` directly, classes_[1] is bonafide and logit 1 would mean *real*, not fake —
# check the printed classes_ against the labels used to train the probe.
probe_classes = list(logreg_clf.classes_)
if len(probe_classes) != 2:
    raise ValueError(f"Expected a binary probe, got classes_={probe_classes}")
print("Probe classes_ (logit 0, logit 1):", probe_classes)

spoof_model = (
    ReDimNetSpoofWrapper(
        redim_model,
        W=logreg_clf.coef_,
        b=logreg_clf.intercept_,
        mean=scaler.mean_,
        scale=scaler.scale_,
    )
    .to(DEVICE)
    .eval()
)
print("Created ReDimNetSpoofWrapper model.")

Created ReDimNetSpoofWrapper model.


In [7]:
# %%
# Layers to run TCAV on, keyed by a short label carried into the result tables.
# Only stage4 is enabled to keep the run fast; stage5 can be added as
#   stage5=spoof_model.redim.backbone.stage5[2],
TARGET_LAYERS = dict(
    stage4=spoof_model.redim.backbone.stage4[2],
)


In [8]:
# %%
def module_name_in_model(model: torch.nn.Module, target_module: torch.nn.Module) -> str:
    """Return the dotted ``named_modules()`` name of ``target_module`` in ``model``.

    Modules are matched by identity (``is``), not equality; the root module
    itself is reported under the empty name ``""``.

    Raises:
        RuntimeError: if ``target_module`` is not reachable from ``model``.
    """
    found = next(
        (qualname for qualname, sub in model.named_modules() if sub is target_module),
        None,
    )
    if found is None:
        raise RuntimeError("Could not find selected layer module in model.named_modules()")
    return found


In [9]:
# %%
def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Force a mel tensor to exactly ``target_frames`` steps on its last axis.

    Longer inputs are cropped around their centre; shorter inputs are
    right-padded with zeros. Only the last axis is touched.

    Args:
        mel_3d: mel tensor, (1, N_MELS, T) in this notebook.
        target_frames: desired length of the last axis.

    Returns:
        Tensor of shape (1, N_MELS, target_frames); the input object itself when
        T already equals ``target_frames``.
    """
    excess = int(mel_3d.shape[-1]) - target_frames
    if excess > 0:
        offset = excess // 2
        return mel_3d[..., offset:offset + target_frames]
    if excess < 0:
        return F.pad(mel_3d, (0, -excess), mode="constant", value=0.0)
    return mel_3d


In [10]:
# %%
def infer_frames_for_random(concept_dirs: list[Path]) -> int:
    """Return the frame count (axis 1) of the first concept mel found.

    Directories are scanned in the given order; within a directory the
    alphabetically first ``*.npy`` file is used (the same sorted order
    ``ConceptNPYDataset`` reads files in). The result therefore does not depend
    on the filesystem's directory-listing order.

    Args:
        concept_dirs: concept folders, each expected to hold ``*.npy`` mels.

    Returns:
        Number of frames (``shape[1]``) of that array.

    Raises:
        RuntimeError: if no folder contains a ``.npy`` file, or the chosen file
            is not a 2-D (n_mels, frames) array.
    """
    for d in concept_dirs:
        first_file = min(d.glob("*.npy"), default=None)
        if first_file is None:
            continue
        mel = np.load(first_file)
        if mel.ndim != 2:
            raise RuntimeError(
                f"{first_file.name}: expected a 2-D (n_mels, frames) array, got shape {mel.shape}"
            )
        return int(mel.shape[1])
    raise RuntimeError("Could not infer frames from concept dirs")

# Every sub-directory of CONCEPT_ROOT is one concept, in sorted path order.
concept_dirs = sorted(path for path in CONCEPT_ROOT.iterdir() if path.is_dir())
if len(concept_dirs) == 0:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")

concept_names = [path.name for path in concept_dirs]
# All concept / random / test mels are cropped or padded to this many frames.
TARGET_FRAMES = infer_frames_for_random(concept_dirs)

print("Concepts:", concept_names)
print("TARGET_FRAMES =", TARGET_FRAMES)


Concepts: ['long_constant_thick', 'long_constant_thick_Vibrato', 'long_dropping_flat_thick', 'long_dropping_flat_thick_Vibrato', 'long_dropping_steep_thick', 'long_dropping_steep_thin', 'long_rising_flat_thick', 'long_rising_steep_thick', 'long_rising_steep_thin', 'short_constant_thick', 'short_dropping_steep_thick', 'short_dropping_steep_thin', 'short_rising_steep_thick', 'short_rising_steep_thin']
TARGET_FRAMES = 304


In [11]:
# %%
# Concept-data settings
CONCEPT_SAMPLES = 100     # max .npy files read per concept folder
RANDOM_SAMPLES = 100      # number of noise "mels" in the random concept
BATCH_SIZE_CONCEPT = 1    # DataLoader batch size for concept and random data
FORCE_TRAIN_CAVS = True   # passed as force_train to TCAV.compute_cavs

class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: int | None = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel = np.load(self.files[idx]).astype(np.float32)  # (N_MELS, T)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(f"{self.files[idx].name}: expected {N_MELS} bins, got {mel.shape}")
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, N_MELS, T)
        x = fix_mel_frames(x, TARGET_FRAMES)    # (1, N_MELS, TARGET_FRAMES)
        return x

class RandomMelDataset(Dataset):
    def __init__(self, n_samples: int, frames: int):
        self.n_samples = n_samples
        self.frames = frames

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return torch.randn(1, N_MELS, self.frames, dtype=torch.float32)

# One captum Concept per concept folder; Concept.id is the folder's position in
# concept_dirs, and the shared random concept takes the next free id.
positive_concepts = [
    Concept(
        id=concept_id,
        name=concept_dir.name,
        data_iter=DataLoader(
            ConceptNPYDataset(concept_dir, limit=CONCEPT_SAMPLES),
            batch_size=BATCH_SIZE_CONCEPT,
            shuffle=False,
            num_workers=0,
        ),
    )
    for concept_id, concept_dir in enumerate(concept_dirs)
]

rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES, frames=TARGET_FRAMES)
rand_dl = DataLoader(rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

# Each experiment contrasts one concept with the shared random concept.
experimental_sets = [[concept, random_concept] for concept in positive_concepts]

print("Prepared", len(positive_concepts), "concepts + random.")


Prepared 14 concepts + random.


In [12]:
# %%
# Initialize one TCAV object per layer.
# CAVs and concept activations are cached under OUT_DIR/cav/<TCAV_MODEL_ID>
# instead of captum's defaults (CWD-relative "./cav/" and "default_model_id"),
# so they cannot collide with artifacts from other notebooks/models run from
# the same working directory.
TCAV_SAVE_PATH = OUT_DIR / "cav"
TCAV_MODEL_ID = "redimnet_b5_asvspoof_probe_50_50"

all_tcav = {}
for layer_key, layer_module in TARGET_LAYERS.items():
    layer_name = module_name_in_model(spoof_model, layer_module)
    print("Layer:", layer_key, "->", layer_name)
    all_tcav[layer_key] = TCAV(
        spoof_model,
        [layer_name],
        model_id=TCAV_MODEL_ID,
        save_path=str(TCAV_SAVE_PATH),
        test_split_ratio=0.33,  # forwarded to captum's default CAV classifier
    )


Layer: stage4 -> redim.backbone.stage4.2




In [13]:
# %%
def compute_cav_acc_df(tcav: TCAV, positive_concepts: list[Concept], random_concept: Concept, layer_key: str) -> pd.DataFrame:
    cavs_dict = tcav.compute_cavs([[c, random_concept] for c in positive_concepts], force_train=FORCE_TRAIN_CAVS)

    rows = []
    for concepts_key, layer_map in cavs_dict.items():
        try:
            pos_id = int(str(concepts_key).split("-")[0])
        except Exception:
            continue
        if not (0 <= pos_id < len(positive_concepts)):
            continue
        concept_name = positive_concepts[pos_id].name

        for layer_name, cav_obj in layer_map.items():
            if cav_obj is None or cav_obj.stats is None:
                continue
            acc = cav_obj.stats.get("accs", None)
            if acc is None:
                acc = cav_obj.stats.get("acc", None)
            if isinstance(acc, torch.Tensor):
                acc = acc.detach().cpu().item()
            rows.append({
                "layer_key": layer_key,
                "concept name": concept_name,
                "layer name": layer_name,
                "cav acc": float(acc) if acc is not None else np.nan,
            })
    return pd.DataFrame(rows)

print("Computing CAV accuracies...")
acc_dfs = []
for layer_key, tcav in all_tcav.items():
    df_acc = compute_cav_acc_df(tcav, positive_concepts, random_concept, layer_key)
    print(layer_key, "rows:", len(df_acc))
    acc_dfs.append(df_acc)

if acc_dfs:
    acc_df_combined = pd.concat(acc_dfs, ignore_index=True)
else:
    acc_df_combined = pd.DataFrame()
# .head(10) of an empty frame is itself empty, so no special case is needed.
display(acc_df_combined.head(10))


Computing CAV accuracies...


  bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore


stage4 rows: 14


Unnamed: 0,layer_key,concept name,layer name,cav acc
0,stage4,long_constant_thick,redim.backbone.stage4.2,0.384615
1,stage4,long_constant_thick_Vibrato,redim.backbone.stage4.2,0.442308
2,stage4,long_dropping_flat_thick,redim.backbone.stage4.2,0.326923
3,stage4,long_dropping_flat_thick_Vibrato,redim.backbone.stage4.2,0.403846
4,stage4,long_dropping_steep_thick,redim.backbone.stage4.2,0.403846
5,stage4,long_dropping_steep_thin,redim.backbone.stage4.2,0.326923
6,stage4,long_rising_flat_thick,redim.backbone.stage4.2,0.384615
7,stage4,long_rising_steep_thick,redim.backbone.stage4.2,0.269231
8,stage4,long_rising_steep_thin,redim.backbone.stage4.2,0.423077
9,stage4,short_constant_thick,redim.backbone.stage4.2,0.269231


In [14]:
# %%
def waveform_to_mel4d(audio_array: np.ndarray) -> torch.Tensor:
    """Turn one waveform into the 4-D mel batch that ``spoof_model`` consumes.

    Args:
        audio_array: waveform samples from the HF ``audio`` column (expected 1-D).

    Returns:
        Tensor of shape (1, 1, N_MELS, TARGET_FRAMES) on DEVICE; the
        spectrogram is computed without gradient tracking.
    """
    wav_batch = torch.tensor(audio_array, dtype=torch.float32, device=DEVICE)[None]  # (1, T)
    with torch.no_grad():
        mel = redim_model.spec(wav_batch)  # (1, N_MELS, Tm)
    # crop/pad to TARGET_FRAMES, then add the channel axis expected by the backbone
    return fix_mel_frames(mel, TARGET_FRAMES)[:, None]

MAX_SAMPLES = None  # set e.g. 50 for a quick smoke test


In [15]:
# %%
rows = []
TARGET_CLASS = 1  # Fake (spoof) class index in our wrapper logits
# NOTE(review): logit 1 is the probe's positive class (logreg_clf.classes_[1]);
# confirm that class is spoof in the labels the probe was trained on.

n = len(tcav_subset) if MAX_SAMPLES is None else min(len(tcav_subset), MAX_SAMPLES)
print("Running TCAV on samples:", n)

# Resolve experiment keys via Concept.id rather than assuming id == list position.
concept_name_by_id = {c.id: c.name for c in positive_concepts}


def _first_scalar(value) -> float:
    """Return element 0 of a tensor or array-like as a Python float.

    Presumably captum reports one score per concept of the [positive, random]
    experimental set; element 0 is taken as the positive concept's score.
    """
    if isinstance(value, torch.Tensor):
        return float(value.detach().cpu().flatten()[0].item())
    return float(np.asarray(value).flatten()[0])


for i in range(n):
    sample = tcav_subset[i]

    audio_array = sample["audio"]["array"]
    speaker_id = sample.get("speaker_id", None)
    system_id = sample.get("system_id", None)
    key = int(sample["key"])  # 1=real, 0=fake
    true_label = 0 if key == 1 else 1  # 0=real, 1=fake (same indexing as TARGET_CLASS)

    x = waveform_to_mel4d(audio_array)

    for layer_key, tcav in all_tcav.items():
        # One utterance per call, so each row's scores come from a single input.
        score_for_label = tcav.interpret(
            inputs=x,
            experimental_sets=experimental_sets,
            target=TARGET_CLASS,
        )

        for exp_key, layer_dict in score_for_label.items():
            try:
                pos_idx = int(str(exp_key).split("-")[0])
            except ValueError:
                continue  # not an "<id>-<id>" experiment key
            concept_name = concept_name_by_id.get(pos_idx)
            if concept_name is None:
                continue

            for layer_name, metrics in layer_dict.items():
                sc = metrics.get("sign_count")
                mg = metrics.get("magnitude")
                if sc is None or mg is None:
                    continue

                rows.append({
                    "idx": i,
                    "speaker_id": speaker_id,
                    "system_id": system_id,
                    "key": key,
                    "true label": true_label,
                    "layer_key": layer_key,
                    "concept name": concept_name,
                    "layer name": layer_name,
                    "positive percentage": _first_scalar(sc),
                    "magnitude": _first_scalar(mg),
                })

df_tcav = pd.DataFrame(rows)
print("df_tcav shape:", df_tcav.shape)
display(df_tcav.head())


Running TCAV on samples: 1150
df_tcav shape: (16100, 10)


Unnamed: 0,idx,speaker_id,system_id,key,true label,layer_key,concept name,layer name,positive percentage,magnitude
0,0,LA_0039,A18,1,0,stage4,long_constant_thick,redim.backbone.stage4.2,1.0,1.386472
1,0,LA_0039,A18,1,0,stage4,long_constant_thick_Vibrato,redim.backbone.stage4.2,1.0,1.66574
2,0,LA_0039,A18,1,0,stage4,long_dropping_flat_thick,redim.backbone.stage4.2,1.0,1.12434
3,0,LA_0039,A18,1,0,stage4,long_dropping_flat_thick_Vibrato,redim.backbone.stage4.2,1.0,1.578779
4,0,LA_0039,A18,1,0,stage4,long_dropping_steep_thick,redim.backbone.stage4.2,1.0,0.497841


In [16]:
# %%
# Attach each concept's CAV accuracy to the per-sample TCAV rows, then save.
# Any "cav acc" from a previous run of this cell is dropped first, so re-running
# the cell does not produce duplicated "cav acc_x"/"cav acc_y" columns.
if not df_tcav.empty and not acc_df_combined.empty:
    df_tcav = df_tcav.drop(columns=["cav acc"], errors="ignore").merge(
        acc_df_combined,
        on=["layer_key", "concept name", "layer name"],
        how="left",
        # exactly one accuracy per (layer, concept): never duplicates sample rows
        validate="many_to_one",
    )

# Name the file after the analysed layers ("stage4" by default).
layers_tag = "_".join(TARGET_LAYERS)
out_csv = OUT_DIR / f"tcav_ASVspoof_{layers_tag}_spoofwrapper.csv"
df_tcav.to_csv(out_csv, index=False)
print("Saved ->", out_csv)


Saved -> /home/SpeakerRec/BioVoice/output/tcav_ASVspoof_stage4_spoofwrapper.csv


In [17]:
# %%
# Mean TCAV positive percentage per concept, split by true label (0 = real,
# 1 = fake). Rows from every layer in TARGET_LAYERS are pooled into one mean;
# filter df_tcav on "layer_key" first if more than one layer is enabled.
if not df_tcav.empty:
    summary = (
        df_tcav.groupby(["concept name", "true label"])["positive percentage"]
        .mean()
        .unstack("true label")
        # Keep both label columns even if the analysed subset (e.g. with
        # MAX_SAMPLES set) has only one class; the missing one becomes NaN.
        .reindex(columns=[0, 1])
        .rename(columns={0: "Real (0) mean", 1: "Fake (1) mean"})
        .rename_axis(columns=None)
    )
    summary["Fake-Real"] = summary["Fake (1) mean"] - summary["Real (0) mean"]
    display(summary.sort_values("Fake-Real", ascending=False).head(20))


Unnamed: 0_level_0,Real (0) mean,Fake (1) mean,Fake-Real
concept name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
short_rising_steep_thin,0.60625,0.619403,0.013153
long_dropping_flat_thick_Vibrato,0.60625,0.60597,-0.00028
long_constant_thick_Vibrato,0.664583,0.643284,-0.0213
long_rising_flat_thick,0.59375,0.564179,-0.029571
short_dropping_steep_thin,0.608333,0.577612,-0.030721
long_dropping_flat_thick,0.64375,0.61194,-0.03181
short_constant_thick,0.672917,0.635821,-0.037096
short_dropping_steep_thick,0.65625,0.61791,-0.03834
long_dropping_steep_thick,0.59375,0.555224,-0.038526
long_constant_thick,0.639583,0.6,-0.039583
