In [None]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import librosa
import soundfile as sf
import time
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.display import Audio, display
from torchmetrics import NormalizedRootMeanSquaredError
import auraloss  # Import auraloss for MultiResolutionSTFTLoss

from src.dataset import NSynth   
from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.dataset import load_raw_waveform
from src.utils.logger import save_training_results

# Log-mel front end: a power mel spectrogram converted to decibels.
# NOTE: no further normalization (e.g. z-scoring) is applied by this transform.
sample_rate, n_fft, n_mels = 16000, 1024, 128
hop_length = n_fft // 4  # hop is a quarter of the FFT size (256 samples)

mel_transform = nn.Sequential(
    # Waveform -> mel-scaled power spectrogram
    T.MelSpectrogram(
        n_mels=n_mels,
        hop_length=hop_length,
        n_fft=n_fft,
        sample_rate=sample_rate,
    ),
    # Power -> dB (log amplitude)
    T.AmplitudeToDB(stype="power"),
)

# Datasets and DataLoaders
# Every partition goes through the same log-mel transform.
train_dataset = NSynth(partition='training', transform=mel_transform)
valid_dataset = NSynth(partition='validation', transform=mel_transform)
test_dataset = NSynth(partition='testing', transform=mel_transform)

# Subset for quicker training
batch_size = 64
# Clamp the requested size so the Subset can never index past the end of the
# training partition.
training_subset_size = min(70000, len(train_dataset))
train_dataset = Subset(train_dataset, range(training_subset_size))

# Options shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

# Only the training data is shuffled; validation/test order is kept fixed so
# evaluation is reproducible from run to run.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Model and optimizer setup
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input geometry handed to the autoencoder: one channel, n_mels rows.
# NOTE(review): 251 frames presumably corresponds to 64000-sample clips
# (64000 // hop_length + 1) — confirm against the dataset.
in_channels = 1
input_height = n_mels
input_width = 251

# Network capacity and optimisation hyper-parameters.
latent_dim = 512
filters = [32, 64, 128]
learning_rate = 1e-4

model = AutoEncoder(input_height, input_width, latent_dim, in_channels, filters)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Reconstruction loss
# The autoencoder reads and writes log-mel (dB) spectrograms, so the loss is
# taken directly between the reconstruction and its input. A waveform-domain
# loss such as auraloss.freq.MelSTFTLoss needs time-domain audio, which
# T.InverseMelScale does not provide (it yields a linear-frequency
# spectrogram), so it cannot be applied to the model output as-is.
# L1 on dB values plays the role of a log-magnitude spectral distance.
loss_fn = nn.L1Loss()

# Training Loop
num_epochs = 50
log_interval = 10  # NOTE(review): not referenced anywhere in this cell
avg_epoch_time = 0.0  # running sum of epoch durations; averaged when saved

# Learning rate scheduler to reduce learning rate based on validation loss
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

print(f"Starting training on {device}...")
for epoch in tqdm(range(num_epochs), desc="Training"):
    model.train()
    start_epoch_time = time.time()
    train_loss = 0.0

    # Training pass
    for mel_spec, _, _, _ in train_loader:
        mel_spec = mel_spec.to(device)

        optimizer.zero_grad()
        # assumes the decoder output has the same shape as its input — TODO confirm
        output = model(mel_spec)

        loss = loss_fn(output, mel_spec)
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean; weight by batch size so the epoch figure
        # is a true per-sample average even when the last batch is short.
        train_loss += loss.item() * mel_spec.size(0)

    # Compute average epoch loss for training
    train_loss /= len(train_loader.dataset)

    # Validation pass: same loss, same inputs/targets as training, no gradients
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for mel_spec, _, _, _ in valid_loader:
            mel_spec = mel_spec.to(device)
            output = model(mel_spec)

            loss = loss_fn(output, mel_spec)
            valid_loss += loss.item() * mel_spec.size(0)

    # Compute average validation loss
    valid_loss /= len(valid_loader.dataset)

    # Step the scheduler with validation loss
    scheduler.step(valid_loss)

    epoch_time = time.time() - start_epoch_time
    avg_epoch_time += epoch_time
    print(f"Epoch {epoch+1}, train_loss={train_loss}, valid_loss={valid_loss}, Time: {epoch_time:.2f}s")

    # Save training results after the last epoch
    if epoch == num_epochs - 1:
        save_training_results({
            "train_loss": train_loss,
            "valid_loss": valid_loss,
            "num_epochs": num_epochs,
            "avg_epoch_time": avg_epoch_time / num_epochs,
            "learning_rate": learning_rate,
            "training_subset_size": training_subset_size,
            "batch_size": batch_size,
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "hop_length": hop_length,
            "n_mels": n_mels,
            "input_height": input_height,
            "input_width": input_width,
            "latent_dim": latent_dim,
            "in_channels": in_channels,
            "filters": filters,
        })

print("Training complete.")

# --------------------------
# Testing, Inversion, and Audio Playback
# --------------------------
model.eval()

# Samples to compare: distinct clips, so the same example is never auditioned
# twice (and never more clips than the test set holds).
num_test_samples = min(10, len(test_dataset))
test_indices = random.sample(range(len(test_dataset)), k=num_test_samples)

for idx in test_indices:
    print(f"\n=== Test sample index: {idx} ===")
    # (mel_spec, sample_rate, key, metadata) from dataset.
    # The per-clip rate is bound to its own name so the global `sample_rate`
    # used to build the mel transform is not overwritten.
    mel_spec, clip_sample_rate, key, metadata = test_dataset[idx]

    # Listen to the Original Audio with no transform applied
    raw_waveform, raw_sr = load_raw_waveform("testing", key)
    print(f"Key: {key}")
    print("Original audio:")
    display(Audio(raw_waveform.numpy(), rate=raw_sr))

    # Reconstruct using the model
    mel_spec = mel_spec.unsqueeze(0).to(device)  # add a leading batch dimension
    with torch.no_grad():
        reconstructed_mel = model(mel_spec)

    # Convert the reconstructed mel to waveform
    # assumes squeeze() leaves a 2-D [n_mels, time_frames] array — TODO confirm
    recon_np = reconstructed_mel.squeeze().cpu().numpy()
    recon_power = librosa.db_to_power(recon_np)  # dB -> power
    # The inversion has to use the same mel filterbank as the forward
    # transform. T.MelSpectrogram was built with its defaults (HTK mel scale,
    # no filter normalization), whereas librosa defaults to the Slaney scale
    # with Slaney normalization, so both are set explicitly here.
    reconstructed_audio = librosa.feature.inverse.mel_to_audio(
        recon_power,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        power=2.0,
        htk=True,
        norm=None,
    )

    print("Reconstructed audio:")
    # Play back at the rate the audio was synthesised at.
    display(Audio(reconstructed_audio, rate=sample_rate))


[(128, 251), (64, 126)]
[(128, 251), (64, 126), (32, 63)]
[(128, 251), (64, 126), (32, 63), (16, 32)]
Encoder:  Encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=65536, out_features=512, bias=True)
  )
)
Decoder:  Decoder(
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=65536, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(128, 16, 32))
    (2): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 0))
    (3): ReLU()
    (4): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), s

Training:   0%|          | 0/50 [00:22<?, ?it/s]


ValueError: too many values to unpack (expected 3)