In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Backend name for tensors/modules: "cuda" when PyTorch reports a usable CUDA
# runtime, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Data flow::

        x --encoder--> h (hidden_dim) --mu / logvar--> (z_dim)
          --reparameterize--> z (z_dim) --decoder--> reconstruction (x_dim)

    Args:
        x_dim: Size of the last dimension of the input (and of the
            reconstruction).
        hidden_dim: Size of the encoder's output features; also scales the
            decoder's hidden widths (``hidden_dim * 2`` and ``hidden_dim * 3``).
        z_dim: Size of the latent variable.
    """

    def __init__(self, x_dim: int, hidden_dim: int, z_dim: int):
        super().__init__()

        # x_dim -> 128 -> 64 -> hidden_dim
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, hidden_dim),
        )
        # z_dim -> 2*hidden_dim -> 3*hidden_dim -> x_dim.
        # The decoder consumes the latent sample, so its input width must be
        # z_dim (not hidden_dim); the final Sigmoid bounds outputs to [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 3),
            nn.ReLU(),
            nn.Linear(hidden_dim * 3, x_dim),
            nn.Sigmoid(),
        )

        # The heads read the encoder output, so their input width must be
        # hidden_dim. When hidden_dim == z_dim these layers have exactly the
        # shapes they had before, keeping old state dicts loadable.
        self.mu = nn.Linear(hidden_dim, z_dim)
        self.logvar = nn.Linear(hidden_dim, z_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``logvar``.

        Args:
            mu: Mean of the approximate posterior.
            logvar: Log-variance of the approximate posterior (same shape
                as ``mu``).

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Input tensor whose last dimension is ``x_dim``.

        Returns:
            A tuple ``(reconstruction, mu, logvar)`` where ``reconstruction``
            has last dimension ``x_dim`` and ``mu`` / ``logvar`` have last
            dimension ``z_dim``.
        """
        h = self.encoder(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE objective: reconstruction error plus KL term.

    Args:
        recon_x: Reconstruction produced by the decoder.
        x: Target tensor the reconstruction is compared against.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor: the sum of squared errors between ``recon_x`` and
        ``x`` plus ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``, both
        summed over every element (no averaging).
    """
    # Squared error summed over all elements.
    sse = nn.functional.mse_loss(recon_x, x, reduction="sum")

    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * per_element.sum()

    return sse + kl