In [1]:
import os
import tensorflow as tf
import numpy as np
import math
from random import sample, shuffle
from PIL import Image
import matplotlib.pyplot as plt
from src.DataGenerator import AudioDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D , Flatten, Reshape, Conv2DTranspose, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from src.helper_functions import plot_reconstruction
import seaborn as sns

In [2]:
data_gen = AudioDataGenerator(
    # Where the images come from and how they are read.
    directory='data/Spotify/comp_pngs/',
    image_size=(128, 512),
    color_mode='rgb',
    # Batching and sampling.
    batch_size=32,
    sample_size=100_000,
    shuffle=True,
    # Hold out 20% of the sampled files as the test set.
    train_test_split=True,
    test_size=0.2,
    # What the generator hands to the model.
    # NOTE(review): exact semantics of output_channel_index / output_size are
    # defined in src.DataGenerator; output_size should match img_height/img_width.
    output_channel_index=0,
    output_size=(128, 128),
)


Found 80000 files for Training set
Found 20000 files for Test set


In [3]:
# Height/width (pixels) of the square model input; matches the generator's output_size.
img_height, img_width = 128, 128
# Convolution hyperparameters shared by the encoder branches and the decoder.
kernel_size, strides = 5, 2

In [4]:
class Time_Freq_Autoencoder(tf.keras.Model):
    """Two-branch Conv1D encoder with a Conv2DTranspose decoder.

    The input batch is encoded twice: ``time_encoder`` receives it unchanged,
    and ``freq_encoder`` receives it with the two spatial axes swapped. Each
    branch treats one spatial axis as the Conv1D sequence axis and the other as
    Conv1D input channels. The two branch codes are concatenated into a
    ``latent_dim``-wide vector, which ``decoder`` maps back to an image.

    Layer sizes come from the notebook-level ``img_height``, ``img_width``,
    ``kernel_size`` and ``strides`` constants. The decoder starts from a fixed
    8x8 seed and upsamples four times by ``strides``. With ``strides == 2`` it
    therefore produces 128x128 outputs, so it assumes
    ``img_height == img_width == 128``.

    Args:
        latent_dim: Total size of the latent code (must be >= 2). It is split
            into ``latent_dim // 2`` units for the time branch and the
            remainder for the freq branch, so odd values still add up to
            ``latent_dim``.
        num_channels: Number of channels produced by the decoder's last layer.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``latent_dim`` is too small to give each branch a unit.
    """

    def __init__(self, latent_dim, num_channels, **kwargs):
        super().__init__(**kwargs)
        if latent_dim < 2:
            raise ValueError(
                f"latent_dim must be >= 2 so each encoder branch gets at least one unit, got {latent_dim}")
        self.latent_dim = latent_dim
        # The freq branch absorbs the remainder for odd latent_dim. This keeps
        # the concatenated code exactly latent_dim wide, which is the width the
        # decoder's InputLayer expects.
        time_units = latent_dim // 2
        freq_units = latent_dim - time_units
        # The time branch sees x as (img_height, img_width).
        self.time_encoder = self._build_branch_encoder(img_height, img_width, time_units)
        # The freq branch sees the axis-swapped x as (img_width, img_height).
        self.freq_encoder = self._build_branch_encoder(img_width, img_height, freq_units)
        self.decoder = self._build_decoder(latent_dim, num_channels)

    @staticmethod
    def _build_branch_encoder(seq_len, features, units):
        """Build one Conv1D encoder branch.

        Args:
            seq_len: Length of the axis the Conv1D layers slide along.
            features: Size of the axis treated as Conv1D input channels.
            units: Width of the branch's output code.

        Returns:
            A ``Sequential`` that reshapes each example to ``(seq_len, features)``
            and returns a ``(batch, units)`` code.
        """
        # Drops the trailing channel axis. This assumes single-channel input
        # (e.g. (batch, seq_len, features, 1)) — TODO confirm if multi-channel
        # images are ever fed in, since Reshape would then fail.
        branch_layers = [Reshape(target_shape=(seq_len, features))]
        for filters in (64, 64, 128, 256):
            branch_layers += [
                tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel_size,
                                       padding="same", strides=strides, activation='relu'),
                BatchNormalization(axis=-1),
            ]
        branch_layers += [
            Flatten(),
            Dense(2048, activation='relu'),
            Dense(units=units),
        ]
        return Sequential(branch_layers)

    @staticmethod
    def _build_decoder(latent_dim, num_channels):
        """Build the decoder for a ``latent_dim`` vector.

        The vector is projected to an (8, 8, 256) seed and upsampled four times
        by ``strides``. A final stride-1 layer emits ``num_channels`` sigmoid
        channels.
        """
        decoder_layers = [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            Dense(units=8 * 8 * 256, activation='relu'),  # reshaped to the 8x8x256 seed below
            BatchNormalization(axis=-1),
            Reshape(target_shape=(8, 8, 256)),
        ]
        for filters in (256, 128, 64, 32):
            decoder_layers += [
                Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides,
                                padding="same", activation="relu"),
                BatchNormalization(axis=-1),
            ]
        # Stride-1 projection to the output channels; sigmoid bounds outputs to (0, 1).
        decoder_layers.append(
            Conv2DTranspose(filters=num_channels, kernel_size=kernel_size,
                            padding="same", activation='sigmoid'))
        return Sequential(decoder_layers)

    def encoder(self, x, training=None):
        """Encode ``x`` (assumed (batch, height, width, channels)) to a (batch, latent_dim) code."""
        # Swap the two spatial axes and keep the batch and channel axes in place.
        x_swapped = tf.transpose(x, perm=[0, 2, 1, 3])
        encoded_time = self.time_encoder(x, training=training)
        encoded_freq = self.freq_encoder(x_swapped, training=training)
        # tf.concat avoids building a new Concatenate layer on every call.
        return tf.concat([encoded_time, encoded_freq], axis=-1)

    def call(self, x, training=None):
        """Encode ``x`` and decode it back to image space."""
        encoded = self.encoder(x, training=training)
        return self.decoder(encoded, training=training)
    
# 256-dim latent code (128 units per encoder branch) and a single output channel.
autoencoder = Time_Freq_Autoencoder(latent_dim=256, num_channels=1)

opt = Adam(learning_rate=1e-3)  # kept as a named variable so later cells can reuse it

# Mean-squared-error loss on the decoder output.
autoencoder.compile(loss=tf.keras.losses.mse, optimizer=opt)

In [5]:
# Build with an unspecified batch size and one input channel so all weights
# exist before summary(). The spatial size comes from the config constants
# instead of repeating the literal 128s.
autoencoder.build(input_shape=(None, img_height, img_width, 1))

In [6]:
# Show the architecture of every sub-layer tracked by the model (the two
# encoder branches and the decoder). Keras summary() prints its own report and
# returns None, so it is called directly rather than inside print(), which
# would append a stray "None" after each table.
for layer in autoencoder.layers:
    layer.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 128, 128)          0         
                                                                 
 conv1d (Conv1D)             (None, 64, 64)            41024     
                                                                 
 batch_normalization (BatchN  (None, 64, 64)           256       
 ormalization)                                                   
                                                                 
 conv1d_1 (Conv1D)           (None, 32, 64)            20544     
                                                                 
 batch_normalization_1 (Batc  (None, 32, 64)           256       
 hNormalization)                                                 
                                                                 
 conv1d_2 (Conv1D)           (None, 16, 128)           4