In [None]:
%run init_notebook.py

import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import librosa
import time
import torchaudio
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.display import Audio, display

from src.dataset import NSynth   
from src.models import VAE
from src.utils.models import adjust_shape
from src.utils.logger import save_training_results
from src.utils.models import compute_magnitude_and_phase

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint file: loaded when it already exists, written after training otherwise.
model_path = r"C:\Users\Articuno\Desktop\TFG-info\data\models\vae.pth"

# Audio / STFT settings.
sample_rate = 16000
n_fft = 800
hop_length = 100
win_length = n_fft  # the analysis window spans the whole FFT frame

# Arguments shared by the forward and the inverse transform so they cannot drift apart.
# onesided=False keeps both halves of the spectrum (n_fft frequency bins).
_stft_settings = {
    "n_fft": n_fft,
    "win_length": win_length,
    "hop_length": hop_length,
    "onesided": False,
}

# Forward STFT.
#   power=None   -> return the complex spectrogram instead of a magnitude/power one,
#                   so both magnitude and phase can be extracted from it.
#   center=False -> frames start at sample 0; the signal is not padded on either side.
stft_transform = T.Spectrogram(power=None, center=False, **_stft_settings)
stft_transform = stft_transform.to(device)

# Inverse STFT.
# NOTE(review): `center` is left at its default here while the forward transform uses
# center=False — confirm that the resulting handling of the clip edges is intended.
istft_transform = T.InverseSpectrogram(**_stft_settings)
istft_transform = istft_transform.to(device)

# One NSynth dataset per partition, built in training / validation / testing order.
train_dataset, valid_dataset, test_dataset = (
    NSynth(partition=split) for split in ("training", "validation", "testing")
)

batch_size = 64

# The training subset currently covers every training example; lower
# `training_subset_size` to train on only the first examples of the partition.
training_subset_size = len(train_dataset)
train_dataset = Subset(train_dataset, list(range(training_subset_size)))

# All three loaders share the same batching / worker configuration.
_loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, **_loader_options)
valid_loader = DataLoader(valid_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)

# Model / optimisation hyperparameters.
input_height = 800  # frequency bins: equals n_fft because the STFT is two-sided
input_width = 633   # time frames — presumably the STFT frame count per clip; TODO confirm
latent_dim = 256
learning_rate = 1.5e-3

def _waveform_to_features(waveform):
    """Turn a batch of raw waveforms into the two-channel tensor fed to the VAE.

    The batch is moved to `device`, transformed with `stft_transform`, split into
    magnitude and phase by `compute_magnitude_and_phase`, and the two parts are
    concatenated along dim 1.

    Per the shapes annotated in this notebook: waveform is
    [batch_size, 1, time_steps] and the result is
    [batch_size, 2, freq_bins, time_frames].
    """
    waveform = waveform.to(device)
    stft_spec = stft_transform(waveform)
    magnitude, phase = compute_magnitude_and_phase(stft_spec)
    return torch.cat([magnitude, phase], dim=1)


model = VAE((input_height, input_width), latent_dim).to(device)
if os.path.exists(model_path):
    try:
        # map_location remaps the stored tensors onto the current device, so a
        # checkpoint written on a GPU machine still loads on a CPU-only one.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model loaded successfully.")
    except PermissionError as e:
        # Best effort: report the problem and continue with the freshly initialised model.
        print(f"PermissionError: {e}. Unable to load the model.")
else:
    print("No saved model found, starting training from scratch.")

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training configuration
    num_epochs = 30
    log_interval = 10  # NOTE(review): not referenced by the loop below
    avg_epoch_time = 0.0  # accumulates total time; divided by num_epochs when logged

    training_losses = []
    validation_losses = []

    # Halve the learning rate when the validation loss has not improved for 3 epochs.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    print(f"Starting training on {device}...")
    for epoch in tqdm(range(num_epochs), desc="Training"):

        model.train()
        start_epoch_time = time.time()
        train_loss = 0.0

        # Training pass
        for waveform, _, _, _ in train_loader:
            # `model_input` rather than `input`: the latter would shadow the
            # builtin for every later cell in the notebook session.
            model_input = _waveform_to_features(waveform)

            optimizer.zero_grad()
            output, mu, log_var = model(model_input)

            loss = model.loss_function(model_input, output, mu, log_var, model_input.shape[0])

            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample even when
            # the last batch is smaller.
            train_loss += loss.item() * waveform.size(0)

        # Average training loss per sample
        train_loss /= len(train_loader.dataset)

        # Validation pass
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for waveform, _, _, _ in valid_loader:
                model_input = _waveform_to_features(waveform)

                output, mu, log_var = model(model_input)
                loss = model.loss_function(model_input, output, mu, log_var, model_input.shape[0])
                valid_loss += loss.item() * waveform.size(0)

        # Average validation loss per sample
        valid_loss /= len(valid_loader.dataset)

        # The scheduler monitors the validation loss.
        scheduler.step(valid_loss)

        # Epoch time covers both the training and the validation pass.
        epoch_time = time.time() - start_epoch_time
        avg_epoch_time += epoch_time
        training_losses.append(train_loss)
        validation_losses.append(valid_loss)

        print(f"Epoch {epoch+1}, train_loss={train_loss}, valid_loss={valid_loss}, Time: {epoch_time:.2f}s")

    print("Training complete.")

    # Persist the weights first, so a failure while logging results cannot
    # discard the trained model. Create the target directory if it is missing.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)

    save_training_results({
        "train_losses": training_losses,
        "valid_loss": validation_losses,
        "num_epochs": num_epochs,
        "avg_epoch_time": avg_epoch_time / num_epochs,
        "learning_rate": learning_rate,
        "training_subset_size": training_subset_size,
        "batch_size": batch_size,
        "sample_rate": sample_rate,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "input_height": input_height,
        "input_width": input_width,
        "latent_dim": latent_dim,
    }, "vae.json")

    # Plot the per-epoch loss curves.
    import matplotlib.pyplot as plt
    epochs = range(1, num_epochs + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, training_losses, label='Training Loss', marker='o')
    plt.plot(epochs, validation_losses, label='Validation Loss', marker='x')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()



Encoder:  Encoder(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(2, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=512000, out_features=256, bias=True)
  (fc2): Linear(in_features=512000, out_features=256, bias=True)
)
Decoder:  Decoder(
  (fc2): Linear(in_features=256, out_features=512000, bias=True)
  (unflatten): Unflatten(dim=1, unflattened_size=(64, 100, 80))
  (decoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)

Training:   0%|          | 0/30 [00:00<?, ?it/s]

### Reconstruction Accuracy

The **reconstruction accuracy** for each sample is computed as:

$$
\text{Accuracy}_i = \left(1 - \frac{\lVert x_i - \hat{x}_i \rVert_2^2}{\lVert x_i \rVert_2^2 + \varepsilon} \right) \times 100\%
$$

where:
- $x_i$ is the input for sample $i$ (of shape `[2, num_freq, time_frames]`; a batch stacks these into `[batch_size, 2, num_freq, time_frames]`),
- $\hat{x}_i$ is the reconstructed output from the model,
- $\lVert \cdot \rVert_2^2$ denotes the squared L2 norm (i.e., the sum of squared elements),
- $\varepsilon$ is a small constant to avoid division by zero.

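The relative error can exceed 1 when a reconstruction is poor, which would make the raw value negative. The implementation therefore clamps the result to $[0, 1]$ before converting it to a percentage, so each sample's accuracy lies between 0% and 100%.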

---

### Total Accuracy

To compute the **overall accuracy** across all samples in the dataset (or batch-wise using a DataLoader), we take the average of all per-sample accuracies:

$$
\text{Total Accuracy} = \frac{1}{N} \sum_{i=1}^{N} \text{Accuracy}_i
$$

where $N$ is the total number of samples evaluated.

This metric tells us how well the model reconstructs the inputs **on average** over the dataset.

In [3]:
def reconstruction_accuracy(x, x_hat, eps=1e-8):
    """Return the per-sample reconstruction accuracy as a fraction in [0, 1].

    For each sample (index along dim 0) this computes
    ``1 - sum((x - x_hat)**2) / (sum(x**2) + eps)``, reducing over every
    dimension except the first, and clamps the result to ``[0, 1]``.
    Multiply by 100 to obtain a percentage.

    Args:
        x: Reference tensor of shape ``[B, ...]`` with at least two dimensions.
        x_hat: Reconstruction, same shape as ``x``.
        eps: Small constant added to the denominator to avoid division by
            zero for all-zero inputs.

    Returns:
        Tensor of shape ``[B]`` with one accuracy value per sample.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions.
    """
    if x.dim() < 2:
        raise ValueError(
            f"expected a batched tensor with at least 2 dimensions, got {x.dim()}"
        )

    # Reduce over everything but the batch dimension; for the 4-D spectrogram
    # tensors used in this notebook this is exactly dim=(1, 2, 3).
    sample_dims = tuple(range(1, x.dim()))

    squared_error = torch.sum((x - x_hat) ** 2, dim=sample_dims)
    energy = torch.sum(x ** 2, dim=sample_dims) + eps
    accuracy = 1 - squared_error / energy

    # The error can exceed the input energy, which would give a negative score.
    return torch.clamp(accuracy, min=0.0, max=1.0)


# Evaluate the mean per-sample reconstruction accuracy over the test set.
model.eval()
total_acc_sum = 0.0   # running sum of per-sample accuracies (fractions)
total_samples = 0     # number of samples seen so far

with torch.no_grad():
    for (waveform, _, _, _) in test_loader:
        raw_waveform = waveform.to(device)  # shape: [B, 1, T]

        stft_spec = stft_transform(raw_waveform)
        magnitude, phase = compute_magnitude_and_phase(stft_spec)
        # Named `model_input` rather than `input` so the builtin is not
        # shadowed for the rest of the notebook session.
        model_input = torch.cat([magnitude, phase], dim=1)

        output, _, _ = model(model_input)

        batch_acc = reconstruction_accuracy(model_input, output)
        total_acc_sum += batch_acc.sum().item()
        total_samples += waveform.size(0)

# Fail with a clear message instead of a bare ZeroDivisionError.
if total_samples == 0:
    raise RuntimeError("test_loader yielded no samples; cannot compute testing accuracy.")

test_accuracy = round((total_acc_sum / total_samples) * 100, 2)
print(f"Testing accuracy: {test_accuracy}%")

Testing accuracy: 49.16%


## Sampling from the VAE

This section demonstrates two types of sampling using the trained Variational Autoencoder (VAE):

---

### Posterior Sampling

We sample from the **posterior distribution** by encoding real data samples through the encoder, and then decoding the resulting latent representations.

Steps:
1. Take a batch of real waveforms from the training set.
2. Compute their STFT representations (magnitude and phase).
3. Concatenate both as input to the VAE.
4. The VAE reconstructs magnitude and phase via the decoder.
5. Convert reconstructed spectrograms back to waveforms using inverse STFT.

This gives us reconstructions of real samples, useful for qualitative inspection of the VAE's reconstruction ability.

---

### Prior Sampling

We sample directly from the **prior distribution** over the latent space, which is assumed to be a standard Gaussian, $\mathcal{N}(0, I)$.

Steps:
1. Sample random latent vectors $z \sim \mathcal{N}(0, I)$.
2. Pass these vectors through the decoder.
3. The output is a pair of magnitude and phase spectrograms.
4. Recombine into complex spectrograms and apply inverse STFT to get raw audio waveforms.

This tests how well the decoder has learned to generate plausible samples from the latent space.

---

Both types of samples are played back using IPython’s `Audio` widget for auditory inspection.


In [1]:
# ------------
# Sampling
# ------------
model.eval()
num_samples = 10  # how many clips to reconstruct / generate

# POSTERIOR SAMPLING: encode real clips and decode them again.
print("Posterior sampling:")
print("-------------------\n")

# A DataLoader cannot be indexed, so pull a single batch through an iterator
# and keep only its first `num_samples` waveforms.
waveform, _, _, _ = next(iter(train_loader))
waveform = waveform[:num_samples].to(device)

stft_spec = stft_transform(waveform)
magnitude, phase = compute_magnitude_and_phase(stft_spec)
input_tensor = torch.cat([magnitude, phase], dim=1).to(device)

with torch.no_grad():
    output, _, _ = model(input_tensor)

    # Channel 0 carries the magnitude, channel 1 the phase.
    recon_magnitude, recon_phase = output[:, 0, :, :], output[:, 1, :, :]

    # Rebuild the complex spectrogram (magnitude * e^{j * phase}) and invert it.
    recon_complex = recon_magnitude * torch.exp(recon_phase * 1j)
    reconstructed_waveform = istft_transform(recon_complex)  # [batch, time_steps]

for idx, clip in enumerate(reconstructed_waveform, start=1):
    print(f"\n=== Generated Sample {idx} ===")
    display(Audio(clip.squeeze().cpu().numpy(), rate=sample_rate))


# PRIOR SAMPLING: decode latent vectors drawn from N(0, 1).
print("Prior sampling:")
print("---------------\n")

with torch.no_grad():
    standard_normal = torch.distributions.Normal(0, 1)
    z = standard_normal.sample(sample_shape=(num_samples, latent_dim)).to(device)

    # Decode, then bring the result to the model's input size.
    # NOTE(review): adjust_shape presumably crops/pads to (input_height, input_width) —
    # confirm in src.utils.models.
    generated_output = adjust_shape(model.decoder(z), (input_height, input_width))

    # Same channel layout as the reconstructions: magnitude first, phase second.
    gen_magnitude = generated_output[:, 0, :, :]
    gen_phase = generated_output[:, 1, :, :]

    gen_complex = gen_magnitude * torch.exp(gen_phase * 1j)
    generated_waveforms = istft_transform(gen_complex)  # [num_samples, time_steps]

# Play the generated samples
for idx, clip in enumerate(generated_waveforms, start=1):
    print(f"\n=== Generated Sample {idx} ===")
    display(Audio(clip.squeeze().cpu().numpy(), rate=sample_rate))

NameError: name 'model' is not defined