In [None]:
# Shared notebook setup script; presumably configures paths/environment so the
# `src` package used by later cells is importable — confirm in init_notebook.py.
%run init_notebook.py

In [None]:

%run init_notebook.py

import torch
import torch.nn as nn
import torchaudio

from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.models import compute_conv2D_output_size, compute_flattened_size

# Define parameters
input_height = 64
input_width = 128
latent_dim = 30
in_channels = 1
filters = [32, 64, 128]

model = AutoEncoder(input_height, input_width, latent_dim, in_channels, filters)
print("AutoEncoder model:")
print(model)

# Dummy input tensor: shape [batch_size, in_channels, input_height, input_width]
batch_size = 4
dummy_input = torch.randn(batch_size, in_channels, input_height, input_width)

# Dummy input through autoencoder
output = model(dummy_input)

# Check shapes
print("Input shape:", dummy_input.shape)
print("Output shape:", output.shape)
if dummy_input.shape == output.shape:
    print("Success: Output shape matches input shape.")
else:
    print("Mismatch: Adjust output_padding in your decoder layers if necessary.")


In [2]:
%run init_notebook.py

import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import librosa
import soundfile as sf
import time
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.display import Audio, display

from src.dataset import NSynth   
from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.dataset import load_raw_waveform

sample_rate = 16000
n_fft = 1024
hop_length = n_fft // 4
n_mels = 64

# Mel spectogram with log amplitude (db)
mel_transform = nn.Sequential(
    T.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels),
    T.AmplitudeToDB(stype="power")
)

# Datasets and DataLoaders
train_dataset = NSynth(partition='training', transform=mel_transform)
valid_dataset = NSynth(partition='validation', transform=mel_transform)
test_dataset  = NSynth(partition='testing', transform=mel_transform)

# Subset for quicker training
train_dataset = Subset(train_dataset, list(range(1000)))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader( test_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)

# Model, Optimizer, and Loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_height = 64
input_width  = 251
latent_dim   = 128
in_channels  = 1
filters      = [32, 64, 128]

model = AutoEncoder(input_height, input_width, latent_dim, in_channels, filters).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Training Loop
num_epochs = 50
log_interval = 10

print(f"Starting training on {device}...")
for epoch in tqdm(range(num_epochs), desc="Training"):
    model.train()
    start_epoch_time = time.time()
    train_loss = 0.0
    
    # Only interested in mel_spec from (mel_spec, sample_rate, key, metadata)
    for i, (mel_spec, _, _, _) in enumerate(train_loader):
        mel_spec = mel_spec.to(device) 

        optimizer.zero_grad()
        output = model(mel_spec)
        
        loss = criterion(output, mel_spec)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * mel_spec.size(0)

        if (i + 1) % log_interval == 0:
            print(f"  [Epoch {epoch+1}, Step {i+1}/{len(train_loader)}] loss: {loss.item():.4f}")

    # Compute average epoch loss
    train_loss /= len(train_loader.dataset)
    epoch_time = time.time() - start_epoch_time

    print(f"Epoch {epoch+1}, train_loss={train_loss:.4f}, Time: {epoch_time:.2f}s")

print("Training complete.")

# --------------------------
# Testing, Inversion, and Audio Playback
# --------------------------
model.eval()

# Samples to compare
test_indices = [random.choice(range(len(test_dataset))) for _ in range(10)]

for idx in test_indices:
    print(f"\n=== Test sample index: {idx} ===")
    # (mel_spec, sample_rate, key, metadata) from dataset
    mel_spec, sample_rate, key, metadata = test_dataset[idx]

    # Listen to the Original Audio with no transform applied
    raw_waveform, raw_sr = load_raw_waveform("testing", key)
    print(f"Key: {key}")
    print("Original audio:")
    display(Audio(raw_waveform.numpy(), rate=raw_sr))

    # Reconstruct using the model
    mel_spec = mel_spec.unsqueeze(0).to(device)  # shape [1, 1, 64, 128]
    with torch.no_grad():
        reconstructed_mel = model(mel_spec)  # shape [1, 1, 64, 128]

    # Convert the reconstructed mel to waveform
    recon_np = reconstructed_mel.squeeze().cpu().numpy()  # [64, 128]
    recon_power = librosa.db_to_power(recon_np)  # dB -> power
    reconstructed_audio = librosa.feature.inverse.mel_to_audio(
        recon_power, sr=sample_rate, n_fft=n_fft, hop_length=hop_length
    )

    print("Reconstructed audio:")
    display(Audio(reconstructed_audio, rate=raw_sr))


[(64, 251), (32, 126)]
[(64, 251), (32, 126), (16, 63)]
[(64, 251), (32, 126), (16, 63), (8, 32)]
Encoder:  Encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=32768, out_features=128, bias=True)
  )
)
Decoder:  Decoder(
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=32768, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(128, 8, 32))
    (2): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 0))
    (3): ReLU()
    (4): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride

Training:   0%|          | 0/50 [00:00<?, ?it/s]

  [Epoch 1, Step 10/16] loss: 2994.1494


Training:   2%|▏         | 1/50 [00:22<18:23, 22.51s/it]

Epoch 1, train_loss=3059.8166, Time: 22.51s
  [Epoch 2, Step 10/16] loss: 1067.9877


Training:   4%|▍         | 2/50 [00:44<17:49, 22.28s/it]

Epoch 2, train_loss=1192.1885, Time: 22.12s
  [Epoch 3, Step 10/16] loss: 800.1092


Training:   6%|▌         | 3/50 [01:06<17:25, 22.25s/it]

Epoch 3, train_loss=807.2681, Time: 22.21s
  [Epoch 4, Step 10/16] loss: 694.3187


Training:   8%|▊         | 4/50 [01:28<16:58, 22.14s/it]

Epoch 4, train_loss=718.8883, Time: 21.98s
  [Epoch 5, Step 10/16] loss: 715.1700


Training:  10%|█         | 5/50 [01:50<16:33, 22.09s/it]

Epoch 5, train_loss=683.1047, Time: 21.98s
  [Epoch 6, Step 10/16] loss: 608.6961


Training:  12%|█▏        | 6/50 [02:13<16:16, 22.20s/it]

Epoch 6, train_loss=661.5088, Time: 22.43s
  [Epoch 7, Step 10/16] loss: 620.2453


Training:  14%|█▍        | 7/50 [02:35<15:51, 22.13s/it]

Epoch 7, train_loss=642.8996, Time: 21.98s
  [Epoch 8, Step 10/16] loss: 617.6389


Training:  16%|█▌        | 8/50 [02:57<15:30, 22.16s/it]

Epoch 8, train_loss=617.5354, Time: 22.23s
  [Epoch 9, Step 10/16] loss: 572.1246


Training:  18%|█▊        | 9/50 [03:19<15:03, 22.03s/it]

Epoch 9, train_loss=569.6076, Time: 21.74s
  [Epoch 10, Step 10/16] loss: 454.3307


Training:  20%|██        | 10/50 [03:41<14:45, 22.14s/it]

Epoch 10, train_loss=468.1519, Time: 22.39s
  [Epoch 11, Step 10/16] loss: 365.3917


Training:  22%|██▏       | 11/50 [04:03<14:26, 22.22s/it]

Epoch 11, train_loss=394.9323, Time: 22.39s
  [Epoch 12, Step 10/16] loss: 357.7527


Training:  24%|██▍       | 12/50 [04:26<14:04, 22.22s/it]

Epoch 12, train_loss=368.0600, Time: 22.21s
  [Epoch 13, Step 10/16] loss: 349.5011


Training:  26%|██▌       | 13/50 [04:48<13:41, 22.20s/it]

Epoch 13, train_loss=347.5374, Time: 22.18s
  [Epoch 14, Step 10/16] loss: 319.5978


Training:  28%|██▊       | 14/50 [05:10<13:17, 22.14s/it]

Epoch 14, train_loss=319.4722, Time: 22.00s
  [Epoch 15, Step 10/16] loss: 291.1854


Training:  30%|███       | 15/50 [05:32<12:53, 22.09s/it]

Epoch 15, train_loss=295.7544, Time: 21.97s
  [Epoch 16, Step 10/16] loss: 273.9614


Training:  32%|███▏      | 16/50 [05:54<12:32, 22.14s/it]

Epoch 16, train_loss=282.2459, Time: 22.25s
  [Epoch 17, Step 10/16] loss: 273.6564


Training:  34%|███▍      | 17/50 [06:16<12:10, 22.14s/it]

Epoch 17, train_loss=270.2778, Time: 22.16s
  [Epoch 18, Step 10/16] loss: 251.6316


Training:  36%|███▌      | 18/50 [06:38<11:47, 22.10s/it]

Epoch 18, train_loss=257.4764, Time: 21.98s
  [Epoch 19, Step 10/16] loss: 231.5952


Training:  38%|███▊      | 19/50 [07:00<11:24, 22.07s/it]

Epoch 19, train_loss=242.2288, Time: 22.00s
  [Epoch 20, Step 10/16] loss: 229.3536


Training:  40%|████      | 20/50 [07:22<11:03, 22.11s/it]

Epoch 20, train_loss=227.5024, Time: 22.20s
  [Epoch 21, Step 10/16] loss: 217.9447


Training:  42%|████▏     | 21/50 [07:44<10:39, 22.07s/it]

Epoch 21, train_loss=214.1492, Time: 21.97s
  [Epoch 22, Step 10/16] loss: 204.0850


Training:  44%|████▍     | 22/50 [08:07<10:20, 22.16s/it]

Epoch 22, train_loss=202.4854, Time: 22.39s
  [Epoch 23, Step 10/16] loss: 213.8890


Training:  46%|████▌     | 23/50 [08:29<09:59, 22.19s/it]

Epoch 23, train_loss=193.3874, Time: 22.27s
  [Epoch 24, Step 10/16] loss: 174.8864


Training:  48%|████▊     | 24/50 [08:51<09:36, 22.18s/it]

Epoch 24, train_loss=186.5547, Time: 22.14s
  [Epoch 25, Step 10/16] loss: 180.5647


Training:  50%|█████     | 25/50 [09:13<09:14, 22.18s/it]

Epoch 25, train_loss=181.9226, Time: 22.19s
  [Epoch 26, Step 10/16] loss: 187.0184


Training:  52%|█████▏    | 26/50 [09:36<08:52, 22.18s/it]

Epoch 26, train_loss=177.2482, Time: 22.17s
  [Epoch 27, Step 10/16] loss: 175.1600


Training:  54%|█████▍    | 27/50 [09:58<08:28, 22.12s/it]

Epoch 27, train_loss=172.9301, Time: 22.00s
  [Epoch 28, Step 10/16] loss: 191.4992


Training:  56%|█████▌    | 28/50 [10:20<08:08, 22.20s/it]

Epoch 28, train_loss=169.3866, Time: 22.38s
  [Epoch 29, Step 10/16] loss: 180.1975


Training:  58%|█████▊    | 29/50 [10:42<07:46, 22.20s/it]

Epoch 29, train_loss=165.3373, Time: 22.20s
  [Epoch 30, Step 10/16] loss: 144.1614


Training:  60%|██████    | 30/50 [11:04<07:24, 22.20s/it]

Epoch 30, train_loss=161.4255, Time: 22.20s
  [Epoch 31, Step 10/16] loss: 170.0315


Training:  62%|██████▏   | 31/50 [11:26<07:00, 22.15s/it]

Epoch 31, train_loss=158.1243, Time: 22.02s
  [Epoch 32, Step 10/16] loss: 151.6772


Training:  64%|██████▍   | 32/50 [11:48<06:38, 22.15s/it]

Epoch 32, train_loss=154.8961, Time: 22.16s
  [Epoch 33, Step 10/16] loss: 172.4924


Training:  66%|██████▌   | 33/50 [12:11<06:15, 22.11s/it]

Epoch 33, train_loss=151.5877, Time: 21.98s
  [Epoch 34, Step 10/16] loss: 140.7167


Training:  68%|██████▊   | 34/50 [12:33<05:54, 22.13s/it]

Epoch 34, train_loss=148.6802, Time: 22.17s
  [Epoch 35, Step 10/16] loss: 152.6289


Training:  70%|███████   | 35/50 [12:55<05:32, 22.15s/it]

Epoch 35, train_loss=146.1917, Time: 22.19s
  [Epoch 36, Step 10/16] loss: 123.9231


Training:  72%|███████▏  | 36/50 [13:17<05:10, 22.16s/it]

Epoch 36, train_loss=143.8542, Time: 22.21s
  [Epoch 37, Step 10/16] loss: 136.0510


Training:  74%|███████▍  | 37/50 [13:39<04:47, 22.11s/it]

Epoch 37, train_loss=141.4568, Time: 21.98s
  [Epoch 38, Step 10/16] loss: 130.2657


Training:  76%|███████▌  | 38/50 [14:01<04:24, 22.08s/it]

Epoch 38, train_loss=139.3420, Time: 22.00s
  [Epoch 39, Step 10/16] loss: 126.3072


Training:  78%|███████▊  | 39/50 [14:23<04:02, 22.05s/it]

Epoch 39, train_loss=137.5366, Time: 21.98s
  [Epoch 40, Step 10/16] loss: 131.5848


Training:  80%|████████  | 40/50 [14:45<03:40, 22.03s/it]

Epoch 40, train_loss=136.3342, Time: 21.99s
  [Epoch 41, Step 10/16] loss: 128.4210


Training:  82%|████████▏ | 41/50 [15:07<03:18, 22.02s/it]

Epoch 41, train_loss=134.7839, Time: 22.00s
  [Epoch 42, Step 10/16] loss: 112.8576


Training:  84%|████████▍ | 42/50 [15:29<02:56, 22.07s/it]

Epoch 42, train_loss=133.0514, Time: 22.19s
  [Epoch 43, Step 10/16] loss: 114.8811


Training:  86%|████████▌ | 43/50 [15:51<02:34, 22.12s/it]

Epoch 43, train_loss=131.6478, Time: 22.25s
  [Epoch 44, Step 10/16] loss: 145.8648


Training:  88%|████████▊ | 44/50 [16:14<02:12, 22.13s/it]

Epoch 44, train_loss=130.0903, Time: 22.14s
  [Epoch 45, Step 10/16] loss: 130.2158


Training:  90%|█████████ | 45/50 [16:36<01:51, 22.21s/it]

Epoch 45, train_loss=129.9570, Time: 22.39s
  [Epoch 46, Step 10/16] loss: 148.7019


Training:  92%|█████████▏| 46/50 [16:58<01:28, 22.15s/it]

Epoch 46, train_loss=128.1131, Time: 22.00s
  [Epoch 47, Step 10/16] loss: 119.6483


Training:  94%|█████████▍| 47/50 [17:20<01:06, 22.16s/it]

Epoch 47, train_loss=126.9109, Time: 22.20s
  [Epoch 48, Step 10/16] loss: 126.8355


Training:  96%|█████████▌| 48/50 [17:42<00:44, 22.11s/it]

Epoch 48, train_loss=125.6503, Time: 21.98s
  [Epoch 49, Step 10/16] loss: 117.0597


Training:  98%|█████████▊| 49/50 [18:04<00:22, 22.07s/it]

Epoch 49, train_loss=124.8087, Time: 21.97s
  [Epoch 50, Step 10/16] loss: 124.9366


Training: 100%|██████████| 50/50 [18:26<00:00, 22.14s/it]

Epoch 50, train_loss=124.1797, Time: 22.19s
Training complete.

=== Test sample index: 1892 ===
Key: vocal_synthetic_003-060-100
Original audio:





Reconstructed audio:



=== Test sample index: 50 ===
Key: guitar_acoustic_015-104-100
Original audio:


Reconstructed audio:



=== Test sample index: 772 ===
Key: flute_synthetic_000-049-025
Original audio:


Reconstructed audio:



=== Test sample index: 2964 ===
Key: reed_acoustic_011-060-075
Original audio:


Reconstructed audio:



=== Test sample index: 3050 ===
Key: mallet_acoustic_062-037-127
Original audio:


Reconstructed audio:



=== Test sample index: 10 ===
Key: bass_synthetic_033-041-075
Original audio:


Reconstructed audio:



=== Test sample index: 2613 ===
Key: guitar_acoustic_015-034-075
Original audio:


Reconstructed audio:



=== Test sample index: 1206 ===
Key: bass_electronic_027-032-127
Original audio:


Reconstructed audio:



=== Test sample index: 1677 ===
Key: keyboard_acoustic_004-090-075
Original audio:


Reconstructed audio:



=== Test sample index: 3834 ===
Key: organ_electronic_057-082-075
Original audio:


Reconstructed audio:
