## Dataset

In [1]:
import tensorflow as tf
import numpy as np
import os
%cd ..


from data.dataset import ini_dataset
from loss.loss_functions import generator_loss, discriminator_loss
from train.steps import train
from model.GAN import get_generator, get_discriminator, GANMonitor, GAN

c:\Users\victo\Documents\Projets IA\drums_sounds_generator


In [2]:
# --- Configuration and data loading ---
# Raw string: the Windows path contains backslashes, and '\d' is not a valid
# escape sequence (newer Python versions emit a warning for it).
data_path = r'D:\data_drum_sounds'
DATA_FILE = 'data_kick_hihat_snare_tom_ride_clap_others_16kHz.npz'
SR = 16000  # presumably the audio sample rate in Hz (file name says 16kHz) - confirm
BATCH_SIZE = 32
EPOCHS = 20
LATENT_DIM = 256
LEARNING_RATE = 3e-4

# Build the archive path from data_path so the directory is defined only once.
data = np.load(os.path.join(data_path, DATA_FILE))['data']
# Shape of a single example, i.e. everything except the leading axis.
SIGNAL_SHAPE = data[0,:,:].shape

print('Data shape :', data.shape)
#dataset, categories = ini_dataset("D:\data_drum_sounds\data_kick_16kHz.npz", BATCH_SIZE)

Data shape : (9885, 16384, 1)


In [3]:
# Build the generator and run a small random batch through it as a smoke test.
generator = get_generator(latent_dim=LATENT_DIM, n_layers=8)
# Use LATENT_DIM rather than a literal 256 so this check follows the config cell.
noise = np.random.rand(4, LATENT_DIM)
waveform = generator(noise)

print('generator output shape : ', waveform.shape)
# summary() prints the table itself and returns None; wrapping it in print()
# would add a stray "None" line to the output.
generator.summary()

generator output shape :  (4, 16384, 1)
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 256)]             0         
_________________________________________________________________
dense (Dense)                (None, 2048)              524288    
_________________________________________________________________
batch_normalization (BatchNo (None, 2048)              8192      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 2048)              0         
_________________________________________________________________
reshape (Reshape)            (None, 8, 256)            0         
_________________________________________________________________
conv1d_transpose (Conv1DTran (None, 16, 256)           196608    
_________________________________________________________________
batch_normalizati

In [4]:
# Build the discriminator and run a small random batch through it as a smoke test.
discriminator = get_discriminator(SIGNAL_SHAPE, n_layers=10)
# Derive the dummy input shape from SIGNAL_SHAPE (the shape the model was built
# with) instead of repeating the literal (16384, 1).
x = np.random.rand(4, *SIGNAL_SHAPE)
pred = discriminator(x)

print('disciminator output shape : ', pred.shape)
# summary() prints the table itself and returns None; wrapping it in print()
# would add a stray "None" line to the output.
discriminator.summary()

disciminator output shape :  (4, 1)
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 16384, 1)]        0         
_________________________________________________________________
conv1d (Conv1D)              (None, 4096, 1)           9         
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 4096, 1)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 1024, 2)           18        
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 1024, 2)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 256, 4)            68        
_________________________________________________________________
leaky_re_lu_11 (L

In [5]:
# Generator and discriminator each get their own Adam instance, configured
# with the same hyperparameters.
adam_settings = dict(learning_rate=LEARNING_RATE, beta_1=0.5, beta_2=0.9)
generator_optimizer = tf.keras.optimizers.Adam(**adam_settings)
discriminator_optimizer = tf.keras.optimizers.Adam(**adam_settings)

def discriminator_loss(real_img, fake_img):
    """Return mean(fake_img) - mean(real_img).

    NOTE(review): despite the ``*_img`` names, these are presumably the
    discriminator's scores for real and generated batches - confirm against
    the GAN training step. This definition also shadows the
    ``discriminator_loss`` imported from ``loss.loss_functions``.
    """
    # The value decreases as real inputs score higher and fake inputs score lower.
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)

def generator_loss(fake_img):
    """Return the negated mean of ``fake_img``.

    NOTE(review): ``fake_img`` is presumably the discriminator's score for
    generated samples - confirm against the GAN training step. This
    definition also shadows the ``generator_loss`` imported from
    ``loss.loss_functions``.
    """
    mean_fake_score = tf.reduce_mean(fake_img)
    return -mean_fake_score

# Monitoring callback, later passed to gan.fit().
# NOTE(review): presumably produces `num_sounds` generated samples at rate `sr`
# during training - confirm in model.GAN.
callbacks = GANMonitor(num_sounds=2, latent_dim=LATENT_DIM, sr=SR)

# Combine the generator and discriminator built above into one trainable model.
# NOTE(review): discriminator_extra_steps looks like the number of
# discriminator updates per generator update - confirm in model.GAN.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=LATENT_DIM, discriminator_extra_steps=2)

# Attach one optimizer and one loss function to each network.
gan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer, g_loss_fn=generator_loss, d_loss_fn=discriminator_loss)

In [6]:
# Train on the full in-memory array, with the monitoring callback attached.
gan.fit(
    data,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[callbacks],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20