In [1]:

from pathlib import Path
import os
import random
import torchaudio
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from speechbrain.pretrained import EncoderClassifier
import itertools
import ast
import json
import librosa
import librosa.display
import matplotlib.pyplot as plt
import torch.nn as nn


  from .autonotebook import tqdm as notebook_tqdm
The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.
  torchaudio.set_audio_backend("soundfile")
torchvision is not available - cannot save figures
The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.
  torchaudio.set_audio_backend("soundfile")


In [2]:
# --------- Grad-CAM hook class ----------
class GradCAM:
    """
    Grad-CAM (Gradient-weighted Class Activation Mapping) for audio.
    
    Identifies which time frames in audio are most important for the model's 
    classification decision by:
    1. Recording layer activations during forward pass
    2. Recording gradients during backward pass
    3. Computing importance weights (gradients averaged over time)
    4. Creating a heatmap showing per-frame importance
    
    This helps visualize what the neural network "pays attention to" when 
    classifying a speaker.
    """
    
    def __init__(self, target_layer):
        """
        Initialize Grad-CAM for a specific layer.
        
        Args:
            target_layer: The neural network layer to analyze (e.g., last embedding layer)
        """
        self.target_layer = target_layer
        self.activations = None  # Will store layer output: [batch, channels, time_frames]
        self.gradients = None    # Will store gradients: [batch, channels, time_frames]

        # Install hooks that automatically capture data during forward/backward passes
        self.target_layer.register_forward_hook(self._forward_hook)
        self.target_layer.register_backward_hook(self._backward_hook)

    def _forward_hook(self, module, inp, out):
        """
        Called automatically during forward pass.
        Saves the layer's output (activations) for later analysis.
        
        Shape: [batch=1, channels=128, time_frames=1000] for audio
        """
        self.activations = out.detach()  


    def _backward_hook(self, module, grad_input, grad_output):
        """
        Called automatically during backward pass.
        Saves gradients flowing back through this layer.
        
        Gradients show how much each activation contributed to the loss.
        Shape: [batch=1, channels=128, time_frames=1000]
        """
        self.gradients = grad_output[0]

    def generate(self):
        """
        Generate the Grad-CAM heatmap from captured activations and gradients.
        
        Algorithm:
        1. Compute importance weight for each channel (average gradient over time)
        2. Weight each channel's activation by its importance
        3. Sum weighted activations across channels → per-frame importance score
        4. Normalize to [0, 1] range for visualization
        
        Returns:
            np.array: 1D array of shape [time_frames] with values in [0, 1]
                     0 = not important, 1 = very important for classification
        """


        grads = self.gradients        # [batch=1, channels=128, time=1000]
        acts = self.activations       # [batch=1, channels=128, time=1000]

        # STEP 1: Compute importance weight per channel
        # Average gradients over time dimension to get one weight per channel
        # This tells us: "how important is channel X overall?"
        weights = grads.mean(dim=2, keepdim=True)  # [batch=1, channels=128, 1]

        # STEP 2: Weight each activation by its channel's importance
        # Then sum across channels → importance score per time frame
        # This tells us: "for each time frame, how important is it?"
        cam = (weights * acts).sum(dim=1)  # [batch=1, time=1000]

        # STEP 3: Clean up the heatmap
        cam = F.relu(cam)  # Keep only positive contributions
        cam = cam.squeeze(0).detach().cpu().numpy()  # Convert [1, 1000] → [1000] numpy array

        # STEP 4: Normalize to [0, 1] range for visualization
        cam -= cam.min()  # Shift minimum to 0
        cam /= (cam.max() + 1e-8)  # Scale maximum to 1 (add 1e-8 to avoid division by zero)
        
        return cam  # [time_frames] with values in [0, 1]


In [3]:
output_dir = "gradcam_results"
os.makedirs(output_dir, exist_ok=True)
data_dir = "data"
# Reload the ecapa model fresh
model_path = "ecapa_pretrained"  
ecapa = EncoderClassifier.from_hparams(
    source=model_path,
    savedir=model_path
)

# # # Save model layers to a text file
# model_str = str(ecapa.mods)
# with open("ecapa_layers.txt", "w", encoding="utf-8") as f:
#     f.write(model_str)

# Use the last SERes2Net block 
target_layer = ecapa.mods.embedding_model.blocks[-1]
cam_extractor = GradCAM(target_layer)

num_speakers = 3  # yoav, idan, eden #NOTE: ecapa pretrained has 1 speaker in classifier head so we need to change it to 3

# New classifier head on top of ECAPA embedding
new_classifier = nn.Linear(192, num_speakers)
new_classifier = new_classifier.to(next(ecapa.parameters()).device)

print("New classifier:", new_classifier)

  torch.load(path, map_location=device), strict=False


New classifier: Linear(in_features=192, out_features=3, bias=True)


  stats = torch.load(path, map_location=device)


In [4]:
# Load data files

files = [f for f in os.listdir(data_dir) if f.endswith(".wav")]


speakers = {"yoav": [], "idan": [], "eden": []}

for f in files:
    prefix = f.split("_")[0]  
    if prefix in speakers:
        speakers[prefix].append(os.path.join(data_dir, f))


selected = {}

for speaker, file_list in speakers.items():
    selected[speaker] = file_list[:5]  

# with open("selected_files.json", "w", encoding="utf-8") as f:
#     json.dump(selected, f, indent=4)


In [None]:
# def run_gradcam_on_wav(wav_path, speaker):
#     # Load audio at 16kHz and convert to tensor
#     wav, sr = librosa.load(wav_path, sr=16000)
#     wav_tensor = torch.tensor([wav]).float()

#     # Switch to eval mode for inference
#     ecapa.eval()
#     new_classifier.eval()

#     # Enable gradients even in eval mode (needed for Grad-CAM)
#     with torch.enable_grad():
#         # Step 1: Convert raw audio to mel-spectrogram features
#         features = ecapa.mods.compute_features(wav_tensor)
#         lengths = torch.tensor([features.shape[-1]])
#         # Normalize features (zero mean, unit variance)
#         features = ecapa.mods.mean_var_norm(features, lengths)

#         # Enable gradient tracking for feature analysis
#         features.requires_grad_(True)

#         # Step 2: Extract speaker embedding (192-dim vector)
#         emb = ecapa.mods.embedding_model(features)
#         emb = emb.squeeze()  # Remove any size-1 dimensions

#         # Ensure embedding is 2D for classifier: [1, 192]
#         if emb.dim() == 1:
#             emb = emb.unsqueeze(0)

#         # Step 3: Classify speaker (get logits for 3 speakers)
#         logits = new_classifier(emb)  # [1, 3] → scores for each speaker
#         logits = logits.squeeze(0)    # Remove batch dim → [3]

#         # Get predicted class (0=yoav, 1=idan, 2=eden)
#         pred_class = logits.argmax().item()
#         print(f"Predicted class: {pred_class}, logits: {logits}")

#         # Step 4: Compute gradients for Grad-CAM
#         # Zero out previous gradients
#         ecapa.zero_grad()
#         new_classifier.zero_grad()

#         # Create one-hot vector for predicted class
#         one_hot = torch.zeros_like(logits)
#         one_hot[pred_class] = 1

#         # Backward pass: compute gradients w.r.t. predicted class
#         logits.backward(gradient=one_hot, retain_graph=True)

#         # Step 5: Generate Grad-CAM heatmap (not displayed, just for reference)
#         cam = cam_extractor.generate()  # Returns normalized heatmap [T]

#         # --- Prepare spectrogram for visualization ---
#         # Compute mel-spectrogram to match ECAPA preprocessing for display
#         n_fft = 512
#         hop_length = 80
#         n_mels = 80
#         mel_spec = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft,
#                                                   hop_length=hop_length, n_mels=n_mels)
#         log_mel = librosa.power_to_db(mel_spec, ref=np.max)

#         # Step 6: Plot mel-spectrogram only
#         fig, ax = plt.subplots(figsize=(10, 4))
#         librosa.display.specshow(log_mel, sr=sr, hop_length=hop_length, x_axis='time', y_axis='mel', ax=ax)
#         ax.set_title(f"Mel-spectrogram: {os.path.basename(wav_path)} (Predicted: {pred_class})")
        
#         plt.tight_layout()
#         out_path = os.path.join(output_dir, f"gradcam_{os.path.basename(wav_path)}.png")
#         fig.savefig(out_path)
#         plt.close(fig)
#         print(f"Saved Mel-spectrogram to {out_path}")

In [5]:
def run_gradcam_on_wav(wav_path, speaker):
    # Load audio at 16kHz and convert to tensor
    wav, sr = librosa.load(wav_path, sr=16000)
    wav_tensor = torch.tensor([wav]).float()

    # Switch to eval mode for inference
    ecapa.eval()
    new_classifier.eval()

    # Enable gradients even in eval mode (needed for Grad-CAM)
    with torch.enable_grad():
        # Step 1: Convert raw audio to mel-spectrogram features
        features = ecapa.mods.compute_features(wav_tensor)
        lengths = torch.tensor([features.shape[-1]])
        features = ecapa.mods.mean_var_norm(features, lengths)

        # Enable gradient tracking for feature analysis
        features.requires_grad_(True)

        # Step 2: Extract speaker embedding
        emb = ecapa.mods.embedding_model(features)
        emb = emb.squeeze()

        if emb.dim() == 1:
            emb = emb.unsqueeze(0)

        # Step 3: Classify speaker
        logits = new_classifier(emb)
        logits = logits.squeeze(0)
        pred_class = logits.argmax().item()
        print(f"Predicted class: {pred_class}, logits: {logits}")

        # Step 4: Compute gradients for Grad-CAM
        ecapa.zero_grad()
        new_classifier.zero_grad()

        one_hot = torch.zeros_like(logits)
        one_hot[pred_class] = 1
        logits.backward(gradient=one_hot, retain_graph=True)

        # Step 5: Generate Grad-CAM heatmap
        cam = cam_extractor.generate()  # [T]

        # Prepare mel spectrogram
        n_fft = 512
        hop_length = 80
        n_mels = 80
        mel_spec = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft,
                                                  hop_length=hop_length, n_mels=n_mels)
        log_mel = librosa.power_to_db(mel_spec, ref=np.max)

        fig, ax = plt.subplots(figsize=(10, 4))

        # 1. plot mel spectrogram
        librosa.display.specshow(log_mel, sr=sr, hop_length=hop_length,
                                 x_axis='time', y_axis='mel', cmap='magma', ax=ax)

        # 2. resize CAM to spectrogram time axis
        cam_resized = np.interp(
            np.linspace(0, len(cam), log_mel.shape[1]),
            np.arange(len(cam)),
            cam
        )

        # 3. overlay CAM (heatmap)
        ax.imshow(cam_resized[np.newaxis, :],
                  cmap='jet',
                  aspect='auto',
                  alpha=0.4,
                  extent=[0, log_mel.shape[1], 0, log_mel.shape[0]])

        ax.set_title(f"Grad-CAM: {os.path.basename(wav_path)} (Predicted: {pred_class})")

        plt.tight_layout()
        out_path = os.path.join(output_dir, f"gradcam_{os.path.basename(wav_path)}.png")
        fig.savefig(out_path)
        plt.close(fig)
        print(f"Saved Grad-CAM overlay to {out_path}")


In [6]:
# run_gradcam_on_wav("data/yoav_001.wav", "yoav")
for speaker in selected:
    print(f"\n===== Speaker: {speaker} =====")
    for wav_path in selected[speaker]:
        run_gradcam_on_wav(wav_path, speaker)




===== Speaker: yoav =====


  wav_tensor = torch.tensor([wav]).float()
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\SpectralOps.cpp:878.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Predicted class: 1, logits: tensor([-15.2227,  -1.5052, -10.4673], grad_fn=<SqueezeBackward1>)


  fig.savefig(out_path)


Saved Grad-CAM overlay to gradcam_results\gradcam_yoav_001.wav.png


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Predicted class: 1, logits: tensor([-13.1848,  -8.0947, -13.4141], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_yoav_002.wav.png
Predicted class: 1, logits: tensor([-9.2355,  4.7210,  0.4761], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_yoav_003.wav.png
Predicted class: 1, logits: tensor([-16.0184,  -7.0774,  -9.7240], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_yoav_004.wav.png
Predicted class: 1, logits: tensor([-3.9303,  5.8597, -6.4010], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_yoav_005.wav.png

===== Speaker: idan =====
Predicted class: 1, logits: tensor([-13.3602,   8.0547,   1.7203], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_idan_001.wav.png
Predicted class: 2, logits: tensor([-7.7366,  1.1696, 10.7644], grad_fn=<SqueezeBackward1>)
Saved Grad-CAM overlay to gradcam_results\gradcam_idan_002.wav.png
Predicted 