In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

In [2]:
class ConvSTFT(nn.Module):
    """
    Conv1d-based STFT that is ONNX-export-friendly.

    Emulates torch.stft with:
      - center=True  -> zero pad by n_fft//2 at both ends
      - normalized=False
      - return_complex=False (we output real/imag via two filter banks)
    """
    def __init__(
        self,
        n_fft: int = 1024,
        hop_length: int = 320,
        win_length: int = 800,
        window: torch.Tensor | None = None,
        pad_center: bool = True,
    ):
        super().__init__()
        assert win_length <= n_fft, "win_length must be <= n_fft"

        self.n_fft = int(n_fft)
        self.hop_length = int(hop_length)
        self.win_length = int(win_length)
        self.pad_center = bool(pad_center)

        # Window
        if window is None:
            window = torch.hann_window(win_length, periodic=False, dtype=torch.float32)
        else:
            window = window.to(dtype=torch.float32)
        self.register_buffer("window", window, persistent=False)

        # Build Fourier basis for positive frequencies [0 .. n_fft//2]
        # Real kernels:  window[n] * cos(2πkn/N)
        # Imag kernels: -window[n] * sin(2πkn/N)
        
        # Build Fourier basis for positive frequencies [0..n_fft//2]
        num_bins = n_fft // 2 + 1
        
        # Center the window inside an n_fft frame
        offset = (n_fft - win_length) // 2
        win_full = torch.zeros(n_fft, dtype=torch.float32)
        win_full[offset:offset+win_length] = self.window  # centered window
        
        # n over 0..n_fft-1, k over 0..num_bins-1
        n = torch.arange(n_fft, dtype=torch.float32).unsqueeze(0)  # [1, n_fft]
        k = torch.arange(num_bins, dtype=torch.float32).unsqueeze(1)  # [num_bins, 1]
        ang = 2 * math.pi * k @ (n / float(n_fft))  # [num_bins, n_fft]
        
        cos_kernels = torch.cos(ang) * win_full      # [num_bins, n_fft]
        sin_kernels = -torch.sin(ang) * win_full     # [num_bins, n_fft]
        
        weight = torch.cat([cos_kernels, sin_kernels], dim=0).unsqueeze(1)  # [2*num_bins, 1, n_fft]
        self.register_buffer("fourier_basis", weight.contiguous(), persistent=False)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T] waveform

        Returns:
            stft_out: [B, num_bins, frames, 2]  where last dim is (real, imag)
        """
        B, T = x.shape

        # center=True -> zero pad by n_fft//2
        if self.pad_center:
            pad = self.n_fft // 2
            x = F.pad(x, (pad, pad), mode="constant", value=0.0)  # [B, T + 2*pad]

        # Convolution to compute dot product with each Fourier kernel at each frame
        # Input must be [B, 1, T]
        y = F.conv1d(
            x.unsqueeze(1),                       # [B, 1, T+pad*2]
            self.fourier_basis,                   # [2*F, 1, win_length]
            bias=None,
            stride=self.hop_length,
            padding=0,
            dilation=1,
            groups=1,
        )  # -> [B, 2*F, frames]

        num_bins = self.n_fft // 2 + 1
        frames = y.shape[-1]
        y = y.view(B, 2, num_bins, frames)        # [B, 2, F, frames]
        real = y[:, 0, :, :]                      # [B, F, frames]
        imag = y[:, 1, :, :]                      # [B, F, frames]

        # Pack to match torch.stft(return_complex=False) output layout
        stft_out = torch.stack([real, imag], dim=-1)  # [B, F, frames, 2]
        return stft_out

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

# Assumes your ConvSTFT is patched for Option A: n_fft-length kernels with centered window
class LogMelSpectrogramConvSTFT(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        sr: int = 32000,
        win_length: int = 800,
        hopsize: int = 320,
        n_fft: int = 1024,
        fmin: float = 0.0,
        fmax: float | None = None,
    ):
        super().__init__()

        if fmax is None:
            fmax = sr // 2

        self.n_mels = int(n_mels)
        self.sr = int(sr)
        self.win_length = int(win_length)
        self.hopsize = int(hopsize)
        self.n_fft = int(n_fft)
        self.fmin = float(fmin)

        nyquist = sr // 2
        if fmax is None:
            fmax = nyquist
        elif fmax > nyquist:
            print(f"[LogMel] fmax={fmax} > Nyquist={nyquist}. Clamping to Nyquist.")
            fmax = float(nyquist)
        self.fmax = float(fmax)

        assert 0.0 <= self.fmin < self.fmax <= (self.sr / 2), \
            f"Invalid band: fmin={self.fmin}, fmax={self.fmax}, nyquist={self.sr/2}"

        # Pre-emphasis kernel y[t] = x[t] - 0.97*x[t-1]
        self.register_buffer(
            "preemphasis_kernel",
            torch.tensor([[[-0.97, 1.0]]], dtype=torch.float32),
            persistent=False
        )

        # STFT via conv: ConvSTFT must center win inside n_fft and use n_fft-length kernels internally
        self.stft = ConvSTFT(
            n_fft=self.n_fft,
            hop_length=self.hopsize,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, periodic=False, dtype=torch.float32),
            pad_center=True,
        )

        # Kaldi mel filter bank (Nyquist excluded), then pad Nyquist
        mel_bins, _ = torchaudio.compliance.kaldi.get_mel_banks(
            num_bins=self.n_mels,
            window_length_padded=self.n_fft,
            sample_freq=self.sr,
            low_freq=self.fmin,
            high_freq=self.fmax,
            vtln_low=100.0,
            vtln_high=-500.0,
            vtln_warp_factor=1.0,
        )
        mel_bins = F.pad(mel_bins, (0, 1), value=0.0)  # -> [n_mels, n_fft//2 + 1]
        self.register_buffer("mel_basis", mel_bins.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Current behavior (length T-1) to match EfficientAT reference:
        x = F.conv1d(x.unsqueeze(1), self.preemphasis_kernel).squeeze(1)

        # STFT -> power
        stft_out = self.stft(x)               # [B, F, frames, 2]
        power = (stft_out ** 2).sum(dim=-1)   # [B, F, frames]

        # Mel projection (use correct einsum string, keep dtype/device aligned)
        mel_basis = self.mel_basis.to(dtype=power.dtype, device=power.device)  # [M, F]
        mel = torch.einsum('mf,bft->bmt', mel_basis, power)                    # [B, M, T]

        # Log compression + normalization
        log_mel = (mel + 1e-5).log()
        log_mel = (log_mel + 4.5) / 5.0
        return log_mel

    @torch.no_grad()
    def power_forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.conv1d(x.unsqueeze(1), self.preemphasis_kernel).squeeze(1)
        stft_out = self.stft(x)
        power = (stft_out ** 2).sum(dim=-1)
        return power

In [4]:
# ---- Paste in your ConvSTFT and LogMelSpectrogramConvSTFT here, or import them ----
# from your_module import ConvSTFT, LogMelSpectrogramConvSTFT
#
# NOTE: Ensure ConvSTFT uses n_fft-length kernels with the window centered inside n_fft (Option A).
# Also ensure LogMelSpectrogramConvSTFT uses mel projection with einsum: torch.einsum('mf, bft -> bmt', mel_basis, power).

# ----- Reference: AugmentMelSTFT from EfficientAT (as provided by you) -----
class AugmentMelSTFT(nn.Module):
    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
        super().__init__()

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
            print(f"Warning: FMAX is None setting to {fmax} ")
        self.fmax = fmax
        self.hopsize = hopsize

        self.register_buffer('window', torch.hann_window(win_length, periodic=False), persistent=False)

        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)

        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        # Pre-emphasis (length reduces by 1 sample)
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)

        # Reference STFT
        x = torch.stft(
            x,
            self.n_fft,
            hop_length=self.hopsize,
            win_length=self.win_length,
            center=True,
            normalized=False,
            window=self.window,
            return_complex=False,
        )
        x = (x ** 2).sum(dim=-1)  # power: [B, F, T]

        # Aug ranges are ignored in eval mode
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        # Kaldi mel basis (Nyquist excluded) then pad Nyquist bin
        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(
            self.n_mels,  self.n_fft, self.sr,
            fmin, fmax,
            vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0
        )
        mel_basis = torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0)  # [M, F]

        # Project to mel
        with torch.cuda.amp.autocast(enabled=False):
            mel_basis = torch.as_tensor(mel_basis, device=x.device, dtype=x.dtype)  # [M, F]
            melspec = torch.einsum('mf, bft -> bmt', mel_basis, x)  # [B, M, T]

        # Log compression
        melspec = (melspec + 1e-5).log()

        # (Optional) masking only in training mode
        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        # Fast normalization
        melspec = (melspec + 4.5) / 5.0
        return melspec


In [5]:
# ---------- Comparison Utilities ----------
def compute_metrics(a: torch.Tensor, b: torch.Tensor, eps_mask: float = 1e-3):
    """
    a, b: tensors of the same shape
    Returns a dict of absolute and (masked) relative error metrics.
    """
    with torch.no_grad():
        abs_err = (a - b).abs()
        max_abs = abs_err.max().item()
        mean_abs = abs_err.mean().item()

        # Masked relative error to avoid division by tiny reference values
        ref = b.abs()
        mask = ref > eps_mask
        if mask.any():
            rel_err = abs_err[mask] / ref[mask]
            max_rel = rel_err.max().item()
            mean_rel = rel_err.mean().item()
        else:
            max_rel = float('nan')
            mean_rel = float('nan')

    return {
        "max_abs_err": max_abs,
        "mean_abs_err": mean_abs,
        "masked_max_rel_err": max_rel,
        "masked_mean_rel_err": mean_rel,
    }


def compare_logmel_modules(
    LogMelClass,            # your LogMelSpectrogramConvSTFT class object
    AugmentMelClass=AugmentMelSTFT,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    dtype: torch.dtype = torch.float32,
    B: int = 2,
    T: int = 16000,
    n_mels: int = 128,
    sr: int = 32000,
    win_length: int = 800,
    hopsize: int = 320,
    n_fft: int = 1024,
    fmin: float = 0.0,
    fmax: float | None = None,
    plot: bool = False,
    also_test_float64: bool = True,
):
    torch.manual_seed(0)

    # Instantiate both models with the same configuration
    ref = AugmentMelClass(
        n_mels=n_mels, sr=sr, win_length=win_length, hopsize=hopsize, n_fft=n_fft,
        freqm=0, timem=0,  # ensure no masking even if .train() is used
        fmin=fmin, fmax=fmax if fmax is not None else sr // 2,
        fmin_aug_range=1, fmax_aug_range=1,  # augmentation off effectively
    ).to(device).eval()

    test = LogMelClass(
        n_mels=n_mels, sr=sr, win_length=win_length, hopsize=hopsize, n_fft=n_fft, fmin=fmin, fmax=fmax
    ).to(device).eval()

    # NOTE: Different constructors - your LogMelSpectrogramConvSTFT likely uses "hopsize" param name.
    # If it's 'hopsize', the above passes it; if it's 'hop_length', adapt accordingly.

    # Random batch
    x = torch.randn(B, T, dtype=dtype, device=device)

    with torch.no_grad():
        # Reference (torch.stft path)
        y_ref = ref(x)             # [B, M, frames]

        # Test (Conv1d STFT path)
        y_test = test(x)           # [B, M, frames]

    # Check shapes
    assert y_ref.shape == y_test.shape, f"Shape mismatch: ref={y_ref.shape}, test={y_test.shape}"

    # Metrics
    metrics = compute_metrics(y_test, y_ref, eps_mask=1e-3)

    # Allclose with realistic fp32 tolerances (for n_fft=1024)
    atol = 5e-3
    rtol = 5e-3
    allclose = torch.allclose(y_test, y_ref, atol=atol, rtol=rtol)

    print("=== Log-Mel: ConvSTFT vs AugmentMelSTFT (fp32) ===")
    print(f"Device: {device}  DType: {dtype}")
    print(f"Input:  B={B}, T={T}, sr={sr}, n_fft={n_fft}, win={win_length}, hop={hopsize}, n_mels={n_mels}")
    print(f"Output shape: {y_test.shape}  (B, M, frames)")
    print(f"Max abs error:          {metrics['max_abs_err']:.6e}")
    print(f"Mean abs error:         {metrics['mean_abs_err']:.6e}")
    print(f"Masked max rel error:   {metrics['masked_max_rel_err']:.6e} (|ref|>1e-3)")
    print(f"Masked mean rel error:  {metrics['masked_mean_rel_err']:.6e} (|ref|>1e-3)")
    print(f"torch.allclose:         {allclose} (atol={atol}, rtol={rtol})")

    if plot:
        try:
            import matplotlib.pyplot as plt

            b = 0
            fig, axs = plt.subplots(2, 3, figsize=(14, 6))
            axs = axs.ravel()

            im0 = axs[0].imshow(y_ref[b].cpu().numpy(), aspect='auto', origin='lower')
            axs[0].set_title("AugmentMelSTFT (ref)")
            plt.colorbar(im0, ax=axs[0], fraction=0.046)

            im1 = axs[1].imshow(y_test[b].cpu().numpy(), aspect='auto', origin='lower')
            axs[1].set_title("LogMelSpectrogramConvSTFT (test)")
            plt.colorbar(im1, ax=axs[1], fraction=0.046)

            diff = (y_test[b] - y_ref[b]).cpu().numpy()
            im2 = axs[2].imshow(abs(diff), aspect='auto', origin='lower')
            axs[2].set_title("|difference|")
            plt.colorbar(im2, ax=axs[2], fraction=0.046)

            # Histograms
            axs[3].hist(diff.flatten(), bins=100, alpha=0.7)
            axs[3].set_title("Diff histogram")

            axs[4].plot((abs(diff)).mean(axis=0))
            axs[4].set_title("Mean |diff| per frame")

            axs[5].plot((abs(diff)).mean(axis=1))
            axs[5].set_title("Mean |diff| per mel bin")

            plt.tight_layout()
            plt.show()
        except Exception as e:
            print(f"(Plot skipped due to error: {e})")

    # Optional: Double-precision sanity check
    if also_test_float64:
        x64 = x.double()
        ref64 = ref.double()
        test64 = test.double()

        with torch.no_grad():
            y_ref64 = ref64(x64)
            y_test64 = test64(x64)

        assert y_ref64.shape == y_test64.shape
        metrics64 = compute_metrics(y_test64, y_ref64, eps_mask=1e-9)
        allclose64 = torch.allclose(y_test64, y_ref64, atol=1e-8, rtol=1e-8)

        print("\n=== Log-Mel: ConvSTFT vs AugmentMelSTFT (float64) ===")
        print(f"Max abs error:          {metrics64['max_abs_err']:.6e}")
        print(f"Mean abs error:         {metrics64['mean_abs_err']:.6e}")
        print(f"Masked max rel error:   {metrics64['masked_max_rel_err']:.6e} (|ref|>1e-9)")
        print(f"Masked mean rel error:  {metrics64['masked_mean_rel_err']:.6e} (|ref|>1e-9)")
        print(f"torch.allclose:         {allclose64} (atol=1e-8, rtol=1e-8)")

    return {
        "fp32": {**metrics, "allclose": allclose},
        "fp64": {**metrics64, "allclose": allclose64} if also_test_float64 else None,
    }



In [6]:
results = compare_logmel_modules(LogMelSpectrogramConvSTFT, plot=False, also_test_float64=True)


=== Log-Mel: ConvSTFT vs AugmentMelSTFT (fp32) ===
Device: cuda  DType: torch.float32
Input:  B=2, T=16000, sr=32000, n_fft=1024, win=800, hop=320, n_mels=128
Output shape: torch.Size([2, 128, 50])  (B, M, frames)
Max abs error:          1.290884e+00
Mean abs error:         3.625098e-03
Masked max rel error:   6.989396e+00 (|ref|>1e-3)
Masked mean rel error:  4.634821e-03 (|ref|>1e-3)
torch.allclose:         False (atol=0.005, rtol=0.005)

=== Log-Mel: ConvSTFT vs AugmentMelSTFT (float64) ===
Max abs error:          1.291029e+00
Mean abs error:         3.624981e-03
Masked max rel error:   6.989822e+00 (|ref|>1e-9)
Masked mean rel error:  4.633626e-03 (|ref|>1e-9)
torch.allclose:         False (atol=1e-8, rtol=1e-8)


Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:873.)
  return _VF.stft(  # type: ignore[attr-defined]
  with torch.cuda.amp.autocast(enabled=False):


In [7]:
import torch
import torch.nn.functional as F
import torchaudio

def compare_logmel_stepwise(
    conv_model: LogMelSpectrogramConvSTFT,
    ref_model: nn.Module,  # AugmentMelSTFT
    x: torch.Tensor,
    fmin: float,
    fmax: float,
    eps: float = 1e-5,
    dtype: torch.dtype = torch.float32,
):
    conv_model.eval()
    ref_model.eval()

    device = x.device
    x = x.to(dtype)

    # ----- 1) Pre-emphasis -----
    pe_kernel = conv_model.preemphasis_kernel.to(dtype=dtype, device=device)
    x_pe = F.conv1d(x.unsqueeze(1), pe_kernel).squeeze(1)   # match both paths

    # ----- 2) STFT -----
    # Conv path power
    stft_conv = conv_model.stft(x_pe)                       # [B, F, T, 2]
    power_conv = (stft_conv ** 2).sum(dim=-1)               # [B, F, T]

    # Reference path power (torch.stft)
    window = torch.hann_window(conv_model.win_length, periodic=False, dtype=dtype, device=device)
    stft_ref = torch.stft(
        x_pe, n_fft=conv_model.n_fft, hop_length=conv_model.hopsize, win_length=conv_model.win_length,
        center=True, normalized=False, window=window, return_complex=False, pad_mode="constant"
    )                                                       # [B, F, T, 2]
    power_ref = (stft_ref ** 2).sum(dim=-1)                 # [B, F, T]

    print("Power max abs:", (power_conv - power_ref).abs().max().item(),
          "mean abs:", (power_conv - power_ref).abs().mean().item())

    # ----- 3) Single shared mel basis -----
    mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(
        conv_model.n_mels, conv_model.n_fft, conv_model.sr,
        fmin, fmax, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0
    )
    mel_basis = F.pad(mel_basis, (0, 1), value=0.0).to(device=device, dtype=dtype)  # [M, F]

    mel_conv = torch.einsum('mf,bft->bmt', mel_basis, power_conv)     # [B, M, T]
    mel_ref  = torch.einsum('mf,bft->bmt', mel_basis, power_ref)      # [B, M, T]

    print("Pre-log mel max abs:", (mel_conv - mel_ref).abs().max().item(),
          "mean abs:", (mel_conv - mel_ref).abs().mean().item())

    # ----- 4) Apply identical log + norm -----
    y_conv = (mel_conv + eps).log(); y_conv = (y_conv + 4.5) / 5.0
    y_ref  = (mel_ref  + eps).log(); y_ref  = (y_ref  + 4.5) / 5.0

    diff = (y_conv - y_ref).abs()
    print("Log-mel max abs:", diff.max().item(), "mean abs:", diff.mean().item())

    # Also report masked relative error on log-mel
    mask = y_ref.abs() > 1e-3
    if mask.any():
        rel = diff[mask] / y_ref.abs()[mask]
        print("Log-mel masked max rel:", rel.max().item(), "mean rel:", rel.mean().item())
    else:
        print("No elements exceed rel-error mask threshold in y_ref.")

    return {
        "power_max_abs": (power_conv - power_ref).abs().max().item(),
        "power_mean_abs": (power_conv - power_ref).abs().mean().item(),
        "mel_prelog_max_abs": (mel_conv - mel_ref).abs().max().item(),
        "mel_prelog_mean_abs": (mel_conv - mel_ref).abs().mean().item(),
        "logmel_max_abs": diff.max().item(),
        "logmel_mean_abs": diff.mean().item(),
    }

In [8]:
waveform = torch.randn(1, 32000, dtype=torch.float32)  # [B, T]  # [B, T]

n_mels = 128
sample_rate = 32000
win_length=800
hopsize=320
n_fft=1024
fmin=0.0
fmax=16000.0

extractor = LogMelSpectrogramConvSTFT(
    n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, fmin=0.0, fmax=16000.0
).eval()
mel = AugmentMelSTFT(n_mels=n_mels,
                     sr=sample_rate,
                     win_length=win_length,
                     hopsize=hopsize,
                     n_fft=n_fft,
                     freqm=48,
                     timem=192,
                     fmin=fmin,
                     fmax=fmax,
                     fmin_aug_range=10, 
                     fmax_aug_range=2000
                     
                     )

mel.eval()


message = compare_logmel_stepwise(extractor,
    mel,  # AugmentMelSTFT
    waveform,
    fmin,
    fmax,
    eps = 1e-5,
    dtype = torch.float32)

print(message)

Power max abs: 0.452392578125 mean abs: 0.014420101419091225
Pre-log mel max abs: 0.46875 mean abs: 0.016031377017498016
Log-mel max abs: 0.00011710822582244873 mean abs: 2.7025034796679392e-06
Log-mel masked max rel: 0.002050897106528282 mean rel: 3.383140210644342e-06
{'power_max_abs': 0.452392578125, 'power_mean_abs': 0.014420101419091225, 'mel_prelog_max_abs': 0.46875, 'mel_prelog_mean_abs': 0.016031377017498016, 'logmel_max_abs': 0.00011710822582244873, 'logmel_mean_abs': 2.7025034796679392e-06}
