In [None]:

%run init_notebook.py

import torch
import torch.nn as nn
import torchaudio
import torchmetrics

from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.models import compute_conv2D_output_size, compute_flattened_size

# Sanity-check that the AutoEncoder returns a tensor of the same shape it is
# fed, using the same construction and input geometry as the training cell:
# the model is built from an (height, width) tuple plus a latent size and
# consumes a 2-channel [magnitude, phase] spectrogram.
# NOTE(review): this cell previously called
# AutoEncoder(input_height, input_width, latent_dim, in_channels, filters),
# which does not match the AutoEncoder((h, w), latent_dim) call used for
# training; confirm against src.models.AutoEncoder.

# Define parameters (same as training: n_fft=512 with onesided=False gives
# 512 frequency rows; 497 STFT frames per clip)
input_height = 512
input_width = 497
latent_dim = 128
in_channels = 2  # magnitude + phase, stacked along the channel axis

model = AutoEncoder((input_height, input_width), latent_dim)
print("AutoEncoder model:")
print(model)

# Dummy input tensor: shape [batch_size, in_channels, input_height, input_width]
batch_size = 4
dummy_input = torch.randn(batch_size, in_channels, input_height, input_width)

# Dummy input through autoencoder. Only shapes matter here, so skip building
# the autograd graph.
with torch.no_grad():
    output = model(dummy_input)

# Check shapes
print("Input shape:", dummy_input.shape)
print("Output shape:", output.shape)
if dummy_input.shape == output.shape:
    print("Success: Output shape matches input shape.")
else:
    print("Mismatch: Adjust output_padding in your decoder layers if necessary.")


In [None]:
%run init_notebook.py

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import librosa
import time
import torchaudio
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.display import Audio, display
from torchmetrics import NormalizedRootMeanSquaredError

from src.dataset import NSynth   
from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.dataset import load_raw_waveform
from src.utils.logger import save_training_results
from src.utils.models import compute_magnitude_and_phase

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# STFT configuration
# --------------------------
sample_rate = 16000
n_fft = 512
hop_length = 128
win_length = n_fft  # Same as n_fft

# Forward STFT. power=None keeps the complex spectrogram so both magnitude and
# phase can be extracted; onesided=False keeps all n_fft frequency bins, which
# is why the model's input_height below equals n_fft (512).
stft_transform = T.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=None,  # Keep as complex spectrogram (magnitude and phase)
    onesided=False,  # Keep the full (two-sided) spectrum
    center=False,
).to(device)

# Inverse STFT over the full (two-sided) complex spectrogram.
# NOTE(review): stft_transform uses center=False, but InverseSpectrogram
# defaults to center=True, so the reconstruction drops n_fft // 2 samples at
# each end of the original and starts that many samples late. Passing
# center=False here is not a drop-in fix with the default Hann window
# (torch.istft rejects a window envelope that is ~0 at the signal edges);
# switching both transforms to center=True changes the frame count
# (input_width) and therefore the model geometry.
istft_transform = torchaudio.transforms.InverseSpectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    onesided=False,  # Reconstruct from the full complex spectrogram
).to(device)


def _to_model_input(waveform):
    """Convert a batch of waveforms into the autoencoder's input tensor.

    Moves *waveform* to ``device``, applies ``stft_transform`` and concatenates
    the magnitude and phase from ``compute_magnitude_and_phase`` along dim 1.
    Assumes *waveform* is [batch, 1, time_steps] and that magnitude/phase are
    each [batch, 1, freq_bins, time_frames], giving
    [batch, 2, freq_bins, time_frames] — TODO confirm against the helper.
    """
    stft_spec = stft_transform(waveform.to(device))
    magnitude, phase = compute_magnitude_and_phase(stft_spec)
    return torch.cat([magnitude, phase], dim=1).to(device)


# --------------------------
# Datasets and DataLoaders
# --------------------------
train_dataset = NSynth(partition='training')
valid_dataset = NSynth(partition='validation')
test_dataset  = NSynth(partition='testing')

# Subset for quicker training (currently the whole training set; lower
# training_subset_size to train on a prefix of it).
batch_size = 64
training_subset_size = len(train_dataset)
train_dataset = Subset(train_dataset, list(range(training_subset_size)))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
# Evaluation loaders don't need shuffling: the averaged loss is independent of
# order, and a fixed order keeps evaluation reproducible.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# --------------------------
# Model, Optimizer, and Loss
# --------------------------
input_height = 512  # n_fft frequency bins (onesided=False)
input_width  = 497  # STFT frames per clip with center=False
latent_dim   = 128
learning_rate = 1e-4

model = AutoEncoder((input_height, input_width), latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# NOTE(review): plain MSE on the phase channel does not account for 2*pi
# wrap-around if phase is an angle in radians — confirm this is intended.
criterion = nn.MSELoss().to(device)

# Training Loop
num_epochs = 50
log_interval = 10
avg_epoch_time = 0.0

# Learning rate scheduler to reduce learning rate based on validation loss
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

print(f"Starting training on {device}...")
for epoch in tqdm(range(num_epochs), desc="Training"):
    model.train()
    start_epoch_time = time.time()
    train_loss = 0.0

    for waveform, _, _, _ in train_loader:
        input_data = _to_model_input(waveform)

        optimizer.zero_grad()
        output = model(input_data)  # the autoencoder reconstructs its own input
        loss = criterion(output, input_data)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per sample even when
        # the last batch is smaller.
        train_loss += loss.item() * waveform.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation Loop
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for waveform, _, _, _ in valid_loader:
            input_data = _to_model_input(waveform)
            output = model(input_data)
            valid_loss += criterion(output, input_data).item() * waveform.size(0)

    valid_loss /= len(valid_loader.dataset)

    # Step the scheduler with validation loss
    scheduler.step(valid_loss)

    epoch_time = time.time() - start_epoch_time
    avg_epoch_time += epoch_time
    print(f"Epoch {epoch+1}, train_loss={train_loss}, valid_loss={valid_loss}, Time: {epoch_time:.2f}s")

    # Save training results after the last epoch
    if epoch == num_epochs - 1:
        save_training_results({
            "train_loss": train_loss,
            "valid_loss": valid_loss,
            "num_epochs": num_epochs,
            "avg_epoch_time": avg_epoch_time / num_epochs,
            "learning_rate": learning_rate,
            "training_subset_size": training_subset_size,
            "batch_size": batch_size,
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "hop_length": hop_length,
            "input_height": input_height,
            "input_width": input_width,
            "latent_dim": latent_dim,
        })

print("Training complete.")


# --------------------------
# Testing, Inversion, and Audio Playback
# --------------------------
model.eval()

# Distinct samples to compare (random.sample never repeats an index).
num_test_samples = min(10, len(test_dataset))
test_indices = random.sample(range(len(test_dataset)), num_test_samples)

for idx in test_indices:
    print(f"\n=== Test sample index: {idx} ===")
    # Only the key is needed. The dataset's sample rate is deliberately not
    # bound to `sample_rate`, so the STFT configuration above stays intact.
    _, _, key, _ = test_dataset[idx]

    # Listen to the Original Audio with no transform applied
    raw_waveform, raw_sr = load_raw_waveform("testing", key)
    print(f"Key: {key}")
    print("Original audio:")
    display(Audio(raw_waveform.numpy(), rate=raw_sr))

    # Add a batch dimension so the STFT sees [batch, 1, time_steps]
    if raw_waveform.ndim == 2:  # e.g., [1, time_steps]
        raw_waveform = raw_waveform.unsqueeze(0)

    input_data = _to_model_input(raw_waveform)

    # Reconstruct using the model
    with torch.no_grad():
        output = model(input_data)  # [batch, 2, freq_bins, time_frames]

    # Split the network output into magnitude and phase
    recon_magnitude = output[:, 0, :, :]  # [batch, freq_bins, time_frames]
    recon_phase = output[:, 1, :, :]      # [batch, freq_bins, time_frames]

    magnitude_ok = bool(torch.isfinite(recon_magnitude).all())
    phase_ok = bool(torch.isfinite(recon_phase).all())
    if not magnitude_ok:
        print("NaN or Inf detected in magnitude output!")
    if not phase_ok:
        print("NaN or Inf detected in phase output!")
    if not (magnitude_ok and phase_ok):
        # Inverting/playing non-finite spectrograms only yields garbage audio.
        print("Skipping reconstruction for this sample.")
        continue

    # Recombine into a complex spectrogram.
    # NOTE(review): this treats channel 0 as linear magnitude and channel 1 as
    # phase in radians; if compute_magnitude_and_phase returns log/dB
    # magnitude, it must be inverted before this step — confirm.
    recon_complex = recon_magnitude * torch.exp(1j * recon_phase)

    # Apply the inverse STFT transformation to get the waveform
    reconstructed_waveform = istft_transform(recon_complex)  # [batch, time_steps]

    print("Reconstructed audio:")
    display(Audio(reconstructed_waveform.squeeze().cpu().numpy(), rate=raw_sr))

Encoder:  Encoder(
  (encoder_blocks): Sequential(
    (0): Sequential(
      (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=516096, out_features=128, bias=True)
)
Decoder:  Decoder(
  (fc2): Linear(in_features=128, out_features=516096, bias=True)
  (unflatten): Unflatten(dim=1, unflattened_size=(128, 64, 63))
  (deconv_blocks): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (2): Sequential(
   

Training:   2%|▏         | 1/50 [07:11<5:52:17, 431.37s/it]

Epoch 1, train_loss=2.4877455362019965, valid_loss=2.2662928536555467, Time: 431.37s
