In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

In [2]:
class ConvSTFT(nn.Module):
    """
    Conv1d-based STFT that is ONNX-export-friendly.

    Emulates torch.stft with:
      - center=True  -> zero pad by n_fft//2 at both ends
      - normalized=False
      - return_complex=False (we output real/imag via two filter banks)
    """
    def __init__(
        self,
        n_fft: int = 1024,
        hop_length: int = 320,
        win_length: int = 800,
        window: torch.Tensor | None = None,
        pad_center: bool = True,
    ):
        super().__init__()
        assert win_length <= n_fft, "win_length must be <= n_fft"

        self.n_fft = int(n_fft)
        self.hop_length = int(hop_length)
        self.win_length = int(win_length)
        self.pad_center = bool(pad_center)

        # Window
        if window is None:
            window = torch.hann_window(win_length, periodic=False, dtype=torch.float32)
        else:
            window = window.to(dtype=torch.float32)
        self.register_buffer("window", window, persistent=False)

        # Build Fourier basis for positive frequencies [0 .. n_fft//2]
        # Real kernels:  window[n] * cos(2πkn/N)
        # Imag kernels: -window[n] * sin(2πkn/N)
        
        # Build Fourier basis for positive frequencies [0..n_fft//2]
        num_bins = n_fft // 2 + 1
        
        # Center the window inside an n_fft frame
        offset = (n_fft - win_length) // 2
        win_full = torch.zeros(n_fft, dtype=torch.float32)
        win_full[offset:offset+win_length] = self.window  # centered window
        
        # n over 0..n_fft-1, k over 0..num_bins-1
        n = torch.arange(n_fft, dtype=torch.float32).unsqueeze(0)  # [1, n_fft]
        k = torch.arange(num_bins, dtype=torch.float32).unsqueeze(1)  # [num_bins, 1]
        ang = 2 * math.pi * k @ (n / float(n_fft))  # [num_bins, n_fft]
        
        cos_kernels = torch.cos(ang) * win_full      # [num_bins, n_fft]
        sin_kernels = -torch.sin(ang) * win_full     # [num_bins, n_fft]
        
        weight = torch.cat([cos_kernels, sin_kernels], dim=0).unsqueeze(1)  # [2*num_bins, 1, n_fft]
        self.register_buffer("fourier_basis", weight.contiguous(), persistent=False)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T] waveform

        Returns:
            stft_out: [B, num_bins, frames, 2]  where last dim is (real, imag)
        """
        B, T = x.shape

        # center=True -> zero pad by n_fft//2
        if self.pad_center:
            pad = self.n_fft // 2
            x = F.pad(x, (pad, pad), mode="constant", value=0.0)  # [B, T + 2*pad]

        # Convolution to compute dot product with each Fourier kernel at each frame
        # Input must be [B, 1, T]
        y = F.conv1d(
            x.unsqueeze(1),                       # [B, 1, T+pad*2]
            self.fourier_basis,                   # [2*F, 1, win_length]
            bias=None,
            stride=self.hop_length,
            padding=0,
            dilation=1,
            groups=1,
        )  # -> [B, 2*F, frames]

        num_bins = self.n_fft // 2 + 1
        frames = y.shape[-1]
        y = y.view(B, 2, num_bins, frames)        # [B, 2, F, frames]
        real = y[:, 0, :, :]                      # [B, F, frames]
        imag = y[:, 1, :, :]                      # [B, F, frames]

        # Pack to match torch.stft(return_complex=False) output layout
        stft_out = torch.stack([real, imag], dim=-1)  # [B, F, frames, 2]
        return stft_out

In [3]:
# Assumes your ConvSTFT is patched for Option A: n_fft-length kernels with centered window
class LogMelSpectrogramConvSTFT(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        sr: int = 32000,
        win_length: int = 800,
        hopsize: int = 320,
        n_fft: int = 1024,
        fmin: float = 0.0,
        fmax: float | None = None,
    ):
        super().__init__()

        if fmax is None:
            fmax = sr // 2

        self.n_mels = int(n_mels)
        self.sr = int(sr)
        self.win_length = int(win_length)
        self.hopsize = int(hopsize)
        self.n_fft = int(n_fft)
        self.fmin = float(fmin)

        nyquist = sr // 2
        if fmax is None:
            fmax = nyquist
        elif fmax > nyquist:
            print(f"[LogMel] fmax={fmax} > Nyquist={nyquist}. Clamping to Nyquist.")
            fmax = float(nyquist)
        self.fmax = float(fmax)

        assert 0.0 <= self.fmin < self.fmax <= (self.sr / 2), \
            f"Invalid band: fmin={self.fmin}, fmax={self.fmax}, nyquist={self.sr/2}"

        # Pre-emphasis kernel y[t] = x[t] - 0.97*x[t-1]
        self.register_buffer(
            "preemphasis_kernel",
            torch.tensor([[[-0.97, 1.0]]], dtype=torch.float32),
            persistent=False
        )

        # STFT via conv: ConvSTFT must center win inside n_fft and use n_fft-length kernels internally
        self.stft = ConvSTFT(
            n_fft=self.n_fft,
            hop_length=self.hopsize,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, periodic=False, dtype=torch.float32),
            pad_center=True,
        )

        # Kaldi mel filter bank (Nyquist excluded), then pad Nyquist
        mel_bins, _ = torchaudio.compliance.kaldi.get_mel_banks(
            num_bins=self.n_mels,
            window_length_padded=self.n_fft,
            sample_freq=self.sr,
            low_freq=self.fmin,
            high_freq=self.fmax,
            vtln_low=100.0,
            vtln_high=-500.0,
            vtln_warp_factor=1.0,
        )
        mel_bins = F.pad(mel_bins, (0, 1), value=0.0)  # -> [n_mels, n_fft//2 + 1]
        self.register_buffer("mel_basis", mel_bins.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Current behavior (length T-1) to match EfficientAT reference:
        x = F.conv1d(x.unsqueeze(1), self.preemphasis_kernel).squeeze(1)

        # STFT -> power
        stft_out = self.stft(x)               # [B, F, frames, 2]
        power = (stft_out ** 2).sum(dim=-1)   # [B, F, frames]

        # Mel projection (use correct einsum string, keep dtype/device aligned)
        mel_basis = self.mel_basis.to(dtype=power.dtype, device=power.device)  # [M, F]
        mel = torch.einsum('mf,bft->bmt', mel_basis, power)                    # [B, M, T]

        # Log compression + normalization
        log_mel = (mel + 1e-5).log()
        log_mel = (log_mel + 4.5) / 5.0
        return log_mel

    @torch.no_grad()
    def power_forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.conv1d(x.unsqueeze(1), self.preemphasis_kernel).squeeze(1)
        stft_out = self.stft(x)
        power = (stft_out ** 2).sum(dim=-1)
        return power

In [4]:
class MelSpecAugment(nn.Module):
    """
    SpecAugment on log-mel (training only).
    """
    def __init__(self, freqm: int = 48, timem: int = 192):
        super().__init__()
        self.freqm = (
            nn.Identity() if freqm == 0 else torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        )
        self.timem = (
            nn.Identity() if timem == 0 else torchaudio.transforms.TimeMasking(timem, iid_masks=True)
        )

    def forward(self, log_mel: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return log_mel
        x = self.freqm(log_mel)
        x = self.timem(x)
        return x

In [5]:
class RandomMelEdgeProjector(nn.Module):
    """
    Re-project power spectrogram with randomized fmin/fmax during training.
    """
    def __init__(
        self,
        n_mels: int = 128,
        sr: int = 32000,
        n_fft: int = 1024,
        fmin: float = 0.0,
        fmax: float | None = None,
        fmin_aug_range: int = 10,
        fmax_aug_range: int = 1000,
    ):
        super().__init__()
        nyquist = sr // 2
        if fmax is None:
            fmax = nyquist
        elif fmax > nyquist:
            # Optional: print/log once so you know a clamp happened
            print(f"[LogMel] fmax={fmax} > Nyquist={nyquist}. Clamping to Nyquist.")
            fmax = float(nyquist)
        
        self.base_fmin = float(fmin)
        self.base_fmax = float(fmax)
        assert 0.0 <= self.base_fmin < self.base_fmax <= (sr / 2), f"Invalid band: fmin={self.fmin}, fmax={self.fmax}, nyquist={self.sr/2}"

        assert fmin_aug_range >= 1
        assert fmax_aug_range >= 1

        self.n_mels = n_mels
        self.sr = sr
        self.n_fft = n_fft
        
        self.fmin_aug_range = int(fmin_aug_range)
        self.fmax_aug_range = int(fmax_aug_range)

    def _build_mel_basis(self, fmin: float, fmax: float, device, dtype):
        mb, _ = torchaudio.compliance.kaldi.get_mel_banks(
            num_bins=self.n_mels,
            window_length_padded=self.n_fft,
            sample_freq=self.sr,
            low_freq=fmin,
            high_freq=fmax,
            vtln_low=100.0,
            vtln_high=-500.0,
            vtln_warp_factor=1.0,
        )
        mb = torch.nn.functional.pad(mb, (0, 1), value=0.0)
        return mb.to(device=device, dtype=dtype)

    def forward(self, power_spec: torch.Tensor) -> torch.Tensor:
        if self.training:
            fmin = self.base_fmin + torch.randint(self.fmin_aug_range, (1,), device=power_spec.device).item()
            fmax = self.base_fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,), device=power_spec.device).item()
            nyquist = self.sr // 2 
            if fmax is None:
                fmax = nyquist
            elif fmax > nyquist:
                # Optional: print/log once so you know a clamp happened
                print(f"[LogMel] fmax={fmax} > Nyquist={nyquist}. Clamping to Nyquist.")
                fmax = float(nyquist)
        else:
            fmin, fmax = self.base_fmin, self.base_fmax

        mel_basis = self._build_mel_basis(fmin, fmax, device=power_spec.device, dtype=power_spec.dtype)
        mel = torch.matmul(mel_basis, power_spec)  # [B, n_mels, frames]
        return mel

#### for training only

In [6]:
extractor = LogMelSpectrogramConvSTFT(
    n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, fmin=0.0, fmax=16000.0
).train()

edge_aug = RandomMelEdgeProjector(
    n_mels=128, sr=32000, n_fft=1024, fmin=0.0, fmax=16000.0, fmin_aug_range=10, fmax_aug_range=2000
).train()

mask_aug = MelSpecAugment(freqm=48, timem=192).train()

waveform = torch.randn(1, 32000, dtype=torch.float32)  # [B, T]  # [B, T]
# If you need to inspect the power spectrogram
with torch.no_grad():
    power = extractor.power_forward(waveform)   # [B, n_fft//2+1, frames]

mel_lin = edge_aug(power)                          # randomized mel basis
log_mel = (mel_lin + 1e-5).log()
log_mel = (log_mel + 4.5) / 5.0
log_mel = mask_aug(log_mel)
print(log_mel)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


In [7]:
extractor.eval()
waveform = torch.randn(1, 32000, dtype=torch.float32)  # [B, T]  # [B, T]
log_mel = extractor(waveform)  # [B, n_mels, frames]
print(log_mel)

tensor([[[ 0.9018,  0.5767,  0.3806,  ...,  0.4952,  0.5852,  0.4006],
         [ 1.1505,  0.8254,  0.6293,  ...,  0.7439,  0.8338,  0.6493],
         [ 1.0858,  0.4610,  0.4911,  ..., -0.3809,  0.7485,  0.5415],
         ...,
         [ 2.6883,  2.8038,  2.7573,  ...,  2.8198,  2.8366,  2.8351],
         [ 2.6765,  2.6485,  2.8606,  ...,  2.9086,  2.8913,  2.7825],
         [ 2.6966,  2.7689,  2.8318,  ...,  2.7054,  2.9003,  2.7768]]])


In [8]:
model = LogMelSpectrogramConvSTFT(
    n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, fmin=0.0, fmax=16000.0
).eval()

dummy = torch.randn(1, 32000, dtype=torch.float32)  # [B, T]
torch.onnx.export(
    model,
    dummy,
    "logmel_convstft.onnx",
    opset_version=17,  # 13+ works; 17 recommended
    input_names=["waveform"],
    output_names=["log_mel"],
    dynamic_axes={"waveform": {0: "batch", 1: "time"}, "log_mel": {0: "batch", 2: "frames"}},
)
print("Exported logmel_convstft.onnx")

Exported logmel_convstft.onnx


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [9]:
import torch.nn as nn
import torchaudio
import torch


class AugmentMelSTFT(nn.Module):
    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
        torch.nn.Module.__init__(self)
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
            print(f"Warning: FMAX is None setting to {fmax} ")
        self.fmax = fmax
        self.hopsize = hopsize
        self.register_buffer('window',
                             torch.hann_window(win_length, periodic=False),
                             persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        x = (x ** 2).sum(dim=-1)  # power mag
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels,  self.n_fft, self.sr,
                                        fmin, fmax, vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0)
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)

        melspec = (melspec + 0.00001).log()

        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        melspec = (melspec + 4.5) / 5.  # fast normalization

        return melspec

In [10]:

waveform = torch.randn(1, 32000, dtype=torch.float32)  # [B, T]  # [B, T]

n_mels = 128
sample_rate = 32000
win_length=800
hopsize=320
n_fft=1024
fmin=0.0
fmax=16000.0

extractor = LogMelSpectrogramConvSTFT(
    n_mels=n_mels, sr=sample_rate, win_length=win_length, hopsize=hopsize, 
    n_fft=n_fft, fmin=fmin, fmax=fmax
).eval()

log_mel = extractor(waveform)  # [B, n_mels, frames]
print(log_mel)

# model to preprocess waveform into mel spectrograms
mel = AugmentMelSTFT(n_mels=n_mels,
                     sr=sample_rate,
                     win_length=win_length,
                     hopsize=hopsize,
                     n_fft=n_fft,
                     freqm=48,
                     timem=192,
                     fmin=fmin,
                     fmax=fmax,
                     fmin_aug_range=10, 
                     fmax_aug_range=2000
                     
                     )

mel.eval()
log_mel2 = mel(waveform)  # [B, n_mels, frames]
print(log_mel2)
print(log_mel.shape, log_mel2.shape)
print(log_mel -log_mel2) 

tensor([[[ 0.5383,  0.2939,  0.3629,  ...,  0.6083,  0.2326, -0.2799],
         [ 0.7870,  0.5426,  0.6115,  ...,  0.8569,  0.4812, -0.0317],
         [ 0.6747,  0.3803,  0.4762,  ..., -0.0573,  0.6308,  0.3580],
         ...,
         [ 2.6563,  2.8344,  2.7222,  ...,  2.8449,  2.6650,  2.9557],
         [ 2.7153,  2.8136,  2.6892,  ...,  2.8037,  2.6342,  2.8036],
         [ 2.7115,  2.7829,  2.7012,  ...,  2.7807,  2.7476,  2.8246]]])
tensor([[[ 0.6461,  0.3027,  0.3629,  ...,  0.6083,  0.2326, -0.2852],
         [ 0.8947,  0.5513,  0.6115,  ...,  0.8569,  0.4812, -0.0371],
         [ 0.7495,  0.3793,  0.4762,  ..., -0.0573,  0.6308,  0.3376],
         ...,
         [ 2.7759,  2.8342,  2.7222,  ...,  2.8449,  2.6650,  2.9557],
         [ 2.8732,  2.8141,  2.6892,  ...,  2.8037,  2.6342,  2.8036],
         [ 2.8636,  2.7832,  2.7012,  ...,  2.7807,  2.7476,  2.8246]]])
torch.Size([1, 128, 100]) torch.Size([1, 128, 100])
tensor([[[-1.0774e-01, -8.7723e-03, -2.9802e-07,  ...,  1.0729e-

Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:873.)
  return _VF.stft(  # type: ignore[attr-defined]
  with torch.cuda.amp.autocast(enabled=False):
