# Convolutionnal Generative Adversarial Networks
## Initialisation and dataset preparation

First, let us import the required libraries.

In [1]:
import tensorflow as tf
import numpy as np
from scipy.io import wavfile
from scipy.signal import spectrogram, stft, istft
import matplotlib.pyplot as plt
import librosa
import librosa.display
from misceallaneous import getWavFileAsNpArray, displaySpectrogram
from IPython.display import Audio

Then, let us include the dataset.

The dataset is made of two files: `clean/p1.wav`and `white/p1.wav` which are converted into arrays of `int32` and then split into segments of `samples_length`.

The goal of the CGAN here is to predict the clean sample, when fed with the white one.

In [2]:
samplerate = 12000
nperseg = 1024

clean = getWavFileAsNpArray("../dataset_2/clean/p1.wav")
white = getWavFileAsNpArray("../dataset_2/white/p1.wav")
clean = np.array(clean, dtype="int32")
white = np.array(white, dtype="int32")

clean_dataset = []
white_dataset = []

samples_length = nperseg

for i in range(0, clean.shape[0]-samples_length, samples_length):
    clean_dataset.append(clean[i:i+samples_length])
    white_dataset.append(white[i:i+samples_length])
clean_dataset = np.array(clean_dataset)
white_dataset = np.array(white_dataset)

In [3]:
stft_clean_dataset_real = []
stft_clean_dataset_imag = []
stft_white_dataset_real = []
stft_white_dataset_imag = []

for i in clean_dataset:
    c, t, inp = stft(i, fs=samplerate, nperseg=nperseg)
    stft_clean_dataset_real.append(np.real(inp).T)
    stft_clean_dataset_imag.append(np.imag(inp).T)
    
for i in white_dataset:
    c, t, inp = stft(i, fs=samplerate, nperseg=nperseg)
    stft_white_dataset_real.append(np.real(inp).T)
    stft_white_dataset_imag.append(np.imag(inp).T)

stft_clean_dataset_real = np.array(stft_clean_dataset_real)
stft_clean_dataset_imag = np.array(stft_clean_dataset_imag)
stft_white_dataset_real = np.array(stft_white_dataset_real)
stft_white_dataset_imag = np.array(stft_white_dataset_imag)
print(stft_clean_dataset_real.shape, stft_clean_dataset_imag.shape, stft_white_dataset_real.shape, stft_white_dataset_imag.shape)

(10659, 3, 513) (10659, 3, 513) (10659, 3, 513) (10659, 3, 513)


# CGAN Model
The main idea of a GAN model is to create two networks who play an adversarial game:
- A Generator, whose goal is to produce the most realistic samples possible to fool the Discriminator
- A Discriminator, whose goal is to correctly guess if its input is a real sample from the clean dataset or an output created by the Generator

### Discriminator

The discriminator here uses a layer to process the Short-Time Fourier Transform (https://en.wikipedia.org/wiki/Short-time_Fourier_transform) before reducing the problem dimension to one single boolean prediction layer.

In [4]:
def discriminator(input_shape):
    inputs = tf.keras.Input(shape=(input_shape[1], input_shape[2]))
    x2 = tf.keras.layers.Dense(512, activation="tanh")(inputs)
    x3 = tf.keras.layers.Dense(256, activation="tanh")(x2)
    x4 = tf.keras.layers.Dense(128, activation="tanh")(x3)
    x5 = tf.keras.layers.Dense(1, activation="tanh")(x4)
    x6 = tf.keras.layers.Flatten()(x5)
    outputs = tf.keras.layers.Dense(1, activation="tanh")(x6)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="discriminator")
    model.summary()
    model.compile(optimizer= 'adam', loss='mse', metrics=['accuracy'])
    return model

## Generator
The generator itself is a Convolutionnal Autoencoder.

Its input size and output size are both the size of the stft array.

In [20]:
def generator(sizes):
    inputs = tf.keras.Input(shape=(sizes[1], sizes[2]))
    x1 = tf.keras.layers.Dense(128, activation='tanh')(inputs)
    x2 = tf.keras.layers.Dense(32, activation='tanh')(x1)
    x3 = tf.keras.layers.Dense(128, activation='tanh')(x2)
    x4 = tf.keras.layers.Dense(sizes[2], activation='tanh')(x3)
    x5 = tf.keras.layers.Add()([inputs, x4])
    outputs = tf.keras.layers.Dense(sizes[2], activation='linear')(x5)
    
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")
    model.summary()
    model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
    return model

Take care, the distance between the raw audio might be 'too continuous' to use a classical distance function. Maybe, use the distance function on the STFT's, or another loss function on the raw audio.

In [6]:
def evaluate_generator(g, inputs, outputs, size):
    res = 0
    s = min(size, inputs.shape[0])
    for i in range(s):
        expected = outputs[i]
        c, t, inp = stft(inputs[i], fs=samplerate, nperseg=nperseg)
        inp_real = np.real(inp)
        inp_imag = np.imag(inp)
        y = np.reshape(inp_real.T, (-1, inp_real.shape[1], inp_real.shape[0]))
        prediction = np.reshape(g.predict(y), (inp_real.shape[0], inp_real.shape[1]))
        t, y1 = istft(prediction + np.imag(inp_imag))
        res += (sum((expected-y1)**2)/expected.shape[0])
    return res/(s*10000000)

## Building the GAN

In [7]:
def get_generator_outputs(white, train_size, g, nperseg, clean):
    rng = np.random.default_rng()
    g_outputs = []
    batch = rng.choice(white, train_size)
    for i in range(train_size):
        t = np.reshape(white[i, :, :], (-1, white.shape[1], white.shape[2]))
        m = g.predict(t)
        g_outputs.append(m)
    g_outputs = np.reshape(np.array(g_outputs), (train_size,  white.shape[1], white.shape[2]))
    input_data = np.concatenate((g_outputs, clean[:train_size,]))
    output_data = np.concatenate((-1*np.ones((train_size,)), np.ones((train_size,))))
    return input_data, output_data

In [44]:
class GAN:
    def __init__(self, size, g, d):
        self.g = g
        self.d = d
        self.z = tf.keras.layers.Input(shape=(size[1],size[2]))
        self.image = self.g(self.z)
        self.valid = self.d(self.image)
        self.combined_network = tf.keras.Model(self.z, self.valid)
        self.compile()
    def block_discriminator(self):
        self.d.trainable = False
        self.g.trainable = True
        self.compile()
    def block_generator(self):
        self.g.trainable = False
        self.d.trainable = True
        self.compile()
    def compile(self):
        self.combined_network.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
        #combined_network.summary()

g = generator(stft_white_dataset_real.shape)
d = discriminator(stft_white_dataset_real.shape)

Model: "autoencoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_17 (InputLayer)           [(None, 3, 513)]     0                                            
__________________________________________________________________________________________________
dense_38 (Dense)                (None, 3, 128)       65792       input_17[0][0]                   
__________________________________________________________________________________________________
dense_39 (Dense)                (None, 3, 32)        4128        dense_38[0][0]                   
__________________________________________________________________________________________________
dense_40 (Dense)                (None, 3, 128)       4224        dense_39[0][0]                   
________________________________________________________________________________________

In [45]:
def train_on_batch(d, i, o, verbose=True):  
    history = d.fit(i, o, batch_size=16, verbose=verbose)
    return np.mean(history.history['accuracy'])

In [46]:
train_size = 10000

In [None]:
gan = GAN(stft_white_dataset_real.shape, g, d)
disc_acc = []
gen_acc = []

for step in range(10):
    g_accuracy = 0
    d_accuracy = 0
    print("Step", step)
    if d_accuracy < 1:
        i, o = get_generator_outputs(stft_white_dataset_real, train_size, gan.g, nperseg, stft_clean_dataset_real)
    gan.block_generator()
    print("Training the discriminator")
    while d_accuracy < 0.5:
        d_accuracy = train_on_batch(gan.d, i, o, verbose=False)
        print(d_accuracy)
        disc_acc.append(d_accuracy)
        gen_acc.append(0)
    gan.block_discriminator()
    print("Training the generator")
    while  (g_accuracy <= 0.95 or g_accuracy <= d_accuracy) and g_accuracy < 1:
        g_accuracy = train_on_batch(gan.combined_network, stft_white_dataset_real[:train_size,], np.ones(train_size), verbose=False)
        print(g_accuracy)
        gen_acc.append(g_accuracy)
        disc_acc.append(0)
plt.plot(disc_acc)
plt.plot(gen_acc)
plt.show()

Step 0
Training the discriminator
0.003949999809265137
0.4874500036239624
0.5
Training the generator
0.5580999851226807
0.984499990940094
Step 1
Training the discriminator
0.4918999969959259
0.5
Training the generator
0.3059999942779541
0.9323999881744385
0.9800999760627747
Step 2
Training the discriminator
0.4941500127315521
0.4999000132083893
0.5
Training the generator
0.1671999990940094
0.8517000079154968
0.9584000110626221
Step 3
Training the discriminator
0.4939500093460083
0.5
Training the generator
0.0017000000225380063
0.26649999618530273
0.7792999744415283
0.9083999991416931
0.9434999823570251
0.9609000086784363
Step 4
Training the discriminator
0.49334999918937683
0.49970000982284546
0.5
Training the generator
0.0
0.021800000220537186
0.4203999936580658
0.7276999950408936
0.8367000222206116
0.8944000005722046
0.9192000031471252
0.9381999969482422
0.9502999782562256
Step 5
Training the discriminator
0.4918000102043152
0.49869999289512634
0.49935001134872437
0.49924999475479126

0.8051000237464905
0.8112000226974487
0.8069000244140625
0.8059999942779541
0.8155999779701233
0.8206999897956848
0.817799985408783
0.8241000175476074
0.8288999795913696
0.8378999829292297
0.8446999788284302
0.8497999906539917
0.8531000018119812
0.8551999926567078
0.8551999926567078
0.8629000186920166
0.8615000247955322
0.8608999848365784
0.8610000014305115
0.8611999750137329
0.8708000183105469
0.8712999820709229
0.8694000244140625
0.8704000115394592
0.8747000098228455
0.8766999840736389
0.876800000667572
0.8827999830245972
0.8841999769210815
0.8877999782562256
0.8891000151634216
0.8931999802589417
0.8901000022888184
0.8931000232696533
0.8942000269889832
0.8942999839782715
0.8917999863624573
0.8878999948501587
0.8844000101089478
0.8896999955177307
0.8974000215530396
0.8924000263214111
0.9013000130653381
0.9046000242233276
0.9043999910354614
0.9045000076293945
0.9077000021934509
0.9049000144004822
0.9093000292778015
0.9118000268936157
0.909500002861023
0.9133999943733215
0.9090999960899

0.0017999999690800905
0.002400000113993883
0.002300000051036477
0.002099999925121665
0.002199999988079071
0.0017000000225380063
0.002300000051036477
0.0017999999690800905
0.0026000000070780516
0.002899999963119626
0.0026000000070780516
0.0024999999441206455
0.003100000089034438
0.004000000189989805
0.003700000001117587
0.003000000026077032
0.003100000089034438
0.0034000000450760126
0.002899999963119626
0.003100000089034438
0.0035000001080334187
0.003100000089034438
0.0031999999191612005
0.003800000064074993
0.004000000189989805
0.004900000058114529
0.004399999976158142
0.004800000227987766
0.0038999998942017555
0.005100000184029341
0.00570000009611249
0.00559999980032444
0.004399999976158142
0.004800000227987766
0.006000000052154064
0.006500000134110451
0.006300000008195639
0.007199999876320362
0.007199999876320362
0.006300000008195639
0.006300000008195639
0.006899999920278788
0.005799999926239252
0.006599999964237213
0.006500000134110451
0.005200000014156103
0.005900000222027302
0.005

0.7261999845504761
0.7305999994277954
0.7211999893188477
0.7263000011444092
0.7289999723434448
0.7278000116348267
0.7286999821662903
0.7285000085830688
0.7294999957084656
0.732200026512146
0.7287999987602234
0.7261999845504761
0.7282999753952026
0.7289000153541565
0.7303000092506409
0.7325000166893005
0.7286999821662903
0.7311000227928162
0.7336999773979187
0.7311000227928162
0.7389000058174133
0.7361000180244446
0.7426000237464905
0.7425000071525574
0.7408000230789185
0.7422000169754028
0.7426999807357788
0.7493000030517578
0.756600022315979
0.7570000290870667
0.7620000243186951
0.7597000002861023
0.7513999938964844
0.7530999779701233
0.7573000192642212
0.758400022983551
0.7603999972343445
0.7642999887466431
0.7666000127792358
0.7664999961853027
0.7692999839782715
0.7677000164985657
0.7723000049591064
0.7670000195503235
0.7742999792098999
0.770799994468689
0.774399995803833
0.7698000073432922
0.7684999704360962
0.7698000073432922
0.7709000110626221
0.7695000171661377
0.770699977874755

In [None]:
inputs = []
outputs = []
for i in range(10):
    x = np.reshape(white_dataset[i, :].T, (1, white_dataset.shape[1]))
    y = np.reshape(stft_white_dataset_real[i, :, :], (-1, stft_white_dataset_real.shape[1], stft_white_dataset_real.shape[2]))
    t, y1 = istft(np.reshape(gan.g.predict(y).T, (513, 3))+np.imag(stft_white_dataset_imag[i]).T)
    x2 = np.reshape(x.T, (clean_dataset.shape[1],))
    y2 = np.reshape(y1.T, (clean_dataset.shape[1],))
    inputs.append(x2)
    outputs.append(y2)

a = np.concatenate(inputs)
b = np.concatenate(outputs)

c, t, axx = stft(a, fs=samplerate, nperseg=nperseg)
c, t, bxx = stft(b, fs=samplerate, nperseg=nperseg)
displaySpectrogram(axx)
plt.show()
displaySpectrogram(bxx)
plt.show()

In [None]:
Audio(a, rate=samplerate)

In [None]:
Audio(b, rate=samplerate)