In [11]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parents[1]   
sys.path.append(str(PROJECT_ROOT))

print("Added to path:", PROJECT_ROOT)


Added to path: /home/SpeakerRec/BioVoice


In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


from utils.Preprocess import audio_to_mel_spectrogram


In [15]:


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# -----------------------------------------------------------
# Load pretrained ReDimNet (IDRnD hub model)
# -----------------------------------------------------------
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",         
    train_type="ptn",
    dataset="vox2",
).to(DEVICE).eval()


print("ReDimNet loaded successfully.")


# -----------------------------------------------------------
# Function: embed WAV file with ReDimNet
# Returns a 192-D normalized embedding
# -----------------------------------------------------------
def embed_with_redim(wav_path: str) -> np.ndarray:
    # Load waveform
    wav, sr = torchaudio.load(wav_path)

    # Resample to 16 kHz if needed
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)

    # Keep mono only
    if wav.shape[0] > 1:
        wav = wav[:1, :]

    wav = wav.to(DEVICE).float()

    # Embedding extraction
    with torch.no_grad():
        emb = redim_model(wav)   # shape: [1, 192]

    emb = emb.squeeze().cpu().numpy()  # → [192]

    # Normalize (mandatory for cosine similarity)
    emb = emb / (np.linalg.norm(emb) + 1e-12)

    return emb



Using device: cuda


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master


ReDimNet loaded successfully.


In [17]:
# -----------------------------------------------------------
# CELL 3 — Reference Embedding (using PROJECT_ROOT paths)
# -----------------------------------------------------------

wav_dir = PROJECT_ROOT / "data" / "wavs"

best_templates = {
    "eden": wav_dir / "eden_013.wav",
    "idan": wav_dir / "idan_012.wav",
    "yoav": wav_dir / "yoav_022.wav",
}


ref_emb = embed_with_redim(str(best_templates["yoav"]))

print("Reference embedding shape:", ref_emb.shape)
print("Reference embedding norm:", np.linalg.norm(ref_emb))


Reference embedding shape: (192,)
Reference embedding norm: 0.99999994


In [20]:
# -----------------------------------------------------------
# CELL 4 — Cosine Similarity Target for Grad-CAM
# -----------------------------------------------------------


class CosineSimilarityTarget:
    """
    Target function for Grad-CAM on embedding models.
    Returns a scalar cosine similarity between:
        model_output = embedding(test_wav)  [1, 192]
        reference_embedding = ref_emb       [192]
    Grad-CAM uses gradients of this score to compute heatmaps.
    """
    def __init__(self, reference_embedding):
        # Convert reference embedding → Torch Tensor on correct device
        self.ref = torch.tensor(reference_embedding).float().to(DEVICE)

    def __call__(self, model_output):
        # model_output shape: [1, 192]
        emb = model_output.squeeze(0)  # → [192]
        return F.cosine_similarity(emb, self.ref, dim=0)


In [22]:
# -----------------------------------------------------------
# CELL 5 — Prepare Mel Spectrogram for Visualization
# -----------------------------------------------------------


def wav_to_mel_for_display(wav_path: str):
    """
    Uses your existing audio_to_mel_spectrogram() utility to
    generate the mel spectrogram for visualization.

    Returns:
        mel       -> raw mel spectrogram (numpy) shape [n_mels, T]
        rgb_base  -> normalized 3-channel image for CAM overlay
    """
    mel = audio_to_mel_spectrogram(
        file_path=Path(wav_path),
        normalization_fn=lambda x: x,   # keep raw mel values
    )

    mel = mel.astype("float32")

    # Normalize mel → [0,1] for visualization
    mel_min = mel.min()
    mel_max = mel.max()
    mel_norm = (mel - mel_min) / (mel_max - mel_min + 1e-8)

    # Convert to RGB (GradCAM overlay expects 3 channels)
    rgb_base = np.stack([mel_norm] * 3, axis=-1)

    return mel, rgb_base


In [24]:
# -----------------------------------------------------------
# CELL 6 — Define ReDimNet Target Layers (6 stages)
# -----------------------------------------------------------

# Each stage has a Conv2d layer at index [2]
target_layers = [
    redim_model.backbone.stage0[2],
    redim_model.backbone.stage1[2],
    redim_model.backbone.stage2[2],
    redim_model.backbone.stage3[2],
    redim_model.backbone.stage4[2],
    redim_model.backbone.stage5[2],
]

layer_names = ["stage0", "stage1", "stage2", "stage3", "stage4", "stage5"]

print("Number of target layers:", len(target_layers))
for name, layer in zip(layer_names, target_layers):
    print(f"{name}: {layer.__class__.__name__}")


Number of target layers: 6
stage0: Conv2d
stage1: Conv2d
stage2: Conv2d
stage3: Conv2d
stage4: Conv2d
stage5: Conv2d


In [33]:
# -----------------------------------------------------------
# CELL 7 — Run Grad-CAM on ALL stages using MODEL MEL
# -----------------------------------------------------------

from pytorch_grad_cam import GradCAM

def run_gradcam_redim(wav_path, ref_emb, save_dir="gradcam_results_redim"):
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    wav_path = str(wav_path)
    fname = Path(wav_path).stem

    # 1) Load waveform
    wav, sr = torchaudio.load(wav_path)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    wav = wav[:1, :].float().to(DEVICE)

    # 2) Use ReDimNet's EXACT mel (what the model really sees)
    with torch.no_grad():
        mel = redim_model.spec(wav)            # [1, 80, T]

    # Prepare mel for CAM (must be 4D: [1,1,H,W])
    mel_for_cam = mel.unsqueeze(0)             # [1, 1, 80, T]

    # Prepare mel image for overlay
    mel_np = mel.squeeze(0).cpu().numpy()      # [80, T]
    mel_norm = (mel_np - mel_np.min()) / (mel_np.max() - mel_np.min() + 1e-8)
    rgb_base = np.stack([mel_norm]*3, axis=-1)

    # 3) Cosine similarity target
    target = CosineSimilarityTarget(ref_emb)

    # 4) Run GradCAM on all stages
    for layer, name in zip(target_layers, layer_names):

        cam = GradCAM(model=redim_model, target_layers=[layer])

        grayscale_cam = cam(
            input_tensor=mel_for_cam,
            targets=[target],
            aug_smooth=True,
            eigen_smooth=True
        )[0]

        overlay = show_cam_on_image(rgb_base, grayscale_cam, use_rgb=True)

        out_file = save_dir / f"{fname}_{name}.png"
        plt.imsave(out_file, overlay)
        print("Saved:", out_file)


In [34]:
test_wav = PROJECT_ROOT / "data" / "wavs" / "yoav_001.wav"
run_gradcam_redim(test_wav, ref_emb)


NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now