In [None]:
import sys
from pathlib import Path

# PROJECT_ROOT = BioVoice/ -- taken as two directories above the notebook's
# working directory (TODO confirm the notebook is always launched from there).
PROJECT_ROOT = Path.cwd().parents[1]

# Make project modules importable; skip the append when the path is already
# present so re-running this cell does not pile up duplicate sys.path entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print("PROJECT_ROOT =", PROJECT_ROOT)



In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image




In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# Pretrained ReDimNet speaker-embedding model via torch.hub (repo "IDRnD/ReDimNet",
# entry point "ReDimNet", variant b5 / train_type "ptn" / dataset "vox2").
# torch.hub fetches the repo on first use. The model is only used for inference,
# so it is switched to eval mode after being moved to DEVICE.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
)
redim_model = redim_model.to(DEVICE)
redim_model.eval()

print("Loaded ReDimNet successfully.")



In [None]:
def embed_with_redim(wav_path: str):
    """Return the L2-normalised ReDimNet embedding of a wav file as a numpy array.

    Only the first channel is used, and the audio is resampled to 16 kHz when
    its native rate differs. The model output is squeezed before normalising
    (the original notes give it as [1, 192]).
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    mono = waveform[:1, :].float().to(DEVICE)

    with torch.no_grad():
        raw = redim_model(mono)  # [1,192] per the original notes

    vec = raw.squeeze().cpu().numpy()
    # The small epsilon keeps an all-zero embedding from dividing by zero.
    return vec / (np.linalg.norm(vec) + 1e-12)

class CosineSimilarityTarget:
    """Grad-CAM target: cosine similarity between the model output and a fixed reference.

    Args:
        ref_embedding: reference embedding (numpy array, sequence or torch.Tensor),
            expected to be 1-D with the same length as the model's embedding.
        device: device on which the reference is stored; defaults to the
            notebook-wide ``DEVICE`` when left as None.
    """

    def __init__(self, ref_embedding, device=None):
        if device is None:
            device = DEVICE
        # as_tensor accepts numpy arrays and tensors alike (torch.tensor() on a
        # tensor warns); clone so later changes to the caller's array do not leak
        # in, and detach so the reference is a constant for the backward pass.
        self.ref = (
            torch.as_tensor(ref_embedding, dtype=torch.float32)
            .detach()
            .clone()
            .to(device)
        )

    def __call__(self, model_output):
        """Return the scalar cosine similarity of ``model_output`` with the reference.

        ``model_output`` is one sample's embedding; a leading singleton axis
        (e.g. [1, 192]) is dropped before comparing.
        """
        emb = model_output.squeeze(0)
        # Follow the output's device in case the model moved after construction.
        ref = self.ref.to(emb.device)
        return F.cosine_similarity(emb, ref, dim=0)

def wav_to_model_mel(wav_path, target_sr=16000):
    """Compute the model's own mel features for a wav file, plus an RGB image of them.

    The audio is resampled to ``target_sr`` (16 kHz by default, the same rate
    ``embed_with_redim`` feeds the model). Without this, the Grad-CAM input
    could be computed at a different sample rate than the reference embedding.
    Only the first channel is used.

    Args:
        wav_path: path to the audio file.
        target_sr: sample rate the audio is resampled to before feature extraction.

    Returns:
        (mel4d, rgb_base):
            mel4d -- ``redim_model.spec`` output with a leading axis added; assumed
                to be [1, 1, n_mels, T] (the original notes give spec as [1, 80, T]).
            rgb_base -- the spec output with its first axis squeezed, min-max
                normalised to [0, 1] and replicated into 3 channels on the last
                axis. It is not flipped, so row 0 is the first mel bin, matching
                the Grad-CAM map computed from ``mel4d``.
    """
    wav, sr = torchaudio.load(wav_path)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    wav = wav[:1, :].float().to(DEVICE)

    with torch.no_grad():
        mel = redim_model.spec(wav)   # [1, 80, T] per the original notes

    mel4d = mel.unsqueeze(0)          # [1, 1, 80, T]

    mel_np = mel.squeeze(0).cpu().numpy()
    # The epsilon avoids dividing by zero on a constant (e.g. silent) input.
    mel_norm = (mel_np - mel_np.min()) / (mel_np.max() - mel_np.min() + 1e-8)
    rgb_base = np.stack([mel_norm] * 3, axis=-1)

    return mel4d, rgb_base


def plot_mel(mel, save_path=None, title=None):
    """Plot a min-max normalised mel spectrogram with low-frequency bins at the bottom.

    Args:
        mel: torch.Tensor or np.ndarray of shape [n_mels, T] or [1, n_mels, T].
        save_path: if given, the figure is saved there at dpi=200 and closed;
            otherwise it is displayed with ``plt.show()``.
        title: optional figure title.

    Raises:
        ValueError: if ``mel`` is not 2-D after dropping a leading singleton axis.
    """
    if isinstance(mel, torch.Tensor):
        # detach() so tensors that require grad can still be converted to numpy.
        mel = mel.detach().cpu().numpy()
    mel = np.asarray(mel)

    # Accept a leading singleton axis for numpy input as well as for tensors.
    if mel.ndim == 3 and mel.shape[0] == 1:
        mel = mel[0]
    if mel.ndim != 2:
        raise ValueError(
            f"expected mel of shape [n_mels, T] or [1, n_mels, T], got {mel.shape}"
        )

    # normalize for visualization; the epsilon guards a constant input
    mel_norm = (mel - mel.min()) / (mel.max() - mel.min() + 1e-8)

    fig, ax = plt.subplots(figsize=(10, 4))
    # origin="lower" draws bin 0 at the bottom AND labels the y-axis to match.
    # A flipud with the default origin put the bin-0 tick label on the top row.
    im = ax.imshow(mel_norm, aspect="auto", cmap="magma", origin="lower")
    fig.colorbar(im, ax=ax, label="Normalized Energy")
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Mel Frequency Bins")

    if title:
        ax.set_title(title)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200)
        plt.close(fig)
    else:
        plt.show()




In [None]:
wav_dir = PROJECT_ROOT.joinpath("data", "wavs")

# One hand-picked template utterance per speaker.
best_templates = {
    speaker: wav_dir / filename
    for speaker, filename in (
        ("eden", "eden_013.wav"),
        ("idan", "idan_012.wav"),
        ("yoav", "yoav_022.wav"),
    )
}

# The reference embedding that Grad-CAM explains similarity against: yoav's template.
ref_emb = embed_with_redim(str(best_templates["yoav"]))
print("Reference embedding shape =", ref_emb.shape)



In [None]:
class ReDimNetMelWrapper(torch.nn.Module):
    """ReDimNet without its ``spec`` front-end: maps a 4-D mel tensor to an embedding.

    Grad-CAM hooks into ``forward`` of the model it is handed and expects a 4-D
    input, so this wrapper starts from precomputed mel features rather than raw
    audio. The backbone, pool, bn and linear submodules are referenced from
    ``redim_model``, not copied, so both modules share the same weights.
    """

    def __init__(self, redim_model):
        super().__init__()
        # Register the four embedding stages in forward order.
        for part in ("backbone", "pool", "bn", "linear"):
            setattr(self, part, getattr(redim_model, part))

    def forward(self, mel4d):
        """mel4d: [B, 1, 80, T] -> output of the final linear layer."""
        features = self.backbone(mel4d)  # stem and stage0..stage5
        return self.linear(self.bn(self.pool(features)))
# Module.to()/eval() act in place. Because the submodules are shared, this also
# affects the corresponding parts of redim_model.
wrapped_model = ReDimNetMelWrapper(redim_model)
wrapped_model.to(DEVICE)
wrapped_model.eval()
print("Wrapper ready.")



In [None]:

# Grad-CAM target layers: one module per backbone section, namely element 0 of
# `stem` and element 2 of every stage. The original notes describe each of
# these as a Conv2d (TODO confirm against the loaded ReDimNet architecture).
layer_names = ["stem", "stage0", "stage1", "stage2", "stage3", "stage4", "stage5"]
target_layers = [
    getattr(wrapped_model.backbone, name)[0 if name == "stem" else 2]
    for name in layer_names
]

print("Target Conv2d layers:")
for name, layer in zip(layer_names, target_layers):
    print(f"{name}: {layer}")



In [None]:
def to_float01(img: np.ndarray) -> np.ndarray:
    """Convert an image array to float32 values clipped to [0, 1].

    Integer arrays (e.g. the uint8 output of ``show_cam_on_image``) are treated
    as 8-bit and always divided by 255. Deciding by dtype means a dark uint8
    image whose maximum is <= 1 is still rescaled. Float arrays are divided by
    255 only when their maximum exceeds 1.0 (assumed to be on a 0-255 scale).
    Empty arrays are returned as empty float32 arrays.
    """
    arr = np.asarray(img)
    if np.issubdtype(arr.dtype, np.integer):
        out = arr.astype(np.float32) / 255.0
    else:
        out = arr.astype(np.float32)
        # `out.size` guard: .max() raises on an empty array.
        if out.size and out.max() > 1.0:
            out = out / 255.0
    return np.clip(out, 0.0, 1.0)


In [None]:
def _align_cam_to_image(grayscale_cam, rgb_shape, lname):
    """Return `grayscale_cam` shaped like the image's (h, w), transposing only a true (w, h) map."""
    h, w = rgb_shape[:2]
    # Test for the expected shape first: when h == w (T == 80 frames) a correct
    # square map would otherwise be transposed.
    if grayscale_cam.shape != (h, w) and grayscale_cam.shape == (w, h):
        grayscale_cam = grayscale_cam.T
    if grayscale_cam.shape != (h, w):
        raise ValueError(
            f"shape mismatch: rgb_base={rgb_shape}, grayscale_cam={grayscale_cam.shape} ({lname})"
        )
    return grayscale_cam


def _save_thresholded_outputs(layer_dir, fname, rgb_base01, grayscale_cam, thr):
    """Save the overlay, overlay-only and binary-mask PNGs for CAM values >= thr."""
    mask = (grayscale_cam >= thr).astype(np.float32)
    cam_masked = grayscale_cam * mask

    # Heatmap of the thresholded CAM blended over the mel image (uint8 -> [0, 1]).
    overlay_masked01 = to_float01(show_cam_on_image(rgb_base01, cam_masked, use_rgb=True))
    # Same overlay with everything below the threshold blacked out.
    overlay_only = np.clip(overlay_masked01 * mask[..., None], 0.0, 1.0)

    plt.imsave(layer_dir / f"{fname}_overlay_thr{thr:.2f}.png", overlay_masked01)
    plt.imsave(layer_dir / f"{fname}_overlay_only_thr{thr:.2f}.png", overlay_only)
    plt.imsave(layer_dir / f"{fname}_mask_thr{thr:.2f}.png", mask, cmap="gray")


def run_gradcam_redim(wav_path, ref_emb, save_dir="gradcam_results_redim_v2", thr=0.8):
    """Explain ReDimNet's cosine similarity to `ref_emb` with Grad-CAM on every target layer.

    Writes the following files:
      - ``<save_dir>/original/``: the RGB mel image and a labelled mel plot.
      - ``<save_dir>/<layer>/``: for each entry of ``target_layers`` /
        ``layer_names``, an overlay, an overlay-only image and a binary mask,
        all thresholded at ``thr``.

    Args:
        wav_path: path to the utterance to explain.
        ref_emb: reference embedding (e.g. from ``embed_with_redim``).
        save_dir: output directory, created if missing.
        thr: CAM values below this are masked out.

    Raises:
        ValueError: if a CAM map cannot be aligned with the mel image.
        RuntimeError: if Grad-CAM produced no map for a layer.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    originals_dir = save_dir / "original"
    originals_dir.mkdir(parents=True, exist_ok=True)
    fname = Path(wav_path).stem

    # ---- compute mel + base image ----
    mel4d, rgb_base = wav_to_model_mel(wav_path)
    rgb_base01 = to_float01(rgb_base)

    rgbbase_file = originals_dir / f"{fname}_mel_rgbbase.png"
    plt.imsave(rgbbase_file, rgb_base01)
    print("Saved:", rgbbase_file)

    mel_file = originals_dir / f"{fname}_original_mel.png"
    plot_mel(mel4d.squeeze(0), save_path=mel_file, title=fname)
    print("Saved:", mel_file)

    # Same score is explained at every layer, so one target object suffices.
    target = CosineSimilarityTarget(ref_emb)

    for layer, lname in zip(target_layers, layer_names):
        layer_dir = save_dir / lname
        layer_dir.mkdir(parents=True, exist_ok=True)

        grayscale_cam = None
        # The context manager releases the hooks GradCAM registers on the model,
        # so they do not accumulate across layers and repeated calls.
        with GradCAM(model=wrapped_model, target_layers=[layer]) as cam:
            grayscale_cam = cam(
                input_tensor=mel4d,
                targets=[target],
                aug_smooth=True,
                eigen_smooth=True,
            )[0]
        # NOTE(review): pytorch_grad_cam's __exit__ is known to suppress some
        # exceptions (IndexError); fail loudly rather than continue without a map.
        if grayscale_cam is None:
            raise RuntimeError(f"Grad-CAM produced no map for layer {lname}")

        grayscale_cam = _align_cam_to_image(grayscale_cam, rgb_base01.shape, lname)
        _save_thresholded_outputs(layer_dir, fname, rgb_base01, grayscale_cam, thr)

        print(f"Saved layer outputs in: {layer_dir}")



In [None]:
# Single-utterance Grad-CAM run: an "idan" utterance explained against ref_emb.
test_wav = wav_dir / "idan_022.wav"
run_gradcam_redim(
    str(test_wav),
    ref_emb,
    save_dir=f"gradcam_results_redim_v3/idan/{test_wav.stem}",
    thr=0.3,
)

# Batch run over the first 5 utterances of each speaker at the default threshold.
# It sits behind an explicit flag instead of commented-out code; set it to True to run.
RUN_BATCH_GRADCAM = False
if RUN_BATCH_GRADCAM:
    for speaker in ("eden", "idan", "yoav"):
        for wav_path in sorted(wav_dir.glob(f"{speaker}_*.wav"))[:5]:
            print(f"Processing {speaker} - {wav_path.name} ...")
            run_gradcam_redim(
                str(wav_path),
                ref_emb,
                save_dir=f"gradcam_results_redim_v3/{speaker}/{wav_path.stem}",
            )

    




In [None]:
# Cosine similarity of each speaker's first 5 utterances to the reference
# embedding. embed_with_redim L2-normalises both vectors, so their dot product
# is the cosine similarity. wav_by_user is built here so that this cell runs on
# a fresh kernel instead of raising NameError.
wav_by_user = {
    speaker: sorted(wav_dir.glob(f"{speaker}_*.wav"))[:5]
    for speaker in ("eden", "idan", "yoav")
}

for speaker, wav_list in wav_by_user.items():
    for wav_path in wav_list:
        test_emb = embed_with_redim(str(wav_path))
        cos = float(np.dot(test_emb, ref_emb))
        print(wav_path.name, cos)
