# In [None]:
class EncoderBlock(gluon.Block):
    """Convolution -> batch normalisation -> ReLU stage used by the encoder.

    Parameters
    ----------
    channels : int
        Number of output channels of the 5x5 convolution.
    """

    def __init__(self, channels):
        super(EncoderBlock, self).__init__()
        self.channels = channels
        # The child layers are created once, here, so that Gluon registers
        # their parameters with this block.  Building them inside forward()
        # produced brand-new, unregistered and uninitialised layers on every
        # call, so nothing could ever be trained.
        with self.name_scope():
            self.conv = nn.Conv2D(self.channels, 5)
            self.norm = nn.BatchNorm()
            # nn.Activation('relu') is the documented Gluon ReLU layer.
            self.act = nn.Activation('relu')

    def forward(self, x):
        """Apply the convolution, batch normalisation and ReLU to ``x``."""
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class DecoderBlock(gluon.Block):
    """Transposed convolution -> batch normalisation -> ReLU decoder stage.

    Parameters
    ----------
    channels : int
        Number of output channels of the 5x5 transposed convolution.
    """

    def __init__(self, channels):
        super(DecoderBlock, self).__init__()
        self.channels = channels
        # The child layers are created once, here, so that Gluon registers
        # their parameters with this block.  Building them inside forward()
        # produced brand-new, unregistered and uninitialised layers on every
        # call, so nothing could ever be trained.
        with self.name_scope():
            self.deconv = nn.Conv2DTranspose(self.channels, 5)
            self.norm = nn.BatchNorm()
            # nn.Activation('relu') is the documented Gluon ReLU layer.
            self.act = nn.Activation('relu')

    def forward(self, x):
        """Apply the transposed convolution, batch norm and ReLU to ``x``."""
        x = self.deconv(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class VAE(gluon.Block):
    """Convolutional variational autoencoder.

    The encoder produces ``2 * num_latent`` values per sample, which are
    split into a latent mean and a latent log-variance.  A latent code is
    drawn with the reparameterisation trick and passed through the decoder.

    Parameters
    ----------
    num_latent : int
        Dimensionality of the latent code.
    ctx : mxnet Context
        Device on which the sampling noise is allocated.
    """

    def __init__(self, num_latent, ctx):
        super(VAE, self).__init__()
        # forward() reads self.n_latent, which was never assigned before.
        self.n_latent = num_latent
        self.ctx = ctx
        # Most recent batch passed to forward(); vae_loss() falls back to it
        # when the caller does not pass the target images explicitly.
        self._last_input = None
        self.encoder = self._get_encoder(num_latent)
        self.decoder = self._get_decoder(num_latent)

    def forward(self, img_batch):
        """Encode ``img_batch``, sample a latent code and decode it.

        Parameters
        ----------
        img_batch : NDArray
            Input batch; the first axis is the batch axis.  Presumably laid
            out as (batch, channels, height, width), the Conv2D default --
            TODO confirm against the data pipeline.

        Returns
        -------
        tuple
            ``(img_hat, latent_mean, latent_logvar)``.
        """
        self._last_input = img_batch
        batch_size = img_batch.shape[0]
        latent_layer = self.encoder(img_batch)
        latent_mean, latent_logvar = nd.split(latent_layer, axis=1, num_outputs=2)

        # Reparameterisation trick: z = mean + sigma * eps, eps ~ N(0, 1).
        # ``shape`` has to be passed as a keyword; the original
        # ``shape(batch_size, ...)`` was a syntax error.
        eps = nd.random_normal(loc=0, scale=1,
                               shape=(batch_size, self.n_latent),
                               ctx=self.ctx)
        latent_z = latent_mean + nd.exp(0.5 * latent_logvar) * eps
        img_hat = self.decoder(latent_z)
        return img_hat, latent_mean, latent_logvar

    def vae_loss(self, img_hat, latent_mean, latent_logvar, img_batch=None):
        """Return the per-sample VAE loss (KL divergence + log-loss).

        Parameters
        ----------
        img_hat : NDArray
            Reconstruction returned by :meth:`forward`, values in (0, 1).
        latent_mean, latent_logvar : NDArray
            Latent statistics returned by :meth:`forward`.
        img_batch : NDArray, optional
            Target images.  When omitted, the batch most recently passed to
            :meth:`forward` is used.

        Returns
        -------
        NDArray
            One loss value per sample.

        Raises
        ------
        ValueError
            If no target batch is given and ``forward`` has not been called.
        """
        if img_batch is None:
            img_batch = self._last_input
        if img_batch is None:
            raise ValueError('vae_loss needs img_batch, or a prior call to forward()')

        KL_div_loss = -0.5 * nd.sum(1 + latent_logvar - latent_mean * latent_mean - nd.exp(latent_logvar),
                                    axis=1)
        # Collapse every non-batch axis so the reconstruction term is reduced
        # to one value per sample, matching the shape of the KL term.
        # NOTE(review): target and reconstruction must hold the same number
        # of elements per sample -- see the note in _get_decoder.
        target = nd.flatten(img_batch)
        recon = nd.flatten(img_hat)
        logloss = -nd.sum(target * nd.log(recon + 1e-10)
                          + (1 - target) * nd.log(1 - recon + 1e-10), axis=1)
        return KL_div_loss + logloss

    def _get_encoder(self, num_latent):
        """Build the encoder: three EncoderBlocks and a dense latent head."""
        with self.name_scope():
            encoder = nn.Sequential(prefix='encoder')
            for channels in [64, 128, 256]:
                encoder.add(EncoderBlock(channels))
            # 2 * num_latent outputs: one half is the mean, the other the
            # log-variance (split in forward()).
            encoder.add(nn.Dense(2 * num_latent, flatten=True))
            # NOTE(review): the ReLU below forces both the latent mean and
            # the log-variance to be non-negative; confirm this is intended.
            encoder.add(nn.BatchNorm())
            encoder.add(nn.Activation('relu'))
        return encoder

    def _get_decoder(self, num_latent):
        """Build the decoder: dense stem, three DecoderBlocks, output conv."""
        with self.name_scope():
            decoder = nn.Sequential(prefix='decoder')
            decoder.add(nn.Dense(num_latent))
            decoder.add(nn.BatchNorm())
            decoder.add(nn.Activation('relu'))
            # Transposed convolutions need 4-D input, so the dense features
            # are viewed as a (batch, features, 1, 1) feature map.
            # NOTE(review): with this 1x1 seed the output spatial size is
            # fixed by the kernel sizes below and is independent of the
            # input image size -- confirm the intended resolution.
            decoder.add(nn.Lambda(lambda x: x.reshape((0, -1, 1, 1))))
            for channels in [256, 128, 32]:
                # DecoderBlock is defined in this file, not in ``nn``.
                decoder.add(DecoderBlock(channels))
            # Sigmoid keeps the reconstruction in (0, 1), which the log-loss
            # in vae_loss() requires.
            decoder.add(nn.Conv2D(3, 5, activation='sigmoid'))
        return decoder