# Wavelet Generative Adversarial Networks - Encoder
## Initialisation and dataset preparation

First, let us import the required libraries.

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pywt
from misceallaneous import getWavFileAsNpArray, displaySpectrogram
from IPython.display import Audio
import librosa
import librosa.display

Next, let us load the dataset.

The dataset consists of two aligned recordings, `clean/p1.wav` and `white/p1.wav`, which are loaded, converted into `float32` arrays and split into non-overlapping segments of `samples_length` samples.

In [None]:
# Wavelet used by every pywt.dwt / pywt.idwt call in this notebook.
# It must be a *discrete* wavelet: "morl" (Morlet) is continuous and is
# rejected by pywt's DWT functions, so a Daubechies wavelet is used instead.
wavelet_family = "db4"

In [None]:
def display_audio(low, detailed, g_low=None, g_detailed=None, p=0, i=0):
    """Save a dB spectrogram of reconstructed segment ``i`` to ``"<p>.png"``.

    The time-domain signal is rebuilt by ``hear_audio`` with the same
    arguments. It is then transformed with ``librosa.stft`` and plotted in dB
    relative to its peak magnitude.

    Args:
        low, detailed, g_low, g_detailed, i: forwarded to ``hear_audio``.
        p: index used as the output file name (``str(p) + ".png"``).
    """
    audio = hear_audio(low, detailed, g_low, g_detailed, p, i)
    S_db = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
    # Use a fresh figure per call. The training loops call this every epoch,
    # and drawing onto the current figure would keep stacking spectrograms on
    # the same axes.
    fig = plt.figure()
    # STFT bins are linearly spaced in frequency, so label the axis linearly;
    # 'mel' would mislabel the bins of a non-mel spectrogram.
    librosa.display.specshow(S_db, x_axis='time', y_axis='linear', sr=samplerate)
    fig.savefig(str(p) + ".png", format='png')
    plt.close(fig)
    
def hear_audio(low, detailed, g_low=None, g_detailed=None, p=0, i=0):
    """Reconstruct segment ``i`` as a time-domain signal via the inverse DWT.

    Each band is taken either directly from the dataset (when its generator
    is ``None``) or from the generator's prediction for that one segment. The
    band is then multiplied by the *clean* normalisation peak
    (``max_clean_low`` / ``max_clean_detailed``), and ``pywt.idwt`` recombines
    the two bands in periodisation mode.

    NOTE(review): because the clean peaks are always used, passing normalised
    *white* bands without a generator rescales them by the wrong factor.

    Args:
        low: normalised approximation coefficients, one segment per row.
        detailed: normalised detail coefficients, one segment per row.
        g_low: optional model whose ``predict`` denoises the low band.
        g_detailed: optional model whose ``predict`` denoises the detail band.
        p: unused here; ``display_audio`` forwards it.
        i: index of the segment to reconstruct.

    Returns:
        The reconstructed signal returned by ``pywt.idwt``.
    """
    if g_low is None:
        data_low = low[i] * max_clean_low
    else:
        # predict expects a batch, so wrap the single segment as one row.
        d_low = np.reshape(low[i], (-1, data_shape[0]))
        data_low = g_low.predict(d_low)[0] * max_clean_low
    if g_detailed is None:
        data_detailed = detailed[i] * max_clean_detailed
    else:
        d_detailed = np.reshape(detailed[i], (-1, data_shape[0]))
        data_detailed = g_detailed.predict(d_detailed)[0] * max_clean_detailed
    return pywt.idwt(data_low, data_detailed, wavelet_family, "per")

def get_distance_audio(white, clean, g, n):
    """Average per-segment sum of squared errors between ``g``'s output and ``clean``.

    The model ``g`` predicts a denoised version of each of the first ``n``
    segments (rows) of ``white``. The squared differences to the matching row
    of ``clean`` are summed, and the total is divided by the number of
    segments. This is averaged over segments, not over individual
    coefficients, so it grows with segment length.

    Args:
        white: model inputs, one segment per row.
        clean: targets aligned row-for-row with ``white``.
        g: object exposing ``predict(batch)`` that returns one row per input row.
        n: number of leading segments to evaluate; clamped to the data size.

    Returns:
        The average as a Python float.

    Raises:
        ValueError: if no segment is available to evaluate.
    """
    # Clamp so that asking for more segments than exist does not IndexError.
    n = min(n, len(white), len(clean))
    if n <= 0:
        raise ValueError("get_distance_audio needs at least one segment to evaluate")
    # One batched predict instead of n single-row calls.
    inputs = np.reshape(white[:n], (n, -1))
    predictions = np.asarray(g.predict(inputs))
    targets = np.reshape(clean[:n], (n, -1))
    return float(np.sum((predictions - targets) ** 2) / n)

In [None]:
# NOTE(review): the sample rate is hard-coded; confirm it matches the WAV files.
samplerate = 12000
nperseg = 1024

clean = getWavFileAsNpArray("../dataset_2/clean/p1.wav")
white = getWavFileAsNpArray("../dataset_2/white/p1.wav")
clean = np.array(clean, dtype="float32")
white = np.array(white, dtype="float32")

samples_length = nperseg*6

# Only the overlapping part of the two recordings yields aligned
# (noisy, clean) pairs.
usable_length = min(clean.shape[0], white.shape[0])
n_segments = usable_length // samples_length
cut = n_segments * samples_length

# Split into non-overlapping segments of samples_length samples and drop the
# trailing partial segment. Every full segment is kept, including one that
# ends exactly at the end of the signal. Any trailing axes (e.g. channels)
# are preserved.
clean_dataset = clean[:cut].reshape((n_segments, samples_length) + clean.shape[1:])
white_dataset = white[:cut].reshape((n_segments, samples_length) + white.shape[1:])

In [None]:
# Probe the first clean segment to learn the length of each DWT band.
# Every segment has the same length, so all of them yield bands of this shape.
ex = clean_dataset[0]
ca, cd = pywt.dwt(ex, wavelet_family, mode="per")
data_shape = np.shape(ca)

In [None]:
def _split_bands(segments):
    """Apply a single-level periodised DWT to every segment.

    Returns two arrays, one row per segment: the approximation (low)
    coefficients and the detail coefficients.
    """
    lows, details = [], []
    for segment in segments:
        approx, detail = pywt.dwt(segment, wavelet_family, "per")
        lows.append(approx)
        details.append(detail)
    return np.array(lows), np.array(details)


def _peak_normalise(coeffs):
    """Divide ``coeffs`` by its largest magnitude; return ``(scaled, peak)``.

    Assumes the peak is non-zero.
    """
    peak = np.max(np.abs(coeffs))
    return coeffs / peak, peak


wavelet_clean_dataset_low, wavelet_clean_dataset_detailed = _split_bands(clean_dataset)
wavelet_white_dataset_low, wavelet_white_dataset_detailed = _split_bands(white_dataset)

# Each band of each recording is normalised by its own peak; hear_audio
# multiplies generator outputs by the clean peaks to undo this.
wavelet_clean_dataset_low, max_clean_low = _peak_normalise(wavelet_clean_dataset_low)
wavelet_white_dataset_low, max_white_low = _peak_normalise(wavelet_white_dataset_low)
wavelet_clean_dataset_detailed, max_clean_detailed = _peak_normalise(wavelet_clean_dataset_detailed)
wavelet_white_dataset_detailed, max_white_detailed = _peak_normalise(wavelet_white_dataset_detailed)

print(np.min(wavelet_white_dataset_detailed))

In [None]:
def train_on_batch(d, i, o, validation_split=0, batch_size=16, verbose=True):
    """Run one ``fit`` call of model ``d`` on inputs ``i`` and targets ``o``.

    Despite the name, ``fit`` is given the whole dataset; ``epochs`` is not
    passed, so the framework's default applies.

    Returns:
        The mean of the ``'accuracy'`` values recorded in the fit history
        (the model must have been compiled with that metric).
    """
    fit_result = d.fit(
        i,
        o,
        batch_size=batch_size,
        validation_split=validation_split,
        verbose=verbose,
    )
    accuracy_per_epoch = fit_result.history['accuracy']
    return np.mean(accuracy_per_epoch)

In [None]:
def generator_low(sizes, lr):
    """Build and compile the residual denoiser for the approximation (low) band.

    The network sees a dropout-corrupted copy of the input. It passes through
    a 512 -> 128 -> 512 dense bottleneck with a skip connection, and the
    resulting correction is added back to the raw input.

    Args:
        sizes: shape of the training array; ``sizes[1]`` is the number of
            coefficients per segment.
        lr: Adam learning rate.

    Returns:
        A compiled ``tf.keras.Model`` trained with MSE loss.
    """
    n_features = sizes[1]
    # Keras expects a shape *tuple*; ``(n_features)`` would be a bare int.
    inputs = tf.keras.Input(shape=(n_features,))
    x = tf.keras.layers.Dropout(0.2)(inputs)
    x1 = tf.keras.layers.Dense(512, activation="tanh")(x)
    x2 = tf.keras.layers.Dense(128, activation="relu")(x1)
    x3 = tf.keras.layers.Dense(512, activation="tanh")(x2)
    a = tf.keras.layers.Add()([x1, x3])
    x4 = tf.keras.layers.Dense(n_features, activation="tanh")(a)
    # Residual output: the network predicts a correction to the noisy input.
    outputs = tf.keras.layers.Add()([inputs, x4])

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")
    model.summary()
    # 'accuracy' is not meaningful for regression, but train_on_batch reads
    # the 'accuracy' history entry, so the metric is kept.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss='mse', metrics=['accuracy'])
    return model

In [None]:
# Bookkeeping for the low-band training loop. ``p`` numbers the saved
# spectrogram PNGs. The MSE trackers start at a large sentinel so the loop
# runs at least once.
p = 0
mse_low = mse_prev_low = 10000
mses_low = []

In [None]:
# Train the low-band generator one fit call at a time until the average
# per-segment squared error on the first 100 segments falls below 0.2.
# max_epochs_low bounds the loop so it terminates even if the threshold is
# never reached.
max_epochs_low = 1000
g_low = generator_low(wavelet_white_dataset_low.shape, lr=0.0001)
epochs_low = 0
while mse_low >= 0.2 and epochs_low < max_epochs_low:
    train_on_batch(g_low, wavelet_white_dataset_low, wavelet_clean_dataset_low, batch_size=4, verbose=True)
    mse_prev_low = mse_low
    mse_low = get_distance_audio(wavelet_white_dataset_low, wavelet_clean_dataset_low, g_low, 100)
    mses_low.append(mse_low)
    # Spectrogram of segment 0 with only the low band denoised; the detail
    # band is taken from the clean data.
    display_audio(wavelet_white_dataset_low, wavelet_clean_dataset_detailed, g_low=g_low, p=p)
    print("MSE_low:", mse_low)
    p += 1
    epochs_low += 1
if mse_low >= 0.2:
    print("Stopped after", epochs_low, "epochs without reaching MSE_low < 0.2")
plt.plot(mses_low)
plt.show()

In [None]:
def generator_detailed(sizes, lr):
    """Build and compile the residual denoiser for the detail (high) band.

    The network sees a dropout-corrupted copy of the input and passes it
    through two full-width tanh layers. The resulting correction is added to
    the *uncorrupted* input, matching the skip connection of generator_low.

    Args:
        sizes: shape of the training array; ``sizes[1]`` is the number of
            coefficients per segment.
        lr: Adam learning rate.

    Returns:
        A compiled ``tf.keras.Model`` trained with MSE loss.
    """
    n_features = sizes[1]
    # Keras expects a shape *tuple*; ``(n_features)`` would be a bare int.
    inputs = tf.keras.Input(shape=(n_features,))
    x = tf.keras.layers.Dropout(0.3)(inputs)
    x1 = tf.keras.layers.Dense(n_features, activation="tanh")(x)
    x2 = tf.keras.layers.Dense(n_features, activation="tanh")(x1)
    # The skip connection uses the raw input; adding the dropout output here
    # would inject dropout noise into the residual path during training only.
    outputs = tf.keras.layers.Add()([inputs, x2])

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")
    model.summary()
    # 'accuracy' is kept because train_on_batch reads that history entry.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss='mse', metrics=['accuracy'])
    return model

In [None]:
# Bookkeeping for the detail-band training loop. The trackers start at a
# large sentinel so the loop runs at least once.
mse_detailed = mse_prev_detailed = 10000
mses_detailed = []

In [None]:
# Train the detail-band generator one fit call at a time until the average
# per-segment squared error on the first 100 segments falls below 0.8.
# max_epochs_detailed bounds the loop so it terminates even if the threshold
# is never reached.
max_epochs_detailed = 1000
g_detailed = generator_detailed(wavelet_white_dataset_detailed.shape, lr=0.00005)
epochs_detailed = 0
while mse_detailed >= 0.8 and epochs_detailed < max_epochs_detailed:
    train_on_batch(g_detailed, wavelet_white_dataset_detailed, wavelet_clean_dataset_detailed, batch_size=4, verbose=True)
    mse_prev_detailed = mse_detailed
    mse_detailed = get_distance_audio(wavelet_white_dataset_detailed, wavelet_clean_dataset_detailed, g_detailed, 100)
    mses_detailed.append(mse_detailed)
    # Spectrogram of segment 0 with only the detail band denoised; the low
    # band is taken from the clean data.
    display_audio(wavelet_clean_dataset_low, wavelet_white_dataset_detailed, g_low=None, g_detailed=g_detailed, p=p)
    print("MSE_detailed:", mse_detailed)
    p += 1
    epochs_detailed += 1
if mse_detailed >= 0.8:
    print("Stopped after", epochs_detailed, "epochs without reaching MSE_detailed < 0.8")
plt.plot(mses_detailed)
plt.show()

## Audio testing

In [89]:
# Noisy input: rebuild the first (up to) 10 segments from the white-noise
# bands. The white normalisation peaks are undone here because hear_audio
# always rescales by the clean peaks, which would distort each band of the
# noisy signal.
n_preview = min(10, len(wavelet_white_dataset_low))
res = [
    pywt.idwt(
        wavelet_white_dataset_low[i] * max_white_low,
        wavelet_white_dataset_detailed[i] * max_white_detailed,
        wavelet_family,
        "per",
    )
    for i in range(n_preview)
]
audio = np.concatenate(res)
Audio(audio, rate=samplerate)

In [90]:
# Denoised output: both bands of the first 10 noisy segments go through
# their generators before the inverse DWT.
res = [
    hear_audio(
        wavelet_white_dataset_low,
        wavelet_white_dataset_detailed,
        g_low=g_low,
        g_detailed=g_detailed,
        i=segment,
    )
    for segment in range(10)
]
audio = np.concatenate(res)
Audio(audio, rate=samplerate)

In [17]:
# Clean reference: the first 10 segments rebuilt from the clean bands.
res = [
    hear_audio(wavelet_clean_dataset_low, wavelet_clean_dataset_detailed, i=segment)
    for segment in range(10)
]
audio = np.concatenate(res)
Audio(audio, rate=samplerate)