In [1]:
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Repo root sits three directories above the notebook's working directory; put it on sys.path
# so project-local modules are importable.
PROJECT_ROOT = Path.cwd().parents[2]
if str(PROJECT_ROOT) not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(str(PROJECT_ROOT))
print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Pretrained ReDimNet (b5, ptn, vox2) from torch.hub; later loads reuse the local hub cache.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
).to(DEVICE).eval()

print("Loaded ReDimNet successfully.")



PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet successfully.


In [2]:
# Linear speaker-classification head checkpoint stored under ./output.

HEAD_PATH = Path.cwd() / "output" / "redim_speaker_head_linear.pt"
assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

# weights_only=True uses torch's restricted unpickler (no arbitrary code execution) and silences the
# FutureWarning about the weights_only default. The keys read below are dicts, a state_dict and a flag.
# NOTE(review): if the checkpoint also stores custom Python objects this will raise; only fall back
# to weights_only=False for files you trust.
ckpt = torch.load(HEAD_PATH, map_location=DEVICE, weights_only=True)

# Restore speaker mappings
speaker_to_id = ckpt["speaker_to_id"]
id_to_speaker = ckpt["id_to_speaker"]
# Order by class index so SPEAKERS[i] is the speaker for logit i (independent of dict insertion order).
SPEAKERS = sorted(speaker_to_id, key=speaker_to_id.__getitem__)

# Define head with correct dimensions
class SpeakerHead(nn.Module):
    """Single linear layer mapping a speaker embedding to one logit per speaker.

    The layer must stay registered as ``fc`` so that checkpoint keys
    (``fc.weight`` / ``fc.bias``) load into it.

    Args:
        in_dim: embedding size fed to the head.
        num_classes: number of speakers (output logits).
    """

    def __init__(self, in_dim: int = 192, num_classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.fc(x)
        return logits

# fc.weight is [num_classes, in_dim]; read the embedding size from it instead of hard-coding 192,
# so a head trained on a different embedding width still loads.
head = SpeakerHead(
    in_dim=int(ckpt["state_dict"]["fc.weight"].shape[1]),
    num_classes=len(SPEAKERS),
).to(DEVICE)

# Load weights
head.load_state_dict(ckpt["state_dict"])
head.eval()

print("Loaded speaker head from:", HEAD_PATH)
print("Speakers:", SPEAKERS)
print("L2-normalized embeddings:", ckpt.get("l2_norm_emb", False))


Loaded speaker head from: /home/SpeakerRec/BioVoice/redimnet/grad_cam/2.0/output/redim_speaker_head_linear.pt
Speakers: ['eden', 'idan', 'yoav']
L2-normalized embeddings: True


  ckpt = torch.load(HEAD_PATH, map_location=DEVICE)


In [3]:
def to_float01(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as float32 clipped to [0, 1].

    If the maximum value exceeds 1.0 the data is treated as 0-255 and divided by 255 first.
    """
    pixels = img.astype(np.float32)
    needs_rescale = pixels.max() > 1.0
    scaled = pixels / 255.0 if needs_rescale else pixels
    return np.clip(scaled, 0.0, 1.0)

def upsample_hw(arr: np.ndarray, size_hw: tuple[int, int], mode: str = "bilinear") -> np.ndarray:
    """Resize a 2-D ``[H, W]`` or channels-last ``[H, W, C]`` array to ``size_hw``.

    Uses ``torch.nn.functional.interpolate`` on CPU and returns a float32 numpy array.

    Args:
        arr: 2-D array, or 3-D channels-last array (any number of channels).
        size_hw: target ``(height, width)``.
        mode: interpolation mode passed to ``F.interpolate`` (e.g. "bilinear", "bicubic", "nearest").

    Raises:
        ValueError: if ``arr`` is neither 2-D nor 3-D.
    """
    # torch.from_numpy rejects negative strides (e.g. np.flipud views), so force a contiguous copy.
    arr = np.ascontiguousarray(arr)
    size_hw = tuple(int(s) for s in size_hw)
    # align_corners is only accepted by the interpolating modes; passing it with "nearest"/"area" raises.
    align_corners = False if mode in ("linear", "bilinear", "bicubic", "trilinear") else None

    if arr.ndim == 2:
        t = torch.from_numpy(arr).float()[None, None]  # [1, 1, H, W]
        t = F.interpolate(t, size=size_hw, mode=mode, align_corners=align_corners)
        return t[0, 0].cpu().numpy()

    if arr.ndim == 3:
        t = torch.from_numpy(arr).float().permute(2, 0, 1)[None]  # [1, C, H, W]
        t = F.interpolate(t, size=size_hw, mode=mode, align_corners=align_corners)
        return t[0].permute(1, 2, 0).cpu().numpy()

    raise ValueError(f"Unsupported shape: {arr.shape}")

def plot_mel(mel, save_path=None, title=None):
    """Plot a mel spectrogram min-max scaled to [0, 1], with row 0 drawn at the bottom.

    Args:
        mel: torch tensor or numpy array of shape ``[n_mels, T]``; leading singleton axes
            (e.g. ``[1, n_mels, T]`` or ``[1, 1, n_mels, T]``) are dropped.
        save_path: if given, save the figure there (dpi=200) and close it; otherwise show it.
        title: optional figure title.

    Raises:
        ValueError: if ``mel`` is not 2-D after dropping leading singleton axes.
    """
    if isinstance(mel, torch.Tensor):
        mel = mel.detach().cpu().numpy()
    mel = np.asarray(mel, dtype=np.float32)
    while mel.ndim > 2 and mel.shape[0] == 1:
        mel = mel[0]
    if mel.ndim != 2:
        raise ValueError(f"Expected a 2-D mel spectrogram, got shape {mel.shape}")

    mel_norm = (mel - mel.min()) / (mel.max() - mel.min() + 1e-8)

    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(mel_norm, aspect="auto", cmap="magma", interpolation="bilinear", origin="lower")
    fig.colorbar(im, ax=ax, label="Normalized Energy")
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Mel Frequency Bins")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200)
        plt.close(fig)  # free the figure; this is called once per layer in loops
    else:
        plt.show()

def wav_to_model_mel(wav_path: str, target_sr: int = 16000):
    """Load a wav file and compute the features ReDimNet's backbone consumes.

    Only the first channel is used. The audio is resampled to ``target_sr`` when the file's
    rate differs (pass ``None`` or ``0`` to disable resampling).

    Returns:
        mel4d: tensor ``[1, 1, n_mels, T]`` on ``DEVICE`` (output of ``redim_model.spec``,
            computed without gradients).
        rgb_base: float numpy array ``[n_mels, T, 3]``: the min-max scaled mel repeated over
            3 channels and flipped vertically (a negative-stride view).
    """
    wav, sr = torchaudio.load(wav_path)
    wav = wav[:1, :].float()

    # NOTE(review): the ReDimNet VoxCeleb checkpoints are assumed to expect 16 kHz input;
    # previously `sr` was ignored, so other rates were fed to the model unresampled.
    if target_sr and sr != target_sr:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    wav = wav.to(DEVICE)

    with torch.no_grad():
        mel = redim_model.spec(wav)  # [1, n_mels, T]

    mel4d = mel.unsqueeze(1)  # add channel axis -> [1, 1, n_mels, T]

    mel_np = mel.squeeze(0).cpu().numpy()  # [n_mels, T]
    mel_norm = (mel_np - mel_np.min()) / (mel_np.max() - mel_np.min() + 1e-8)
    rgb_base = np.stack([mel_norm] * 3, axis=-1)  # [n_mels, T, 3]
    rgb_base = np.flipud(rgb_base)  # legacy orientation: row 0 = top of image

    return mel4d, rgb_base


def magma_rgb(img2d: np.ndarray) -> np.ndarray:
    """Map a 2-D array to an RGB image with the magma colormap.

    The input is min-max scaled to [0, 1] first. Returns float32 ``[H, W, 3]`` (alpha dropped).
    """
    img2d = img2d.astype(np.float32)
    img2d = (img2d - img2d.min()) / (img2d.max() - img2d.min() + 1e-8)
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap remains.
    return plt.get_cmap("magma")(img2d)[..., :3].astype(np.float32)


In [4]:
class ReDimNetMelLogitsWrapper(nn.Module):
    """Chain ReDimNet's backbone -> pool -> bn -> linear with a classifier head.

    The wrapper takes mel features directly (it does not call ``redim_model.spec``), so
    Grad-CAM can attribute the class logits to the feature map the backbone sees.
    Submodules are shared with ``redim_model`` and ``head``, not copied.

    Args:
        redim_model: object exposing ``backbone``, ``pool``, ``bn`` and ``linear`` modules.
        head: classifier applied to the embedding.
        l2_norm_emb: if True, L2-normalise the embedding before ``head``. This must match
            how the head was trained.

    Input:  mel4d ``[B, 1, n_mels, T]``
    Output: logits ``[B, num_classes]``
    """

    def __init__(self, redim_model, head, l2_norm_emb=True):
        super().__init__()
        # Registration order (backbone, pool, bn, linear, head) fixes the state_dict key order.
        self.backbone = redim_model.backbone
        self.pool = redim_model.pool
        self.bn = redim_model.bn
        self.linear = redim_model.linear
        self.head = head
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d):
        features = mel4d
        for stage in (self.backbone, self.pool, self.bn, self.linear):
            features = stage(features)

        embedding = features
        if self.l2_norm_emb:
            # The epsilon avoids division by zero for an all-zero embedding.
            embedding = embedding / (embedding.norm(p=2, dim=1, keepdim=True) + 1e-12)

        return self.head(embedding)

# Use the normalisation flag saved with the head so the wrapper matches how the head was trained;
# fall back to True (the previously hard-coded value) when the checkpoint lacks the key.
wrapped_model = ReDimNetMelLogitsWrapper(
    redim_model, head, l2_norm_emb=bool(ckpt.get("l2_norm_emb", True))
).to(DEVICE).eval()
print("Wrapper (logits) ready.")



Wrapper (logits) ready.


In [5]:
# Grad-CAM hook points: the stem's first module plus the module at index 2 of each backbone stage.
# Names and modules are derived from one list so they cannot drift out of sync.
_backbone = wrapped_model.backbone
_stage_names = [f"stage{i}" for i in range(6)]

layer_names = ["stem"] + _stage_names
target_layers = [_backbone.stem[0]] + [getattr(_backbone, stage)[2] for stage in _stage_names]

print("Target layers:")
for lname, layer in zip(layer_names, target_layers):
    print(lname, "->", layer)



Target layers:
stem -> Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
stage0 -> Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
stage1 -> Conv2d(32, 128, kernel_size=(2, 1), stride=(2, 1))
stage2 -> Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
stage3 -> Conv2d(64, 128, kernel_size=(2, 1), stride=(2, 1))
stage4 -> Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
stage5 -> Conv2d(128, 256, kernel_size=(2, 1), stride=(2, 1))


In [6]:
class ClassLogitTarget:
    """Grad-CAM target that scores the raw logit of one class.

    Accepts a single sample's logits (1-D) or a batch (2-D, summed over the batch).
    """

    def __init__(self, class_idx: int):
        self.class_idx = int(class_idx)

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        rank = model_output.ndim
        if rank == 2:
            return model_output[:, self.class_idx].sum()
        if rank == 1:
            return model_output[self.class_idx]
        raise ValueError(f"Unexpected model_output shape: {tuple(model_output.shape)}")


In [7]:
# Both networks stay in eval mode for attribution.
redim_model.eval()
head.eval()

# Backbone parameters keep gradients enabled (needed for Grad-CAM).
redim_model.requires_grad_(True)
# The head's own weights need no gradients; gradients still flow *through* it to the backbone.
head.requires_grad_(False);


In [8]:
def _minmax01(arr: np.ndarray) -> np.ndarray:
    """Min-max scale ``arr`` to float32 in [0, 1]; a constant array maps to all zeros."""
    arr = np.asarray(arr, dtype=np.float32)
    scaled = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)


def run_gradcam_speaker(
    wav_path: str,
    target_speaker: str,
    save_dir: str = "gradcam_results_cls",
    upscale: int = 10,
    cam_quantile: float = 0.85,
    eigen_smooth: bool = True,
):
    """Render Grad-CAM maps for ``target_speaker``'s logit at every layer in ``target_layers``.

    Files written for the wav's stem ``<stem>``:
      * ``<save_dir>/original/<stem>_mel.png``: the model-space mel spectrogram.
      * ``<save_dir>/<layer>/<stem>_mel_focus_q<cam_quantile>.png``: the mel with every cell whose
        CAM value is below the ``cam_quantile`` quantile set to the mel minimum.
      * ``<save_dir>/<layer>/<stem>_overlay.png``: the CAM heat-map over the mel, upscaled by
        ``upscale``.

    Every image draws mel row 0 at the bottom, so the plots and overlays share one orientation.

    Args:
        wav_path: path to the input audio file.
        target_speaker: key of ``speaker_to_id`` whose logit is explained.
        save_dir: output directory (created if missing).
        upscale: integer magnification applied to the overlay image.
        cam_quantile: quantile in [0, 1] used as the focus-mask threshold.
        eigen_smooth: forwarded to GradCAM's ``eigen_smooth`` option.

    Raises:
        AssertionError: if ``target_speaker`` is unknown.
        RuntimeError: if GradCAM returned no map for a layer.
    """
    assert target_speaker in speaker_to_id, f"Unknown speaker: {target_speaker}"

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    fname = Path(wav_path).stem

    # 1) Model-space mel (the backbone input that Grad-CAM attributes to).
    mel4d, _ = wav_to_model_mel(wav_path)  # expected: [1, 1, n_mels, T]
    mel4d = mel4d.detach().to(DEVICE)
    mel4d.requires_grad_(True)  # important for Grad-CAM

    mel2d = mel4d.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.float32)  # [n_mels, T]
    mel_norm = _minmax01(mel2d)

    originals_dir = save_dir / "original"
    originals_dir.mkdir(parents=True, exist_ok=True)
    plot_mel(
        mel2d,
        save_path=originals_dir / f"{fname}_mel.png",
        title=f"{fname} (target={target_speaker})",
    )

    # 2) Target definition
    targets = [ClassLogitTarget(speaker_to_id[target_speaker])]

    # The overlay background is the same for every layer, so build it once.
    H, W = mel2d.shape
    H2, W2 = H * upscale, W * upscale
    rgb_big = upsample_hw(np.stack([mel_norm] * 3, axis=-1), (H2, W2), mode="bicubic")
    rgb_big = np.clip(rgb_big, 0.0, 1.0).astype(np.float32)  # bicubic can overshoot [0, 1]

    # 3) Iterate over ALL layers
    for layer, lname in zip(target_layers, layer_names):
        print(f"[Grad-CAM] Layer: {lname}")

        layer_dir = save_dir / lname
        layer_dir.mkdir(parents=True, exist_ok=True)

        wrapped_model.zero_grad(set_to_none=True)

        # The context manager removes the hooks GradCAM registers on `layer` when the block exits.
        # Its __exit__ may suppress some exceptions, so verify that a map was actually produced.
        grayscale_cam = None
        with GradCAM(model=wrapped_model, target_layers=[layer]) as cam:
            grayscale_cam = cam(
                input_tensor=mel4d,
                targets=targets,
                aug_smooth=False,
                eigen_smooth=eigen_smooth,
            )[0]  # [H', W']
        if grayscale_cam is None:
            raise RuntimeError(f"Grad-CAM produced no map for layer {lname}")

        # 4) Align the CAM to mel space. No flip: the CAM keeps mel4d's row order, and the
        #    mask below indexes mel2d directly (flipping here misaligned it vertically).
        cam_on_mel = upsample_hw(np.asarray(grayscale_cam, dtype=np.float32), mel2d.shape, mode="bicubic")
        cam_norm = _minmax01(cam_on_mel)

        # 5) High-activation focus (relative threshold)
        thr = float(np.quantile(cam_norm, cam_quantile))
        mask = cam_norm >= thr

        mel_focus = np.where(mask, mel2d, mel2d.min())
        plot_mel(
            mel_focus,
            save_path=layer_dir / f"{fname}_mel_focus_q{cam_quantile}.png",
            title=f"{fname} | {lname} | focus (q={cam_quantile})",
        )

        # 6) Overlay (show_cam_on_image needs float32 inputs in [0, 1])
        cam_big = upsample_hw(cam_norm, (H2, W2), mode="bicubic")
        cam_big = np.clip(cam_big, 0.0, 1.0).astype(np.float32)

        overlay = show_cam_on_image(rgb_big, cam_big, use_rgb=True)

        # origin="lower" draws mel row 0 at the bottom, matching plot_mel.
        plt.imsave(layer_dir / f"{fname}_overlay.png", overlay, origin="lower")







In [9]:
wav_dir = PROJECT_ROOT / "data" / "wavs"

test_wav = wav_dir / "idan_009.wav"
run_gradcam_speaker(
    wav_path=str(test_wav),
    target_speaker="idan",
    save_dir=f"gradcam_results/idan/{test_wav.stem}",
    upscale=12,
)

# Optional batch run over the top-3 clips per speaker. Set to True to enable.
RUN_TOP3_BATCH = False

top3_by_user = {
    "eden": ["eden_017.wav", "eden_021.wav", "eden_012.wav"],
    "idan": ["idan_009.wav", "idan_004.wav", "idan_012.wav"],
    "yoav": ["yoav_028.wav", "yoav_024.wav", "yoav_022.wav"],
}

if RUN_TOP3_BATCH:
    for spk, fnames in top3_by_user.items():
        for fname in fnames:
            wav_path = wav_dir / fname
            if not wav_path.exists():
                print(f"[WARN] missing file: {wav_path}")
                continue

            # The old `thr=` argument no longer exists; the focus threshold is now the
            # `cam_quantile` parameter (default 0.85).
            run_gradcam_speaker(
                wav_path=str(wav_path),
                target_speaker=spk,
                save_dir=f"gradcam_results_2.0/{spk}/{wav_path.stem}",
                upscale=8,
            )

[Grad-CAM] Layer: stem
[Grad-CAM] Layer: stage0
[Grad-CAM] Layer: stage1
[Grad-CAM] Layer: stage2
[Grad-CAM] Layer: stage3
[Grad-CAM] Layer: stage4
[Grad-CAM] Layer: stage5
