In [None]:
%run init_notebook.py

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
import time
import torchaudio
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from IPython.display import Audio, display

from src.dataset import NSynth   
from src.models import CVAE
from src.utils.models import adjust_shape
from src.utils.logger import save_training_results
from src.utils.models import compute_magnitude_and_phase

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint location. Set the CVAE_MODEL_PATH environment variable to override
# the author's machine-specific default without editing the notebook.
model_path = os.environ.get(
    "CVAE_MODEL_PATH",
    r"C:\Users\Articuno\Desktop\TFG-info\data\models\cvae.pth",
)

# STFT parameters.
sample_rate = 16000
n_fft = 800
hop_length = 100
win_length = n_fft  # window length equals the FFT size

# Forward STFT. power=None keeps the complex spectrogram, so both magnitude and
# phase can be extracted from it. onesided=False keeps all n_fft frequency bins
# (instead of n_fft // 2 + 1), i.e. the full, conjugate-symmetric spectrum.
stft_transform = T.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=None,
    onesided=False,
    center=False,
).to(device)

# Inverse STFT over the full (two-sided) complex spectrogram.
# NOTE(review): the forward transform uses center=False, but this one uses its
# default center setting (True in torchaudio). Reconstructed waveforms are
# therefore likely trimmed/shifted relative to the input audio. Simply passing
# center=False here is probably not a drop-in fix: torch.istft may reject
# windows whose overlap-add envelope is zero at the signal edges.
# TODO: confirm and align.
istft_transform = torchaudio.transforms.InverseSpectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    onesided=False,
).to(device)

# Datasets and DataLoaders, training parameters.
train_dataset = NSynth(partition='training')
valid_dataset = NSynth(partition='validation')
test_dataset = NSynth(partition='testing')

batch_size = 64
# Lower this to train on a prefix of the training set (e.g. quick experiments).
training_subset_size = len(train_dataset)
train_dataset = Subset(train_dataset, list(range(training_subset_size)))

# Settings shared by all loaders. Only the training data is shuffled. The
# validation loss (which drives the LR scheduler) and any test metrics are then
# computed over the same batches every time, so they are comparable across
# epochs.
loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

input_height = 800  # == n_fft: all frequency bins of the two-sided STFT
input_width = 633   # expected STFT time frames per clip — TODO confirm against the dataset clip length
latent_dim = 128
learning_rate = 1e-4

NUM_INSTRUMENT_CLASSES = 11


def _batch_to_model_inputs(waveform, metadata):
    """Convert a DataLoader batch into ``(model_input, labels)`` on ``device``.

    ``model_input`` concatenates the magnitude and phase returned by
    ``compute_magnitude_and_phase`` along dim 1. ``labels`` is the collated
    ``metadata['one_hot_instrument']`` tensor: one one-hot row of length
    NUM_INSTRUMENT_CLASSES per sample.
    """
    waveform = waveform.to(device, non_blocking=True)
    stft_spec = stft_transform(waveform)
    magnitude, phase = compute_magnitude_and_phase(stft_spec)
    model_input = torch.cat([magnitude, phase], dim=1).to(device)
    # `metadata` is a dict. Iterating it directly yields its keys (strings),
    # so select the label tensor explicitly.
    labels = metadata['one_hot_instrument'].to(device, non_blocking=True)
    return model_input, labels


model = CVAE((input_height, input_width), latent_dim, NUM_INSTRUMENT_CLASSES).to(device)
if os.path.exists(model_path):
    try:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
        # machine (and vice versa).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model loaded successfully.")
    except PermissionError as e:
        print(f"PermissionError: {e}. Unable to load the model.")
else:
    print("No saved model found, starting training from scratch.")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    num_epochs = 30
    log_interval = 10  # not used yet
    avg_epoch_time = 0.0  # accumulates total time; divided by num_epochs when saved

    training_losses = []
    validation_losses = []

    # Halve the learning rate when the validation loss plateaus for 3 epochs.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

    print(f"Starting training on {device}...")
    for epoch in tqdm(range(num_epochs), desc="Training"):
        model.train()
        start_epoch_time = time.time()
        train_loss = 0.0

        for waveform, _, _, metadata in train_loader:
            model_input, labels = _batch_to_model_inputs(waveform, metadata)

            optimizer.zero_grad()
            output, mu, log_var = model(model_input, labels)
            loss = model.loss_function(model_input, output, mu, log_var, model_input.shape[0])
            loss.backward()
            optimizer.step()

            # Weight each batch by its size so the epoch value below is a
            # per-sample average (assumes loss_function returns a per-sample
            # mean — TODO confirm).
            train_loss += loss.item() * model_input.shape[0]

        train_loss /= len(train_loader.dataset)

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for waveform, _, _, metadata in valid_loader:
                model_input, labels = _batch_to_model_inputs(waveform, metadata)
                output, mu, log_var = model(model_input, labels)
                loss = model.loss_function(model_input, output, mu, log_var, model_input.shape[0])
                valid_loss += loss.item() * model_input.shape[0]

        valid_loss /= len(valid_loader.dataset)

        scheduler.step(valid_loss)

        epoch_time = time.time() - start_epoch_time
        avg_epoch_time += epoch_time
        training_losses.append(train_loss)
        validation_losses.append(valid_loss)

        print(f"Epoch {epoch+1}, train_loss={train_loss}, valid_loss={valid_loss}, Time: {epoch_time:.2f}s")

    print("Training complete.")

    # Persist the weights first, so a failure while logging or plotting
    # cannot throw away a finished training run.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), model_path)

    save_training_results({
        "train_losses": training_losses,
        "valid_loss": validation_losses,
        "num_epochs": num_epochs,
        "avg_epoch_time": avg_epoch_time / num_epochs,
        "learning_rate": learning_rate,
        "training_subset_size": training_subset_size,
        "batch_size": batch_size,
        "sample_rate": sample_rate,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "input_height": input_height,
        "input_width": input_width,
        "latent_dim": latent_dim,
    }, "cvae.json")

    import matplotlib.pyplot as plt
    epochs = range(1, num_epochs + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, training_losses, label='Training Loss', marker='o')
    plt.plot(epochs, validation_losses, label='Validation Loss', marker='x')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()



Encoder: Encoder(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(2, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=512000, out_features=128, bias=True)
  (fc2): Linear(in_features=512000, out_features=128, bias=True)
)
Decoder: Decoder(
  (fc): Linear(in_features=128, out_features=512000, bias=True)
  (unflatten): Unflatten(dim=1, unflattened_size=(64, 100, 80))
  (decoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 

Training:   0%|          | 0/30 [00:00<?, ?it/s]

Metadata is:  {'one_hot_instrument': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 

Training:   0%|          | 0/30 [42:51<?, ?it/s]


TypeError: expected Tensor as element 0 in argument 0, but got str

: 

TESTING

In [None]:
from src.utils.dataset import INSTRUMENT_ID_2_STR, NUM_INSTRUMENTS

# Posterior sampling (conditioned on instrument): encode a few real training
# clips and decode them with each instrument's one-hot label.
model.eval()
num_samples = 3  # Number of samples per instrument

print("Posterior sampling (conditioned on instrument):")
print("---------------------------------------------\n")

# Fetch a single batch up front. Calling next(iter(train_loader)) inside the loop
# would restart the DataLoader worker processes for every instrument. Reusing
# the same clips also makes the per-instrument outputs directly comparable.
waveform, _, _, _ = next(iter(train_loader))
waveform = waveform[:num_samples].to(device)
stft_spec = stft_transform(waveform)

magnitude, phase = compute_magnitude_and_phase(stft_spec)
input_tensor = torch.cat([magnitude, phase], dim=1).to(device)
n_clips = input_tensor.shape[0]  # may be < num_samples if the batch is smaller

for instrument_id in range(NUM_INSTRUMENTS):
    print(f"\nGenerating samples for instrument: {INSTRUMENT_ID_2_STR[instrument_id]}")

    # One-hot label of the target instrument for every clip, shaped
    # (n_clips, NUM_INSTRUMENTS). This is the same layout as the
    # 'one_hot_instrument' labels used during training.
    selected_label = torch.zeros(n_clips, NUM_INSTRUMENTS, device=device)
    selected_label[:, instrument_id] = 1.0

    with torch.no_grad():
        output, _, _ = model(input_tensor, selected_label)

        recon_magnitude = output[:, 0, :, :]    # [batch, freq_bins, time_frames]
        recon_phase = output[:, 1, :, :]        # [batch, freq_bins, time_frames]

        recon_complex = recon_magnitude * torch.exp(1j * recon_phase)
        reconstructed_waveform = istft_transform(recon_complex)

    for i in range(reconstructed_waveform.shape[0]):
        print(f"\n=== Generated Sample {i+1} for {INSTRUMENT_ID_2_STR[instrument_id]} ===")
        display(Audio(reconstructed_waveform[i].squeeze().cpu().numpy(), rate=sample_rate))



# Prior sampling (conditioned on instrument): draw latent vectors from a
# standard normal and decode them with each instrument's one-hot label.
model.eval()
num_samples = 3  # clips generated per instrument

print("Prior sampling (conditioned on instrument):")
print("----------------------------------------\n")

standard_normal = torch.distributions.Normal(0, 1)

for instrument_id in range(NUM_INSTRUMENTS):
    instrument_name = INSTRUMENT_ID_2_STR[instrument_id]
    print(f"\nGenerating samples for instrument: {instrument_name}")

    # Every sample in the batch gets the same one-hot instrument label.
    instrument_label = torch.zeros(num_samples, NUM_INSTRUMENTS, device=device)
    instrument_label[:, instrument_id] = 1

    with torch.no_grad():
        z = standard_normal.sample(sample_shape=(num_samples, latent_dim)).to(device)

        # Decode, then bring the output to the spectrogram size used in training.
        generated_output = adjust_shape(
            model.decoder(z, instrument_label), (input_height, input_width)
        )

        # Channel 0 holds magnitude and channel 1 holds phase; recombine them
        # into a complex spectrogram and invert it to audio.
        gen_magnitude, gen_phase = generated_output[:, 0, :, :], generated_output[:, 1, :, :]
        gen_complex = gen_magnitude * torch.exp(1j * gen_phase)
        generated_waveforms = istft_transform(gen_complex)

    for sample_number, clip in enumerate(generated_waveforms, start=1):
        print(f"\n=== Generated Sample {sample_number} for {instrument_name} ===")
        display(Audio(clip.squeeze().cpu().numpy(), rate=sample_rate))

