In [None]:
# %%
from __future__ import annotations

import glob
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# ----------------------------
# Config
# ----------------------------
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Project layout: every input/output location hangs off PROJECT_ROOT.
PROJECT_ROOT = Path("/home/SpeakerRec/BioVoice")
WAV_DIR = PROJECT_ROOT.joinpath("data", "wavs")
CONCEPTS_ROOT = PROJECT_ROOT.joinpath("concept", "new_concepts")
OUT_DIR = PROJECT_ROOT.joinpath("output")
OUT_DIR.mkdir(parents=True, exist_ok=True)  # side effect: output folder is created here
OUT_CSV = OUT_DIR.joinpath("tcav_results.csv")

# Mel windowing: number of mel bins, frames per window, hop between windows.
N_MELS = 72
TARGET_FRAMES = 192
WINDOW_STRIDE = 24

# CAV training: random windows drawn per wav, cap on wavs used, loader batch size.
SAMPLES_PER_WAV_FOR_CAV = 2
MAX_WAVS_FOR_CAV = 400
BATCH_SIZE_CAV = 32

# Stamp: amplitude is STAMP_STRENGTH * background std, clamped to +/- MAX_ABS_DELTA.
STAMP_STRENGTH = 1.0
MAX_ABS_DELTA = 4.0

# Backbone layers to probe (resolved by name against the backbone module).
LAYER_KEYS = [
    "stem",
    "stage0",
    "stage1",
    "stage2",
    "stage3",
    "stage4",
    "stage5",
]

# Recursive glob pattern matching every wav below WAV_DIR.
EVAL_WAV_GLOB = str(WAV_DIR.joinpath("**", "*.wav"))

def infer_true_label_from_path(wav_path: Path, wav_root: Optional[Path] = None) -> str:
    """
    Infer the ground-truth speaker label of a wav file from where it lives.

    Two layouts are handled:
      * nested: ``<root>/<speaker>/<file>.wav``  -> the parent directory name
      * flat:   ``<root>/<speaker>_<digits>.wav`` -> the stem without the
        trailing ``_<digits>`` suffix

    The original implementation always returned the parent directory name, so
    files stored directly in the wav root were all labelled with the root
    folder's name (e.g. "wavs") instead of their speaker.

    Args:
        wav_path: Path to the wav file.
        wav_root: Directory treated as the dataset root; defaults to WAV_DIR.

    Returns:
        The inferred label. For a file directly under the root whose stem does
        not look like ``<speaker>_<digits>``, the parent directory name is
        returned (the previous behaviour).
    """
    wav_path = Path(wav_path)
    root = WAV_DIR if wav_root is None else Path(wav_root)

    # Nested layout: the enclosing folder is the speaker.
    if wav_path.parent != root:
        return wav_path.parent.name

    # Flat layout: strip a numeric utterance suffix such as "_001".
    speaker, sep, utt_id = wav_path.stem.rpartition("_")
    if sep and speaker and utt_id.isdigit():
        return speaker

    return wav_path.parent.name

def id_to_name(i: int, id_to_speaker):
    """
    Translate a class index into a speaker name.

    `id_to_speaker` may be a sequence indexed by position or a mapping keyed by
    the class id. Mappings are looked up with the integer key first and then
    with its string form, which covers mappings whose keys were stored as
    strings. (The original had two identical branches and failed on the latter.)

    Args:
        i: Class index.
        id_to_speaker: Sequence or mapping from class id to speaker name.

    Returns:
        The speaker name as a string.

    Raises:
        IndexError: `id_to_speaker` is a sequence and `i` is out of range.
        KeyError: `id_to_speaker` is a mapping with neither `i` nor `str(i)`.
    """
    try:
        return str(id_to_speaker[i])
    except KeyError:
        pass  # mapping without an int key; try the string form below

    str_key = str(i)
    if str_key in id_to_speaker:
        return str(id_to_speaker[str_key])
    raise KeyError(f"Class id {i!r} not found in id_to_speaker")

# Echo the effective configuration so it is recorded in the cell output.
for _label, _value in (
    ("DEVICE:", DEVICE),
    ("WAV_DIR:", WAV_DIR),
    ("CONCEPTS_ROOT:", CONCEPTS_ROOT),
    ("OUT_CSV:", OUT_CSV),
):
    print(_label, _value)

# ----------------------------
# Model + head + wrapper
# ----------------------------
# Pretrained ReDimNet backbone from torch.hub, moved to DEVICE and put in eval mode.
redim_model = torch.hub.load(
    "IDRnD/ReDimNet",
    "ReDimNet",
    model_name="b5",
    train_type="ptn",
    dataset="vox2",
).to(DEVICE).eval()

# NOTE(review): this path is resolved against the current working directory,
# not PROJECT_ROOT / OUT_DIR, so it depends on where the kernel was started.
HEAD_PATH = Path.cwd() / "output" / "redim_speaker_head_linear.pt"
# Explicit raise instead of `assert`: asserts are removed when Python runs with -O.
if not HEAD_PATH.exists():
    raise FileNotFoundError(f"Missing head checkpoint: {HEAD_PATH}")

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. Consider weights_only=True once the checkpoint contents are
# confirmed to be compatible with it.
ckpt = torch.load(HEAD_PATH, map_location=DEVICE)
speaker_to_id: Dict[str, int] = ckpt["speaker_to_id"]
id_to_speaker = ckpt["id_to_speaker"]
# Whether embeddings are L2-normalised before the head; defaults to True when absent.
l2_norm_emb = bool(ckpt.get("l2_norm_emb", True))

# The linear head's weight matrix is [num_classes, in_dim].
fc_w = ckpt["state_dict"]["fc.weight"]
in_dim = int(fc_w.shape[1])
num_classes = int(fc_w.shape[0])

print("Loaded head:", HEAD_PATH)
print("Head in_dim:", in_dim, "num_classes:", num_classes, "l2_norm_emb:", l2_norm_emb)

class SpeakerHead(nn.Module):
    """Single linear layer that maps an embedding vector to per-class logits."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        # The attribute name `fc` fixes the state-dict keys (fc.weight / fc.bias).
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape [..., num_classes] for inputs of shape [..., in_dim]."""
        logits = self.fc(x)
        return logits

# Build the classification head on DEVICE, restore its trained weights, switch to eval mode.
head = SpeakerHead(in_dim=in_dim, num_classes=num_classes)
head = head.to(DEVICE)
head.load_state_dict(ckpt["state_dict"])
head = head.eval()

class ReDimNetMelLogitsWrapper(nn.Module):
    """
    End-to-end module from mel windows to speaker logits.

    Input:  mel4d [B, 1, N_MELS, T]
    Output: logits [B, num_speakers]

    Chains the `backbone`, `pool`, `bn` and `linear` sub-modules taken from the
    given ReDimNet model, optionally L2-normalises the embedding, then applies
    `head`.
    """

    def __init__(self, redim_model, head: nn.Module, l2_norm_emb: bool):
        super().__init__()
        # Re-use the model's own sub-modules (no copies), so hooks registered on
        # them fire during this wrapper's forward pass.
        self.backbone = redim_model.backbone
        self.pool = redim_model.pool
        self.bn = redim_model.bn
        self.linear = redim_model.linear
        self.head = head
        self.l2_norm_emb = l2_norm_emb

    def forward(self, mel4d: torch.Tensor) -> torch.Tensor:
        # Move the input to wherever the parameters live and force float32.
        target_device = next(self.parameters()).device
        feats = mel4d.to(target_device).float()

        for stage in (self.backbone, self.pool, self.bn):
            feats = stage(feats)
        emb = self.linear(feats)

        if self.l2_norm_emb:
            # Small epsilon keeps the division finite for an all-zero embedding.
            norms = emb.norm(p=2, dim=1, keepdim=True)
            emb = emb / (norms + 1e-12)

        return self.head(emb)

# Assemble the mel -> logits model used by every later step, on DEVICE and in eval mode.
wrapped_model = ReDimNetMelLogitsWrapper(redim_model, head, l2_norm_emb=l2_norm_emb)
wrapped_model = wrapped_model.to(DEVICE).eval()
print("wrapped_model ready.")

# ----------------------------
# Layer resolution + hooks
# ----------------------------
def resolve_layer_module(backbone: torch.nn.Module, layer_key: str) -> torch.nn.Module:
    """
    Resolve a layer module from the *exact backbone instance used in forward*.

    Resolution order:
      1. A direct attribute of `backbone` named `layer_key`.
      2. For keys of the form ``stage<N>``: element N of the first of the
         attributes ``stages`` / ``stage`` / ``blocks`` / ``layers`` that can be
         indexed with N.
      3. Fallback search over `backbone.named_modules()`, preferring a module
         whose last dotted name component equals `layer_key`, then a name
         ending with `layer_key`, then any name containing it. Within a tier
         the shortest name wins (ties keep registration order). The tiering
         stops e.g. "stage1" from resolving to a module called "stage10".

    Args:
        backbone: Module to search in.
        layer_key: Layer name to resolve.

    Returns:
        The resolved sub-module.

    Raises:
        ValueError: no module matches `layer_key`.
    """
    bk = backbone

    if hasattr(bk, layer_key):
        return getattr(bk, layer_key)

    if layer_key.startswith("stage") and layer_key[5:].isdigit():
        idx = int(layer_key[5:])
        for attr in ("stages", "stage", "blocks", "layers"):
            if not hasattr(bk, attr):
                continue
            seq = getattr(bk, attr)
            try:
                return seq[idx]
            except (IndexError, KeyError, TypeError):
                # Out of range, keyed differently, or not indexable: try the next container.
                continue

    ranked = []
    for order, (name, mod) in enumerate(bk.named_modules()):
        if layer_key not in name:
            continue
        if name.split(".")[-1] == layer_key:
            tier = 0  # exact match on the final name component
        elif name.endswith(layer_key):
            tier = 1
        else:
            tier = 2  # plain substring match, least specific
        ranked.append((tier, len(name), order, name, mod))

    if ranked:
        _, _, _, name, mod = min(ranked, key=lambda item: item[:3])
        print(f"[resolve_layer_module] Using fallback match: {name}")
        return mod

    raise ValueError(f"Could not resolve layer_key={layer_key}")
def pool_activation_to_vec(act: torch.Tensor) -> torch.Tensor:
    """
    Reduce an activation tensor to one feature vector per batch item.

    * rank 2 ``[B, C]``: returned unchanged
    * rank 3 ``[B, C, T]`` and rank 4 ``[B, C, H, W]``: mean over every axis after the channel axis
    * any other rank: everything after the batch axis is flattened
    """
    rank = act.ndim
    if rank == 2:
        return act
    if rank in (3, 4):
        trailing_axes = tuple(range(2, rank))
        return act.mean(dim=trailing_axes)
    return act.flatten(start_dim=1)

class LayerHook:
    """
    Forward hook that remembers the most recent output of a layer.

    `out` holds the captured output (the first element when the layer returns
    a tuple or list) and is None until the layer has run. Call `close()` to
    detach the hook from the layer.
    """

    def __init__(self, layer: torch.nn.Module):
        self.out = None
        self.h = layer.register_forward_hook(self._hook)

    def _hook(self, module, inp, out):
        # Multi-output layers: keep only the first output.
        self.out = out[0] if isinstance(out, (tuple, list)) else out

    def close(self):
        """Remove the hook from the layer."""
        self.h.remove()

# ----------------------------
# Mel cache (huge speedup)
# ----------------------------
class MelStore:
    """
    In-memory cache of mel spectrograms keyed by wav path string.

    Each wav is loaded and passed through `redim_model.spec` once; the 2-D
    result is kept on the CPU and returned for every later request.
    """

    def __init__(self):
        self.cache: Dict[str, torch.Tensor] = {}

    @torch.no_grad()
    def get(self, wav_path: Path) -> torch.Tensor:
        """
        Return the cached mel for `wav_path`, computing it on first use.

        Raises:
            ValueError: the computed mel does not have N_MELS rows.
        """
        key = str(wav_path)
        cached = self.cache.get(key)
        if cached is not None:
            return cached

        # NOTE(review): the file's sample rate is read but not checked or
        # resampled — confirm the wavs match what the spec front-end expects.
        waveform, _sr = torchaudio.load(key)
        if waveform.shape[0] > 1:
            # Multi-channel audio: average the channels down to mono.
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.to(DEVICE).float()      # IMPORTANT: match spec device

        # First element of the spec output, brought back to the CPU for caching.
        mel2d = redim_model.spec(waveform)[0].detach().cpu().float()

        if mel2d.shape[0] != N_MELS:
            raise ValueError(f"Expected N_MELS={N_MELS}, got {mel2d.shape[0]} for {wav_path}")

        # Only validated mels are stored.
        self.cache[key] = mel2d
        return mel2d

# Shared module-level cache instance; the dataset classes and the main runner
# below fetch mel spectrograms through it.
mel_store = MelStore()

# ----------------------------
# Windows
# ----------------------------
def mel2d_to_windows(
    mel2d: torch.Tensor,
    target_frames: int,
    stride: int,
    include_tail: bool = False,
) -> torch.Tensor:
    """
    Cut a 2-D mel into fixed-width windows along the time axis.

    Returns windows4d: [B,1,N_MELS,target_frames], float32, on the device of `mel2d`.

    * If the mel has at most `target_frames` frames, a single window is
      returned, right-padded with zeros when shorter.
    * Otherwise windows start at 0, stride, 2*stride, ... while they fit.
      Frames after the last full window are dropped unless `include_tail` is
      True, in which case one extra window aligned to the end is appended.

    Args:
        mel2d: Tensor of shape [N_MELS, T].
        target_frames: Width of every window in frames (must be > 0).
        stride: Hop between window starts (must be > 0 when windowing is needed).
        include_tail: Append an end-aligned window so the last frames are
            covered. Defaults to False, which keeps the original behaviour.

    Raises:
        ValueError: `mel2d` is not 2-D, `target_frames` is not positive, or
            `stride` is not positive while more than one window is needed.
    """
    if mel2d.ndim != 2:
        raise ValueError(f"mel2d must be 2-D [N_MELS, T], got shape {tuple(mel2d.shape)}")
    if target_frames <= 0:
        raise ValueError(f"target_frames must be positive, got {target_frames}")

    _n_mels, t = mel2d.shape
    if t <= target_frames:
        chunk = mel2d
        if t < target_frames:
            chunk = torch.nn.functional.pad(chunk, (0, target_frames - t), value=0.0)
        return chunk.unsqueeze(0).unsqueeze(0).contiguous().float()

    # Only validated here so short inputs keep working whatever the stride is.
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    last_start = t - target_frames
    starts = list(range(0, last_start + 1, stride))
    if include_tail and starts[-1] != last_start:
        starts.append(last_start)

    wins = [mel2d[:, s:s + target_frames] for s in starts]
    x = torch.stack(wins, dim=0)  # [B,N_MELS,T]
    return x.unsqueeze(1).contiguous().float()  # [B,1,N_MELS,T]

# ----------------------------
# Concept patches (0..1 raw masks only)
# ----------------------------
def extract_nonzero_patch(mask2d: torch.Tensor, threshold: float = 1e-3) -> torch.Tensor:
    """
    Crop a 2-D mask to the bounding box of its entries above `threshold`.

    Returns a copy of the cropped region (values inside the box are kept as
    they are, including those at or below the threshold).

    Raises:
        ValueError: no entry exceeds `threshold`.
    """
    active = mask2d > threshold
    if not active.any():
        raise ValueError("Mask has no nonzero region.")

    row_idx, col_idx = torch.where(active)
    top, left = int(row_idx.min()), int(col_idx.min())
    bottom, right = int(row_idx.max()) + 1, int(col_idx.max()) + 1

    cropped = mask2d[top:bottom, left:right]
    return cropped.clone()

def load_concept_patches(concept_dir: Path) -> List[torch.Tensor]:
    """
    Load every usable concept mask below `concept_dir` as a cropped patch.

    A ``.npy`` file is used only if it is 2-D, has N_MELS rows, and its values
    lie in [0, 1] (with a 1e-3 tolerance). Each accepted mask is cropped to the
    bounding box of its entries above 1e-3. Files that fail a check, cannot be
    read, or contain no entry above the threshold are counted as skipped.

    Args:
        concept_dir: Directory searched recursively for ``*.npy`` files.

    Returns:
        The list of cropped patches (float32 tensors).

    Raises:
        ValueError: no ``.npy`` file exists, or none of them is usable.
    """
    npys = sorted(concept_dir.rglob("*.npy"))
    if not npys:
        raise ValueError(f"No .npy files found under {concept_dir}")

    patches: List[torch.Tensor] = []
    skipped = 0

    for p in npys:
        try:
            arr = np.load(str(p))
        except (OSError, EOFError, ValueError):
            # Unreadable or non-array file: skip it instead of aborting the whole run.
            skipped += 1
            continue
        if arr.ndim != 2:
            skipped += 1
            continue
        t = torch.from_numpy(arr).float()
        if t.shape[0] != N_MELS:
            skipped += 1
            continue
        mn = float(t.min().item())
        mx = float(t.max().item())
        # accept only 0..1 masks (raw_energy / systematic raw)
        if mn < -1e-3 or mx > 1.0 + 1e-3:
            skipped += 1
            continue
        try:
            patches.append(extract_nonzero_patch(t, threshold=1e-3))
        except ValueError:
            # Raised for a mask with nothing above the threshold; any other
            # exception indicates a real problem and is allowed to propagate.
            skipped += 1

    if not patches:
        raise ValueError(
            f"No usable 0..1 mask npys found under {concept_dir}. "
            f"Ensure raw_energy/systematic raw masks are saved (0..1)."
        )
    print(f"[concept] {concept_dir.name}: loaded {len(patches)} patches, skipped {skipped}")
    return patches

def list_concept_dirs(concepts_root: Path) -> List[Path]:
    """
    List the concept directories to process.

    When `concepts_root` itself contains ``raw_energy.npy`` it is treated as a
    single concept and returned alone; otherwise its immediate sub-directories
    are returned in sorted order.
    """
    single_concept_marker = concepts_root / "raw_energy.npy"
    if single_concept_marker.exists():
        return [concepts_root]
    return sorted(child for child in concepts_root.iterdir() if child.is_dir())

# ----------------------------
# In-distribution datasets
# ----------------------------
@dataclass(frozen=True)
class StampConfig:
    """Immutable settings for stamping a concept patch onto a background window."""
    # Multiplier applied to the background window's std to get the stamp amplitude
    # (as used by StampedConceptDataset).
    strength: float = STAMP_STRENGTH
    # Bound on the stamp amplitude: the scaled value is clamped to +/- this.
    max_abs_delta: float = MAX_ABS_DELTA
    # Base seed of the per-item RNG that picks the patch and its position.
    seed: int = 0

def random_crop_or_pad(mel2d: torch.Tensor, target_frames: int, rng: random.Random) -> torch.Tensor:
    """
    Bring a 2-D mel to exactly `target_frames` frames along its last axis.

    Longer inputs are cropped at a start position drawn from `rng`; shorter
    inputs are right-padded with zeros; inputs of the right length are returned
    as-is. `rng` is consumed only when a crop is needed.
    """
    _, n_frames = mel2d.shape
    surplus = n_frames - target_frames

    if surplus > 0:
        start = rng.randint(0, surplus)
        return mel2d[:, start:start + target_frames]
    if surplus < 0:
        return torch.nn.functional.pad(mel2d, (0, -surplus), value=0.0)
    return mel2d
class BackgroundWindowDataset(torch.utils.data.Dataset):
    """
    Negative examples: real background windows.

    Every wav contributes `samples_per_wav` items. Each item is a window of
    `target_frames` frames taken from the wav's cached mel, with the crop
    position drawn from an RNG seeded by the dataset seed and the item index,
    so the same index yields the same window.

    Returns x3d: [1, N_MELS, T] on CPU.
    DataLoader will batch into [B, 1, N_MELS, T] (correct).
    """

    def __init__(
        self,
        wav_paths: Sequence[Path],
        target_frames: int,
        samples_per_wav: int,
        seed: int = 0,
    ):
        self.wav_paths = list(wav_paths)
        self.target_frames = int(target_frames)
        self.samples_per_wav = int(samples_per_wav)
        self.seed = int(seed)

        # Flat item list of (wav index, repeat index) pairs, grouped by wav.
        self.index: List[Tuple[int, int]] = [
            (wav_i, rep_i)
            for wav_i in range(len(self.wav_paths))
            for rep_i in range(self.samples_per_wav)
        ]

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> torch.Tensor:
        wav_i, rep_i = self.index[idx]
        # Per-item RNG makes the crop position a pure function of (seed, idx).
        item_rng = random.Random(self.seed + idx * 1337 + rep_i)

        full_mel = mel_store.get(self.wav_paths[wav_i])  # [N_MELS, Tfull] CPU
        window = random_crop_or_pad(full_mel, self.target_frames, item_rng)  # [N_MELS, T]

        return window.unsqueeze(0).float()  # [1, N_MELS, T]


class StampedConceptDataset(torch.utils.data.Dataset):
    """
    Positive examples: background windows with a randomly placed patch stamp.

    Item `idx` is a copy of `bg[idx]` with one patch added at a random
    position. The patch is scaled by `stamp.strength` times the window's
    standard deviation, clamped to +/- `stamp.max_abs_delta`. Patch choice and
    position come from an RNG seeded by `stamp.seed` and `idx`.

    Returns x3d: [1, N_MELS, T] on CPU.
    DataLoader will batch into [B, 1, N_MELS, T] (correct).
    """

    def __init__(self, bg: BackgroundWindowDataset, patches: List[torch.Tensor], stamp: StampConfig):
        self.bg = bg
        self.patches = patches
        self.stamp = stamp

    def __len__(self) -> int:
        return len(self.bg)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Work on a copy so the background item itself is never modified.
        stamped = self.bg[idx].clone()  # [1, N_MELS, T]
        rng = random.Random(self.stamp.seed + idx * 9176)

        # RNG draw order matters for reproducibility: patch, then row, then column.
        patch = self.patches[rng.randint(0, len(self.patches) - 1)]  # [H,W], 0..1
        ph, pw = patch.shape
        _, n_mels, t_frames = stamped.shape

        if ph > n_mels or pw > t_frames:
            raise ValueError(f"Patch {patch.shape} too big for window [N_MELS={n_mels},T={t_frames}]")

        y0 = rng.randint(0, n_mels - ph)
        x0 = rng.randint(0, t_frames - pw)

        # Amplitude follows the window's own spread; the floor avoids a zero stamp on flat windows.
        bg_std = float(stamped.std().clamp_min(1e-6))
        limit = self.stamp.max_abs_delta
        delta = float(max(-limit, min(limit, self.stamp.strength * bg_std)))

        # Add the scaled patch in place on the single channel (index 0).
        region = stamped[0, y0:y0 + ph, x0:x0 + pw]
        region += delta * patch.to(stamped.dtype)
        return stamped.float()


# ----------------------------
# CAV training (manual, robust)
# ----------------------------
@torch.no_grad()
def collect_layer_features(loader: DataLoader, layer: torch.nn.Module) -> np.ndarray:
    """
    Collect pooled activations of one layer over every batch of `loader`.

    A forward hook on `layer` captures its output while `wrapped_model` runs;
    each captured activation is reduced with `pool_activation_to_vec`. The hook
    is always removed, including when the forward pass or the loader raises, so
    a failed call does not leave a stale hook on the model.

    Args:
        loader: Yields input batches [B,1,N_MELS,T].
        layer: Module inside `wrapped_model` whose output is collected.

    Returns:
        X [N,D] numpy float32, rows in the order the loader produced them.

    Raises:
        RuntimeError: the hook captured nothing for a batch.
        ValueError: the loader produced no batches.
    """
    hook = LayerHook(layer)
    X_list: List[np.ndarray] = []

    try:
        wrapped_model.eval()
        for x4d in loader:
            x4d = x4d.to(DEVICE)  # [B,1,N_MELS,T]
            hook.out = None  # reset so a stale activation is never reused
            _ = wrapped_model(x4d)
            act = hook.out
            if act is None:
                raise RuntimeError("Hook didn't capture activation. Check layer resolution.")
            vec = pool_activation_to_vec(act).detach().cpu().float().numpy()  # [B,D]
            X_list.append(vec)
    finally:
        hook.close()

    if not X_list:
        raise ValueError("Loader produced no batches; cannot collect layer features.")
    return np.concatenate(X_list, axis=0)

def train_cav_logreg(X_pos: np.ndarray, X_neg: np.ndarray, seed: int) -> Tuple[np.ndarray, float]:
    """
    Fit a logistic-regression probe separating positive from negative features.

    Negatives are labelled 0 and positives 1; 20% of the data is held out with
    a stratified split. Returns the unit-normalised weight vector (pointing
    towards the positive class) and the accuracy on the held-out part.
    """
    n_neg, n_pos = len(X_neg), len(X_pos)

    features = np.concatenate([X_neg, X_pos], axis=0)
    labels = np.zeros(n_neg + n_pos, dtype=np.int64)
    labels[n_neg:] = 1  # positives follow the negatives

    X_tr, X_te, y_tr, y_te = train_test_split(
        features,
        labels,
        test_size=0.2,
        random_state=seed,
        stratify=labels,
    )

    probe = LogisticRegression(
        max_iter=2000,
        class_weight="balanced",
        solver="liblinear",
        random_state=seed,
    )
    probe.fit(X_tr, y_tr)
    heldout_acc = float(probe.score(X_te, y_te))

    direction = probe.coef_[0].astype(np.float32)   # direction towards class=1
    # Epsilon guards against division by zero for an all-zero weight vector.
    direction = direction / (np.linalg.norm(direction) + 1e-9)
    return direction, heldout_acc

# ----------------------------
# Directional derivatives (TCAV-style) for a layer
# ----------------------------
def directional_derivatives(
    windows4d_cpu: torch.Tensor,  # [B,1,N_MELS,T] CPU
    target_idx: int,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """
    Gradient of the target logit with respect to one layer's activation.

    Returns g [B,D] where D matches pooled activation dim for that layer.
    Computes grad wrt *hooked activation*, then pools grad the same way as activations.
    This avoids pooled.grad == None issues.

    The forward hook is removed in a `finally` block, so it does not stay
    attached to the model when the forward or backward pass raises.

    NOTE(review): for rank-3/4 activations the gradient is mean-pooled, so the
    values are presumably scaled down by the number of pooled positions
    relative to a summed gradient — confirm this is the intended magnitude.

    Args:
        windows4d_cpu: Input windows [B,1,N_MELS,T].
        target_idx: Index of the logit whose gradient is taken.
        layer: Module inside `wrapped_model` to differentiate with respect to.

    Raises:
        RuntimeError: inference mode is active, the hook captured nothing, the
            captured output is not a tensor, or it does not require grad.
    """
    # Hard-fail if inference mode is enabled (enable_grad cannot override it)
    if hasattr(torch, "is_inference_mode_enabled") and torch.is_inference_mode_enabled():
        raise RuntimeError(
            "torch.inference_mode() is enabled; gradients are disabled. "
            "Restart kernel and ensure you are not inside inference_mode."
        )

    hook = LayerHook(layer)
    try:
        with torch.enable_grad():
            wrapped_model.eval()

            x = windows4d_cpu.to(DEVICE).float()
            x = x.detach()  # keep clean graph
            x.requires_grad_(True)

            wrapped_model.zero_grad(set_to_none=True)
            hook.out = None

            logits = wrapped_model(x)  # [B,C]
            act = hook.out
            if act is None:
                raise RuntimeError("Hook didn't capture activation. Wrong layer module.")

            # Unwrap once more in case the layer returned nested containers.
            if isinstance(act, (tuple, list)):
                act = act[0]

            if not isinstance(act, torch.Tensor):
                raise RuntimeError(f"Hook output is not a Tensor: {type(act)}")

            if not act.requires_grad:
                raise RuntimeError(
                    "Hooked activation does not require grad. "
                    "This usually means you're in inference_mode() or grad is globally disabled."
                )

            # Summing over the batch gives each item its own gradient in one backward pass.
            score = logits[:, target_idx].sum()

            grad_act = torch.autograd.grad(
                outputs=score,
                inputs=act,
                retain_graph=False,
                create_graph=False,
                allow_unused=False,
            )[0]
    finally:
        hook.close()
        hook.out = None  # drop the reference to the captured activation

    # Pool grad to match the CAV feature space (same pooling as activations)
    g = pool_activation_to_vec(grad_act)  # [B,D]
    return g.detach().cpu().float()


@torch.no_grad()
def predict_over_windows(windows4d_cpu: torch.Tensor) -> Tuple[int, str, float]:
    """
    Predict one speaker for a stack of windows.

    windows4d_cpu: [B,1,N_MELS,T] CPU

    The per-window softmax distributions are averaged and the class with the
    highest mean probability is chosen.

    Returns (pred_id, pred_name, pred_prob), where pred_prob is that mean probability.
    """
    logits = wrapped_model(windows4d_cpu.to(DEVICE))       # [B,C]
    mean_probs = torch.softmax(logits, dim=1).mean(dim=0)  # [C]

    pred_id = int(mean_probs.argmax().item())
    pred_name = id_to_name(pred_id, id_to_speaker)
    pred_prob = float(mean_probs[pred_id].item())
    return pred_id, pred_name, pred_prob

# ----------------------------
# Main runner
# ----------------------------
def run_tcav_style_to_csv(
    concepts_root: Path,
    eval_wav_paths: List[Path],
    layer_keys: List[str],
    *,
    target_mode: str = "pred",  # "pred" or "true"
    seed: int = 123,
):
    """
    Train one CAV per (concept, layer) and score every wav against it.

    Steps:
      1. Sample background windows from a shuffled subset of `eval_wav_paths`
         (at most MAX_WAVS_FOR_CAV wavs) as negatives; positives are the same
         windows with a concept patch stamped on.
      2. For each concept and layer, fit a linear probe on the pooled layer
         activations and keep its direction and held-out accuracy.
      3. For each wav, take the gradient of the target logit with respect to
         each layer, project it on each CAV, and record the share of windows
         with a positive projection and the mean absolute projection.
      4. Write all rows to OUT_CSV and return them as a DataFrame.

    Differences from the earlier version of this function:
      * Feature loaders are not shuffled. Shuffling only permuted the rows
        using the global torch RNG, which made the train/test split (and so
        the CAVs) differ between runs despite `seed`.
      * Negative features are computed once per layer instead of once per
        (concept, layer); they do not depend on the concept.
      * `target_mode` and an empty input are rejected before any work is done.
      * Wavs skipped in "true" mode are counted and reported.

    Args:
        concepts_root: Concept directory, or a directory of concept directories.
        eval_wav_paths: Wavs to evaluate (also the pool for CAV backgrounds).
        layer_keys: Layer names resolved against `wrapped_model.backbone`.
        target_mode: "pred" explains the predicted class; "true" explains the
            label inferred from the path (wavs with an unknown label are skipped).
        seed: Seed for wav selection, window sampling, stamping and the probe.

    Returns:
        DataFrame with one row per (wav, concept, layer).

    Raises:
        ValueError: unknown `target_mode`, empty `eval_wav_paths`, or no
            concept directory found.
    """
    if target_mode not in ("pred", "true"):
        raise ValueError(f"Unknown target_mode={target_mode}")
    if not eval_wav_paths:
        raise ValueError("eval_wav_paths is empty; nothing to train CAVs on or evaluate.")

    # Pick wavs for CAV training (generic backgrounds)
    cav_wavs = list(eval_wav_paths)
    random.Random(seed).shuffle(cav_wavs)
    if MAX_WAVS_FOR_CAV is not None:
        cav_wavs = cav_wavs[:MAX_WAVS_FOR_CAV]

    bg_ds = BackgroundWindowDataset(
        wav_paths=cav_wavs,
        target_frames=TARGET_FRAMES,
        samples_per_wav=SAMPLES_PER_WAV_FOR_CAV,
        seed=seed,
    )

    concept_dirs = list_concept_dirs(concepts_root)
    if not concept_dirs:
        raise ValueError(f"No concept dirs found under {concepts_root}")

    # Resolve layer modules once
    layers: Dict[str, torch.nn.Module] = {}
    for lk in layer_keys:
        layers[lk] = resolve_layer_module(wrapped_model.backbone, lk)

        print(f"[layer] {lk} -> {layers[lk].__class__.__name__}")

    # Train CAVs: (concept, layer) -> (cav_dir, cav_acc)
    cav_dir: Dict[Tuple[str, str], np.ndarray] = {}
    cav_acc: Dict[Tuple[str, str], float] = {}

    # shuffle=False: row order must be deterministic for a reproducible split.
    neg_loader = DataLoader(bg_ds, batch_size=BATCH_SIZE_CAV, shuffle=False, num_workers=0)

    # Negative features are concept-independent, so compute them once per layer.
    neg_features: Dict[str, np.ndarray] = {
        lk: collect_layer_features(neg_loader, layer_mod) for lk, layer_mod in layers.items()
    }

    for cdir in concept_dirs:
        cname = cdir.name
        patches = load_concept_patches(cdir)

        pos_ds = StampedConceptDataset(
            bg=bg_ds,
            patches=patches,
            stamp=StampConfig(strength=STAMP_STRENGTH, max_abs_delta=MAX_ABS_DELTA, seed=seed),
        )
        pos_loader = DataLoader(pos_ds, batch_size=BATCH_SIZE_CAV, shuffle=False, num_workers=0)

        for lk, layer_mod in layers.items():
            X_pos = collect_layer_features(pos_loader, layer_mod)

            w, acc = train_cav_logreg(X_pos=X_pos, X_neg=neg_features[lk], seed=seed)
            cav_dir[(cname, lk)] = w
            cav_acc[(cname, lk)] = acc
            print(f"[CAV] concept={cname} layer={lk} acc={acc:.3f} dim={w.shape[0]}")

    # Evaluate wavs
    rows: List[Dict[str, object]] = []
    skipped_unknown_label = 0

    for wav_path in eval_wav_paths:
        true_label = infer_true_label_from_path(wav_path)

        mel2d = mel_store.get(wav_path)  # cached CPU mel
        windows4d = mel2d_to_windows(mel2d, TARGET_FRAMES, WINDOW_STRIDE)  # [B,1,N_MELS,T] CPU

        pred_id, pred_name, pred_prob = predict_over_windows(windows4d)

        if target_mode == "pred":
            target_idx = pred_id
        else:  # "true" (validated above)
            if true_label not in speaker_to_id:
                # skip if label not in head mapping
                skipped_unknown_label += 1
                continue
            target_idx = int(speaker_to_id[true_label])

        # For each layer compute grads once (g [B,D])
        grads_by_layer: Dict[str, torch.Tensor] = {}
        for lk, layer_mod in layers.items():
            grads_by_layer[lk] = directional_derivatives(windows4d, target_idx=target_idx, layer=layer_mod)

        # For each concept/layer compute dd and metrics
        for cdir in concept_dirs:
            cname = cdir.name
            for lk in layer_keys:
                g = grads_by_layer[lk]  # [B,D]
                w = torch.from_numpy(cav_dir[(cname, lk)]).float()  # [D]
                dd = (g * w.unsqueeze(0)).sum(dim=1)  # [B] projection of each window's gradient on the CAV

                pos_pct = float((dd > 0).float().mean().item() * 100.0)
                magnitude = float(dd.abs().mean().item())

                rows.append({
                    "path": str(wav_path),
                    "concept name": cname,
                    "layer name": lk,
                    "positive percentage": pos_pct,
                    "magnitude": magnitude,
                    "true label": str(true_label),
                    "predicted label": str(pred_name),
                    "predicted probability": float(pred_prob),
                    "cav acc": float(cav_acc[(cname, lk)]),
                })

    if skipped_unknown_label:
        print(
            f"[warn] target_mode='true': skipped {skipped_unknown_label} of "
            f"{len(eval_wav_paths)} wavs whose inferred label is not in speaker_to_id"
        )

    # Explicit column list keeps the CSV header stable even when `rows` is empty.
    df = pd.DataFrame(rows, columns=[
        "path",
        "concept name",
        "layer name",
        "positive percentage",
        "magnitude",
        "true label",
        "predicted label",
        "predicted probability",
        "cav acc",
    ])
    df.to_csv(OUT_CSV, index=False)
    print("Wrote:", OUT_CSV, "rows:", len(df))
    return df

# ----------------------------
# Run
# ----------------------------
# Gather every wav below WAV_DIR in a stable (sorted) order.
eval_wav_paths = sorted(Path(match) for match in glob.glob(EVAL_WAV_GLOB, recursive=True))
print("Eval wavs:", len(eval_wav_paths))

# target_mode="pred" explains the predicted logit; change to "true" to explain
# the ground-truth logit (requires speaker_to_id coverage).
df = run_tcav_style_to_csv(
    CONCEPTS_ROOT,
    eval_wav_paths,
    LAYER_KEYS,
    target_mode="pred",
    seed=123,
)

df.head()


DEVICE: cuda
WAV_DIR: /home/SpeakerRec/BioVoice/data/wavs
CONCEPTS_ROOT: /home/SpeakerRec/BioVoice/concept/new_concepts
OUT_CSV: /home/SpeakerRec/BioVoice/output/tcav_results.csv


Using cache found in /home/SpeakerRec/.cache/torch/hub/IDRnD_ReDimNet_master
  ckpt = torch.load(HEAD_PATH, map_location=DEVICE)


Loaded head: /home/SpeakerRec/BioVoice/redimnet/tcav/output/redim_speaker_head_linear.pt
Head in_dim: 192 num_classes: 3 l2_norm_emb: True
wrapped_model ready.
Eval wavs: 92
[layer] stage4 -> Sequential
[concept] constant_long_thick: loaded 1 patches, skipped 120
[CAV] concept=constant_long_thick layer=stage4 acc=0.568 dim=2304
[concept] constant_long_thin: loaded 1 patches, skipped 120
[CAV] concept=constant_long_thin layer=stage4 acc=0.500 dim=2304
[concept] constant_short_thick: loaded 1 patches, skipped 120
[CAV] concept=constant_short_thick layer=stage4 acc=0.392 dim=2304
[concept] constant_short_thin: loaded 1 patches, skipped 120
[CAV] concept=constant_short_thin layer=stage4 acc=0.351 dim=2304
Wrote: /home/SpeakerRec/BioVoice/output/tcav_results.csv rows: 368


Unnamed: 0,path,concept name,layer name,positive percentage,magnitude,true label,predicted label,predicted probability,cav acc
0,/home/SpeakerRec/BioVoice/data/wavs/eden_001.wav,constant_long_thick,stage4,100.0,3e-05,wavs,eden,0.769236,0.567568
1,/home/SpeakerRec/BioVoice/data/wavs/eden_001.wav,constant_long_thin,stage4,100.0,3.8e-05,wavs,eden,0.769236,0.5
2,/home/SpeakerRec/BioVoice/data/wavs/eden_001.wav,constant_short_thick,stage4,100.0,2.7e-05,wavs,eden,0.769236,0.391892
3,/home/SpeakerRec/BioVoice/data/wavs/eden_001.wav,constant_short_thin,stage4,100.0,1.8e-05,wavs,eden,0.769236,0.351351
4,/home/SpeakerRec/BioVoice/data/wavs/eden_002.wav,constant_long_thick,stage4,100.0,3.5e-05,wavs,eden,0.789764,0.567568


In [None]:
# Sanity check: compare the value range of a real mel with a saved concept array.
# Uses only names defined in the first cell (WAV_DIR, CONCEPTS_ROOT, DEVICE,
# redim_model, list_concept_dirs), so it runs on a fresh kernel after that cell.

def stats(name, x):
    """Print shape, min/max and the 1st/50th/99th percentiles of array `x`."""
    print(name, "shape", x.shape,
          "min/max", float(x.min()), float(x.max()),
          "p1/p50/p99", np.percentile(x, [1,50,99]))

# pick 1 real wav (sorted, so the same file is chosen on every run)
wav_candidates = sorted(WAV_DIR.rglob("*.wav"))
if not wav_candidates:
    raise FileNotFoundError(f"No .wav files found under {WAV_DIR}")
wav = wav_candidates[0]
with torch.no_grad():
    # First channel only; the input is moved to DEVICE to match the spec module.
    mel_real = redim_model.spec(torchaudio.load(str(wav))[0][:1].float().to(DEVICE))  # (1,N,T)
mel_real = mel_real[0].cpu().numpy()

# pick 1 concept npy (a numbered file, not raw_energy)
concept_dirs_found = list_concept_dirs(CONCEPTS_ROOT)
if not concept_dirs_found:
    raise FileNotFoundError(f"No concept dirs found under {CONCEPTS_ROOT}")
numbered_npys = sorted(
    p for p in concept_dirs_found[0].glob("*.npy") if p.name != "raw_energy.npy"
)
if not numbered_npys:
    raise FileNotFoundError(f"No numbered .npy files found under {concept_dirs_found[0]}")
cfile = numbered_npys[0]
mel_con = np.load(cfile)

stats("REAL", mel_real)
stats("CON ", mel_con)

# also check if spec already has ~zero mean per mel bin
print("REAL per-mel mean abs:", float(np.mean(np.abs(mel_real.mean(axis=1)))))
print("CON  per-mel mean abs:", float(np.mean(np.abs(mel_con.mean(axis=1)))))


REAL shape (72, 183) min/max -8.323339462280273 10.122404098510742 p1/p50/p99 [-7.13723767 -0.20183468  7.66903698]
CON  shape (72, 304) min/max -1.6737136840820312 11.715995788574219 p1/p50/p99 [-1.67371368e+00  9.53674316e-07  9.53674316e-07]
REAL per-mel mean abs: 3.2983405162667623e-07
CON  per-mel mean abs: 9.313661166743259e-07
