# Wavelet Generative Adversarial Networks - Decoder
## Initialisation and dataset preparation

First, let us import the required libraries.

In [152]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pywt
from misceallaneous import getWavFileAsNpArray, displaySpectrogram
from IPython.display import Audio
import librosa
import librosa.display

Next, let us load the dataset.

The dataset is made of two files, `clean/p1.wav` and `white/p1.wav`. Each is read into a NumPy array, cast to `float32`, and then split into consecutive, non-overlapping segments of `samples_length` samples.

In [153]:
# Wavelet name passed to every pywt.dwt call in this notebook
# ("db38": a Daubechies wavelet in PyWavelets' "dbN" naming).
wavelet_family = "db38"

In [154]:
# NOTE(review): `samplerate` is not used in this cell; presumably the wav sampling rate in Hz — confirm.
samplerate = 12000
nperseg = 1024

clean = getWavFileAsNpArray("../dataset_2/clean/p1.wav")
white = getWavFileAsNpArray("../dataset_2/white/p1.wav")
clean = np.array(clean, dtype="float32")
white = np.array(white, dtype="float32")

# Every segment holds 2 * nperseg samples.
samples_length = nperseg*2

# Segment both signals over their *common* length so clean/white segments stay
# paired index-for-index; slicing `white` past its end would otherwise produce
# short, ragged segments. The `+ 1` on the stop makes the final full segment
# (starting at n_common - samples_length) part of the dataset.
n_common = min(clean.shape[0], white.shape[0])
segment_starts = range(0, n_common - samples_length + 1, samples_length)
clean_dataset = np.array([clean[i:i + samples_length] for i in segment_starts])
white_dataset = np.array([white[i:i + samples_length] for i in segment_starts])

In [155]:
# Probe the first clean segment to learn the shape of one coefficient array
# returned by the single-level, periodized ("per") DWT used in this notebook.
ex = clean_dataset[0]
ca, cd = pywt.dwt(ex, wavelet_family, mode="per")
data_shape = np.shape(ca)

In [177]:
def to_wavelet_dataset(segments):
    """Apply a single-level periodized DWT to every segment and scale by the global peak.

    Each output row is ``np.concatenate((cA, cD))`` where
    ``cA, cD = pywt.dwt(segment, wavelet_family, "per")``.

    Args:
        segments: iterable of equal-length 1-D signals.

    Returns:
        (coeffs, peak): the stacked coefficient rows divided by ``peak``, and
        ``peak``, the maximum absolute coefficient over the whole set (returned so
        callers can undo the scaling).
    """
    coeffs = np.array([np.concatenate(pywt.dwt(segment, wavelet_family, "per")) for segment in segments])
    if coeffs.size == 0:
        # np.max on an empty array raises; an empty dataset stays empty.
        return coeffs, 0.0
    peak = np.max(np.abs(coeffs))
    # Guard all-zero input: dividing by a zero peak would turn every coefficient into NaN.
    return coeffs / (peak if peak > 0 else 1), peak


# Clean and white sets are normalised independently, each by its own peak.
wavelet_clean_dataset, max_clean = to_wavelet_dataset(clean_dataset)
wavelet_white_dataset, max_white = to_wavelet_dataset(white_dataset)

print(np.max(wavelet_white_dataset))
print(wavelet_white_dataset.shape)

0.790526
(5329, 2048)


In [178]:
def train_on_batch(d, i, o, validation_split=0, batch_size=16, verbose=True):
    """Fit ``d`` on inputs ``i`` / targets ``o`` once and return its mean training accuracy.

    Despite the name, this calls ``d.fit`` (a full pass over the data with Keras'
    default ``epochs=1``), not Keras' ``Model.train_on_batch``.

    Args:
        d: compiled model exposing ``fit`` and reporting an ``'accuracy'`` metric.
        i: training inputs.
        o: training targets.
        validation_split: forwarded to ``fit``.
        batch_size: forwarded to ``fit``.
        verbose: forwarded to ``fit``.

    Returns:
        Mean of ``history['accuracy']`` over the epochs ``fit`` ran.
    """
    fit_kwargs = dict(batch_size=batch_size, validation_split=validation_split, verbose=verbose)
    accuracies = d.fit(i, o, **fit_kwargs).history['accuracy']
    return np.mean(accuracies)

In [245]:
def discriminator(sizes, lr, hidden_units=(512, 64), dropout=0.3, show_summary=True):
    """Build and compile a dense discriminator over flat coefficient vectors.

    Args:
        sizes: shape of the training array; only ``sizes[1]`` (features per
            sample) is used, as the input width.
        lr: Adam learning rate.
        hidden_units: widths of the tanh hidden layers, applied in order.
        dropout: dropout rate applied directly to the inputs.
        show_summary: print ``model.summary()`` after building.

    Returns:
        A compiled ``tf.keras.Model`` named "discriminator" with one tanh output
        unit, trained with MSE loss and reporting accuracy.
    """
    # The trailing comma makes a 1-tuple; `(sizes[1])` is just a parenthesised int.
    inputs = tf.keras.Input(shape=(sizes[1],))
    x = tf.keras.layers.Dropout(dropout)(inputs)
    for units in hidden_units:
        x = tf.keras.layers.Dense(units, activation="tanh")(x)
    # NOTE(review): tanh output + MSE against 0/1 labels; sigmoid + binary
    # cross-entropy is the more common discriminator head — kept as-is.
    outputs = tf.keras.layers.Dense(1, activation="tanh")(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="discriminator")
    if show_summary:
        model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss='mse', metrics=['accuracy'])
    return model

In [247]:
# Fix TensorFlow's global seed so weight init, dropout and fit's shuffling are repeatable.
tf.random.set_seed(0)
d = discriminator(wavelet_white_dataset.shape, lr=0.0001)

# Targets: 0 for every white-noise segment, 1 for every clean segment, matching
# the row order of the concatenated inputs. Both arrays are built once, outside the loop.
# NOTE(review): rows are ordered white-then-clean; if `validation_split` is ever
# passed, Keras takes the *last* rows (all clean) as validation data.
train_inputs = np.concatenate((wavelet_white_dataset, wavelet_clean_dataset))
output = np.concatenate((np.zeros(wavelet_white_dataset.shape[0]), np.ones(wavelet_clean_dataset.shape[0])))

n_epochs = 250
# verbose=False: each of the n_epochs separate fit() calls would otherwise print its own progress bar.
epoch_accuracies = [train_on_batch(d, train_inputs, output, verbose=False) for _ in range(n_epochs)]
epoch_accuracies[-1]

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_110 (InputLayer)       [(None, 2048)]            0         
_________________________________________________________________
dropout_109 (Dropout)        (None, 2048)              0         
_________________________________________________________________
dense_295 (Dense)            (None, 512)               1049088   
_________________________________________________________________
dense_296 (Dense)            (None, 64)                32832     
_________________________________________________________________
dense_297 (Dense)            (None, 1)                 65        
Total params: 1,081,985
Trainable params: 1,081,985
Non-trainable params: 0
_________________________________________________________________






In [249]:
# Score the first white-noise segment; slicing keeps the batch axis, so no hard-coded width is needed.
d.predict(wavelet_white_dataset[:1])

array([[0.31221697]], dtype=float32)