In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import soundfile as sf
import librosa

# Define paths to the wavs and flacs folders
wav_dir = "wavs"
flac_dir = "flacs"

# Function to load and preprocess audio files in batches
def audio_file_generator(audio_dir, batch_size, sample_rate=22050, duration=5):
    files = [file for file in os.listdir(audio_dir) if file.endswith(".wav") or file.endswith(".flac")]
    while True:
        batch_files = np.random.choice(files, batch_size)
        audio_batch = []
        for file in batch_files:
            file_path = os.path.join(audio_dir, file)
            audio, file_sample_rate = sf.read(file_path)
            # Resample if needed (optional)
            if file_sample_rate != sample_rate:
                audio = librosa.resample(audio, orig_sr=file_sample_rate, target_sr=sample_rate)
            # Pad or truncate audio to fixed length
            audio = pad_or_truncate(audio, int(sample_rate * duration))
            # Ensure audio is 2D: (length, 1) if mono, (length, channels) if stereo
            if len(audio.shape) == 1:
                audio = audio[:, np.newaxis]
            audio_batch.append(audio.astype(np.float32))  # Use float32 for memory efficiency
        yield np.array(audio_batch, dtype=np.float32)  # Use float32 for memory efficiency

def pad_or_truncate(audio, length):
    if len(audio) > length:
        return audio[:length]
    elif len(audio) < length:
        return np.pad(audio, (0, length - len(audio)), 'constant')
    else:
        return audio

# Create data generators
batch_size = 32
train_wav_generator = audio_file_generator(wav_dir, batch_size)
train_flac_generator = audio_file_generator(flac_dir, batch_size)

# Generator model
def build_generator(input_shape):
    model = models.Sequential([
        layers.InputLayer(input_shape=input_shape),
        layers.Conv1D(64, kernel_size=9, padding='same', activation='relu'),
        layers.Conv1D(128, kernel_size=9, padding='same', activation='relu'),
        layers.Conv1D(1, kernel_size=9, padding='same', activation='tanh')
    ])
    return model

# Discriminator model
def build_discriminator(input_shape):
    model = models.Sequential([
        layers.InputLayer(input_shape=input_shape),
        layers.Conv1D(64, kernel_size=9, strides=2, padding='same', activation='relu'),
        layers.Conv1D(128, kernel_size=9, strides=2, padding='same', activation='relu'),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# Build and compile the GAN
def build_gan(generator, discriminator, input_shape):
    discriminator.compile(optimizer='adam', loss='binary_crossentropy')
    discriminator.trainable = False
    
    gan_input = layers.Input(shape=input_shape)
    generated_audio = generator(gan_input)
    gan_output = discriminator(generated_audio)
    
    gan = models.Model(gan_input, gan_output)
    gan.compile(optimizer='adam', loss='binary_crossentropy')
    return gan

# Define the input shape based on sample rate and duration
input_shape = (22050 * 5, 1)  # 22050 samples/second * 5 seconds, 1 channel

# Initialize models
generator = build_generator(input_shape)
discriminator = build_discriminator(input_shape)
gan = build_gan(generator, discriminator, input_shape)

# Training parameters
epochs = 5000

# Training loop
for epoch in range(epochs):
    # Load a batch of wav files
    real_wav = next(train_wav_generator)
    
    # Generate compressed audio
    generated_flac = generator.predict(real_wav)
    
    # Load corresponding flac files
    real_flac = next(train_flac_generator)

    # Debug shapes
    print(f"Epoch {epoch}")
    print("Generated audio shape:", generated_flac.shape)
    print("Real flac shape:", real_flac.shape)

    # Ensure shapes match
    if generated_flac.shape != real_flac.shape:
        print(f"Shape mismatch: Generated shape {generated_flac.shape}, Real shape {real_flac.shape}")
        # Reshape real_flac to ensure it matches generated_flac shape
        if len(real_flac.shape) == 2:
            real_flac = real_flac[:, :, np.newaxis]
        # Ensure both have the same length
        min_length = min(generated_flac.shape[1], real_flac.shape[1])
        generated_flac = generated_flac[:, :min_length, :]
        real_flac = real_flac[:, :min_length, :]
        # Check again
        if generated_flac.shape != real_flac.shape:
            raise ValueError("Shape mismatch between generated and real audio after adjustment")

    # Train discriminator
    d_loss_real = discriminator.train_on_batch(real_flac, np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch(generated_flac, np.zeros((batch_size, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
    # Train generator
    g_loss = gan.train_on_batch(real_wav, np.ones((batch_size, 1)))
    
    # Print progress
    if epoch % 100 == 0:
        print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {d_loss}, Generator Loss: {g_loss}")

# Save the models
generator.save('generator.h5')
discriminator.save('discriminator.h5')




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 0
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)




Epoch 0/50, Discriminator Loss: 0.6898515224456787, Generator Loss: [array(0.691087, dtype=float32), array(0.691087, dtype=float32)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 1
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 2
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 3
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 4
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Epoch 5
Generated audio shape: (32, 110250, 1)
Real flac shape: (32, 110250, 1)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
Epoch 6
Generated audio shape: (32, 110250, 1)


