In [19]:
import tensorflow as tf
from tensorflow.keras.layers import LeakyReLU
import numpy as np

In [None]:
# TODO: set Kbar = 1 before running the cells below.

In [44]:
class Encoder(tf.keras.Model):
    """Convolutional encoder that downsamples a 2-D grid into a flat vector.

    Every convolution uses kernel == strides with 'valid' padding, so each one
    divides its axis by the kernel size (integer division) with no overlap:

    * two "pitch" convolutions shrink the second spatial axis by 12, then 7;
    * four "time" convolutions shrink the first spatial axis by 3, 2, 2, 2.

    Each convolution has 16 filters and a LeakyReLU activation and is followed
    by BatchNormalization. The final feature map is flattened.

    Example: a (batch, 96, 84, 1) input becomes (batch, 4, 1, 16) before the
    flatten, i.e. 64 features per example.
    """

    def __init__(self):
        super().__init__()

        def downsample(window):
            # Same tuple for kernel_size and strides -> non-overlapping windows.
            return tf.keras.layers.Conv2D(16, window, window, activation=LeakyReLU())

        # Attribute names and creation order are kept fixed on purpose: they
        # determine the auto-generated layer names and checkpoint keys.
        self.pitchconv1 = downsample((1, 12))
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pitchconv2 = downsample((1, 7))
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.timeconv1 = downsample((3, 1))
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.timeconv2 = downsample((2, 1))
        self.bn4 = tf.keras.layers.BatchNormalization()
        self.timeconv3 = downsample((2, 1))
        self.bn5 = tf.keras.layers.BatchNormalization()
        self.timeconv4 = downsample((2, 1))
        self.bn6 = tf.keras.layers.BatchNormalization()
        self.flatten = tf.keras.layers.Flatten()

    def call(self, inputs):
        """Apply each (convolution, batch-norm) stage in order, then flatten."""
        stages = (
            (self.pitchconv1, self.bn1),
            (self.pitchconv2, self.bn2),
            (self.timeconv1, self.bn3),
            (self.timeconv2, self.bn4),
            (self.timeconv3, self.bn5),
            (self.timeconv4, self.bn6),
        )
        features = inputs
        for conv_layer, norm_layer in stages:
            features = norm_layer(conv_layer(features))
        return self.flatten(features)


class Decoder(tf.keras.Model):
    """Transposed-convolution decoder that mirrors ``Encoder``.

    Plays the role of the generator. Every transposed convolution uses
    kernel == strides with 'valid' padding, so each one multiplies its axis by
    the kernel size:

    * four "time" layers grow the first spatial axis by 2, 2, 2, 3;
    * two "pitch" layers grow the second spatial axis by 7, then 12.

    All layers use a LeakyReLU activation followed by BatchNormalization. All
    have 16 filters except the last, which has 1.

    Example: a (batch, 4, 1, 16) input becomes (batch, 96, 84, 1).
    """

    def __init__(self):
        super().__init__()

        def upsample(filters, window):
            # Same tuple for kernel_size and strides -> exact integer upsampling.
            return tf.keras.layers.Conv2DTranspose(filters, window, window, activation=LeakyReLU())

        # Attribute names and creation order are kept fixed on purpose: they
        # determine the auto-generated layer names and checkpoint keys.
        self.timeconv1 = upsample(16, (2, 1))
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.timeconv2 = upsample(16, (2, 1))
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.timeconv3 = upsample(16, (2, 1))
        self.bn3 = tf.keras.layers.BatchNormalization()
        self.timeconv4 = upsample(16, (3, 1))
        self.bn4 = tf.keras.layers.BatchNormalization()
        self.pitchconv1 = upsample(16, (1, 7))
        self.bn5 = tf.keras.layers.BatchNormalization()
        # NOTE(review): the output layer is LeakyReLU followed by
        # BatchNormalization, so reconstructions are unbounded. If the targets
        # are bounded (e.g. binary piano-roll cells), a sigmoid/tanh output
        # without BN may be more appropriate -- confirm intent.
        self.pitchconv2 = upsample(1, (1, 12))
        self.bn6 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        """Apply each (transposed convolution, batch-norm) stage in order."""
        stages = (
            (self.timeconv1, self.bn1),
            (self.timeconv2, self.bn2),
            (self.timeconv3, self.bn3),
            (self.timeconv4, self.bn4),
            (self.pitchconv1, self.bn5),
            (self.pitchconv2, self.bn6),
        )
        grid = inputs
        for deconv_layer, norm_layer in stages:
            grid = norm_layer(deconv_layer(grid))
        return grid

In [None]:
# Note: G is updated once every five updates of D, and batch normalization is applied only to G.

In [46]:
class Autoencoder(tf.keras.Model):
    """Reconstruction model: ``Encoder`` -> reshape -> ``Decoder``.

    ``Encoder`` returns a flat vector per example. It is reshaped to
    (4, 1, 16) so that ``Decoder`` can upsample it back to the input grid.
    The hard-coded (4, 1, 16) matches the Encoder's output for
    (96, 84, 1) inputs. Other input sizes need a different target shape.
    """

    def __init__(self):
        super().__init__()
        # Assignment order (encoder, decoder, reshape) is preserved so layer
        # tracking and summary order stay the same.
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.reshape = tf.keras.layers.Reshape((4,1,16))

    def call(self, inputs):
        """Encode ``inputs``, restore the 3-D latent grid, and decode it."""
        latent_grid = self.reshape(self.encoder(inputs))
        return self.decoder(latent_grid)

In [47]:
model = Autoencoder()
# Build with an unspecified batch size. Only the per-example shape (96, 84, 1)
# is fixed, because Autoencoder reshapes the latent vector to (4, 1, 16).
model.build((None, 96, 84, 1))
model.summary()
# Keras needs a loss *instance* (or a loss function / string such as 'mse').
# Passing the bare MeanSquaredError class makes Keras call
# MeanSquaredError(y_true, y_pred), which constructs a new Loss object instead
# of computing a loss value, so training would fail.
model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())

Model: "autoencoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_15 (Encoder)         multiple                  4768      
_________________________________________________________________
decoder_3 (Decoder)          multiple                  4693      
_________________________________________________________________
reshape (Reshape)            multiple                  0         
Total params: 9,461
Trainable params: 9,107
Non-trainable params: 354
_________________________________________________________________
