In [1]:
import math
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional

# ====== 1) Your model from previous step ======
from LogMelAffine import MelSpectrogramMatched

# For this snippet, Iâ€™ll redefine a stub import path:
# Make sure you actually import your implemented class instead
from types import SimpleNamespace


# ====== 2) Example dataset (replace with your own) ======
class WaveDataset(Dataset):
    """
    Replace this with your real dataset.
    This example returns raw (B, T) waveforms at a fixed sample rate.
    """
    def __init__(self, wavs: List[torch.Tensor], sample_rate: int = 16000):
        self.wavs = wavs
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, idx):
        wav = self.wavs[idx]
        return wav  # shape: (T,)


def collate_pad(batch: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pads variable-length 1D waveforms to the max length in batch and returns a mask.
    Returns:
        wavs: (B, T_max)
        mask: (B, T_max) with 1.0 for valid samples else 0.0 (optional if you need masking later)
    """
    lengths = [x.shape[-1] for x in batch]
    T_max = max(lengths)
    B = len(batch)
    wavs = torch.zeros(B, T_max, dtype=batch[0].dtype)
    mask = torch.zeros(B, T_max, dtype=batch[0].dtype)
    for i, x in enumerate(batch):
        T = x.shape[-1]
        wavs[i, :T] = x
        mask[i, :T] = 1.0
    return wavs, mask


# ====== 3) Build the torchaudio reference (target) ======
def build_torchaudio_reference(
    sample_rate: int,
    n_fft: int,
    win_length: int,
    hop_length: int,
    n_mels: int,
    f_min: float,
    f_max: Optional[float],
    power: float,
    norm: Optional[str],
    mel_scale: str,
    log_eps: float
):
    """
    Returns a callable that maps wav -> y (log-mel target).
    We compute torchaudio MelSpectrogram (linear), then apply torch.log with clamp (eps).
    """
    mel_ref = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        window_fn=torch.hann_window,   # matches periodic=True behavior
        power=power,                   # |X|^power
        normalized=False,
        center=True,
        pad_mode="reflect",
        onesided=True,
        norm=norm,
        mel_scale=mel_scale,
    )

    def compute_logmel(wav: torch.Tensor) -> torch.Tensor:
        """
        wav: (B, T)
        return: (B, M, frames) log-mel (natural log) with log_eps clamp
        """
        with torch.no_grad():
            mel_lin = mel_ref(wav)  # (B, M, frames)
            y = torch.log(torch.clamp_min(mel_lin, log_eps))
        return y

    return mel_ref, compute_logmel


# ====== 4) Build your trainable model (prediction) ======
def build_trainable_mel(
    sample_rate: int,
    n_fft: int,
    win_length: int,
    hop_length: int,
    n_mels: int,
    f_min: float,
    f_max: Optional[float],
    power: float,
    norm: Optional[str],
    mel_scale: str,
    log_eps: float
) -> nn.Module:
    """
    Builds MelSpectrogramMatched with learn_affine=True.
    All other parameters (STFT, mel) are frozen.
    """
    model = MelSpectrogramMatched(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
        normalized=False,
        center=True,
        pad_mode="reflect",
        onesided=True,
        norm=norm,
        mel_scale=mel_scale,
        apply_log=True,        # We want y_hat to be log-mel as well
        log_eps=log_eps,
        learn_affine=True,     # enable the learnable per-mel affine
    )

    # Freeze everything except affine parameters
    for name, p in model.named_parameters():
        if "affine" in name:
            p.requires_grad = True
        else:
            p.requires_grad = False

    return model


# ====== 5) Training loop ======
def train_affine_to_match_torchaudio(
    train_loader: DataLoader,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    win_length: int = 800,
    hop_length: int = 320,
    n_mels: int = 64,
    f_min: float = 0.0,
    f_max: Optional[float] = None,
    power: float = 2.0,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
    log_eps: float = 1e-6,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    max_epochs: int = 5,
    grad_clip: Optional[float] = 5.0,
    use_amp: bool = False,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    # Build reference (target) and model (prediction)
    mel_ref_module, compute_logmel = build_torchaudio_reference(
        sample_rate, n_fft, win_length, hop_length, n_mels, f_min, f_max, power, norm, mel_scale, log_eps
    )
    model = build_trainable_mel(
        sample_rate, n_fft, win_length, hop_length, n_mels, f_min, f_max, power, norm, mel_scale, log_eps
    )

    mel_ref_module.to(device)
    model.to(device)

    # Only optimize affine params
    affine_params = [p for p in model.parameters() if p.requires_grad]
    assert len(affine_params) > 0, "No trainable parameters found (affine)."
    optimizer = torch.optim.AdamW(affine_params, lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and device.startswith("cuda")))

    # Loss in log-domain: MSE usually works well
    criterion = nn.MSELoss()

    model.train()
    mel_ref_module.eval()  # we compute targets without grad

    for epoch in range(1, max_epochs + 1):
        running_loss = 0.0
        count = 0

        for wavs, _mask in train_loader:
            wavs = wavs.to(device)

            # Compute target: y (log-mel from torchaudio)
            with torch.no_grad():
                y = compute_logmel(wavs)   # (B, M, frames)

            # Compute prediction: y_hat (log-mel from our module with learnable affine)
            optimizer.zero_grad(set_to_none=True)

            if scaler.is_enabled():
                with torch.cuda.amp.autocast():
                    y_hat = model(wavs)         # (B, M, frames)
                    # Align shapes (frames may differ by 1 if padding/center differences exist; but we matched configs)
                    # If your dataset is tricky, you can min-align:
                    T_min = min(y.shape[-1], y_hat.shape[-1])
                    loss = criterion(y_hat[..., :T_min], y[..., :T_min])
            else:
                y_hat = model(wavs)
                T_min = min(y.shape[-1], y_hat.shape[-1])
                loss = criterion(y_hat[..., :T_min], y[..., :T_min])

            # Backward + step
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                if grad_clip is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(affine_params, grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(affine_params, grad_clip)
                optimizer.step()

            running_loss += loss.item() * wavs.size(0)
            count += wavs.size(0)

        epoch_loss = running_loss / max(count, 1)
        print(f"[Epoch {epoch}/{max_epochs}] train MSE: {epoch_loss:.6f}")

    print("Training complete.")
    return model




In [2]:
# ====== 6) Example usage ======

# Make a toy dataset: a few random waveforms (replace with real audio)
sr = 16000
ex_wavs = [torch.randn(sr * 2) for _ in range(64)]  # 2 seconds each

ds = WaveDataset(ex_wavs, sample_rate=sr)
dl = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=collate_pad, num_workers=0)

trained_model = train_affine_to_match_torchaudio(
    train_loader=dl,
    sample_rate=sr,
    n_fft=1024,
    win_length=800,
    hop_length=320,
    n_mels=64,
    f_min=0.0,
    f_max=sr / 2,
    power=2.0,
    norm=None,
    mel_scale="htk",
    log_eps=1e-6,
    lr=1e-3,
    weight_decay=0.0,
    max_epochs=3,
    grad_clip=5.0,
    use_amp=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Save learned affine parameters if you want
state = {k: v for k, v in trained_model.state_dict().items() if "affine" in k}
torch.save(state, "learned_affine_only.pt")
print("Saved learned affine weights to learned_affine_only.pt")

  scaler = torch.cuda.amp.GradScaler(enabled=(use_amp and device.startswith("cuda")))
  with torch.cuda.amp.autocast():


[Epoch 1/3] train MSE: 0.000000
[Epoch 2/3] train MSE: 0.000000
[Epoch 3/3] train MSE: 0.000000
Training complete.
Saved learned affine weights to learned_affine_only.pt
