# In[ ]:
# This script is still a work in progress.

import numpy as np
import librosa
import librosa.display
import soundfile as sf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Flatten
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import os
import glob

# Load real audio data
def load_audio_data(file_paths, sample_rate=16000, duration=1.0):
    """Load audio files into a fixed-length waveform matrix.

    Each file is read for at most ``duration`` seconds at its native rate,
    resampled to ``sample_rate`` when the rates differ, and then zero-padded
    or truncated to exactly ``round(sample_rate * duration)`` samples so that
    all clips stack into a single 2-D array.

    Args:
        file_paths: Iterable of paths to audio files readable by ``librosa``.
        sample_rate: Target sample rate in Hz.
        duration: Clip length in seconds (must be a number, not ``None``).

    Returns:
        ``np.ndarray`` of shape ``(n_files, round(sample_rate * duration))``.
        An empty ``file_paths`` yields an array of shape ``(0, target_len)``.
    """
    target_len = int(round(sample_rate * duration))
    clips = []
    for file_path in file_paths:
        # sr=None keeps the file's native rate; we resample explicitly below.
        y, sr = librosa.load(file_path, sr=None, duration=duration)
        if sr != sample_rate:
            y = librosa.resample(y, orig_sr=sr, target_sr=sample_rate)
        # Trim as well as pad: resampling may overshoot by a sample or two,
        # and any duration other than 1 s changes the expected length.
        y = y[:target_len]
        if len(y) < target_len:
            y = np.pad(y, (0, target_len - len(y)), mode='constant')
        clips.append(y)
    if not clips:
        # Keep the 2-D shape so downstream per-row operations still work.
        return np.zeros((0, target_len), dtype=np.float32)
    return np.array(clips)

# Load all .wav files from the data folder
def load_audio_data_from_folder(folder_path, sample_rate=16000, duration=1.0):
    """Load every ``.wav`` file located directly inside ``folder_path``.

    Decoding, resampling and length handling are delegated to
    ``load_audio_data`` so the two loaders cannot drift apart.

    Args:
        folder_path: Directory to scan (non-recursive) for ``*.wav`` files.
        sample_rate: Target sample rate in Hz.
        duration: Clip length in seconds.

    Returns:
        Tuple ``(audio_data, file_paths)`` where row ``i`` of ``audio_data``
        was loaded from ``file_paths[i]``.
    """
    # glob's result order depends on the filesystem; sort for reproducible runs.
    file_paths = sorted(glob.glob(os.path.join(folder_path, "*.wav")))
    audio_data = load_audio_data(file_paths, sample_rate=sample_rate, duration=duration)
    return audio_data, file_paths

# Generator model
def build_generator(latent_dim, output_dim):
    """Build a fully connected generator that maps latent noise to a waveform.

    Layer widths: ``latent_dim -> 256 -> 512 -> output_dim``. Each hidden layer
    is followed by ``LeakyReLU(0.2)``, and the output uses ``tanh`` so every
    generated sample lies in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
        output_dim: Number of waveform samples produced per example.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    net = Sequential()
    net.add(Dense(256, input_dim=latent_dim))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Dense(512))
    net.add(LeakyReLU(alpha=0.2))
    # tanh keeps the raw waveform output in [-1, 1].
    net.add(Dense(output_dim, activation='tanh'))
    return net

# Discriminator model
def build_discriminator(input_dim):
    """Build a fully connected real-vs-fake classifier for waveforms.

    Layer widths: ``input_dim -> 512 -> 256 -> 1``. Each hidden layer is
    followed by ``LeakyReLU(0.2)``, and the single output unit uses ``sigmoid``
    for binary classification.

    Args:
        input_dim: Number of waveform samples per input example.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    net = Sequential()
    net.add(Dense(512, input_dim=input_dim))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Dense(256))
    net.add(LeakyReLU(alpha=0.2))
    # One sigmoid unit: score in (0, 1) for the binary real/fake decision.
    net.add(Dense(1, activation='sigmoid'))
    return net

# GAN training
# Loss history, one entry per training step; train_gan appends to these lists,
# so they accumulate across repeated calls.
d_losses = []
g_losses = []


def _scalar_loss(value):
    """Return the loss entry of a Keras ``train_on_batch`` result as a float.

    ``train_on_batch`` returns a bare scalar when the model has no metrics and
    ``[loss, *metrics]`` otherwise; taking the first flattened element handles
    both forms.
    """
    return float(np.ravel(value)[0])


def train_gan(generator, discriminator, gan, real_data, latent_dim, epochs=1000, batch_size=32):
    """Run alternating discriminator/generator updates, one batch per "epoch".

    Each step trains ``discriminator`` on a random batch of real samples
    (label 1) and a batch of generator output (label 0), then trains the
    stacked ``gan`` on fresh noise labelled 1 so the generator learns to
    fool the discriminator.

    Args:
        generator: Keras model mapping ``(batch, latent_dim)`` noise to samples.
        discriminator: Compiled Keras model scoring samples as real/fake.
        gan: Compiled ``generator -> discriminator`` stack used to update the
            generator.
        real_data: Array of real samples indexed along axis 0 (sampled with
            replacement).
        latent_dim: Length of each noise vector.
        epochs: Number of training steps (one batch each).
        batch_size: Samples per discriminator/generator batch.

    Returns:
        ``(d_losses, g_losses)`` — the module-level history lists. Each
        discriminator entry is the sum of its real-batch and fake-batch losses.

    Raises:
        ValueError: If ``real_data`` contains no samples.
    """
    if len(real_data) == 0:
        raise ValueError("real_data is empty; need at least one training sample")

    # Labels never change between steps, so build them once.
    labels_real = np.ones((batch_size, 1))
    labels_fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Train discriminator
        real_samples = real_data[np.random.randint(0, real_data.shape[0], batch_size)]
        noise = np.random.randn(batch_size, latent_dim)
        # verbose=0: avoid printing a Keras progress bar on every step.
        fake_samples = generator.predict(noise, verbose=0)
        d_loss_real = _scalar_loss(discriminator.train_on_batch(real_samples, labels_real))
        d_loss_fake = _scalar_loss(discriminator.train_on_batch(fake_samples, labels_fake))
        d_loss = d_loss_real + d_loss_fake

        # Train generator: fresh noise labelled "real".
        noise = np.random.randn(batch_size, latent_dim)
        g_loss = _scalar_loss(gan.train_on_batch(noise, labels_real))

        d_losses.append(d_loss)
        g_losses.append(g_loss)

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, D Loss: {d_loss}, G Loss: {g_loss}")

    return d_losses, g_losses

def plot_losses(d_history=None, g_history=None):
    """Plot discriminator and generator loss curves.

    Args:
        d_history: Discriminator losses; defaults to the module-level
            ``d_losses`` list that ``train_gan`` appends to.
        g_history: Generator losses; defaults to the module-level ``g_losses``.
    """
    if d_history is None:
        d_history = d_losses
    if g_history is None:
        g_history = g_losses
    plt.plot(d_history, label="Discriminator Loss")
    plt.plot(g_history, label="Generator Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def save_audio(waveform, sample_rate, file_name):
    """Write ``waveform`` to ``file_name`` as audio sampled at ``sample_rate`` Hz.

    The parent directory is created first if it does not exist, so output
    paths such as ``../output/x.wav`` work on a fresh checkout.
    """
    directory = os.path.dirname(file_name)
    if directory:
        os.makedirs(directory, exist_ok=True)
    sf.write(file_name, waveform, sample_rate)
    print(f"Generated audio saved to {file_name}")


def normalize_audio(audio, axis=None):
    """Peak-normalize ``audio`` into [-1, 1].

    Args:
        audio: Waveform array.
        axis: Axis along which to take the peak; ``1`` normalizes each row of
            a ``(n_clips, n_samples)`` batch independently, ``None`` uses the
            global peak.

    Returns:
        ``audio`` divided by its peak magnitude, with the input dtype kept for
        float inputs. Slices whose peak is 0 (pure silence) are returned
        unchanged instead of turning into NaN.
    """
    peak = np.max(np.abs(audio), axis=axis, keepdims=True)
    peak[peak == 0] = 1
    return audio / peak


# Example usage
if __name__ == "__main__":
    # Folder containing the training .wav files.
    data_folder = "../data/raw"
    sample_rate = 16000

    real_audio, audio_files = load_audio_data_from_folder(data_folder, sample_rate=sample_rate)
    # Fail early with a clear message rather than an axis error below.
    if len(audio_files) == 0:
        raise SystemExit(f"No .wav files found in {data_folder}")

    print(f"Loaded {len(audio_files)} audio files:")
    for path in audio_files:
        print(path)

    # Normalize each clip to [-1, 1], matching the generator's tanh range.
    real_audio = normalize_audio(real_audio, axis=1)

    # GAN setup. Input/output width equals sample_rate, i.e. one 1-second clip
    # (the loader's default duration).
    latent_dim = 100
    generator = build_generator(latent_dim, output_dim=sample_rate)
    discriminator = build_discriminator(input_dim=sample_rate)
    discriminator.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')
    gan = Sequential([generator, discriminator])
    # Keras keeps the `trainable` state a model had when compile() was called:
    # the discriminator was compiled while trainable, so its own train_on_batch
    # still updates it, whereas `gan` (compiled after freezing) only updates
    # the generator.
    discriminator.trainable = False
    gan.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')

    train_gan(generator, discriminator, gan, real_audio, latent_dim, epochs=1000, batch_size=16)

    # Plot only now, after train_gan has filled the loss history.
    plot_losses()

    # Generate one clip and normalize it to [-1, 1].
    noise = np.random.randn(1, latent_dim)
    generated_audio = generator.predict(noise)[0]
    generated_audio = normalize_audio(generated_audio)

    save_audio(generated_audio, sample_rate, "../output/generated_audio_GAN.wav")