In [27]:
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm


from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# The notebook runs from <repo>/redimnet/grad_cam/2.0, so the repo root is three levels up.
PROJECT_ROOT = Path.cwd().parents[2]
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
from utils.spectogram_player_html import save_spectrogram_player_html

print("PROJECT_ROOT =", PROJECT_ROOT)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Pretrained ReDimNet (b5, "ptn" training, vox2) pulled through torch.hub; put in eval mode.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
).to(DEVICE).eval()

print("Loaded ReDimNet successfully.")



PROJECT_ROOT = /home/SpeakerRec/BioVoice
Using device: cuda


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


Loaded ReDimNet successfully.


In [28]:
# Linear speaker-classification head trained on top of ReDimNet embeddings.

HEAD_PATH = Path.cwd() / "output" / "redim_speaker_head_linear.pt"
assert HEAD_PATH.exists(), f"Missing head checkpoint: {HEAD_PATH}"

# NOTE(review): torch.load unpickles arbitrary objects - only load checkpoints you produced.
# If the checkpoint holds only tensors / primitive containers, weights_only=True would also
# silence the FutureWarning seen in the output - TODO confirm against the training script.
ckpt = torch.load(HEAD_PATH, map_location=DEVICE)

# Restore speaker mappings
speaker_to_id = ckpt["speaker_to_id"]
id_to_speaker = ckpt["id_to_speaker"]
# Order by class index so SPEAKERS[i] is the speaker of logit i, independent of the
# dict's insertion order.
SPEAKERS = sorted(speaker_to_id, key=speaker_to_id.__getitem__)

class SpeakerHead(nn.Module):
    """Single linear layer mapping a speaker embedding to per-speaker logits.

    The sub-module must stay named ``fc``: state_dict keys are derived from it
    (``fc.weight`` / ``fc.bias``), so renaming it breaks checkpoint loading.
    """

    def __init__(self, in_dim=192, num_classes=3):
        super(SpeakerHead, self).__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

# Size the head from the checkpoint itself (nn.Linear weight is [out_features, in_features])
# so its shape can never silently disagree with what was trained.
fc_weight = ckpt["state_dict"]["fc.weight"]
num_classes, in_dim = fc_weight.shape
assert num_classes == len(SPEAKERS), (
    f"Head has {num_classes} outputs but checkpoint lists {len(SPEAKERS)} speakers"
)

head = SpeakerHead(
    in_dim=in_dim,
    num_classes=num_classes,
).to(DEVICE)

# Load weights
head.load_state_dict(ckpt["state_dict"])
head.eval()

print("Loaded speaker head from:", HEAD_PATH)
print("Speakers:", SPEAKERS)
# Default True matches ReDimNetMelLogitsWrapper's l2_norm_emb default.
print("L2-normalized embeddings:", ckpt.get("l2_norm_emb", True))


Loaded speaker head from: /home/SpeakerRec/BioVoice/redimnet/grad_cam/2.0/output/redim_speaker_head_linear.pt
Speakers: ['eden', 'idan', 'yoav']
L2-normalized embeddings: True


  ckpt = torch.load(HEAD_PATH, map_location=DEVICE)


In [29]:
# # ========== 1) LOAD YOUR TRAINED FC HEAD (SPEAKER CLASSIFIER) ==========
# SPEAKERS = ["eden", "idan", "yoav"]
# speaker_to_id = {s: i for i, s in enumerate(SPEAKERS)}
# id_to_speaker = {i: s for s, i in speaker_to_id.items()}

# class SpeakerHead(nn.Module):
#     def __init__(self, in_dim=192, num_classes=3):
#         super().__init__()
#         self.fc = nn.Linear(in_dim, num_classes)

#     def forward(self, x):
#         return self.fc(x)

# head = SpeakerHead(in_dim=192, num_classes=len(SPEAKERS)).to(DEVICE)

# head_path = PROJECT_ROOT / "data" / "redim_speaker_head.pt"
# assert head_path.exists(), f"Missing head checkpoint: {head_path}"

# ckpt = torch.load(head_path, map_location="cpu")
# head.load_state_dict(ckpt["state_dict"])
# head.eval()

# if "speakers" in ckpt:
#     SPEAKERS = ckpt["speakers"]
#     speaker_to_id = {s: i for i, s in enumerate(SPEAKERS)}
#     id_to_speaker = {i: s for s, i in speaker_to_id.items()}

# print("Loaded head from:", head_path)
# print("Speakers:", SPEAKERS)




In [30]:
def to_float01(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as float32 in [0, 1].

    Arrays whose maximum exceeds 1.0 are treated as 0-255 images and divided by 255.
    The result is always clipped to [0, 1].
    """
    as_float = img.astype(np.float32)
    scale = 255.0 if as_float.max() > 1.0 else 1.0
    return np.clip(as_float / scale, 0.0, 1.0)

def upsample_hw(arr: np.ndarray, size_hw: tuple[int, int], mode: str = "bilinear") -> np.ndarray:
    """Resize a 2-D (H, W) or channels-last 3-D (H, W, C) array with ``F.interpolate``.

    Args:
        arr: (H, W) or (H, W, C) array. Negative-stride views (e.g. from ``np.flipud``)
            are accepted; they are copied to contiguous memory first because
            ``torch.from_numpy`` rejects negative strides.
        size_hw: target (height, width).
        mode: any ``F.interpolate`` mode valid for 4-D input ("nearest", "bilinear",
            "bicubic", "area", ...).

    Returns:
        float32 array of shape ``size_hw`` (2-D input) or ``(*size_hw, C)`` (3-D input).

    Raises:
        ValueError: if ``arr`` is not 2-D or 3-D.
    """
    arr = np.ascontiguousarray(arr)
    size_hw = (int(size_hw[0]), int(size_hw[1]))
    # F.interpolate raises if align_corners is passed to a non-interpolating mode
    # (nearest / area), so only forward it where it is accepted.
    interp_kwargs = (
        {"align_corners": False}
        if mode in ("linear", "bilinear", "bicubic", "trilinear")
        else {}
    )

    if arr.ndim == 2:
        t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).float()  # [1,1,H,W]
        t = F.interpolate(t, size=size_hw, mode=mode, **interp_kwargs)
        return t[0, 0].cpu().numpy()

    if arr.ndim == 3:
        t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float()  # [1,C,H,W]
        t = F.interpolate(t, size=size_hw, mode=mode, **interp_kwargs)
        return t[0].permute(1, 2, 0).cpu().numpy()

    raise ValueError(f"Unsupported shape: {arr.shape}")

def plot_mel(mel, save_path=None, title=None):
    """Draw a min-max normalised mel spectrogram (magma, low mel bins at the bottom).

    Args:
        mel: 2-D array (n_mels, T), or a torch tensor whose leading size-1 axis is
            squeezed away (e.g. (1, n_mels, T)).
        save_path: if given, the figure is saved there at 200 dpi and closed;
            otherwise it is shown.
        title: optional figure title.
    """
    data = mel.squeeze(0).detach().cpu().numpy() if isinstance(mel, torch.Tensor) else mel
    scaled = (data - data.min()) / (data.max() - data.min() + 1e-8)

    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(scaled, aspect="auto", cmap="magma", interpolation="bilinear", origin="lower")
    fig.colorbar(image, ax=ax, label="Normalized Energy")
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Mel Frequency Bins")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if not save_path:
        plt.show()
        return
    fig.savefig(save_path, dpi=200)
    plt.close(fig)

def wav_to_model_mel(wav_path: str):
    """Load a wav file and run it through ReDimNet's own mel front-end (``redim_model.spec``).

    Only the first channel is used and no resampling is performed.
    NOTE(review): the file's sample rate is presumably expected to match the one the
    model was trained on - confirm for your data.

    Returns:
        mel4d: front-end output with an extra leading axis (``[1, 1, 80, T]``), on DEVICE.
        rgb_base: numpy (80, T, 3) grey image of the min-max normalised mel, flipped
            vertically (row 0 = highest mel bin) for image-style display.
        sr: sample rate reported by ``torchaudio.load``.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    first_channel = waveform[:1, :].float().to(DEVICE)

    with torch.no_grad():
        mel = redim_model.spec(first_channel)  # [1, 80, T]

    mel_np = mel.squeeze(0).cpu().numpy()  # [80, T]
    lo, hi = mel_np.min(), mel_np.max()
    grey = (mel_np - lo) / (hi - lo + 1e-8)
    rgb_base = np.flipud(np.repeat(grey[..., None], 3, axis=-1))  # [80, T, 3], image orientation

    return mel.unsqueeze(0), rgb_base, sample_rate


def magma_rgb(img2d: np.ndarray) -> np.ndarray:
    """Min-max normalise a 2-D array and colour it with matplotlib's "magma" colormap.

    Returns:
        float32 (H, W, 3) RGB array (the colormap's alpha channel is dropped).
    """
    img2d = img2d.astype(np.float32)
    img2d = (img2d - img2d.min()) / (img2d.max() - img2d.min() + 1e-8)
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
    # pyplot.get_cmap is the supported spelling on old and new releases alike.
    return plt.get_cmap("magma")(img2d)[..., :3].astype(np.float32)


In [31]:
class ReDimNetMelLogitsWrapper(nn.Module):
    """ReDimNet without its mel front-end, followed by a speaker head, as one logits model.

    Grad-CAM needs a module that maps the mel input straight to class logits. This one
    chains ``redim_model``'s ``backbone -> pool -> bn -> linear`` sub-modules and feeds
    the resulting embedding to ``head``.

    Input:  mel4d [B, 1, 80, T]
    Output: logits [B, num_classes] (whatever ``head`` produces)
    """

    def __init__(self, redim_model, head, l2_norm_emb=True):
        super().__init__()
        # The same module objects are reused, so parameters are shared with redim_model.
        self.backbone = redim_model.backbone
        self.pool = redim_model.pool
        self.bn = redim_model.bn
        self.linear = redim_model.linear  # produces the speaker embedding
        self.head = head
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d):
        emb = mel4d
        for stage in (self.backbone, self.pool, self.bn, self.linear):
            emb = stage(emb)

        # Match head training: it was fit on L2-normalised embeddings when this flag is set.
        if self.l2_norm_emb:
            emb = emb / (emb.norm(p=2, dim=1, keepdim=True) + 1e-12)

        return self.head(emb)

# Normalise embeddings exactly as the head was trained; the checkpoint records this.
# Defaulting to True keeps the previous hard-coded behaviour when the key is absent.
wrapped_model = ReDimNetMelLogitsWrapper(
    redim_model, head, l2_norm_emb=bool(ckpt.get("l2_norm_emb", True))
).to(DEVICE).eval()
print("Wrapper (logits) ready.")



Wrapper (logits) ready.


In [32]:
_backbone = wrapped_model.backbone
# (name, module) pairs: the stem's first module plus element [2] of each of stage0..stage5.
# Keeping names and modules together prevents the two lists from drifting apart.
_named_layers = [("stem", _backbone.stem[0])] + [
    (f"stage{i}", getattr(_backbone, f"stage{i}")[2]) for i in range(6)
]
target_layers = [module for _, module in _named_layers]
layer_names = [name for name, _ in _named_layers]

print("Target layers:")
for name, module in _named_layers:
    print(name, "->", module)



Target layers:
stem -> Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
stage0 -> Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
stage1 -> Conv2d(32, 128, kernel_size=(2, 1), stride=(2, 1))
stage2 -> Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
stage3 -> Conv2d(64, 128, kernel_size=(2, 1), stride=(2, 1))
stage4 -> Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
stage5 -> Conv2d(128, 256, kernel_size=(2, 1), stride=(2, 1))


In [33]:
class ClassLogitTarget:
    """Grad-CAM target that scores one class logit.

    Given a 1-D logit vector it returns that class's logit. Given a 2-D
    ``[B, num_classes]`` batch it returns the class logit summed over the batch.
    Any other rank raises ``ValueError``.
    """

    def __init__(self, class_idx: int):
        self.class_idx = int(class_idx)

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        ndim = model_output.ndim
        if ndim not in (1, 2):
            raise ValueError(f"Unexpected model_output shape: {tuple(model_output.shape)}")
        column = model_output[..., self.class_idx]
        return column if ndim == 1 else column.sum()


In [34]:
# Grad-CAM back-propagates the target logit to the backbone activations. The mel input is
# produced under no_grad and detached, so the backbone parameters must require grad for the
# activations to be part of a graph. The head's own weights need no gradients; the signal
# still flows through it to the embedding.
redim_model.eval()
head.eval()
redim_model.requires_grad_(True)
head.requires_grad_(False);


In [35]:
# def run_gradcam_speaker(
#     wav_path: str,
#     target_speaker: str,
#     save_dir: str = "gradcam_results_cls",
#     upscale: int = 10,
#     cam_quantile: float = 0.85,
# ):
#     assert target_speaker in speaker_to_id, f"Unknown speaker: {target_speaker}"

#     save_dir = Path(save_dir)
#     save_dir.mkdir(parents=True, exist_ok=True)

#     fname = Path(wav_path).stem

#     # ----------------------------
#     # 1) Compute model Mel
#     # ----------------------------
#     mel4d, rgb_base = wav_to_model_mel(wav_path)  # mel4d: [1,1,80,T]
#     mel4d = mel4d.detach().to(DEVICE)

#     mel2d = mel4d.squeeze(0).squeeze(0).cpu().numpy()  # [80, T]
#     mel_norm = (mel2d - mel2d.min()) / (mel2d.max() - mel2d.min() + 1e-8)

#     # Save original mel
#     originals_dir = save_dir / "original"
#     originals_dir.mkdir(exist_ok=True)
#     plot_mel(
#         mel2d,
#         save_path=originals_dir / f"{fname}_mel.png",
#         title=f"{fname} (target={target_speaker})",
#     )

#     # ----------------------------
#     # 2) Target definition
#     # ----------------------------
#     target = ClassLogitTarget(speaker_to_id[target_speaker])

#     # ----------------------------
#     # 3) Iterate over ALL layers
#     # ----------------------------
#     layer_stats = {}

#     for layer, lname in zip(target_layers, layer_names):
#         print(f"[Grad-CAM] Layer: {lname}")

#         layer_dir = save_dir / lname
#         layer_dir.mkdir(parents=True, exist_ok=True)

#         wrapped_model.zero_grad(set_to_none=True)

#         cam = GradCAM(
#             model=wrapped_model,
#             target_layers=[layer],
#         )

#         grayscale_cam = cam(
#             input_tensor=mel4d,
#             targets=[target],
#             aug_smooth=False,
#             eigen_smooth=True,
#         )[0]  # [H', W']

#         # Flip to match mel orientation
#         grayscale_cam = np.flipud(grayscale_cam)

#         # ----------------------------
#         # 4) Align CAM to Mel space
#         # ----------------------------
#         cam_on_mel = upsample_hw(grayscale_cam, mel2d.shape, mode="bicubic")

#         # Normalize CAM (continuous, NOT binary)
#         cam_norm = (cam_on_mel - cam_on_mel.min()) / (
#             cam_on_mel.max() - cam_on_mel.min() + 1e-8
#         )

#         # ----------------------------
#         # 5) Continuous CAM-weighted Mel
#         # ----------------------------
#         mel_weighted = mel2d * cam_norm

#         plot_mel(
#             mel_weighted,
#             save_path=layer_dir / f"{fname}_mel_weighted.png",
#             title=f"{fname} | {lname} | weighted",
#         )

#         # ----------------------------
#         # 6) High-activation focus (relative threshold)
#         # ----------------------------
#         thr = np.quantile(cam_norm, cam_quantile)
#         mask = cam_norm >= thr

#         mel_focus = np.where(mask, mel2d, mel2d.min())

#         plot_mel(
#             mel_focus,
#             save_path=layer_dir / f"{fname}_mel_focus_q{cam_quantile}.png",
#             title=f"{fname} | {lname} | focus",
#         )

#         # ----------------------------
#         # 7) Visualization overlay
#         # ----------------------------
#         H, W = mel_norm.shape
#         H2, W2 = H * upscale, W * upscale

#         rgb_big = upsample_hw(
#             np.stack([mel_norm] * 3, axis=-1),
#             (H2, W2),
#             mode="bicubic",
#         )

#         cam_big = upsample_hw(cam_norm, (H2, W2), mode="bicubic")
#         overlay = show_cam_on_image(rgb_big, cam_big, use_rgb=True)

#         plt.imsave(
#             layer_dir / f"{fname}_overlay.png",
#             overlay,
#         )

#         # ----------------------------
#         # 8) Quantitative layer stats
#         # ----------------------------
#         entropy = -np.sum(
#             cam_norm * np.log(cam_norm + 1e-8)
#         ) / cam_norm.size

#         support = (cam_norm > 0.5).mean()

#         layer_stats[lname] = {
#             "entropy": float(entropy),
#             "support@0.5": float(support),
#         }

#     # ----------------------------
#     # 9) Save stats summary
#     # ----------------------------
#     stats_df = pd.DataFrame(layer_stats).T
#     stats_df.to_csv(save_dir / "layer_cam_stats.csv")

#     print("Saved Grad-CAM results to:", save_dir)
#     print(stats_df)


In [36]:
import numpy as np
import matplotlib.pyplot as plt

def hz_to_mel(hz: np.ndarray, mel_scale: str = "htk") -> np.ndarray:
    """Convert frequencies in Hz to mels.

    ``mel_scale`` is matched case-insensitively. "htk" uses 2595 * log10(1 + f / 700).
    Any other value selects the Slaney (librosa-style) scale: linear f / (200/3) below
    1 kHz and logarithmic above.
    """
    freqs = np.asarray(hz, dtype=np.float64)
    if mel_scale.lower() == "htk":
        return 2595.0 * np.log10(1.0 + freqs / 700.0)

    # Slaney / librosa constants.
    f_sp, min_log_hz = 200.0 / 3, 1000.0
    min_log_mel = min_log_hz / f_sp  # 15
    log_step = np.exp(np.log(6.4) / 27.0)

    linear_region = freqs < min_log_hz
    log_region = ~linear_region
    mels = np.empty_like(freqs)
    mels[linear_region] = freqs[linear_region] / f_sp
    mels[log_region] = min_log_mel + np.log(freqs[log_region] / min_log_hz) / np.log(log_step)
    return mels

def mel_to_hz(mel: np.ndarray, mel_scale: str = "htk") -> np.ndarray:
    """Convert mels back to Hz; the exact inverse of ``hz_to_mel`` for the same scale.

    ``mel_scale`` is matched case-insensitively. "htk" uses 700 * (10**(m / 2595) - 1).
    Any other value selects the Slaney (librosa-style) scale: linear below 15 mel
    (= 1 kHz) and exponential above.
    """
    mels = np.asarray(mel, dtype=np.float64)
    if mel_scale.lower() == "htk":
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    # Slaney / librosa constants.
    f_sp, min_log_hz = 200.0 / 3, 1000.0
    min_log_mel = min_log_hz / f_sp  # 15
    log_step = np.exp(np.log(6.4) / 27.0)

    linear_region = mels < min_log_mel
    log_region = ~linear_region
    freqs = np.empty_like(mels)
    freqs[linear_region] = f_sp * mels[linear_region]
    freqs[log_region] = min_log_hz * (log_step ** (mels[log_region] - min_log_mel))
    return freqs

def mel_bin_centers_hz(n_mels: int, f_min: float, f_max: float, mel_scale: str = "htk") -> np.ndarray:
    """Centre frequency in Hz of each of ``n_mels`` mel bins spanning [f_min, f_max].

    ``n_mels + 2`` points are spaced evenly on the mel scale. The outer two are band
    edges and the ``n_mels`` points between them are the bin centres.
    """
    lo_mel, hi_mel = hz_to_mel(f_min, mel_scale), hz_to_mel(f_max, mel_scale)
    grid = np.linspace(lo_mel, hi_mel, num=n_mels + 2)
    return mel_to_hz(grid[1:n_mels + 1], mel_scale)

def _get_spec_params(spec, sr_fallback: int):
    # Try common attribute names; fall back safely.
    f_min = getattr(spec, "f_min", None)
    if f_min is None:
        f_min = getattr(spec, "fmin", 0.0)

    f_max = getattr(spec, "f_max", None)
    if f_max is None:
        f_max = getattr(spec, "fmax", None)

    mel_scale = getattr(spec, "mel_scale", "htk")

    # If spec doesn't provide f_max, assume Nyquist.
    if f_max is None:
        f_max = sr_fallback / 2.0

    return float(f_min), float(f_max), str(mel_scale)


def plot_mel_with_hz_axis(mel2d: np.ndarray, sr: int, spec=None, save_path=None, title=None):
    """Plot a (n_mels, T) mel spectrogram with the y-axis labelled in Hz.

    Eight evenly spaced rows get y-ticks labelled with that mel bin's centre frequency.
    Band limits are read from ``spec`` via ``_get_spec_params`` when it is given;
    otherwise a 0 Hz to Nyquist HTK filterbank is assumed.

    Args:
        mel2d: 2-D numpy array or torch tensor of shape (n_mels, T).
        sr: sample rate, used for the Nyquist default.
        spec: optional mel front-end to read f_min / f_max / mel_scale from.
        save_path: if given, save at 400 dpi and close the figure; otherwise show it.
        title: optional figure title.
    """
    data = mel2d.detach().cpu().numpy() if isinstance(mel2d, torch.Tensor) else mel2d
    n_mels, _ = data.shape

    if spec is not None:
        f_min, f_max, mel_scale = _get_spec_params(spec, sr_fallback=sr)
    else:
        f_min, f_max, mel_scale = 0.0, sr / 2.0, "htk"

    centers_hz = mel_bin_centers_hz(n_mels, f_min, f_max, mel_scale)
    scaled = (data - data.min()) / (data.max() - data.min() + 1e-8)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.imshow(scaled, aspect="auto", cmap="magma", interpolation="bilinear", origin="lower")

    tick_rows = np.linspace(0, n_mels - 1, 8).round().astype(int)
    ax.set_yticks(tick_rows)
    ax.set_yticklabels([f"{centers_hz[row]:.0f}" for row in tick_rows])

    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlabel("Time Frames")
    if title:
        ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=400)
        plt.close(fig)
    else:
        plt.show()


In [37]:
def _layer_cam(layer, mel4d, target):
    """Grad-CAM map for one target layer of the notebook-level ``wrapped_model``.

    Returns the first map of the batch as a 2-D array, in the same orientation as
    ``mel4d`` (row 0 = mel bin 0). pytorch_grad_cam rescales it to the input's spatial
    size. The ``with`` block removes the hooks GradCAM registers on ``layer`` so they do
    not accumulate on the shared backbone across layers and files.
    """
    wrapped_model.zero_grad(set_to_none=True)
    with GradCAM(model=wrapped_model, target_layers=[layer]) as cam:
        return cam(
            input_tensor=mel4d,
            targets=[target],
            aug_smooth=False,
            eigen_smooth=True,
        )[0]
    # Only reachable if GradCAM.__exit__ suppressed an exception (pytorch_grad_cam swallows
    # IndexError there). Fail loudly instead of returning None.
    raise RuntimeError(f"Grad-CAM failed for layer {layer!r}")


def run_gradcam_speaker(
    wav_path: str,
    target_speaker: str,
    save_dir: str = "gradcam_results_cls",
    thr: float = 0.3,
    upscale: int = 8,
):
    """Run Grad-CAM for one utterance on every layer in ``target_layers`` and save images.

    Uses the notebook globals ``wrapped_model``, ``target_layers``, ``layer_names`` and
    ``speaker_to_id``.

    Written to ``save_dir / "original"``:
      * ``{stem}_original_mel.png``: the model mel drawn by ``plot_mel``.
      * ``{stem}_rgb.png``: the upscaled magma image of the mel.

    Written to ``save_dir / <layer name>`` for each layer (``<tag>`` = ``{speaker}_thr{thr:.2f}``):
      * ``{stem}_overlay_<tag>.png``: JET CAM overlay, with CAM values below ``thr`` zeroed.
      * ``{stem}_mel_masked_<tag>.png`` and ``.html``: the magma mel kept where CAM >= ``thr``,
        plus a spectrogram/audio player page.
      * ``{stem}_mel_masked_plot_<tag>.png``: the same mask drawn on the raw mel with axes.

    Args:
        wav_path: audio file to analyse.
        target_speaker: speaker whose logit is explained; must be a key of ``speaker_to_id``.
        save_dir: output root directory (created if needed).
        thr: CAM threshold in [0, 1] for masking.
        upscale: integer pixel upscaling factor for the saved images.

    Raises:
        AssertionError: if ``target_speaker`` is unknown.
    """
    assert target_speaker in speaker_to_id, f"Unknown speaker: {target_speaker}"

    save_dir = Path(save_dir)
    originals_dir = save_dir / "original"
    originals_dir.mkdir(parents=True, exist_ok=True)  # also creates save_dir

    fname = Path(wav_path).stem
    tag = f"{target_speaker}_thr{thr:.2f}"

    mel4d, rgb_base, _ = wav_to_model_mel(wav_path)
    mel4d = mel4d.detach()

    plot_mel(
        mel4d.squeeze(0),
        save_path=originals_dir / f"{fname}_original_mel.png",
        title=f"{fname} (target={target_speaker})",
    )

    # ---- layer-independent work, done once ------------------------------------------
    # Raw mel in model orientation (row 0 = lowest mel bin), for the masked axes plot.
    mel = mel4d.squeeze(0)
    mel2d_np = (mel[0] if mel.ndim == 3 else mel).float().cpu().numpy()  # (80, T)

    # rgb_base is flipped for image display (row 0 = highest mel bin).
    rgb_base01 = to_float01(rgb_base)
    h, w = rgb_base01.shape[:2]
    big_hw = (h * upscale, w * upscale)
    rgb_big = upsample_hw(rgb_base01, big_hw, mode="bilinear")
    rgb_big_magma = magma_rgb(rgb_big[..., 0])
    plt.imsave(originals_dir / f"{fname}_rgb.png", rgb_big_magma)

    target = ClassLogitTarget(speaker_to_id[target_speaker])

    for layer, lname in zip(target_layers, layer_names):
        layer_dir = save_dir / lname
        layer_dir.mkdir(parents=True, exist_ok=True)

        cam_mel = _layer_cam(layer, mel4d, target)  # model orientation, like mel2d_np
        cam_img = np.flipud(cam_mel)                 # image orientation, like rgb_base

        # 1) JET overlay of the thresholded CAM on the grey mel image.
        cam_big = upsample_hw(cam_img, big_hw)
        mask_big = (cam_big >= thr).astype(np.float32)
        overlay_big = show_cam_on_image(rgb_big, cam_big * mask_big, use_rgb=True)
        plt.imsave(layer_dir / f"{fname}_overlay_{tag}.png", overlay_big)

        # 2) Magma mel kept only where CAM >= thr, plus an HTML audio player over it.
        masked_png = layer_dir / f"{fname}_mel_masked_{tag}.png"
        mel_only = np.clip(rgb_big_magma * mask_big[..., None], 0.0, 1.0)
        plt.imsave(masked_png, mel_only)
        save_spectrogram_player_html(
            audio_path=wav_path,
            spectrogram_png_path=masked_png,
            out_html_path=layer_dir / f"{fname}_mel_masked_{tag}.html",
            total_time_sec=None,
            copy_audio=True,
            embed_image=True,
        )

        # 3) Same mask on the raw mel, drawn with axes. The CAM and the mel are both in model
        #    orientation and plot_mel uses origin="lower", so no flip is needed. Pairing the
        #    flipped CAM with the unflipped mel would mis-align the mask vertically.
        cam_on_mel = upsample_hw(cam_mel, mel2d_np.shape)
        mel_masked = np.where(cam_on_mel >= thr, mel2d_np, mel2d_np.min())
        plot_mel(
            mel_masked,
            save_path=layer_dir / f"{fname}_mel_masked_plot_{tag}.png",
            title=f"{fname} ({lname}, target={target_speaker})",
        )

        print(f"Saved: {layer_dir}")



In [38]:
wav_dir = PROJECT_ROOT / "data" / "augmented_wavs"
GRADCAM_OUT_ROOT = Path("gradcam_results_3.0")
GRADCAM_THR = 0.2      # CAM values below this are masked out
GRADCAM_UPSCALE = 8    # pixel upscaling factor of the saved spectrogram images

# test_wav = wav_dir / "idan_012.wav"
# run_out_dir = Path(f"gradcam_results_2.0/idan/{test_wav.stem}")
# originals_dir = run_out_dir / "original"

# run_gradcam_speaker(
#     wav_path=str(test_wav),
#     target_speaker="idan",
#     save_dir=str(run_out_dir),
#     thr=0.2,
#     upscale=12,
# )

# top3_by_user = {
#     "eden": ["eden_017.wav", "eden_021.wav", "eden_012.wav"],
#     "idan": ["idan_009.wav", "idan_004.wav", "idan_012.wav"],
#     "yoav": ["yoav_028.wav", "yoav_024.wav", "yoav_022.wav"],
# }

# glob() yields full paths in arbitrary order; sort so the processing order is reproducible.
wav_by_user: dict[str, list[Path]] = {
    spk: sorted(wav_dir.glob(f"{spk}_*.wav")) for spk in SPEAKERS
}

for spk, wav_paths in wav_by_user.items():
    if not wav_paths:
        print(f"[WARN] no files matching {spk}_*.wav in {wav_dir}")
        continue
    for wav_path in wav_paths:
        run_gradcam_speaker(
            wav_path=str(wav_path),
            target_speaker=spk,
            save_dir=str(GRADCAM_OUT_ROOT / spk / wav_path.stem),
            thr=GRADCAM_THR,
            upscale=GRADCAM_UPSCALE,
        )

KeyboardInterrupt: 