In [3]:
# %%
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from wespeaker.cli.speaker import load_model

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


PROJECT_ROOT = Path.cwd().parents[1]
sys.path.append(str(PROJECT_ROOT))
from utils.spectogram_player_html import save_spectrogram_player_html

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("PROJECT_ROOT =", PROJECT_ROOT)
print("Using device:", DEVICE)

# Load WeSpeaker bundle + move backbone to GPU
speaker = load_model(PROJECT_ROOT / "wespeaker-voxceleb-resnet293-LM")
speaker.model = speaker.model.to(DEVICE).eval()

print("ResNet-293 loaded")

# WeSpeaker mel params
SAMPLE_RATE = 16000
N_MELS = 80
FRAME_LENGTH_MS = 25
FRAME_SHIFT_MS = 10
N_FFT = int(SAMPLE_RATE * FRAME_LENGTH_MS / 1000)
HOP_LENGTH = int(SAMPLE_RATE * FRAME_SHIFT_MS / 1000)

print(f"Mel params: n_fft={N_FFT}, hop_length={HOP_LENGTH}, n_mels={N_MELS}")

PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda
{'data_type': 'shard', 'dataloader_args': {'batch_size': 32, 'drop_last': True, 'num_workers': 16, 'pin_memory': False, 'prefetch_factor': 8}, 'dataset_args': {'aug_prob': 0.6, 'fbank_args': {'dither': 1.0, 'frame_length': 25, 'frame_shift': 10, 'num_mel_bins': 80}, 'num_frms': 200, 'shuffle': True, 'shuffle_args': {'shuffle_size': 2500}, 'spec_aug': False, 'spec_aug_args': {'max_f': 8, 'max_t': 10, 'num_f_mask': 1, 'num_t_mask': 1, 'prob': 0.6}, 'speed_perturb': True}, 'exp_dir': 'exp/ResNet293-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150', 'gpus': [0, 1], 'log_batch_interval': 100, 'loss': 'CrossEntropyLoss', 'loss_args': {}, 'margin_scheduler': 'MarginScheduler', 'margin_update': {'epoch_iter': 17062, 'final_margin': 0.2, 'fix_start_epoch': 40, 'increase_start_epoch': 20, 'increase_type': 'exp', 'initial_margin': 0.0, 'update_margin': True}, 'model': 'ResNet293', 'model_args': {'embed_dim': 2

In [4]:
# %%
# ========== LOAD YOUR TRAINED FC HEAD ==========
SPEAKERS = ["eden", "idan", "yoav"]
speaker_to_id = {s: i for i, s in enumerate(SPEAKERS)}
id_to_speaker = {i: s for s, i in speaker_to_id.items()}


class SpeakerHead(nn.Module):
    def __init__(self, in_dim=256, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


head = SpeakerHead(in_dim=256, num_classes=len(SPEAKERS)).to(DEVICE)

head_path = PROJECT_ROOT / "data" / "heads" / "resnet_293_speaker_head.pt"
assert head_path.exists(), f"Missing head checkpoint: {head_path}"

ckpt = torch.load(head_path, map_location="cpu")
head.load_state_dict(ckpt["state_dict"])
head = head.to(DEVICE).eval()

if "speakers" in ckpt:
    SPEAKERS = ckpt["speakers"]
    speaker_to_id = {s: i for i, s in enumerate(SPEAKERS)}
    id_to_speaker = {i: s for s, i in speaker_to_id.items()}

print("Loaded head from:", head_path)
print("Speakers:", SPEAKERS)

Loaded head from: /home/SpeakerRec/BioVoice/data/heads/resnet_293_speaker_head.pt
Speakers: ['eden', 'idan', 'yoav']


  ckpt = torch.load(head_path, map_location="cpu")


In [5]:
class WeSpeakerWithHeadForGradCAM(nn.Module):
    def __init__(
        self, backbone: nn.Module, head: nn.Module, try_transpose: bool = True
    ):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.try_transpose = try_transpose

    def forward(self, feats):
        # Accept NCHW from Grad-CAM and squeeze channel dim
        if feats.ndim == 4 and feats.size(1) == 1:
            feats = feats.squeeze(1)

        # Run backbone
        try:
            out = self.backbone(feats)
        except Exception:
            if not self.try_transpose:
                raise
            out = self.backbone(feats.transpose(1, 2))

        # Unwrap tuple/list (WeSpeaker behavior)
        if isinstance(out, (tuple, list)):
            emb = out[-1]
        else:
            emb = out

        # ---- HARD SHAPE GUARD ----
        if emb.ndim == 0:
            emb = emb.unsqueeze(0).unsqueeze(0)
        elif emb.ndim == 1:
            emb = emb.unsqueeze(0)
        elif emb.ndim > 2:
            emb = emb.view(emb.size(0), -1)

        # ---- HARD DEVICE GUARD (THIS FIXES YOUR ERROR) ----
        emb = emb.to(next(self.head.parameters()).device)

        logits = self.head(emb)

        if logits.ndim == 1:
            logits = logits.unsqueeze(0)

        return logits


wrapped_model = WeSpeakerWithHeadForGradCAM(speaker.model, head).to(DEVICE).eval()

# Enable gradients for backbone (needed for Grad-CAM)
for p in wrapped_model.backbone.parameters():
    p.requires_grad_(True)

# Head grads not required (optional)
for p in wrapped_model.head.parameters():
    p.requires_grad_(False)

print("wrapped_model ready")

wrapped_model ready


In [6]:
# %%
# ========== HELPERS (your existing ones, unchanged) ==========
def magma_rgb(img2d: np.ndarray) -> np.ndarray:
    img2d = img2d.astype(np.float32)
    img2d = (img2d - img2d.min()) / (img2d.max() - img2d.min() + 1e-8)
    return cm.get_cmap("magma")(img2d)[..., :3].astype(np.float32)

def normalize_01(arr: np.ndarray) -> np.ndarray:
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    vmin = float(arr.min())
    vmax = float(arr.max())
    if vmax - vmin < 1e-8:
        return np.zeros_like(arr, dtype=np.float32)
    return ((arr - vmin) / (vmax - vmin)).astype(np.float32)

def upsample_hw(arr: np.ndarray, size_hw: tuple[int, int], mode: str = "bilinear") -> np.ndarray:
    arr = np.ascontiguousarray(arr)

    if arr.ndim == 2:
        t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).float()  # [1,1,H,W]
        t = F.interpolate(t, size=size_hw, mode=mode, align_corners=False)
        return t[0, 0].cpu().numpy()

    if arr.ndim == 3 and arr.shape[2] == 3:
        t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()  # [1,3,H,W]
        t = F.interpolate(t, size=size_hw, mode=mode, align_corners=False)
        return t[0].permute(1, 2, 0).cpu().numpy()

    raise ValueError(f"Unsupported shape: {arr.shape}")

def to_float01(img: np.ndarray) -> np.ndarray:
    img = img.astype(np.float32)
    if img.max() > 1.0:
        img = img / 255.0
    return np.clip(img, 0.0, 1.0)

def plot_mel(mel, save_path=None, title=None):
    if isinstance(mel, torch.Tensor):
        mel = mel.squeeze(0).detach().cpu().numpy()

    mel_db = 10.0 * np.log10(np.maximum(mel, 1e-10))
    mel_db = np.clip(mel_db, mel_db.max() - 80.0, mel_db.max())
    mel_norm = (mel_db - mel_db.min()) / (mel_db.max() - mel_db.min() + 1e-8)

    plt.figure(figsize=(10, 4))
    plt.imshow(mel_norm, aspect="auto", cmap="magma", interpolation="bilinear", origin="lower")
    plt.colorbar(label="Normalized Energy")
    plt.xlabel("Time Frames")
    plt.ylabel("Mel Frequency Bins")
    if title:
        plt.title(title)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=200)
        plt.close()
    else:
        plt.show()

def load_wav_mono_16k(wav_path: str) -> torch.Tensor:
    wav, sr = torchaudio.load(wav_path)

    if wav.shape[0] > 1:
        wav = wav[:1, :]

    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    return wav  # keep on CPU for mel transform stability


In [7]:
# %%
# ========== FEATURE EXTRACTOR (mel) ==========
# Important: mel_transform must be on DEVICE if you want mel tensor on GPU.
mel_transform = T.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
).to(DEVICE)

In [8]:
# %%
# ========== TARGET LAYERS (same idea, but via wrapped_model.backbone) ==========
target_layers = [
    # wrapped_model.backbone.layer1[9].conv3,
    # wrapped_model.backbone.layer2[19].conv3,
    # wrapped_model.backbone.layer3[63].conv3,
    wrapped_model.backbone.layer4[2].conv3,
]

layer_names = [
    # "layer1.9.conv3",
    # "layer2.19.conv3",
    # "layer3.63.conv3",
    "layer4.2.conv3",
]

print("Target layers:")
for n, l in zip(layer_names, target_layers):
    print(n, "->", l)

Target layers:
layer4.2.conv3 -> Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [9]:
# %%
def run_gradcam_speaker(
    wav_path: str,
    target_speaker: str,
    save_dir: str = "gradcam_results_cls",
    upscale: int = 10,
    cam_quantile: float = 0.85,
):
    assert target_speaker in speaker_to_id, f"Unknown speaker: {target_speaker}"

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    fname = Path(wav_path).stem

    # 1) Load wav (CPU), compute mel ON GPU with grad
    wav_cpu = load_wav_mono_16k(wav_path)              # [1, T] CPU
    wav_gpu = wav_cpu.to(DEVICE)                       # move to GPU for mel_transform
    feats = mel_transform(wav_gpu)                     # [1, 80, Tfeat] on GPU
    feats = feats.contiguous().float()
    feats_cam = feats.unsqueeze(1)                  # [1, 1, 80, Tfeat] for Grad-CAM

    # 1b) Detached mel for visualization (numpy)
    with torch.no_grad():
        mel2d = feats.detach()[0].cpu().numpy()        # [80, Tfeat]
    mel_db = 10.0 * np.log10(np.maximum(mel2d, 1e-10))
    mel_db = np.clip(mel_db, mel_db.max() - 80.0, mel_db.max())
    mel_norm = normalize_01(mel_db)

    originals_dir = save_dir / "original"
    originals_dir.mkdir(exist_ok=True)
    plot_mel(
        mel2d,
        save_path=originals_dir / f"{fname}_mel.png",
        title=f"{fname} (target={target_speaker})",
    )

    # 2) target class
    target = ClassifierOutputTarget(speaker_to_id[target_speaker])

    layer_stats = {}

    for layer, lname in zip(target_layers, layer_names):
        print(f"[Grad-CAM] Layer: {lname}")

        layer_dir = save_dir / lname
        layer_dir.mkdir(parents=True, exist_ok=True)

        wrapped_model.zero_grad(set_to_none=True)
        if feats.grad is not None:
            feats.grad.zero_()

        # Optional sanity check
        with torch.no_grad():
            tmp = wrapped_model(feats)
        print("wrapped_model(feats) shape:", tmp.shape)

        try:
            cam = GradCAM(model=wrapped_model, target_layers=[layer], use_cuda=(DEVICE == "cuda"))
        except TypeError:
            cam = GradCAM(model=wrapped_model, target_layers=[layer])

        with torch.enable_grad():
            grayscale_cam = cam(
                input_tensor=feats_cam,      # [1,1,80,T] for Grad-CAM
                targets=[target],
                aug_smooth=False,
                eigen_smooth=True,
            )[0]  # [H', W'] (feature map space)

        # 4) align CAM to mel for plotting
        cam_on_mel = upsample_hw(grayscale_cam, mel2d.shape, mode="bicubic")
        cam_norm = normalize_01(cam_on_mel)

        # 5) weighted mel
        mel_weighted = mel2d * cam_norm
        plot_mel(
            mel_weighted,
            save_path=layer_dir / f"{fname}_mel_weighted.png",
            title=f"{fname} | {lname} | weighted",
        )

        # 6) focus mask
        thr = np.quantile(cam_norm, cam_quantile)
        mask = cam_norm >= thr
        mel_focus = np.where(mask, mel2d, mel2d.min())
        plot_mel(
            mel_focus,
            save_path=layer_dir / f"{fname}_mel_focus_q{cam_quantile}.png",
            title=f"{fname} | {lname} | focus",
        )

        # 7) overlay
        H, W = mel_norm.shape
        H2, W2 = H * upscale, W * upscale

        rgb_big = upsample_hw(np.stack([mel_norm] * 3, axis=-1), (H2, W2), mode="bicubic")
        cam_big = upsample_hw(cam_norm, (H2, W2), mode="bicubic")

        rgb_big = to_float01(rgb_big)
        cam_big = normalize_01(cam_big)
        overlay = show_cam_on_image(rgb_big, cam_big, use_rgb=True)
        plt.imsave(layer_dir / f"{fname}_overlay.png", overlay)

        # 7b) masked mel (no axes) + HTML player
        thr_big = np.quantile(cam_big, cam_quantile)
        mask_big = (cam_big >= thr_big).astype(np.float32)
        mel_only = magma_rgb(rgb_big[..., 0]) * mask_big[..., None]
        mel_only = np.clip(mel_only, 0.0, 1.0)
        mel_masked_png = layer_dir / f"{fname}_mel_masked_q{cam_quantile:.2f}.png"
        plt.imsave(mel_masked_png, mel_only)
        save_spectrogram_player_html(
            audio_path=wav_path,
            spectrogram_png_path=mel_masked_png,
            out_html_path=layer_dir / f"{fname}_mel_masked_q{cam_quantile:.2f}.html",
            total_time_sec=None,
            copy_audio=True,
            embed_image=True,
        )

        # 8) stats
        entropy = -np.sum(cam_norm * np.log(cam_norm + 1e-8)) / cam_norm.size
        support = (cam_norm > 0.5).mean()
        layer_stats[lname] = {"entropy": float(entropy), "support@0.5": float(support)}

    stats_df = pd.DataFrame(layer_stats).T
    stats_df.to_csv(save_dir / "layer_cam_stats.csv")

    print("Saved Grad-CAM results to:", save_dir)
    print(stats_df)


In [10]:
wav_dir = PROJECT_ROOT / "data" / "wavs"

# Use an existing audio file
test_wav = wav_dir / "idan_012.wav"
if not test_wav.exists():
    # Fallback to first available idan file
    idan_files = list(wav_dir.glob("idan_*.wav"))
    if idan_files:
        test_wav = sorted(idan_files)[0]
    else:
        raise FileNotFoundError(f"No audio files found in {wav_dir}")

print(f"Testing with: {test_wav}")

run_out_dir = Path(f"gradcam_results/idan/{test_wav.stem}")
originals_dir = run_out_dir / "original"

run_gradcam_speaker(
    wav_path=str(test_wav),
    target_speaker="idan",
    save_dir=str(run_out_dir),
    upscale=12,
)

# Uncomment to run on multiple speakers
# top3_by_speaker = {
#     "eden": ["eden_007.wav", "eden_010.wav", "eden_012.wav"],
#     "idan": ["idan_001.wav", "idan_002.wav", "idan_005.wav"],
#     "yoav": ["yoav_004.wav", "yoav_014.wav", "yoav_015.wav"],
# }

# for spk, fnames in top3_by_speaker.items():
#     for fname in fnames:
#         wav_path = wav_dir / fname
#         if not wav_path.exists():
#             print(f"[WARN] missing file: {wav_path}")
#             continue

#         run_gradcam_speaker(
#             wav_path=str(wav_path),
#             target_speaker=spk,
#             save_dir=f"gradcam_results/{spk}/{fname.replace('.wav', '')}",
#             upscale=12,
#         )
# Run all wavs per speaker (set RUN_ALL = True to enable)
RUN_ALL = False
if RUN_ALL:
    wav_by_user = {spk: list(wav_dir.glob(f"{spk}_*.wav")) for spk in SPEAKERS}
    for spk, files in wav_by_user.items():
        for wav_path in files:
            if not wav_path.exists():
                print(f"[WARN] missing file: {wav_path}")
                continue
            run_gradcam_speaker(
                wav_path=str(wav_path),
                target_speaker=spk,
                save_dir=f"gradcam_results/{spk}/{wav_path.stem}",
                upscale=12,
            )


Testing with: /home/SpeakerRec/BioVoice/data/wavs/idan_012.wav
[Grad-CAM] Layer: layer4.2.conv3


Exception ignored in: <function BaseCAM.__del__ at 0x7f7ea54ec040>
Traceback (most recent call last):
  File "/home/SpeakerRec/BioVoice/.venvResnet/lib/python3.10/site-packages/pytorch_grad_cam/base_cam.py", line 212, in __del__
    self.activations_and_grads.release()
AttributeError: 'GradCAM' object has no attribute 'activations_and_grads'


wrapped_model(feats) shape: torch.Size([1, 3])


  return cm.get_cmap("magma")(img2d)[..., :3].astype(np.float32)


Saved Grad-CAM results to: gradcam_results/idan/idan_012
                 entropy  support@0.5
layer4.2.conv3  0.068207     0.030352
