In [2]:
%run init_notebook.py

In [5]:

import torch
import torch.nn as nn
import torchaudio

# Adjust these imports according to your project structure:
from src.models import AutoEncoder
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING
from src.utils.models import compute_output_size, compute_flattened_size

# Define parameters
input_height = 64
input_width = 128
latent_dim = 30
in_channels = 1
filters = [32, 64, 128]

# Instantiate the autoencoder model.
# Make sure your AutoEncoder class is defined to accept (input_height, input_width, latent_dim, in_channels, filters)
model = AutoEncoder(input_height, input_width, latent_dim, in_channels, filters)
print("AutoEncoder model:")
print(model)

# Create a dummy input tensor: shape [batch_size, in_channels, input_height, input_width]
batch_size = 4
dummy_input = torch.randn(batch_size, in_channels, input_height, input_width)

# Pass the dummy input through the autoencoder d
output = model(dummy_input)

# Print input and output shapes
print("Input shape:", dummy_input.shape)
print("Output shape:", output.shape)

# Check if the output shape matches the input shape
if dummy_input.shape == output.shape:
    print("Success: Output shape matches input shape.")
else:
    print("Mismatch: Adjust output_padding in your decoder layers if necessary.")


AutoEncoder model:
AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=16384, out_features=30, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=30, out_features=16384, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(128, 8, 16))
    (2): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (7): ReLU()
    (8): ConvTranspose2d(32, 1, kernel_size=

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import librosa
import soundfile as sf
from tqdm import tqdm
from torch.utils.data import DataLoader
from IPython.display import Audio

# Adjust these imports according to your project structure:
from src.dataset import NSynth      # Your NSynth dataset class
from src.models import AutoEncoder  # Your autoencoder model
from src.config import CONV_KERNEL_SIZE, CONV_STRIDE, CONV_PADDING

# --------------------------
# Define the Mel Spectrogram Transform
# --------------------------
mel_transform = T.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,
    hop_length=501,  # So that the mel spectrogram has shape [1, 64, 128]
    n_mels=64,
    normalized=True # Important
)

# --------------------------
# Create Datasets and DataLoaders
# --------------------------
train_dataset = NSynth(partition='training', transform=mel_transform)
test_dataset  = NSynth(partition='testing', transform=mel_transform)

# Optionally, use a subset for quicker debugging:
from torch.utils.data import Subset
train_dataset = Subset(train_dataset, list(range(100)))

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=1, shuffle=False)

# --------------------------
# Model, Optimizer, and Loss
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_height = 64
input_width  = 128
latent_dim   = 5
in_channels  = 1
filters      = [8]

model = AutoEncoder(input_height, input_width, latent_dim, in_channels, filters).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Reduce learning rate
criterion = nn.MSELoss()

# --------------------------
# Training Loop
# --------------------------
num_epochs = 20
print("Starting training...")
for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
    model.train()
    epoch_loss = 0.0
    for i, (_, mel_spec, _) in enumerate(train_loader):
        print(f"Iteration {i+1}")
        mel_spec = mel_spec.to(device)  # Expected shape: [batch, 1, 64, 128]
        optimizer.zero_grad()
        output = model(mel_spec)
        loss = criterion(output, mel_spec)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * mel_spec.size(0)

    epoch_loss /= len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print("Training complete.")

# --------------------------
# Testing, Inversion, and Audio Playback
# --------------------------
model.eval()
with torch.no_grad():
    for _, mel_spec, sample_rate in test_loader:
        mel_spec = mel_spec.to(device)  # shape [1, 1, 64, 128]
        reconstructed = model(mel_spec)
        break

print("Input shape:", mel_spec.shape)
print("Reconstructed shape:", reconstructed.shape)

# Convert the reconstructed mel spectrogram to a numpy array (squeeze batch and channel dims)
reconstructed_np = reconstructed.squeeze().cpu().numpy()  # Expected shape: [64, 128]

# Inverse Mel Transformation:
sr = 16000       # Sample rate
n_fft = 1024     # FFT window size
hop_length = 501 # Must match mel_transform hop_length
n_mels = 64      # Number of mel bins

# Convert mel spectrogram back to waveform
reconstructed_audio = librosa.feature.inverse.mel_to_audio(
    reconstructed_np, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)

# Save the audio file
sf.write("reconstructed.wav", reconstructed_audio, sr)
print("Reconstructed audio saved to 'reconstructed.wav'.")

# Play the audio in the notebook
Audio(reconstructed_audio, rate=sr)


Starting training...


Training Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:   5%|▌         | 1/20 [06:39<2:06:31, 399.53s/it]

Iteration 25
Epoch 1/20, Loss: 17.3807
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  10%|█         | 2/20 [13:10<1:58:24, 394.70s/it]

Iteration 25
Epoch 2/20, Loss: 17.3724
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  15%|█▌        | 3/20 [19:43<1:51:31, 393.60s/it]

Iteration 25
Epoch 3/20, Loss: 17.3681
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  20%|██        | 4/20 [26:10<1:44:16, 391.05s/it]

Iteration 25
Epoch 4/20, Loss: 17.3592
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  25%|██▌       | 5/20 [32:40<1:37:43, 390.87s/it]

Iteration 25
Epoch 5/20, Loss: 17.3526
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  30%|███       | 6/20 [39:13<1:31:19, 391.36s/it]

Iteration 25
Epoch 6/20, Loss: 17.3452
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14
Iteration 15
Iteration 16
Iteration 17
Iteration 18
Iteration 19
Iteration 20
Iteration 21
Iteration 22
Iteration 23
Iteration 24


Training Epochs:  35%|███▌      | 7/20 [45:43<1:24:43, 391.04s/it]

Iteration 25
Epoch 7/20, Loss: 17.3383
Iteration 1
Iteration 2
Iteration 3
Iteration 4
Iteration 5
Iteration 6
Iteration 7
Iteration 8
Iteration 9
Iteration 10
Iteration 11
Iteration 12
Iteration 13
Iteration 14


Training Epochs:  35%|███▌      | 7/20 [49:27<1:31:51, 423.98s/it]


KeyboardInterrupt: 