# In[ ]:
class AAE_generator(tf.keras.Model):
    """Convolutional autoencoder used as the generator of an adversarial autoencoder.

    Inputs are channels-last sequences shaped
    ``(batch, input_size, input_channels)``, the layout expected by
    ``tf.keras.layers.Conv1D``. The encoder downsamples them to length 1 and
    projects the result to a flat code of ``encoding_dims`` values. The decoder
    maps such a code back to a tensor with the same shape as the input.

    Args:
        encoding_dims: Size of the latent code produced by ``encoder_fc``.
        input_size: Length (number of time steps) of each input sequence.
            Must be >= 1.
        input_channels: Number of features per time step. The decoder's last
            layer emits this many channels.
        step_channels: Filter count of the first convolution. The stride-4
            blocks of the encoder and of the decoder each start at this count
            and multiply it by 4 after every block.
        non_linearity: Activation placed after every hidden convolution.
            ``None`` (the default) means ``LeakyReLU(0.2)``. A Keras layer
            instance is re-created from its config for every use, so no layer
            object is shared. Any other callable is wrapped in
            ``tf.keras.layers.Activation``.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``input_size`` is smaller than 1.
    """

    def __init__(
        self,
        encoding_dims,
        input_size,
        input_channels,
        step_channels=16,
        non_linearity=None,
        **kwargs):
        super().__init__(**kwargs)
        if input_size < 1:
            raise ValueError(f"input_size must be >= 1, got {input_size!r}")
        self.encoding_dims = encoding_dims

        def make_activation():
            # Return a fresh activation layer so that no layer object is shared
            # between positions in this network or between model instances.
            if non_linearity is None:
                return tf.keras.layers.LeakyReLU(0.2)
            if isinstance(non_linearity, tf.keras.layers.Layer):
                config = non_linearity.get_config()
                config.pop("name", None)  # let Keras assign a unique name
                return type(non_linearity).from_config(config)
            return tf.keras.layers.Activation(non_linearity)

        def ceil_div(length, stride):
            # Output length of a padding="same" convolution with this stride.
            return -(-length // stride)

        # Encoder Part: one stride-2 convolution, then stride-4 blocks until
        # the sequence length reaches 1; `size` tracks that length.
        encoder = [
            tf.keras.Sequential([
                tf.keras.layers.Conv1D(step_channels, 5, strides=2, padding="same"),
                make_activation(),
            ])
        ]
        size = ceil_div(input_size, 2)
        channels = step_channels
        while size > 1:
            encoder.append(tf.keras.Sequential([
                tf.keras.layers.Conv1D(channels, 5, strides=4, padding="same"),
                tf.keras.layers.BatchNormalization(),
                make_activation(),
            ]))
            channels *= 4
            size = ceil_div(size, 4)
        # Length is now 1; flatten (batch, 1, C) into a (batch, C) vector.
        encoder.append(tf.keras.layers.Flatten())
        self.encoder = tf.keras.Sequential(encoder)
        # Can add a Tanh non-linearity if training is unstable as noise prior is Gaussian
        self.encoder_fc = tf.keras.layers.Dense(encoding_dims)

        # Decoder Part: project the code to step_channels features and treat
        # it as a length-1 sequence for the transposed convolutions.
        self.decoder_fc = tf.keras.layers.Dense(step_channels)
        self.decoder_reshape = tf.keras.layers.Reshape((1, step_channels))
        decoder = []
        size = 1
        channels = step_channels
        target_half = ceil_div(input_size, 2)
        while size < target_half:
            # padding="same" transposed convolution: length * stride.
            decoder.append(tf.keras.Sequential([
                tf.keras.layers.Conv1DTranspose(channels, 5, strides=4, padding="same"),
                tf.keras.layers.BatchNormalization(),
                make_activation(),
            ]))
            channels *= 4
            size *= 4
        # Final linear layer restores the input's channel count and doubles the length.
        decoder.append(
            tf.keras.layers.Conv1DTranspose(input_channels, 5, strides=2, padding="same"))
        size *= 2
        # Powers-of-4 upsampling can overshoot; trim the tail so the output
        # length equals input_size.
        if size > input_size:
            decoder.append(tf.keras.layers.Cropping1D((0, size - input_size)))
        self.decoder = tf.keras.Sequential(decoder)

    def sample(self, noise, training=None):
        """Decode latent codes ``noise`` of shape (batch, encoding_dims).

        Returns a tensor of shape (batch, input_size, input_channels).
        """
        x = self.decoder_fc(noise)
        x = self.decoder_reshape(x)
        return self.decoder(x, training=training)

    def call(self, inputs, training=None):
        """Encode ``inputs`` and decode them back, returning the reconstruction."""
        encoded = self.encoder_fc(self.encoder(inputs, training=training))
        return self.sample(encoded, training=training)

# In[ ]:
learning_rate = 0.001
# NOTE(review): encoding_dims, input_size and input_channels are expected to be
# defined in an earlier notebook cell (not visible here).
model = AAE_generator(encoding_dims, input_size, input_channels)
# The generator's last layer is a Conv1DTranspose with no activation, so its
# outputs are unnormalized logits rather than probabilities; from_logits=True
# makes the loss apply the softmax itself.
# NOTE(review): categorical cross-entropy assumes the targets are one-hot /
# probability vectors along the channel axis; for real-valued signals a
# reconstruction loss such as MeanSquaredError is the usual choice.
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
)