In [1]:
import sys
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader

from captum.concept import TCAV, Concept
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -------- Project paths / device --------
PROJECT_ROOT = Path.cwd().parents[1]
sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

ATTR_CSV_PATH = PROJECT_ROOT / "data" / "real_fake_paths.csv"
CONCEPT_ROOT  = Path(PROJECT_ROOT / "concept" / "temp_concepts")

# Pick one layer key you want TCAV on:
# LAYER_KEY = "stage5"

CONCEPT_SAMPLES = 100
RANDOM_SAMPLES  = 100
BATCH_SIZE_CONCEPT = 1  # keep 1 (safe if variable T)
FORCE_TRAIN_CAVS = True  # set True if you want to retrain CAVs

# OUT_CSV = Path(f"stage5_temp_concepts_{LAYER_KEY}.csv")

assert ATTR_CSV_PATH.exists(), f"Missing {ATTR_CSV_PATH}"
assert CONCEPT_ROOT.exists(), f"Missing {CONCEPT_ROOT}"

PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda


In [3]:
redim_model = (
    torch.hub.load(
        "IDRnD/ReDimNet",
        "ReDimNet",
        model_name="b5",
        train_type="ptn",
        dataset="vox2",
    )
    .to(DEVICE)
    .eval()
)
print("Loaded ReDimNet successfully.")

with torch.no_grad():
    dummy_wav = torch.zeros(1, 16000, device=DEVICE)
    dummy_mel = redim_model.spec(dummy_wav)  # (1, N_MELS, T)
N_MELS = int(dummy_mel.shape[1])
print("ReDimNet spec N_MELS =", N_MELS)


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet successfully.
ReDimNet spec N_MELS = 72


In [4]:

MODEL_DIR = PROJECT_ROOT / "data" / "models" / "real_vs_fake"

LOGREG_PATH = MODEL_DIR / "logistic_regression.pkl"
SCALER_PATH = MODEL_DIR / "scaler.pkl"

assert LOGREG_PATH.exists(), f"Missing {LOGREG_PATH}"
assert SCALER_PATH.exists(), f"Missing {SCALER_PATH}"

with open(LOGREG_PATH, "rb") as f:
    logreg_clf = pickle.load(f)

with open(SCALER_PATH, "rb") as f:
    scaler = pickle.load(f)

print("Loaded Logistic Regression and Scaler.")

Loaded Logistic Regression and Scaler.


In [5]:
# -------- Wrap ReDimNet -> embeddings --------
class ReDimNetEmbeddingWrapper(nn.Module):
    """
    Input:  mel4d [B, 1, N_MELS, T]
    Output: embeddings [B, D]
    """

    def __init__(self, redim_model, l2_norm_emb: bool = True):
        super().__init__()
        self.backbone = redim_model.backbone
        self.pool = redim_model.pool
        self.bn = redim_model.bn
        self.linear = redim_model.linear
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d: torch.Tensor) -> torch.Tensor:
        x = self.backbone(mel4d)
        x = self.pool(x)
        x = self.bn(x)
        emb = self.linear(x)
        if self.l2_norm_emb:
            emb = emb / (emb.norm(p=2, dim=1, keepdim=True) + 1e-12)
        return emb


wrapped_model = (
    ReDimNetEmbeddingWrapper(redim_model, l2_norm_emb=True).to(DEVICE).eval()
)

print("wrapped_model (embedding-only) ready.")

wrapped_model (embedding-only) ready.


In [6]:
TARGET_LAYERS = {
    # "stem":   wrapped_model.backbone.stem[0],
    # "stage0": wrapped_model.backbone.stage0[2],
    # "stage1": wrapped_model.backbone.stage1[2],
    # "stage2": wrapped_model.backbone.stage2[2],
    # "stage3": wrapped_model.backbone.stage3[2],
    "stage4": wrapped_model.backbone.stage4[2],
    "stage5": wrapped_model.backbone.stage5[2],
}
# assert LAYER_KEY in TARGET_LAYERS, f"{LAYER_KEY=} not in TARGET_LAYERS: {list(TARGET_LAYERS.keys())}"

In [7]:
def module_name_in_model(model: torch.nn.Module, target_module: torch.nn.Module) -> str:
    for name, mod in model.named_modules():
        if mod is target_module:
            return name
    raise RuntimeError(
        "Could not find the selected layer module in model.named_modules()"
    )

In [8]:
TCAV_DEVICE = torch.device("cpu")
print("TCAV_DEVICE =", TCAV_DEVICE)

redim_model = redim_model.to(TCAV_DEVICE).eval()
wrapped_model = wrapped_model.to(TCAV_DEVICE).eval()

DEVICE = TCAV_DEVICE


TCAV_DEVICE = cpu


In [9]:
class ConceptNPYDataset(Dataset):
    def __init__(self, concept_dir: Path, limit: int | None = None):
        self.files = sorted(concept_dir.glob("*.npy"))
        if not self.files:
            raise RuntimeError(f"No .npy found in {concept_dir}")
        if limit is not None:
            self.files = self.files[:limit]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel = np.load(self.files[idx]).astype(np.float32)  # (N_MELS, T)
        if mel.shape[0] != N_MELS:
            raise RuntimeError(f"{self.files[idx].name}: expected {N_MELS} bins, got {mel.shape}")
        x = torch.from_numpy(mel).unsqueeze(0)  # (1, N_MELS, T) on CPU
        return x



def infer_frames_for_random(concept_dirs: list[Path]) -> int:
    for d in concept_dirs:
        f = next(d.glob("*.npy"), None)
        if f is not None:
            mel = np.load(f)
            return int(mel.shape[1])
    raise RuntimeError("Could not infer frames from concept dirs")

class RandomMelDataset(Dataset):
    def __init__(self, n_samples: int, frames: int):
        self.n_samples = n_samples
        self.frames = frames

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        mel = torch.randn(N_MELS, self.frames, dtype=torch.float32)  
        return mel.unsqueeze(0) 


In [10]:
concept_dirs = sorted([d for d in CONCEPT_ROOT.iterdir() if d.is_dir()])
if not concept_dirs:
    raise RuntimeError(f"No concept folders in {CONCEPT_ROOT}")

concept_names = [d.name for d in concept_dirs]
print("Concepts:", concept_names)

TARGET_FRAMES = infer_frames_for_random(concept_dirs)
print("Using fixed frames for TCAV (from concepts):", TARGET_FRAMES)

Concepts: ['long_constant_thick', 'long_constant_thick_Vibrato', 'long_dropping_flat_thick', 'long_dropping_flat_thick_Vibrato', 'long_dropping_steep_thick', 'long_dropping_steep_thin', 'long_rising_flat_thick', 'long_rising_steep_thick', 'long_rising_steep_thin', 'short_constant_thick', 'short_dropping_steep_thick', 'short_dropping_steep_thin', 'short_rising_steep_thick', 'short_rising_steep_thin']
Using fixed frames for TCAV (from concepts): 304


In [11]:
# -------- Prepare concepts (same for all layers) --------
positive_concepts = []
for idx, cdir in enumerate(concept_dirs):
    ds = ConceptNPYDataset(cdir, limit=CONCEPT_SAMPLES)
    dl = DataLoader(ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
    positive_concepts.append(Concept(id=idx, name=cdir.name, data_iter=dl))

rand_ds = RandomMelDataset(n_samples=RANDOM_SAMPLES, frames=TARGET_FRAMES)
rand_dl = DataLoader(rand_ds, batch_size=BATCH_SIZE_CONCEPT, shuffle=False, num_workers=0)
random_concept = Concept(id=len(positive_concepts), name="random", data_iter=rand_dl)

experimental_sets = [[c, random_concept] for c in positive_concepts]

# -------- Loop over all layers --------
all_tcav_results = {}  # Store results for each layer
all_acc_dfs = []       # Accumulate accuracy DataFrames

for layer_key, layer_module in TARGET_LAYERS.items():
    
    # Resolve layer name for Captum
    layer_name = module_name_in_model(wrapped_model, layer_module)
    # Initialize TCAV for this layer
    tcav = TCAV(wrapped_model, [layer_name], test_split_ratio=0.33)
    all_tcav_results[layer_key] = tcav




In [12]:
def compute_cav_acc_df(tcav: TCAV, positive_concepts: list[Concept], random_concept: Concept, layer_key: str) -> pd.DataFrame:
    cavs_dict = tcav.compute_cavs([[c, random_concept] for c in positive_concepts], force_train=FORCE_TRAIN_CAVS)

    rows = []
    for concepts_key, layer_map in cavs_dict.items():
        try:
            pos_id = int(str(concepts_key).split("-")[0])
        except Exception:
            continue
        if not (0 <= pos_id < len(positive_concepts)):
            continue
        concept_name = positive_concepts[pos_id].name

        for layer_name, cav_obj in layer_map.items():
            if cav_obj is None or cav_obj.stats is None:
                continue
            acc = cav_obj.stats.get("accs", None)
            if acc is None:
                acc = cav_obj.stats.get("acc", None)
            if isinstance(acc, torch.Tensor):
                acc = acc.detach().cpu().item()
            rows.append({
                "layer_key": layer_key,
                "concept name": concept_name,
                "layer name": layer_name,
                "cav acc": float(acc) if acc is not None else np.nan,
            })
    return pd.DataFrame(rows, columns=["layer_key", "concept name", "layer name", "cav acc"])

# Compute CAV accuracies for all layers
print("Computing CAV accuracies for all layers...")
for layer_key, tcav in all_tcav_results.items():
    print(f"  Computing CAVs for {layer_key}...")
    acc_df = compute_cav_acc_df(tcav, positive_concepts, random_concept, layer_key)
    all_acc_dfs.append(acc_df)
    print(f"    Found {len(acc_df)} CAV accuracies")

acc_df_combined = pd.concat(all_acc_dfs, ignore_index=True) if all_acc_dfs else pd.DataFrame()
print(f"\nTotal CAV accuracies: {len(acc_df_combined)}")
if not acc_df_combined.empty:
    print(acc_df_combined.head(10))


Computing CAV accuracies for all layers...
  Computing CAVs for stage4...


  av = torch.load(fl)
  bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Found 14 CAV accuracies
  Computing CAVs for stage5...


  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)
  av = torch.load(fl)


    Found 14 CAV accuracies

Total CAV accuracies: 28
  layer_key                      concept name         layer name   cav acc
0    stage4               long_constant_thick  backbone.stage4.2  0.365385
1    stage4       long_constant_thick_Vibrato  backbone.stage4.2  0.423077
2    stage4          long_dropping_flat_thick  backbone.stage4.2  0.346154
3    stage4  long_dropping_flat_thick_Vibrato  backbone.stage4.2  0.346154
4    stage4         long_dropping_steep_thick  backbone.stage4.2  0.365385
5    stage4          long_dropping_steep_thin  backbone.stage4.2  0.461538
6    stage4            long_rising_flat_thick  backbone.stage4.2  0.423077
7    stage4           long_rising_steep_thick  backbone.stage4.2  0.480769
8    stage4            long_rising_steep_thin  backbone.stage4.2  0.307692
9    stage4              short_constant_thick  backbone.stage4.2  0.500000


In [13]:
def fix_mel_frames(mel_3d: torch.Tensor, target_frames: int) -> torch.Tensor:
    """
    mel_3d: (1, N_MELS, T)
    returns: (1, N_MELS, target_frames)
    """
    T = int(mel_3d.shape[-1])
    if T == target_frames:
        return mel_3d
    if T > target_frames:
        start = (T - target_frames) // 2
        return mel_3d[..., start:start + target_frames]
    pad = target_frames - T
    return F.pad(mel_3d, (0, pad), mode="constant", value=0.0)

def wav_path_to_mel4d(path: Path) -> torch.Tensor:
    wav, sr = torchaudio.load(str(path))
    wav = wav[:1, :].float().to(DEVICE)
    with torch.no_grad():
        mel = redim_model.spec(wav)          # (1, N_MELS, T)
    mel = fix_mel_frames(mel, TARGET_FRAMES) # (1, N_MELS, TARGET_FRAMES)
    return mel.unsqueeze(0)                  # (1, 1, N_MELS, TARGET_FRAMES)

def predict_speaker(path: Path) -> tuple[str, float]:
    x = wav_path_to_mel4d(path)
    with torch.no_grad():
        logits = wrapped_model(x)            # (1, num_speakers)
        probs = F.softmax(logits, dim=1)[0]
        pred_id = int(torch.argmax(probs).item())
        pred_name = id_to_speaker[pred_id]
        pred_prob = float(probs[pred_id].item())
    return pred_name, pred_prob


In [14]:
df_attr = pd.read_csv(ATTR_CSV_PATH)

# Expect columns: path, label (0=real, 1=fake)
if "path" not in df_attr.columns or "label" not in df_attr.columns:
    raise RuntimeError(
        f"CSV must contain columns ['path','label']. Got: {list(df_attr.columns)}"
    )

rows = []

TARGET_CLASS = 1  # Fake

for _, r in df_attr.iterrows():
    path = Path(r["path"])
    true_label = int(r["label"])

    if not path.exists():
        continue

    # Prepare input
    x = wav_path_to_mel4d(path)

    # Loop over all layers and collect TCAV scores
    for layer_key, tcav in all_tcav_results.items():
        score_for_label = tcav.interpret(
            inputs=x,
            experimental_sets=experimental_sets,
            target=TARGET_CLASS,
        )

        for exp_key, layer_dict in score_for_label.items():
            try:
                pos_idx = int(str(exp_key).split("-")[0])
            except Exception:
                continue
            if not (0 <= pos_idx < len(positive_concepts)):
                continue

            concept_name = positive_concepts[pos_idx].name

            for layer_name, metrics in layer_dict.items():
                sc = metrics.get("sign_count")
                mg = metrics.get("magnitude")
                if sc is None or mg is None:
                    continue

                if isinstance(sc, torch.Tensor):
                    sc = sc.detach().cpu().tolist()
                if isinstance(mg, torch.Tensor):
                    mg = mg.detach().cpu().tolist()

                rows.append(
                    {
                        "path": str(path),
                        "layer_key": layer_key,
                        "concept name": concept_name,
                        "layer name": layer_name,
                        "positive percentage": float(sc[0]),
                        "magnitude": float(mg[0]),
                        "true label": true_label,  # 0=real, 1=fake
                    }
                )

df_tcav = pd.DataFrame(
    rows,
    columns=[
        "path",
        "layer_key",
        "concept name",
        "layer name",
        "positive percentage",
        "magnitude",
        "true label",
    ],
)

# Merge with CAV accuracies
df_tcav = df_tcav.merge(
    acc_df_combined,
    on=["layer_key", "concept name", "layer name"],
    how="left",
)

OUT_CSV = Path("tcav_real_vs_fake_all_layers.csv")
df_tcav.to_csv(OUT_CSV, index=False)

print(f"Saved → {OUT_CSV}")
print(f"Results shape: {df_tcav.shape}")
print("\nFirst few rows:")
print(df_tcav.head())

  save_dict = torch.load(cavs_path)


Saved → tcav_real_vs_fake_all_layers.csv
Results shape: (5040, 8)

First few rows:
                                               path layer_key  \
0  /home/SpeakerRec/BioVoice/data/wavs/idan_001.wav    stage4   
1  /home/SpeakerRec/BioVoice/data/wavs/idan_001.wav    stage4   
2  /home/SpeakerRec/BioVoice/data/wavs/idan_001.wav    stage4   
3  /home/SpeakerRec/BioVoice/data/wavs/idan_001.wav    stage4   
4  /home/SpeakerRec/BioVoice/data/wavs/idan_001.wav    stage4   

                       concept name         layer name  positive percentage  \
0               long_constant_thick  backbone.stage4.2                  0.0   
1       long_constant_thick_Vibrato  backbone.stage4.2                  0.0   
2          long_dropping_flat_thick  backbone.stage4.2                  0.0   
3  long_dropping_flat_thick_Vibrato  backbone.stage4.2                  0.0   
4         long_dropping_steep_thick  backbone.stage4.2                  0.0   

   magnitude  true label   cav acc  
0  -0.009237  