In [100]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import os
from torch.utils.data import Dataset, DataLoader
import sys

In [101]:
# Define the VAE Encoder
class VAEEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAEEncoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # Mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # Log variance

    def forward(self, x):
        h1 = torch.relu(self.fc1(x))  # Apply ReLU activation
        z_mean = self.fc21(h1)         # Mean of the latent space
        z_log_var = self.fc22(h1)      # Log variance of the latent space
        return z_mean, z_log_var

# Define the VAE Decoder (Generator)
class VAEDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(VAEDecoder, self).__init__()
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h3 = torch.relu(self.fc3(z))   # Apply ReLU activation
        return torch.tanh(self.fc4(h3))  # Output layer with tanh activation

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        h1 = torch.relu(self.fc1(x))   # Apply ReLU activation
        return torch.sigmoid(self.fc2(h1))  # Output layer with sigmoid activation



In [102]:
# Custom Dataset for loading .wav files
class AudioDataset(Dataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.file_paths[idx])
        
        # Ensure the audio is exactly 1 second (48000 samples)
        if waveform.size(1) < 48000:
            waveform = torch.nn.functional.pad(waveform, (0, 48000 - waveform.size(1)))
        elif waveform.size(1) > 48000:
            waveform = waveform[:, :48000]
            
        waveform = np.float32(waveform)
        waveform = (waveform - min(waveform)) / (max(waveform) - min(waveform)+sys.float_info.epsilon) * 2 - 1 
        waveform = waveform/np.max(waveform)
        return torch.tensor(waveform,dtype=torch.float32)


In [103]:
# Reparameterization trick for VAE
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

In [104]:
# Path to your .wav files directory
wav_directory = 'C:/Users/Acer/work/git/AudioMNIST/data/01'  # Change this to your directory with .wav files
file_paths = [os.path.join(wav_directory, f) for f in os.listdir(wav_directory) if f.endswith('.wav')]

# Create dataset and dataloader
dataset = AudioDataset(file_paths)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

input_dim = 48000  # For 1 second of audio at 48 kHz
hidden_dim = 256
latent_dim = 64

vae_encoder = VAEEncoder(input_dim, hidden_dim, latent_dim)
vae_decoder = VAEDecoder(latent_dim, hidden_dim, input_dim)
discriminator = Discriminator(input_dim)

In [105]:
optimizer_vae = optim.Adam(list(vae_encoder.parameters()) + list(vae_decoder.parameters()), lr=0.001)
optimizer_disc = optim.Adam(discriminator.parameters(), lr=0.001)
kl_losses  = []
vae_gan_losses = []
disc_losses = []
batch_count = 0
for epoch in range(12):
    for real_data in data_loader:
        # real_data is already in the correct shape (batch_size, input_dim)
        # Train Discriminator
        optimizer_disc.zero_grad()
        mu, logvar = vae_encoder(real_data)
        z = reparameterize(mu, logvar)
        fake_data = vae_decoder(z)

        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach())
        loss_disc = -torch.mean(torch.log(disc_real) + torch.log(1 - disc_fake))
        disc_losses.append(loss_disc.item())
        loss_disc.backward()
        optimizer_disc.step()

        # Train VAE-GAN
        optimizer_vae.zero_grad()
        disc_fake = discriminator(fake_data)
        loss_vae_gan = -torch.mean(torch.log(disc_fake))  # Adversarial loss
        vae_gan_losses.append(loss_vae_gan.item())
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())  # KL divergence
        kl_losses.append(kl_loss.item())
        loss_vae_total = loss_vae_gan + kl_loss
        loss_vae_total.backward()
        optimizer_vae.step()
        if batch_count % 10 == 0:
            print(f'Epoch: {epoch}, VAE Loss: {loss_vae_gan.item()}, Discriminator Loss: {loss_disc.item()}, KL Loss {kl_loss.item()}')
        batch_count +=1

Epoch: 0, VAE Loss: 7.771003723144531, Discriminator Loss: 1.3903833627700806, KL Loss 0.03726758435368538
Epoch: 0, VAE Loss: 45.461753845214844, Discriminator Loss: 5.587935891782081e-09, KL Loss 4836209.0
Epoch: 1, VAE Loss: nan, Discriminator Loss: nan, KL Loss nan
Epoch: 1, VAE Loss: nan, Discriminator Loss: nan, KL Loss nan
Epoch: 2, VAE Loss: nan, Discriminator Loss: nan, KL Loss nan
Epoch: 3, VAE Loss: nan, Discriminator Loss: nan, KL Loss nan


KeyboardInterrupt: 

In [None]:
# Plot the losses
plt.figure(figsize=(12, 6))
plt.plot(disc_losses, label='Discriminator Loss')
plt.plot(kl_losses, label='KL Losses')
plt.plot(vae_gan_losses, label='VAE GAN Loss')
plt.title('VAE-GAN Training Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.show()        

In [None]:
import matplotlib.pyplot as plt
vae_decoder.eval()
with torch.no_grad():
    z =  torch.tanh(torch.randn(1, latent_dim))
    generated_audio = vae_decoder(torch.randn(1, latent_dim)).detach().numpy().flatten()
print(generated_audio.shape)
# Plot the generated audio signal
plt.figure(figsize=(12, 6))
plt.plot(generated_audio, label='Generated Audio Signal', alpha=0.5)
plt.title('Generated Audio Signal')
plt.xlabel('Sample Index')
plt.ylabel('Amplitude')
plt.legend()
plt.show()

In [None]:
from scipy.io.wavfile import write
import numpy as np
sample_rate = 48000
# Assuming the array values are in the range of int16 for WAV format
# Scale the array if necessary
generated_audio = np.int16(generated_audio / np.max(np.abs(generated_audio)) * 32767)
# Save the array as a WAV file
write('generated/generated.wav', sample_rate, generated_audio) 