### Import modules and configure GPU memory

In [None]:
import tensorflow as tf

# Limit each visible GPU to GPU_MEMORY_LIMIT_MB via a logical device configuration.
# Approach from https://medium.com/ibm-data-ai/memory-hygiene-with-tensorflow-during-model-training-and-deployment-for-inference-45cf49a15688
GPU_MEMORY_LIMIT_MB = 4096

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    print(gpu)
    try:
        # Must run before TensorFlow initializes the GPU; re-running this cell
        # afterwards raises RuntimeError, which we report instead of crashing.
        tf.config.set_logical_device_configuration(
            gpu,
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)],
        )
    except RuntimeError as err:
        print(f"Could not limit memory on {gpu.name}: {err}")


import tensorflow.keras as keras
# Take `layers` from the same Keras module bound to `keras`; a bare
# `from keras import layers` may resolve to a separately installed package.
layers = keras.layers
import numpy as np
# NOTE(review): star import — `create_dataset` (used below) comes from here;
# prefer an explicit import once the module's public names are confirmed.
from dataset import *

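#### GPU config

The GPU memory cap (4096 MB per visible GPU) is applied in the import cell above; it has to run before TensorFlow initializes the GPUs.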

### Create dataset from the training audio file


In [None]:
# Build the spectrogram dataset from the training recording.
# NOTE(review): the positional arguments 2 and 0.5 are interpreted by
# dataset.create_dataset (not visible here) — confirm their meaning and name them.
# `xf` is not used by any visible cell; `spec_params` is unpacked before audio reconstruction.
TRAINING_AUDIO_PATH = 'audio/training_audio.wav'
xf, X, spec_params = create_dataset(TRAINING_AUDIO_PATH, 2, 0.5)
print(X.shape)

## Define model

In [None]:
# Per-example input shape: X's per-example shape plus a trailing channel axis
# (the X_fit cell adds it with tf.expand_dims(X, -1)). Both spatial dims must be
# divisible by 2 ** len(ENCODER_FILTERS) = 64 so the decoder restores the input size.
INPUT_SHAPE = (1024, 256, 1)
ENCODER_FILTERS = (2, 4, 8, 16, 32, 64)
BOTTLENECK_UNITS = 2048

inputs = keras.Input(INPUT_SHAPE)

# Encoder: each stage is BatchNorm -> 3x3 conv -> 2x2 max-pool, halving both
# spatial dims. The *normalized* tensor must feed the conv: a BatchNormalization
# whose output is not on the path from inputs to outputs is dropped from the model.
x = inputs
for n_filters in ENCODER_FILTERS:
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(n_filters, (3, 3), (1, 1), padding='same', activation='relu')(x)
    x = layers.MaxPool2D((2, 2), strides=2)(x)
x = layers.BatchNormalization()(x)
encoded_shape = x.shape[1:]

flatten = layers.Flatten()(x)
dense = layers.Dense(BOTTLENECK_UNITS)(flatten)

# Decoder: project back to the encoder's output shape, then mirror the encoder.
dense2 = layers.Dense(flatten.shape[1])(dense)
x = layers.Reshape(encoded_shape)(dense2)

decoder_filters = ENCODER_FILTERS[::-1]  # (64, 32, 16, 8, 4, 2) — integers, not floats
for n_filters in decoder_filters[:-1]:
    x = layers.BatchNormalization()(x)
    # 2x2 stride-2 transposed conv doubles both spatial dims (undoes one max-pool).
    x = layers.Conv2DTranspose(n_filters, (2, 2), strides=2)(x)
    # Keep n_filters channels; collapsing to a single channel at every stage
    # shrank the representation below the Dense bottleneck.
    x = layers.Conv2DTranspose(n_filters, (3, 3), strides=1, padding='same', activation='relu')(x)

# Final stage (no BatchNorm, as before): single-channel sigmoid output in [0, 1].
x = layers.Conv2DTranspose(decoder_filters[-1], (2, 2), strides=2)(x)
outputs = layers.Conv2DTranspose(1, (3, 3), strides=1, padding='same', activation='sigmoid')(x)

assert tuple(outputs.shape[1:]) == tuple(inputs.shape[1:]), (outputs.shape, inputs.shape)

# NOTE: this layer graph differs from earlier versions of this cell, so weights
# previously saved under "cnn_autoencoder" will not load into it.
model = keras.Model(inputs=inputs, outputs=outputs, name="autoencoder")

In [None]:
# Binary cross-entropy pairs with the sigmoid output layer and with training
# inputs that are min-max scaled to [0, 1] before fitting.
model.compile(loss='binary_crossentropy', optimizer='adam')
model.summary()

In [None]:
# Add a trailing channel axis and min-max scale to [0, 1] so the targets match
# the sigmoid output / binary cross-entropy loss. Min and max are computed once.
x_min = np.min(X)
x_max = np.max(X)
x_range = x_max - x_min
X_fit = tf.expand_dims(X, -1)
# Guard against constant data, which would otherwise divide by zero.
X_fit = (X_fit - x_min) / (x_range if x_range != 0 else 1)

print("==================")
print(X_fit.shape)
print("==================")

# Resume from a previous run only if saved weights exist (a later cell saves the
# model under this name); on a fresh checkout there is nothing to load yet.
# Checks both a SavedModel directory and a TF checkpoint prefix.
WEIGHTS_PATH = "cnn_autoencoder"
if tf.io.gfile.exists(WEIGHTS_PATH) or tf.io.gfile.exists(WEIGHTS_PATH + ".index"):
    model.load_weights(WEIGHTS_PATH)
else:
    print(f"No saved weights at {WEIGHTS_PATH!r}; training from scratch.")


In [None]:
# Autoencoder training: the scaled input is also the reconstruction target.
EPOCHS = 150
model.fit(
    x=X_fit,
    y=X_fit,
    epochs=EPOCHS,
    shuffle=True,
)

### Compare input and reconstructed spectrograms as images


In [None]:
import matplotlib.pyplot as plt

In [None]:
# Reconstruct the first example. The model was trained on the min-max scaled
# X_fit, so it must be fed the scaled version, not raw X.
x_example = np.asarray(X_fit[0:1])  # (1, H, W, 1): batch of one with channel axis

# imsave rescales colors to the image's own min/max, so the scaled input renders
# the same picture as the raw spectrogram.
plt.imsave("INPUT_EXAMPLE.png", x_example[0, :, :, 0])
print(x_example.shape)
prediction = model.predict(x_example)
plt.imsave("OUTPUT_EXAMPLE.png", prediction[0, :, :, 0])

In [None]:
# Persist the trained model under the same path passed to model.load_weights
# earlier in the notebook.
# NOTE(review): Keras 3 requires a ".keras" (or ".h5") suffix for Model.save;
# this extension-less path relies on TF2 / Keras 2 SavedModel behaviour — confirm the TF version.
model.save(filepath="cnn_autoencoder")

### Save to audio file

In [None]:
import importlib
import postprocessing

In [None]:
# Reload so edits to postprocessing.py are picked up without restarting the kernel.
importlib.reload(postprocessing)
out_samp, out_win, out_stride = spec_params

# The model outputs min-max scaled values in [0, 1]. Map them back to X's
# original range (the inverse of the scaling applied before training) so the
# inverse spectrogram receives data on the same scale create_dataset produced.
x_min, x_max = np.min(X), np.max(X)
predicted_spectrogram = prediction[0, :, :, 0] * (x_max - x_min) + x_min
postprocessing.reverse_spectrogram(predicted_spectrogram, out_samp, out_win, out_stride)

In [None]:
# Remove `spectrogram` and `spectrogram_timestep` from the namespace if present.
# Neither name is defined in any visible cell (possibly they come from
# `from dataset import *` — TODO confirm), so a plain `del` raises NameError on
# a fresh kernel; pop(..., None) is a no-op when the name is absent.
for _name in ("spectrogram", "spectrogram_timestep"):
    globals().pop(_name, None)