In [None]:
import tensorflow as tf
from tensorflow import keras as k
import numpy as np

In [None]:
def _compute_number_of_neurons(layers, seed):
    neurons = [seed]

    tmp = seed
    for _ in range(layers-1):
        tmp = int(tmp / 2)
        neurons.append(tmp)
    return neurons

In [None]:
# Sanity check of the width schedule; expected [96, 48, 24]
_compute_number_of_neurons(layers=3, seed=96)

In [None]:
class AAEClassModel(k.models.Model):
    """Adversarial autoencoder with a categorical and a continuous latent code.

    Loosely follows https://doi.org/10.48550/arXiv.1511.05644, Sect. 7 / Fig. 10.
    The encoder emits a class distribution ``y`` (softmax) and a code ``z``.
    ``y`` is mapped to a learned "cluster head" ``hz``, and the decoder
    reconstructs the input from ``z + hz``. Two critics push ``y`` towards
    one-hot vectors drawn uniformly over the classes and ``z`` towards a
    standard normal. A repulsion term keeps cluster heads at least
    ``dist_threshold`` apart.

    Sub-models exposed as attributes:
        enc    : inp -> [y, z]
        sum    : [y, z] -> z + heads(y)
        dec    : combined latent code -> reconstruction
        heads  : y -> cluster head position
        y_disc : y -> unbounded critic score
        z_disc : z -> unbounded critic score

    Args:
        n_features: width of the flat input vector.
        n_classes: number of categories in ``y``.
        enc_layers, enc_seed: number of encoder hidden layers and the width of
            the first one (each further layer halves it). The decoder mirrors
            the encoder.
        disc_layers, disc_seed: same for each discriminator.
        af: activation of the encoder/decoder hidden layers.
        bn_momentum: BatchNormalization momentum.
        leak_alpha: negative slope of the discriminators' LeakyReLU.
        latent_dim: dimensionality of ``z`` and of the cluster heads.
        dist_threshold: cluster heads closer than this (Euclidean) are pushed apart.
    """

    def __init__(self,
            n_features,
            n_classes,
            enc_layers, enc_seed,
            disc_layers, disc_seed,
            af='gelu',
            bn_momentum=0.8, leak_alpha=0.2,
            latent_dim=2,
            dist_threshold=2.
        ):

        super().__init__()
        self.n_classes = n_classes
        self.latent_dim = latent_dim
        # Squared, so it can be compared directly with squared pairwise distances.
        self.dist_threshold2 = dist_threshold*dist_threshold

        enc_neurons = _compute_number_of_neurons(enc_layers, enc_seed)
        disc_neurons = _compute_number_of_neurons(disc_layers, disc_seed)

        # Encoder: input -> hidden stack -> (z, y)
        inp = k.Input(shape=(n_features,),name='inp')
        out = inp

        for n in range(enc_layers):
            out = k.layers.Dense(enc_neurons[n],activation=af,name=f'enc_{n}')(out)
            out = k.layers.BatchNormalization(momentum=bn_momentum, name=f'enc_bn_{n}')(out)

        z = k.layers.Dense(latent_dim,name='latent_z')(out)
        y = k.layers.Dense(n_classes,activation='softmax',name='latent_y')(out)

        # Cluster heads: linear map from class distribution to latent space,
        # added to the continuous code.
        hz = k.layers.Dense(latent_dim,name='cluster_heads')(y)
        latent = k.layers.Add(name='add_z_hz')([z,hz])

        # Decoder mirrors the encoder widths in reverse order.
        out = latent
        for n in reversed(range(enc_layers)):
            out = k.layers.Dense(enc_neurons[n],activation=af,name=f'dec_{n}')(out)
            out = k.layers.BatchNormalization(momentum=bn_momentum, name=f'dec_bn_{n}')(out)

        dec_out = k.layers.Dense(n_features,name='dec_out')(out)

        # Categorical critic on y (linear output = raw score).
        out = y
        for n in range(disc_layers):
            out = k.layers.Dense(disc_neurons[n],name=f'y_disc_{n}')(out)
            out = k.layers.LeakyReLU(negative_slope=leak_alpha,name=f'y_disc_relu_{n}')(out)

        y_disc_out = k.layers.Dense(1,name='y_disc_out')(out)

        # Continuous critic on z (linear output = raw score).
        out = z
        for n in range(disc_layers):
            out = k.layers.Dense(disc_neurons[n],name=f'z_disc_{n}')(out)
            out = k.layers.LeakyReLU(negative_slope=leak_alpha,name=f'z_disc_relu_{n}')(out)

        z_disc_out = k.layers.Dense(1,name='z_disc_out')(out)

        self.enc = k.Model(inputs=inp,outputs=[y,z])
        self.sum = k.Model(inputs=[y,z],outputs=latent)
        self.dec = k.Model(inputs=latent,outputs=dec_out)
        self.heads = k.Model(inputs=y,outputs=hz)
        self.y_disc = k.Model(inputs=y,outputs=y_disc_out)
        self.z_disc = k.Model(inputs=z,outputs=z_disc_out)

    def compile(self, optimizer=None, lr=None):
        """Set up the optimizers and the reconstruction loss.

        Args:
            optimizer: Keras optimizer used as the template for every training
                sub-problem. If None, ``Adam`` is used, with ``learning_rate=lr``
                when ``lr`` is given and Keras' default otherwise.
            lr: learning rate for the default Adam; ignored if ``optimizer`` is given.
        """
        if optimizer is None:
            optimizer = (k.optimizers.Adam() if lr is None
                         else k.optimizers.Adam(learning_rate=lr))

        def _clone(opt):
            # Same hyper-parameters, independent slot variables / step counter.
            return opt.__class__.from_config(opt.get_config())

        # One optimizer per sub-problem. They update different variable sets
        # against competing objectives, and Keras 3 optimizers refuse variables
        # they were not built with.
        self.opt = optimizer              # autoencoder reconstruction
        self.gen_opt = _clone(optimizer)  # encoder trying to fool both critics
        self.y_disc_opt = _clone(optimizer)
        self.z_disc_opt = _clone(optimizer)
        self.heads_opt = _clone(optimizer)

        self.ae_loss = k.losses.MeanSquaredError()

        # Marks the outer model as compiled so that fit() is allowed.
        super().compile(optimizer=self.opt)

        self.enc.compile()
        self.dec.compile()
        self.y_disc.compile()
        self.z_disc.compile()

    @staticmethod
    def _noisy_mean(scores):
        """Mean of critic scores after multiplying each by U(1, 1.05) jitter."""
        # NOTE(review): purpose of the multiplicative jitter is not documented;
        # presumably a light regulariser on the critic — confirm.
        return tf.reduce_mean(scores * tf.random.uniform(tf.shape(scores), 1., 1.05))

    # No @tf.function here: Keras compiles train_step itself unless run_eagerly=True.
    def train_step(self, in_batch):
        """Run one round of the AAE sub-updates on a batch.

        Order: (1) autoencoder reconstruction, (2) y-critic, (3) encoder vs.
        y-critic, (4) z-critic, (5) encoder vs. z-critic, (6) cluster-head
        repulsion. Each step uses its own gradient tape.

        Args:
            in_batch: input batch, or a tuple/list whose first element is the
                input batch (any targets are ignored).

        Returns:
            dict of scalar losses for Keras' progress logging.
        """
        batch = in_batch[0] if isinstance(in_batch, (tuple, list)) else in_batch
        # Dynamic size: batch.shape[0] may be None inside a traced graph.
        batch_size = tf.shape(batch)[0]

        # 1) autoencoder: reconstruct the input from z + heads(y)
        with tf.GradientTape() as aet:
            y, z = self.enc(batch, training=True)
            rec = self.dec(self.sum([y, z], training=True), training=True)
            ael = self.ae_loss(batch, rec)

        aew = self.enc.trainable_weights + self.sum.trainable_weights + self.dec.trainable_weights
        self.opt.apply_gradients(zip(aet.gradient(ael, aew), aew))

        # 2) categorical critic. "Real" samples are random one-hot class vectors,
        # "fake" samples are the encoder's y. Loss = mean(D(fake)) - mean(D(real)):
        # a Wasserstein-style critic on raw scores (no sigmoid / cross-entropy).
        # NOTE(review): there is no weight clipping or gradient penalty, so the
        # critic scores may grow without bound.
        idx = tf.random.uniform((batch_size,), minval=0, maxval=self.n_classes, dtype=tf.int32)
        randy = tf.one_hot(idx, depth=self.n_classes)

        with tf.GradientTape() as yt:
            y_disc_loss = self._noisy_mean(self.y_disc(y)) - self._noisy_mean(self.y_disc(randy))

        ydw = self.y_disc.trainable_weights
        self.y_disc_opt.apply_gradients(zip(yt.gradient(y_disc_loss, ydw), ydw))

        # 3) encoder tries to fool the categorical critic. Variables that do not
        # feed y (e.g. latent_z) get None gradients and are skipped by Keras.
        encw = self.enc.trainable_weights
        with tf.GradientTape() as yct:
            y_fake, _ = self.enc(batch, training=True)
            ycl = -self._noisy_mean(self.y_disc(y_fake))

        self.gen_opt.apply_gradients(zip(yct.gradient(ycl, encw), encw))

        # 4) continuous critic: "real" = standard normal samples, "fake" = encoder's z.
        randz = tf.random.normal(shape=[batch_size, self.latent_dim])

        with tf.GradientTape() as zt:
            z_disc_loss = self._noisy_mean(self.z_disc(z)) - self._noisy_mean(self.z_disc(randz))

        zdw = self.z_disc.trainable_weights
        self.z_disc_opt.apply_gradients(zip(zt.gradient(z_disc_loss, zdw), zdw))

        # 5) encoder tries to fool the continuous critic
        with tf.GradientTape() as zct:
            _, z_fake = self.enc(batch, training=True)
            zcl = -self._noisy_mean(self.z_disc(z_fake))

        self.gen_opt.apply_gradients(zip(zct.gradient(zcl, encw), encw))

        # 6) keep cluster heads apart: penalise exp(-d^2) for every ordered pair
        # of distinct heads whose squared distance is below the threshold.
        randc = tf.linalg.diag(tf.random.uniform((self.n_classes,), 0.95, 1.05))
        off_diag = tf.eye(self.n_classes) < 0.5  # exclude each head's distance to itself
        with tf.GradientTape() as ht:
            ch = self.heads(randc)
            norms = tf.reduce_sum(tf.square(ch), axis=1, keepdims=True)  # shape=(N,1)
            dists_squared = norms - 2. * tf.matmul(ch, ch, transpose_b=True) + tf.transpose(norms)
            too_close = tf.logical_and(off_diag, dists_squared < self.dist_threshold2)
            hl = tf.reduce_sum(tf.where(too_close, tf.exp(-dists_squared), tf.zeros_like(dists_squared)))

        hw = self.heads.trainable_weights
        self.heads_opt.apply_gradients(zip(ht.gradient(hl, hw), hw))

        return {
            'ae_loss' : ael,
            'y_disc_loss': y_disc_loss,
            'z_disc_loss': z_disc_loss,
            'y_cheat_loss': ycl,
            'z_cheat_loss': zcl,
            'cluster_head_loss': hl
        }

    def call(self, inp, training=None):
        """Map inputs to the combined latent code ``z + heads(y)``."""
        return self.sum(self.enc(inp, training=training), training=training)

            

In [None]:
# 123 input features, 14 classes; encoder widths 96/48/24, critic widths 48/24/12
m = AAEClassModel(
    n_features=123,
    n_classes=14,
    enc_layers=3, enc_seed=96,
    disc_layers=3, disc_seed=48,
)

In [None]:
# Encoder sub-model: input -> [latent_y (softmax over classes), latent_z]
m.enc.summary()

In [None]:
# Decoder sub-model: combined latent code (z + cluster head) -> reconstructed features
m.dec.summary()

In [None]:
# Categorical critic: class distribution y -> single unbounded score
m.y_disc.summary()

In [None]:
# Combiner: [y, z] -> z + cluster_heads(y)
m.sum.summary()

In [None]:
# Near-one-hot class indicators (diagonal jittered in [0.95, 1.05)) mapped to
# their cluster-head positions in latent space.
randc = tf.linalg.diag(
    tf.random.uniform(shape=(m.n_classes,), minval=0.95, maxval=1.05)
)
m.heads(randc)

In [None]:
# One jitter value per class, drawn from [0.95, 1.05)
tf.random.uniform(shape=(m.n_classes,), minval=0.95, maxval=1.05)

In [None]:
# Smoke test: 4 random samples with n_features=123 -> combined latent codes
m(tf.random.normal(shape=(4, 123)))