In [1]:
# %%
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from wespeaker.cli.speaker import load_model
from pytorch_grad_cam.utils.image import show_cam_on_image  # only for overlay image

# Project root is taken as the grandparent of the kernel's current working
# directory, so this cell assumes the notebook is launched from a folder two
# levels below the repository root.
PROJECT_ROOT = Path.cwd().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Guard keeps Restart/Re-run of this cell from stacking duplicate entries.
    sys.path.append(str(PROJECT_ROOT))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("PROJECT_ROOT =", PROJECT_ROOT)
print("Using device:", DEVICE)

# Load WeSpeaker backbone and put it in inference mode on DEVICE
speaker = load_model(PROJECT_ROOT / "wespeaker-voxceleb-resnet293-LM")
backbone = speaker.model.to(DEVICE).eval()
print("ResNet-293 loaded")

# WeSpeaker mel params (frame length / shift are in milliseconds)
SAMPLE_RATE = 16000
N_MELS = 80
FRAME_LENGTH_MS = 25
FRAME_SHIFT_MS = 10
# Same window / hop expressed in samples at SAMPLE_RATE (printed for reference)
N_FFT = int(SAMPLE_RATE * FRAME_LENGTH_MS / 1000)
HOP_LENGTH = int(SAMPLE_RATE * FRAME_SHIFT_MS / 1000)

print(f"Mel params: n_fft={N_FFT}, hop_length={HOP_LENGTH}, n_mels={N_MELS}")

  from .autonotebook import tqdm as notebook_tqdm
  torchaudio.set_audio_backend("sox_io")
ESPnet is not installed, cannot use espnet_hubert upstream


PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda
{'data_type': 'shard', 'dataloader_args': {'batch_size': 32, 'drop_last': True, 'num_workers': 16, 'pin_memory': False, 'prefetch_factor': 8}, 'dataset_args': {'aug_prob': 0.6, 'fbank_args': {'dither': 1.0, 'frame_length': 25, 'frame_shift': 10, 'num_mel_bins': 80}, 'num_frms': 200, 'shuffle': True, 'shuffle_args': {'shuffle_size': 2500}, 'spec_aug': False, 'spec_aug_args': {'max_f': 8, 'max_t': 10, 'num_f_mask': 1, 'num_t_mask': 1, 'prob': 0.6}, 'speed_perturb': True}, 'exp_dir': 'exp/ResNet293-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150', 'gpus': [0, 1], 'log_batch_interval': 100, 'loss': 'CrossEntropyLoss', 'loss_args': {}, 'margin_scheduler': 'MarginScheduler', 'margin_update': {'epoch_iter': 17062, 'final_margin': 0.2, 'fix_start_epoch': 40, 'increase_start_epoch': 20, 'increase_type': 'exp', 'initial_margin': 0.0, 'update_margin': True}, 'model': 'ResNet293', 'model_args': {'embed_dim': 2

  checkpoint = torch.load(path, map_location="cpu")


ResNet-293 loaded
Mel params: n_fft=400, hop_length=160, n_mels=80


In [2]:
# %%
# ========== HELPERS ==========
def upsample_hw(
    arr: np.ndarray, size_hw: tuple[int, int], mode: str = "bilinear"
) -> np.ndarray:
    """
    Resize a 2D map or an RGB image to ``size_hw`` with ``F.interpolate``.

    Args:
        arr: either a 2D array [H, W] or an image [H, W, 3].
        size_hw: target (height, width).
        mode: any ``F.interpolate`` mode valid for 4D input
            (e.g. "nearest", "bilinear", "bicubic", "area").

    Returns:
        float32 numpy array of shape [H_out, W_out] (2D input) or
        [H_out, W_out, 3] (RGB input).

    Raises:
        ValueError: if ``arr`` is neither 2D nor [H, W, 3].
    """
    arr = np.ascontiguousarray(arr)

    # F.interpolate raises if align_corners is given for non-interpolating
    # modes (nearest / area / nearest-exact), so only pass it where it applies.
    interp_kwargs = {"mode": mode}
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        interp_kwargs["align_corners"] = False

    if arr.ndim == 2:
        t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).float()  # [1,1,H,W]
        t = F.interpolate(t, size=size_hw, **interp_kwargs)
        return t[0, 0].cpu().numpy()

    if arr.ndim == 3 and arr.shape[2] == 3:
        t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()  # [1,3,H,W]
        t = F.interpolate(t, size=size_hw, **interp_kwargs)
        return t[0].permute(1, 2, 0).cpu().numpy()

    raise ValueError(f"Unsupported shape: {arr.shape}")


def plot_mel(mel2d: np.ndarray, save_path=None, title=None):
    """
    Render a 2D mel/feature map as a min-max normalised heatmap.

    The map is rescaled to [0, 1] before plotting. If ``save_path`` is truthy
    the figure is written there (dpi=200) and closed; otherwise it is shown.
    ``title`` is added only when truthy.
    """
    lo, hi = mel2d.min(), mel2d.max()
    mel_norm = (mel2d - lo) / (hi - lo + 1e-8)  # epsilon avoids 0/0 on flat input

    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        mel_norm, aspect="auto", cmap="magma", interpolation="bilinear", origin="lower"
    )
    fig.colorbar(image, ax=ax, label="Normalized Energy")
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Mel Bins")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if not save_path:
        plt.show()
        return

    fig.savefig(save_path, dpi=200)
    # Close after saving so figures do not pile up when called in a loop.
    plt.close(fig)


def load_wav_mono_16k(wav_path: str) -> torch.Tensor:
    """
    Load an audio file, keep a single channel and resample to SAMPLE_RATE.

    Multi-channel audio is reduced by keeping the first channel only (the
    channels are not averaged). Resampling happens only when the file's rate
    differs from SAMPLE_RATE. The returned tensor is left on the CPU.
    """
    waveform, orig_sr = torchaudio.load(wav_path)
    waveform = waveform[:1, :] if waveform.shape[0] > 1 else waveform
    if orig_sr == SAMPLE_RATE:
        return waveform
    return torchaudio.functional.resample(waveform, orig_sr, SAMPLE_RATE)




In [3]:
# %%
import torchaudio.compliance.kaldi as kaldi


def wespeaker_fbank(
    wav_cpu: torch.Tensor,
    sample_rate: int = 16000,
    num_mel_bins: int = 80,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    snip_edges: bool = True,
    dither: float = 0.0,
    do_cmvn: bool = True,
) -> torch.Tensor:
    """
    Compute Kaldi-style log-mel filterbank features for one utterance.

    Args:
        wav_cpu: waveform tensor, [T] or [C, T]; only the first channel is used.
        sample_rate: sampling rate of ``wav_cpu`` in Hz.
        num_mel_bins: number of mel filterbank bins.
        frame_length: analysis window length in milliseconds.
        frame_shift: hop between windows in milliseconds.
        snip_edges: forwarded to ``kaldi.fbank``.
        dither: forwarded to ``kaldi.fbank`` (0.0 keeps the output deterministic).
        do_cmvn: apply per-utterance mean/variance normalisation per mel bin.

    Returns:
        float32 tensor of shape [1, num_mel_bins, num_frames].

    Raises:
        ValueError: if the waveform is too short to produce any frame.

    NOTE(review): WeSpeaker's own CLI frontend appears to use mean-only
    normalisation on non-normalised (int16-range) waveforms; with
    ``do_cmvn=True`` this function additionally divides by the per-bin std.
    Confirm this matches what the checkpoint was trained with.
    """
    # kaldi.fbank expects a float tensor of shape [1, T]
    if wav_cpu.dtype != torch.float32:
        wav_cpu = wav_cpu.float()

    if wav_cpu.dim() == 1:
        wav_cpu = wav_cpu.unsqueeze(0)  # [T] -> [1, T]
    elif wav_cpu.dim() == 2 and wav_cpu.size(0) > 1:
        wav_cpu = wav_cpu[:1, :]

    feats = kaldi.fbank(
        wav_cpu,  # [1, T]
        sample_frequency=sample_rate,
        num_mel_bins=num_mel_bins,
        frame_length=frame_length,
        frame_shift=frame_shift,
        dither=dither,
        snip_edges=snip_edges,
        energy_floor=0.0,
        use_energy=False,
    )  # [num_frames, num_mel_bins]

    # Fail early with a clear message instead of propagating empty/NaN features.
    if feats.dim() != 2 or feats.size(0) == 0:
        raise ValueError(
            f"Audio too short to extract fbank frames "
            f"(samples={wav_cpu.size(-1)}, frame_length={frame_length}ms)"
        )

    # CMVN: per-utterance mean/var normalization over the time axis
    if do_cmvn:
        mean = feats.mean(dim=0, keepdim=True)
        if feats.size(0) > 1:
            std = feats.std(dim=0, keepdim=True).clamp_min(1e-5)
            feats = (feats - mean) / std
        else:
            # The std of a single frame is NaN (clamp_min does not fix NaN),
            # so only centre the features in that case.
            feats = feats - mean

    # Convert to [1, F, T]
    feats = feats.transpose(0, 1).unsqueeze(0).contiguous()
    return feats

In [4]:
# %%
# ========== TARGET LAYERS ==========
# Dotted module paths are the single source of truth: the modules are looked up
# from these names so the labels can never drift out of sync with the layers.
layer_names = [
    "layer1.9.conv3",
    "layer2.19.conv3",
    "layer3.63.conv3",
    "layer4.2.conv3",
]

# get_submodule raises AttributeError immediately if a path does not exist.
target_layers = [backbone.get_submodule(name) for name in layer_names]

print("Target layers:")
for layer_name, layer in zip(layer_names, target_layers):
    print(layer_name, "->", layer)

Target layers:
layer1.9.conv3 -> Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
layer2.19.conv3 -> Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
layer3.63.conv3 -> Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
layer4.2.conv3 -> Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [5]:
# %%
def _run_backbone_with_layout_fallback(backbone: torch.nn.Module, feats: torch.Tensor):
    """
    Some speech backbones expect [B, F, T] and others [B, T, F].
    We try both layouts to avoid shape headaches.
    """
    try:
        return backbone(feats)  # try [B, F, T]
    except Exception:
        return backbone(feats.transpose(1, 2))  # try [B, T, F]


def _activation_to_cam2d(act: torch.Tensor) -> np.ndarray:
    """
    Convert conv activation to 2D CAM-like map by averaging channels.
    Handles common shapes:
      [B, C, H, W]  -> cam: [H, W]
      [B, C, T]     -> cam: [1, T]  (rare)
      [C, H, W]     -> cam: [H, W]
    """
    if isinstance(act, (tuple, list)):
        # just in case a module returns tuple
        act = next(x for x in act if isinstance(x, torch.Tensor))

    if act.ndim == 4:  # [B,C,H,W]
        cam = act[0].mean(dim=0)  # [H,W]
    elif act.ndim == 3:  # [B,C,T] or [C,H,W]
        if act.shape[0] == 1:  # [B,C,T]
            cam = act[0].mean(dim=0).unsqueeze(0)  # [1,T]
        else:  # [C,H,W]
            cam = act.mean(dim=0)  # [H,W]
    else:
        raise ValueError(f"Unsupported activation shape: {tuple(act.shape)}")

    cam = cam.detach().float().cpu().numpy()
    return cam


def run_activation_cam(
    wav_path: str,
    save_dir: str,
    upscale: int = 10,
    cam_quantile: float = 0.85,
):
    """
    Activation-CAM (no gradients) for one audio file.

    Pipeline:
    - wav -> fbank features (CPU) -> forward pass through the global ``backbone``
    - for each entry in the global ``layer_names`` / ``target_layers``: capture
      the layer output with a forward hook
    - CAM = mean over channels (activation strength), resized to the mel size
    - save weighted / focus / overlay images per layer plus a stats CSV

    Args:
        wav_path: path of the audio file to analyse.
        save_dir: output directory (created if missing). Gets an ``original/``
            subfolder, one subfolder per layer and
            ``layer_activation_stats.csv``.
        upscale: integer factor by which the overlay image is enlarged.
        cam_quantile: quantile of the normalised CAM used as the threshold for
            the "focus" mask.

    Relies on notebook globals: ``backbone``, ``layer_names``,
    ``target_layers``, ``DEVICE``, ``SAMPLE_RATE``, ``N_MELS``,
    ``FRAME_LENGTH_MS``, ``FRAME_SHIFT_MS``.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    fname = Path(wav_path).stem

    # 1) wav -> wespeaker fbank (CPU) -> DEVICE
    wav_cpu = load_wav_mono_16k(wav_path)  # [1,T] CPU
    feats_cpu = wespeaker_fbank(
        wav_cpu,
        sample_rate=SAMPLE_RATE,
        num_mel_bins=N_MELS,
        frame_length=FRAME_LENGTH_MS,
        frame_shift=FRAME_SHIFT_MS,
        snip_edges=True,   # try False if needed
        dither=0.0,
        do_cmvn=True,
    )  # [1,80,T] CPU

    feats = feats_cpu.to(DEVICE)  # moved to DEVICE for the backbone forward

    # for visualization
    mel2d = feats_cpu[0].numpy()  # [80,T] CPU numpy

    mel_norm = (mel2d - mel2d.min()) / (mel2d.max() - mel2d.min() + 1e-8)

    originals_dir = save_dir / "original"
    originals_dir.mkdir(exist_ok=True)
    plot_mel(mel2d, save_path=originals_dir / f"{fname}_mel.png", title=f"{fname}")

    # 2) Register hooks to capture activations
    acts = {}
    hooks = []

    def make_hook(name):
        def hook_fn(module, inp, out):
            acts[name] = out

        return hook_fn

    # try/finally guarantees the hooks are detached from the shared backbone
    # even if registration or the forward pass raises; otherwise they would
    # accumulate across calls.
    try:
        for lname, layer in zip(layer_names, target_layers):
            hooks.append(layer.register_forward_hook(make_hook(lname)))

        # 3) Forward pass (NO GRADS)
        with torch.no_grad():
            _ = _run_backbone_with_layout_fallback(backbone, feats)
    finally:
        for h in hooks:
            h.remove()

    # 4) Build outputs per layer
    layer_stats = {}
    for lname in layer_names:
        if lname not in acts:
            print(f"[WARN] No activation captured for {lname}")
            continue

        layer_dir = save_dir / lname
        layer_dir.mkdir(parents=True, exist_ok=True)

        cam_raw = _activation_to_cam2d(acts[lname])  # [H',W'] in layer space
        cam_on_mel = upsample_hw(cam_raw, mel2d.shape, mode="bicubic")

        cam_norm = (cam_on_mel - cam_on_mel.min()) / (
            cam_on_mel.max() - cam_on_mel.min() + 1e-8
        )

        # weighted mel: weight the min-max normalised (non-negative) mel.
        # The CMVN features in mel2d are zero-mean, so multiplying them by a
        # [0,1] mask would pull suppressed regions toward mid-scale instead of
        # toward the floor.
        mel_weighted = mel_norm * cam_norm
        plot_mel(
            mel_weighted,
            save_path=layer_dir / f"{fname}_mel_weighted.png",
            title=f"{fname} | {lname} | weighted",
        )

        # focus mask: keep mel values where the CAM is in the top
        # (1 - cam_quantile) fraction, fill the rest with the mel minimum
        thr = np.quantile(cam_norm, cam_quantile)
        mask = cam_norm >= thr
        mel_focus = np.where(mask, mel2d, mel2d.min())
        plot_mel(
            mel_focus,
            save_path=layer_dir / f"{fname}_mel_focus_q{cam_quantile}.png",
            title=f"{fname} | {lname} | focus",
        )

        # overlay (big)
        H, W = mel_norm.shape
        H2, W2 = H * upscale, W * upscale
        rgb_big = upsample_hw(
            np.stack([mel_norm] * 3, axis=-1), (H2, W2), mode="bicubic"
        )
        cam_big = upsample_hw(cam_norm, (H2, W2), mode="bicubic")
        # bicubic interpolation can overshoot [0,1]; clip before blending
        rgb_big = np.clip(rgb_big.astype(np.float32), 0.0, 1.0)
        cam_big = np.clip(cam_big.astype(np.float32), 0.0, 1.0)

        overlay = show_cam_on_image(rgb_big, cam_big, use_rgb=True)
        plt.imsave(layer_dir / f"{fname}_overlay.png", overlay)

        # stats: mean of -p*log(p) over the normalised CAM, and the fraction of
        # the map whose normalised value exceeds 0.5
        entropy = -np.sum(cam_norm * np.log(cam_norm + 1e-8)) / cam_norm.size
        support = (cam_norm > 0.5).mean()
        layer_stats[lname] = {"entropy": float(entropy), "support@0.5": float(support)}

    stats_df = pd.DataFrame(layer_stats).T
    stats_df.to_csv(save_dir / "layer_activation_stats.csv")
    print("Saved Activation-CAM results to:", save_dir)
    print(stats_df)

In [6]:
# # %%
# wav_dir = PROJECT_ROOT / "data" / "wavs"
# test_wav = wav_dir / "idan_012.wav"
# print("Testing with:", test_wav)
# run_out_dir = Path(f"activation_cam_results/idan/{test_wav.stem}")

# run_activation_cam(
#     wav_path=str(test_wav),
#     save_dir=str(run_out_dir),
#     upscale=12,
#     cam_quantile=0.85,
# )




In [7]:
# Batch driver: run Activation-CAM on a fixed selection of three files per
# speaker; results go to activation_cam_results/<speaker>/<file stem>/.
wav_dir = PROJECT_ROOT / "data" / "wavs"
out_dir = Path("activation_cam_results")

top3_by_speaker = {
    "eden": ["eden_007.wav", "eden_010.wav", "eden_012.wav"],
    "idan": ["idan_001.wav", "idan_002.wav", "idan_005.wav"],
    "yoav": ["yoav_004.wav", "yoav_014.wav", "yoav_015.wav"],
}

for spk, fnames in top3_by_speaker.items():
    for fname in fnames:
        wav_path = wav_dir / fname
        if wav_path.exists():
            print(f"\n=== Running Activation-CAM | speaker={spk} | file={fname} ===")

            run_out_dir = out_dir / spk / Path(fname).stem

            run_activation_cam(
                wav_path=str(wav_path),
                save_dir=str(run_out_dir),
                upscale=12,
                cam_quantile=0.85,
            )
        else:
            # Missing inputs are reported and skipped rather than aborting the batch.
            print(f"[WARN] missing file: {wav_path}")


=== Running Activation-CAM | speaker=eden | file=eden_007.wav ===
Saved Activation-CAM results to: activation_cam_results/eden/eden_007
                  entropy  support@0.5
layer1.9.conv3   0.285503     0.851207
layer2.19.conv3  0.253512     0.937114
layer3.63.conv3  0.310794     0.661390
layer4.2.conv3   0.210128     0.906178

=== Running Activation-CAM | speaker=eden | file=eden_010.wav ===
Saved Activation-CAM results to: activation_cam_results/eden/eden_010
                  entropy  support@0.5
layer1.9.conv3   0.294476     0.812092
layer2.19.conv3  0.241299     0.963949
layer3.63.conv3  0.317666     0.601766
layer4.2.conv3   0.221715     0.887862

=== Running Activation-CAM | speaker=eden | file=eden_012.wav ===
Saved Activation-CAM results to: activation_cam_results/eden/eden_012
                  entropy  support@0.5
layer1.9.conv3   0.284342     0.840766
layer2.19.conv3  0.252354     0.923198
layer3.63.conv3  0.321082     0.543750
layer4.2.conv3   0.204536     0.922523

===