
# TCAV on ASVspoof5 Train-Only Subset (ReDimNet + Spoof Wrapper)

This notebook runs TCAV on the **train-only ASVspoof5 subset** you prepared (`A/B/C = 30/15/5`, 50 speakers total).

Key choices in this version:
- Consistent labels with explicit enum:
  - `label_str in {bonafide, spoof}`
  - `label_id: bonafide=0, spoof=1`
- `TARGET_CLASS_ID = 1` (spoof) in wrapper logits `[bonafide, spoof]`
- **Recompute CAVs** (configurable, default `True`)
- Faster/safer CSV generation via:
  - checkpoint chunk writes (`CHECKPOINT_EVERY_N`)
  - resume support (`RESUME_FROM_PARTIAL`)
  - optional mel cache to disk (`ENABLE_MEL_CACHE`)

Expected server location for this notebook:
- `/home/SpeakerRec/BioVoice/redimnet/tcav/deepfakes/asvspoof5/`

Expected inputs already uploaded on SSH:
- `asvspoof5_train_only_selected_utterances_plan.csv` (in same folder as this notebook)
- subset audio root: `/home/SpeakerRec/BioVoice/data/datasets/asvspoof5_train_only_subset_audio` with folders `A/`, `B/`, `C/`


In [23]:

# Imports
import sys
import time
import json
import pickle
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from captum.concept import TCAV, Concept
from torch.utils.data import DataLoader, Dataset

print('torch:', torch.__version__)


torch: 2.1.2+cu121


In [33]:

# Paths + config (SSH defaults)
PROJECT_ROOT = Path('/home/SpeakerRec/BioVoice')
sys.path.append(str(PROJECT_ROOT))

# Notebook folder contains the uploaded plan CSVs
NOTEBOOK_DIR = PROJECT_ROOT / 'redimnet' / 'tcav' / 'deepfakes' / 'asvspoof5'
PLAN_CSV = NOTEBOOK_DIR / 'asvspoof5_train_only_selected_utterances_plan.csv'
PLAN_SUMMARY_JSON = NOTEBOOK_DIR / 'asvspoof5_train_only_plan_summary.json'

# Extracted train-only subset audio root created from your local extraction+zip pipeline
AUDIO_ROOT = PROJECT_ROOT / 'data' / 'datasets' / 'asvspoof5_train_only_subset_audio'

# Concepts + probe (same style as previous notebook; adjust if needed)
CONCEPT_ROOT = PROJECT_ROOT / 'concept' / 'final_concepts'
MODEL_DIR = PROJECT_ROOT / 'data' / 'models' / 'asvspoof_probe_50_50'
LOGREG_PATH = MODEL_DIR / 'logistic_regression.pkl'
SCALER_PATH = MODEL_DIR / 'scaler.pkl'

OUT_DIR = PROJECT_ROOT / 'data' / 'tcav'
OUT_DIR.mkdir(parents=True, exist_ok=True)
RUN_TAG = 'ASVspoof5_train_only_stage4_spoofwrapper'
RUN_DIR = OUT_DIR / RUN_TAG
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Device controls
# PREPROCESS_DEVICE can use GPU to speed up audio->mel preprocessing (ReDimNet.spec).
# TCAV_DEVICE controls where the TCAV wrapper model runs.
# If TCAV_DEVICE stays CPU, only preprocessing gets GPU acceleration.
# PREPROCESS_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PREPROCESS_DEVICE = torch.device('cpu')
TCAV_DEVICE = torch.device('cpu')

# Wrapper/model device used by TCAV interpret()
MODEL_DEVICE = TCAV_DEVICE

# Label enum (keep this consistent everywhere)
LABEL_TO_ID = {'bonafide': 0, 'spoof': 1}
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
TARGET_CLASS_NAME = 'spoof'
TARGET_CLASS_ID = LABEL_TO_ID[TARGET_CLASS_NAME]  # wrapper logits index [bonafide, spoof]

# Run controls
RUN_GROUPS = ['A', 'B', 'C']          # You can set ['A'] for first run if needed
MAX_SAMPLES = None                    # e.g. 100 for smoke test
RECOMPUTE_CAVS = True                 # user requested: recompute CAVs
CHECKPOINT_EVERY_N = 50               # write partial CSV every N samples
RESUME_FROM_PARTIAL = False            # skip already processed utt_id if partial CSV exists

# Speed / reliability helpers for the CSV-generation stage
ENABLE_MEL_CACHE = True
MEL_CACHE_DIR = RUN_DIR / 'mel_cache_stage4_inputs'
MEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
PARTIAL_CSV = RUN_DIR / f'{RUN_TAG}__partial.csv'
FINAL_CSV = RUN_DIR / f'{RUN_TAG}.csv'
PROGRESS_JSON = RUN_DIR / f'{RUN_TAG}__progress.json'
CAV_ACC_CSV = RUN_DIR / f'{RUN_TAG}__concept_cav_acc.csv'

print('NOTEBOOK_DIR =', NOTEBOOK_DIR)
print('PLAN_CSV =', PLAN_CSV, '| exists =', PLAN_CSV.exists())
print('AUDIO_ROOT =', AUDIO_ROOT, '| exists =', AUDIO_ROOT.exists())
print('CONCEPT_ROOT =', CONCEPT_ROOT, '| exists =', CONCEPT_ROOT.exists())
print('LOGREG_PATH =', LOGREG_PATH, '| exists =', LOGREG_PATH.exists())
print('SCALER_PATH =', SCALER_PATH, '| exists =', SCALER_PATH.exists())
print('RUN_DIR =', RUN_DIR)
print('PREPROCESS_DEVICE =', PREPROCESS_DEVICE)
print('TCAV_DEVICE =', TCAV_DEVICE)
print('MODEL_DEVICE =', MODEL_DEVICE)
print('TARGET_CLASS_ID =', TARGET_CLASS_ID, '(', TARGET_CLASS_NAME, ')')


NOTEBOOK_DIR = /home/SpeakerRec/BioVoice/redimnet/tcav/deepfakes/asvspoof5
PLAN_CSV = /home/SpeakerRec/BioVoice/redimnet/tcav/deepfakes/asvspoof5/asvspoof5_train_only_selected_utterances_plan.csv | exists = True
AUDIO_ROOT = /home/SpeakerRec/BioVoice/data/datasets/asvspoof5_train_only_subset_audio | exists = True
CONCEPT_ROOT = /home/SpeakerRec/BioVoice/concept/final_concepts | exists = True
LOGREG_PATH = /home/SpeakerRec/BioVoice/data/models/asvspoof_probe_50_50/logistic_regression.pkl | exists = True
SCALER_PATH = /home/SpeakerRec/BioVoice/data/models/asvspoof_probe_50_50/scaler.pkl | exists = True
RUN_DIR = /home/SpeakerRec/BioVoice/data/tcav/ASVspoof5_train_only_stage4_spoofwrapper
PREPROCESS_DEVICE = cpu
TCAV_DEVICE = cpu
MODEL_DEVICE = cpu
TARGET_CLASS_ID = 1 ( spoof )


In [25]:

# Load subset plan and build SSH audio paths (no local paths used)
assert PLAN_CSV.exists(), f'Missing plan CSV: {PLAN_CSV}'
assert AUDIO_ROOT.exists(), f'Missing subset audio root: {AUDIO_ROOT}'

plan_df = pd.read_csv(PLAN_CSV)
required_cols = {'group','partition','speaker_id','utt_id','gender','label','system_id'}
missing = sorted(required_cols - set(plan_df.columns))
assert not missing, f'Missing columns in plan CSV: {missing}'

plan_df['utt_id'] = plan_df['utt_id'].astype(str)
plan_df['speaker_id'] = plan_df['speaker_id'].astype(str)
plan_df['group'] = plan_df['group'].astype(str)
plan_df['label'] = plan_df['label'].astype(str)
plan_df['label_id'] = plan_df['label'].map(LABEL_TO_ID)
assert plan_df['label_id'].notna().all(), 'Unexpected label values in plan CSV'
plan_df['label_id'] = plan_df['label_id'].astype(int)

# Reconstruct audio path by extraction convention: AUDIO_ROOT/{group}/{label}/{utt_id}.flac (or .wav fallback)
def resolve_audio_path(row):
    p_flac = AUDIO_ROOT / row['group'] / row['label'] / f"{row['utt_id']}.flac"
    if p_flac.exists():
        return str(p_flac)
    p_wav = AUDIO_ROOT / row['group'] / row['label'] / f"{row['utt_id']}.wav"
    if p_wav.exists():
        return str(p_wav)
    return None

plan_df['audio_path'] = plan_df.apply(resolve_audio_path, axis=1)
missing_audio = plan_df['audio_path'].isna().sum()
print('Plan rows total:', len(plan_df))
print('Rows with missing audio_path:', int(missing_audio))
assert missing_audio == 0, 'Some audio files are missing under AUDIO_ROOT. Check uploaded subset folders.'

# Filter groups for this run
run_df = plan_df[plan_df['group'].isin(RUN_GROUPS)].copy().reset_index(drop=True)
if MAX_SAMPLES is not None:
    run_df = run_df.head(int(MAX_SAMPLES)).copy().reset_index(drop=True)

print('Run rows:', len(run_df))
print('Groups:', run_df['group'].value_counts().sort_index().to_dict())
print('Labels:', run_df['label'].value_counts().sort_index().to_dict())
print('Spoof systems:', sorted(run_df.loc[run_df['label']=='spoof','system_id'].unique().tolist()))
display(run_df.head())


Plan rows total: 3200
Rows with missing audio_path: 0
Run rows: 3200
Groups: {'A': 1920, 'B': 960, 'C': 320}
Labels: {'bonafide': 1600, 'spoof': 1600}
Spoof systems: ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08']


Unnamed: 0,group,partition,speaker_id,utt_id,gender,label,system_id,codec_id,codec_q,source_utt_id,attack_codec_id,selected_reason,label_id,audio_path
0,A,train,T_0170,T_0000003481,F,bonafide,bonafide,-,-,-,-,bonafide_quota,0,/home/SpeakerRec/BioVoice/data/datasets/asvspo...
1,A,train,T_0170,T_0000012452,F,bonafide,bonafide,-,-,-,-,bonafide_quota,0,/home/SpeakerRec/BioVoice/data/datasets/asvspo...
2,A,train,T_0170,T_0000014646,F,bonafide,bonafide,-,-,-,-,bonafide_quota,0,/home/SpeakerRec/BioVoice/data/datasets/asvspo...
3,A,train,T_0170,T_0000018611,F,bonafide,bonafide,-,-,-,-,bonafide_quota,0,/home/SpeakerRec/BioVoice/data/datasets/asvspo...
4,A,train,T_0170,T_0000023686,F,bonafide,bonafide,-,-,-,-,bonafide_quota,0,/home/SpeakerRec/BioVoice/data/datasets/asvspo...


In [26]:

# Audio loader + ReDimNet + wrapper probe
# Uses soundfile if available, falls back to torchaudio.
try:
    import soundfile as sf
    _HAS_SF = True
except Exception:
    sf = None
    _HAS_SF = False
    import torchaudio


def load_audio_16k_mono(audio_path: str) -> np.ndarray:
    p = str(audio_path)
    if _HAS_SF:
        wav, sr = sf.read(p)
        if wav.ndim == 2:
            wav = wav.mean(axis=1)
        wav = wav.astype(np.float32)
        if sr != 16000:
            # light fallback via torchaudio for resampling if needed
            w = torch.tensor(wav, dtype=torch.float32).unsqueeze(0)
            w = torchaudio.functional.resample(w, sr, 16000)
            wav = w.squeeze(0).cpu().numpy().astype(np.float32)
        return wav
    else:
        w, sr = torchaudio.load(p)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)
        if sr != 16000:
            w = torchaudio.functional.resample(w, sr, 16000)
        return w.squeeze(0).cpu().numpy().astype(np.float32)

# Load ReDimNet
redim_model = (
    torch.hub.load(
        'IDRnD/ReDimNet',
        'ReDimNet',
        model_name='b5',
        train_type='ptn',
        dataset='vox2',
    )
    .to(MODEL_DEVICE)
    .eval()
)

with torch.no_grad():
    dummy_wav = torch.zeros(1, 16000, device=MODEL_DEVICE)
    dummy_mel = redim_model.spec(dummy_wav)
N_MELS = int(dummy_mel.shape[1])
print('Loaded ReDimNet. N_MELS =', N_MELS)

# Load scaler + logistic probe
with open(LOGREG_PATH, 'rb') as f:
    logreg_clf = pickle.load(f)
with open(SCALER_PATH, 'rb') as f:
    scaler = pickle.load(f)
print('Loaded probe:', type(logreg_clf), type(scaler))

class ReDimNetSpoofWrapper(nn.Module):
    def __init__(self, redim_model, W, b, mean, scale, l2_norm_emb=True):
        super().__init__()
        self.redim = redim_model
        self.l2_norm_emb = l2_norm_emb
        D = W.shape[1]
        self.register_buffer('mean', torch.tensor(mean, dtype=torch.float32))
        self.register_buffer('scale', torch.tensor(scale, dtype=torch.float32))
        self.linear = nn.Linear(D, 1)
        self.linear.weight.data = torch.tensor(W, dtype=torch.float32)
        self.linear.bias.data = torch.tensor(b, dtype=torch.float32)

    def forward(self, mel4d):
        x = self.redim.backbone(mel4d)
        x = self.redim.pool(x)
        x = self.redim.bn(x)
        emb = self.redim.linear(x)
        if self.l2_norm_emb:
            emb = emb / (emb.norm(p=2, dim=1, keepdim=True) + 1e-12)
        emb = (emb - self.mean) / self.scale
        score = self.linear(emb)  # [B,1]
        logits = torch.cat([-score, score], dim=1)  # [B,2] = [bonafide, spoof]
        return logits

spoof_model = ReDimNetSpoofWrapper(
    redim_model,
    W=logreg_clf.coef_,
    b=logreg_clf.intercept_,
    mean=scaler.mean_,
    scale=scaler.scale_,
).to(MODEL_DEVICE).eval()

print('Created ReDimNetSpoofWrapper. Logits order = [bonafide, spoof]')
print('Note: If TCAV_DEVICE=cpu, the wrapper forward (backbone+embedding+probe) still runs on CPU.')


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet. N_MELS = 72
Loaded probe: <class 'sklearn.linear_model._logistic.LogisticRegression'> <class 'sklearn.preprocessing._data.StandardScaler'>
Created ReDimNetSpoofWrapper. Logits order = [bonafide, spoof]
Note: If TCAV_DEVICE=cpu, the wrapper forward (backbone+embedding+probe) still runs on CPU.


In [27]:

# Concepts + TCAV setup (recompute CAVs by default)
assert CONCEPT_ROOT.exists(), f'Missing concept root: {CONCEPT_ROOT}'

# Random concept creation (auto-generate if missing)
AUTO_CREATE_RANDOM_CONCEPT = True
RANDOM_CONCEPT_NAME = 'random'
RANDOM_CONCEPT_N = 60
RANDOM_CONCEPT_SEED = 42

# Minimal concept dataset for Captum Concept/DataLoader using .npy files in concept folders
class NpyConceptDataset(Dataset):
    def __init__(self, npy_files: list[Path]):
        self.npy_files = list(npy_files)
    def __len__(self):
        return len(self.npy_files)
    def __getitem__(self, idx):
        arr = np.load(self.npy_files[idx]).astype(np.float32)
        x = torch.tensor(arr)
        # Expected per-sample shape for ReDimNet conv input is [1, H, W].
        # Many concept files are saved as [H, W], so add the channel dim.
        if x.ndim == 2:
            x = x.unsqueeze(0)
        elif x.ndim == 3 and x.shape[0] != 1 and x.shape[-1] == 1:
            # Convert [H, W, 1] -> [1, H, W] if needed.
            x = x.permute(2, 0, 1)
        if x.ndim != 3 or x.shape[0] != 1:
            raise RuntimeError(f'Unexpected concept tensor shape {tuple(x.shape)} in {self.npy_files[idx]} (expected [1,H,W])')
        return x

def concept_loader_from_dir(cdir: Path, batch_size: int = 16) -> DataLoader:
    npy_files = sorted(list(cdir.glob('*.npy')))
    assert len(npy_files) > 0, f'No .npy files in concept dir: {cdir}'
    ds = NpyConceptDataset(npy_files)
    return DataLoader(ds, batch_size=batch_size, shuffle=True)

concept_dirs = [p for p in sorted(CONCEPT_ROOT.iterdir()) if p.is_dir() and p.name != RANDOM_CONCEPT_NAME]
assert len(concept_dirs) > 0, f'No concept directories found in {CONCEPT_ROOT}'

positive_concepts = []
for idx, cdir in enumerate(concept_dirs):
    dl = concept_loader_from_dir(cdir)
    positive_concepts.append(Concept(id=idx, name=cdir.name, data_iter=dl))

# random concept directory (TCAV baseline). If missing, optionally auto-create by mixing samples from other concepts.
random_dir = CONCEPT_ROOT / RANDOM_CONCEPT_NAME
if (not random_dir.exists()) and AUTO_CREATE_RANDOM_CONCEPT:
    random_dir.mkdir(parents=True, exist_ok=True)
    pool = []
    for cdir in concept_dirs:
        pool.extend(sorted(cdir.glob('*.npy')))
    if len(pool) < RANDOM_CONCEPT_N:
        raise RuntimeError(f'Not enough .npy files to create random concept: pool={len(pool)} need={RANDOM_CONCEPT_N}')
    rng = np.random.default_rng(RANDOM_CONCEPT_SEED)
    chosen = rng.choice(np.array(pool, dtype=object), size=RANDOM_CONCEPT_N, replace=False)
    # Copy files into random/ with stable numbering (keeps original concept dirs untouched).
    import shutil
    for i, src_npy in enumerate(chosen, start=1):
        dst = random_dir / f"{i:06d}.npy"
        shutil.copy2(Path(src_npy), dst)
    # Basic metadata for traceability
    (random_dir / 'meta.json').write_text(json.dumps({
        'kind': 'auto_random_mixed_concepts',
        'n_files': RANDOM_CONCEPT_N,
        'seed': RANDOM_CONCEPT_SEED,
        'source_root': str(CONCEPT_ROOT),
        'excluded_dir': RANDOM_CONCEPT_NAME,
    }, indent=2), encoding='utf-8')
    print(f'[INFO] Auto-created random concept at {random_dir} with {RANDOM_CONCEPT_N} mixed .npy files')

if random_dir.exists() and random_dir.is_dir():
    rand_dl = concept_loader_from_dir(random_dir)
else:
    raise RuntimeError(f'Missing random concept directory: {random_dir}. Set AUTO_CREATE_RANDOM_CONCEPT=True or create {random_dir}/*.npy')

random_concept = Concept(id=len(positive_concepts), name=RANDOM_CONCEPT_NAME, data_iter=rand_dl)
experimental_sets = [[c, random_concept] for c in positive_concepts]

# Shape sanity checks (common TCAV failure point)
first_concept_batch = next(iter(positive_concepts[0].data_iter))
first_random_batch = next(iter(random_concept.data_iter))
print('First concept batch shape:', tuple(first_concept_batch.shape))
print('First random batch shape:', tuple(first_random_batch.shape))
assert first_concept_batch.ndim == 4 and first_concept_batch.shape[1] == 1,     f'Concept batch must be [B,1,H,W], got {tuple(first_concept_batch.shape)}'
assert first_random_batch.ndim == 4 and first_random_batch.shape[1] == 1,     f'Random concept batch must be [B,1,H,W], got {tuple(first_random_batch.shape)}'

TARGET_LAYERS = {
    'stage4': redim_model.backbone.stage4[2],
}

all_tcav = {}
for layer_key, layer_module in TARGET_LAYERS.items():
    layer_name = getattr(layer_module, 'layer_name', None) or 'redim.backbone.stage4.2'
    print('Layer:', layer_key, '->', layer_name)
    all_tcav[layer_key] = TCAV(spoof_model, [layer_name], test_split_ratio=0.33)

FORCE_TRAIN_CAVS = bool(RECOMPUTE_CAVS)
print('Prepared', len(positive_concepts), 'concepts + random.')
print('FORCE_TRAIN_CAVS =', FORCE_TRAIN_CAVS)


First concept batch shape: (16, 1, 72, 304)
First random batch shape: (16, 1, 72, 304)
Layer: stage4 -> redim.backbone.stage4.2
Prepared 28 concepts + random.
FORCE_TRAIN_CAVS = True




In [28]:

# Compute and save CAV accuracies (recomputed if FORCE_TRAIN_CAVS=True)
def compute_cav_acc_df(tcav: TCAV, positive_concepts: list[Concept], random_concept: Concept, layer_key: str) -> pd.DataFrame:
    cavs_dict = tcav.compute_cavs([[c, random_concept] for c in positive_concepts], force_train=FORCE_TRAIN_CAVS)
    rows = []
    for concepts_key, layer_map in cavs_dict.items():
        try:
            pos_id = int(str(concepts_key).split('-')[0])
        except Exception:
            continue
        if not (0 <= pos_id < len(positive_concepts)):
            continue
        concept_name = positive_concepts[pos_id].name
        for layer_name, cav_obj in layer_map.items():
            if cav_obj is None or cav_obj.stats is None:
                continue
            acc = cav_obj.stats.get('accs', None)
            if acc is None:
                acc = cav_obj.stats.get('acc', None)
            if isinstance(acc, torch.Tensor):
                acc = acc.detach().cpu().item()
            rows.append({
                'layer_key': layer_key,
                'concept_name': concept_name,
                'layer_name': layer_name,
                'cav_acc': float(acc) if acc is not None else np.nan,
            })
    return pd.DataFrame(rows)

print('Computing CAV accuracies...')
acc_dfs = []
for layer_key, tcav in all_tcav.items():
    df_acc = compute_cav_acc_df(tcav, positive_concepts, random_concept, layer_key)
    print(layer_key, 'rows:', len(df_acc))
    acc_dfs.append(df_acc)
acc_df_combined = pd.concat(acc_dfs, ignore_index=True) if acc_dfs else pd.DataFrame()
acc_df_combined.to_csv(CAV_ACC_CSV, index=False)
print('Saved CAV acc ->', CAV_ACC_CSV)
display(acc_df_combined.head(20))


Computing CAV accuracies...
stage4 rows: 28
Saved CAV acc -> /home/SpeakerRec/BioVoice/data/tcav/ASVspoof5_train_only_stage4_spoofwrapper/ASVspoof5_train_only_stage4_spoofwrapper__concept_cav_acc.csv


Unnamed: 0,layer_key,concept_name,layer_name,cav_acc
0,stage4,arch_long_thick,redim.backbone.stage4.2,0.461538
1,stage4,arch_long_thin,redim.backbone.stage4.2,0.512821
2,stage4,arch_short_thick,redim.backbone.stage4.2,0.435897
3,stage4,arch_short_thin,redim.backbone.stage4.2,0.512821
4,stage4,const_long_thick,redim.backbone.stage4.2,0.694915
5,stage4,const_long_thin,redim.backbone.stage4.2,0.661017
6,stage4,const_short_thick,redim.backbone.stage4.2,0.644068
7,stage4,const_short_thin,redim.backbone.stage4.2,0.661017
8,stage4,dropping_long_thick,redim.backbone.stage4.2,0.692308
9,stage4,dropping_long_thin,redim.backbone.stage4.2,0.487179


In [37]:
# Mel preprocessing helpers + optional disk cache
TARGET_FRAMES = 304  # keep same style as prior notebook; adjust if your concept tensors expect a different length

def fix_mel_frames(mel: torch.Tensor, target_frames: int) -> torch.Tensor:
    # mel: (B, N_MELS, T)
    T = mel.shape[-1]
    if T == target_frames:
        return mel
    if T > target_frames:
        return mel[..., :target_frames]
    pad = target_frames - T
    return F.pad(mel, (0, pad))

def mel_cache_path_for_utt(utt_id: str) -> Path:
    return MEL_CACHE_DIR / f'{utt_id}.npy'

def waveform_to_mel4d_from_audio_path(audio_path: str, utt_id: str) -> torch.Tensor:
    cache_path = mel_cache_path_for_utt(utt_id)
    if ENABLE_MEL_CACHE and cache_path.exists():
        arr = np.load(cache_path)
        return torch.tensor(arr, dtype=torch.float32, device=PREPROCESS_DEVICE)

    wav_np = load_audio_16k_mono(audio_path)
    wav = torch.tensor(wav_np, dtype=torch.float32, device=PREPROCESS_DEVICE).unsqueeze(0)  # (1,T)
    with torch.no_grad():
        mel = redim_model.spec(wav)  # (1,N_MELS,Tm)
    mel = fix_mel_frames(mel, TARGET_FRAMES)
    mel4d = mel.unsqueeze(1)  # (1,1,N_MELS,TARGET_FRAMES)

    if ENABLE_MEL_CACHE:
        np.save(cache_path, mel4d.detach().cpu().numpy().astype(np.float32))
    # TCAV wrapper runs on TCAV_DEVICE; move tensor explicitly before returning.
    mel4d = mel4d.to(TCAV_DEVICE)
    return mel4d

print('ENABLE_MEL_CACHE =', ENABLE_MEL_CACHE)
print('MEL_CACHE_DIR =', MEL_CACHE_DIR)

ENABLE_MEL_CACHE = True
MEL_CACHE_DIR = /home/SpeakerRec/BioVoice/data/tcav/ASVspoof5_train_only_stage4_spoofwrapper/mel_cache_stage4_inputs


In [38]:
# TCAV scoring loop -> CSV with checkpointing/resume (this is the expensive 'CSV creation' stage)
# Output is long-format: one row per (sample, concept) per layer.


def load_processed_utts_from_partial(partial_csv: Path) -> set[str]:
    if (not RESUME_FROM_PARTIAL) or (not partial_csv.exists()):
        return set()
    try:
        tmp = pd.read_csv(partial_csv)
        if "utt_id" not in tmp.columns:
            return set()
        # Only treat utt_ids as processed if at least one non-error TCAV row exists.
        if "layer_key" in tmp.columns:
            tmp = tmp[tmp["layer_key"] != "ERROR"]
        return set(tmp["utt_id"].astype(str).unique().tolist())
    except Exception as e:
        print("[WARN] Failed to read partial CSV for resume, starting fresh:", e)
        return set()


def _scalarize_metric(x):
    if isinstance(x, torch.Tensor):
        return float(x.detach().cpu().flatten()[0].item())
    return float(np.array(x).flatten()[0])


rows_buffer = []
start_time = time.time()
processed_utts = load_processed_utts_from_partial(PARTIAL_CSV)
if (not RESUME_FROM_PARTIAL) and PARTIAL_CSV.exists():
    print(
        f"[INFO] RESUME_FROM_PARTIAL=False -> removing old partial CSV: {PARTIAL_CSV}"
    )
    PARTIAL_CSV.unlink()

print(
    "Resume enabled =",
    RESUME_FROM_PARTIAL,
    "| already processed utt_ids =",
    len(processed_utts),
)
if PREPROCESS_DEVICE.type == "cuda" and TCAV_DEVICE.type == "cpu":
    print(
        "[INFO] Using GPU for mel preprocessing only; TCAV wrapper forward runs on CPU. Inputs will be moved to TCAV_DEVICE before interpret()."
    )

# Determine run order deterministically
iter_df = run_df.copy()
iter_df = iter_df.sort_values(
    ["group", "speaker_id", "label", "system_id", "utt_id"]
).reset_index(drop=True)

n_total = len(iter_df)
n_done = 0
n_skipped = 0
n_failed = 0

for i, rec in iter_df.iterrows():
    utt_id = str(rec["utt_id"])
    if utt_id in processed_utts:
        n_skipped += 1
        continue

    try:
        x = waveform_to_mel4d_from_audio_path(rec["audio_path"], utt_id=utt_id)
        # Defensive device alignment: cached/preprocessed tensors must match TCAV/model device
        x = x.to(TCAV_DEVICE)

        for layer_key, tcav in all_tcav.items():
            score_for_label = tcav.interpret(
                inputs=x,
                experimental_sets=experimental_sets,
                target=TARGET_CLASS_ID,
            )

            for exp_key, layer_dict in score_for_label.items():
                try:
                    pos_idx = int(str(exp_key).split("-")[0])
                except Exception:
                    continue
                if not (0 <= pos_idx < len(positive_concepts)):
                    continue
                concept_name = positive_concepts[pos_idx].name

                for layer_name, metrics in layer_dict.items():
                    sc = metrics.get("sign_count")
                    mg = metrics.get("magnitude")
                    if sc is None or mg is None:
                        continue

                    rows_buffer.append(
                        {
                            "row_index": int(i),
                            "utt_id": utt_id,
                            "group": str(rec["group"]),
                            "partition": str(rec["partition"]),
                            "speaker_id": str(rec["speaker_id"]),
                            "gender": str(rec["gender"]),
                            "label_str": str(rec["label"]),
                            "label_id": int(rec["label_id"]),
                            "system_id": str(rec["system_id"]),
                            "target_class_name": TARGET_CLASS_NAME,
                            "target_class_id": int(TARGET_CLASS_ID),
                            "layer_key": layer_key,
                            "concept_name": concept_name,
                            "layer_name": layer_name,
                            "positive_percentage": _scalarize_metric(sc),
                            "magnitude": _scalarize_metric(mg),
                        }
                    )

        processed_utts.add(utt_id)
        n_done += 1

    except Exception as e:
        n_failed += 1
        rows_buffer.append(
            {
                "row_index": int(i),
                "utt_id": utt_id,
                "group": str(rec["group"]),
                "partition": str(rec["partition"]),
                "speaker_id": str(rec["speaker_id"]),
                "gender": str(rec["gender"]),
                "label_str": str(rec["label"]),
                "label_id": int(rec["label_id"]),
                "system_id": str(rec["system_id"]),
                "target_class_name": TARGET_CLASS_NAME,
                "target_class_id": int(TARGET_CLASS_ID),
                "layer_key": "ERROR",
                "concept_name": "ERROR",
                "layer_name": "ERROR",
                "positive_percentage": np.nan,
                "magnitude": np.nan,
                "error": repr(e),
            }
        )
        print(f"[WARN] Failed on utt_id={utt_id}: {e}")

    # Checkpoint by sample count
    if (n_done + n_failed) % CHECKPOINT_EVERY_N == 0 and rows_buffer:
        chunk_df = pd.DataFrame(rows_buffer)
        write_header = not PARTIAL_CSV.exists()
        chunk_df.to_csv(PARTIAL_CSV, mode="a", index=False, header=write_header)
        rows_buffer = []
        elapsed = time.time() - start_time
        progress = {
            "n_total_iter_rows": int(n_total),
            "n_done": int(n_done),
            "n_skipped_resume": int(n_skipped),
            "n_failed": int(n_failed),
            "elapsed_sec": float(elapsed),
            "partial_csv": str(PARTIAL_CSV),
        }
        PROGRESS_JSON.write_text(json.dumps(progress, indent=2), encoding="utf-8")
        print(
            f"[CHKPT] done={n_done} skipped={n_skipped} failed={n_failed} elapsed={elapsed/60:.1f} min"
        )

# Flush remaining rows
if rows_buffer:
    chunk_df = pd.DataFrame(rows_buffer)
    write_header = not PARTIAL_CSV.exists()
    chunk_df.to_csv(PARTIAL_CSV, mode="a", index=False, header=write_header)
    rows_buffer = []

print("Scoring loop complete.")
print("done =", n_done, "| skipped =", n_skipped, "| failed =", n_failed)
print("Partial CSV =", PARTIAL_CSV)

Resume enabled = False | already processed utt_ids = 0
[CHKPT] done=50 skipped=0 failed=0 elapsed=7.9 min
[CHKPT] done=100 skipped=0 failed=0 elapsed=15.8 min
[CHKPT] done=150 skipped=0 failed=0 elapsed=23.8 min
[CHKPT] done=200 skipped=0 failed=0 elapsed=31.6 min
[CHKPT] done=250 skipped=0 failed=0 elapsed=39.5 min
[CHKPT] done=300 skipped=0 failed=0 elapsed=47.4 min
[CHKPT] done=350 skipped=0 failed=0 elapsed=55.2 min
[CHKPT] done=400 skipped=0 failed=0 elapsed=63.1 min
[CHKPT] done=450 skipped=0 failed=0 elapsed=70.9 min
[CHKPT] done=500 skipped=0 failed=0 elapsed=78.8 min
[CHKPT] done=550 skipped=0 failed=0 elapsed=86.7 min
[CHKPT] done=600 skipped=0 failed=0 elapsed=94.8 min
[CHKPT] done=650 skipped=0 failed=0 elapsed=102.9 min
[CHKPT] done=700 skipped=0 failed=0 elapsed=110.8 min
[CHKPT] done=750 skipped=0 failed=0 elapsed=118.9 min
[CHKPT] done=800 skipped=0 failed=0 elapsed=126.8 min
[CHKPT] done=850 skipped=0 failed=0 elapsed=134.8 min
[CHKPT] done=900 skipped=0 failed=0 elaps

In [31]:

# Finalize CSV: merge CAV accuracies, save final file, and sanity checks
assert PARTIAL_CSV.exists(), f'Missing partial CSV: {PARTIAL_CSV}'
df_tcav = pd.read_csv(PARTIAL_CSV)
print('Partial df_tcav shape:', df_tcav.shape)

# Remove error rows from final scoring CSV (keep them separately if needed)
if 'error' in df_tcav.columns:
    err_rows = df_tcav[df_tcav['layer_key'] == 'ERROR'].copy()
    if not err_rows.empty:
        err_csv = RUN_DIR / f'{RUN_TAG}__errors.csv'
        err_rows.to_csv(err_csv, index=False)
        print('Saved error rows ->', err_csv, '| count =', len(err_rows))
    df_tcav = df_tcav[df_tcav['layer_key'] != 'ERROR'].copy()

if not acc_df_combined.empty:
    df_tcav = df_tcav.merge(
        acc_df_combined,
        on=['layer_key', 'concept_name', 'layer_name'],
        how='left'
    )

# Consistent final column names (underscore style)
final_cols = [
    'row_index','utt_id','group','partition','speaker_id','gender',
    'label_str','label_id','system_id','target_class_name','target_class_id',
    'layer_key','concept_name','layer_name','positive_percentage','magnitude','cav_acc'
]
extra_cols = [c for c in df_tcav.columns if c not in final_cols]
df_tcav = df_tcav[[c for c in final_cols if c in df_tcav.columns] + extra_cols]

df_tcav.to_csv(FINAL_CSV, index=False)
print('Saved final CSV ->', FINAL_CSV)
print('Final shape:', df_tcav.shape)

print('Sanity checks:')
print('Unique utt_id:', df_tcav['utt_id'].nunique())
print('Rows per utt (value counts):')
print(df_tcav.groupby('utt_id').size().value_counts().sort_index())
print('Labels:', df_tcav[['utt_id','label_str']].drop_duplicates()['label_str'].value_counts().to_dict())
print('Groups:', df_tcav[['utt_id','group']].drop_duplicates()['group'].value_counts().to_dict())
print('Spoof systems in scored rows:', sorted(df_tcav.loc[df_tcav['label_str']=='spoof','system_id'].astype(str).unique().tolist())[:20])

display(df_tcav.head(20))


Partial df_tcav shape: (6400, 17)
Saved error rows -> /home/SpeakerRec/BioVoice/data/tcav/ASVspoof5_train_only_stage4_spoofwrapper/ASVspoof5_train_only_stage4_spoofwrapper__errors.csv | count = 6400
Saved final CSV -> /home/SpeakerRec/BioVoice/data/tcav/ASVspoof5_train_only_stage4_spoofwrapper/ASVspoof5_train_only_stage4_spoofwrapper.csv
Final shape: (0, 18)
Sanity checks:
Unique utt_id: 0
Rows per utt (value counts):
Series([], dtype: int64)
Labels: {}
Groups: {}
Spoof systems in scored rows: []


Unnamed: 0,row_index,utt_id,group,partition,speaker_id,gender,label_str,label_id,system_id,target_class_name,target_class_id,layer_key,concept_name,layer_name,positive_percentage,magnitude,cav_acc,error
