# Variational Autoencoders

## 1

The first term, $\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$, encourages the model to reconstruct the input $x$ accurately from the latent variable $z$. The second term, $D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p(z))$, regularizes the encoder by pushing $q_\phi(z|x)$ towards the prior $p(z)$, so that the latent space typically resembles a standard normal distribution.
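
For reference, the objective these two terms belong to can be written out as below. This is the standard evidence lower bound (ELBO). The closed-form KL expression assumes a diagonal Gaussian encoder $q_\phi(z|x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and a standard normal prior; the symbols $\mu_k$, $\sigma_k$ and the latent dimensionality $d$ are notation introduced here for illustration.

$$
\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\mathrm{KL}}(q_\phi(z|x) \,\|\, p(z))
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{k=1}^{d} \left(1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2\right)
$$

Maximizing $\mathcal{L}$ is equivalent to minimizing the reconstruction loss plus the KL term, which is the form a training loop would typically use.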


## 2

### (a)

Decoding only the means of the embeddings is insufficient to recover the data distribution because the variances $\sigma_z^2$ are crucial for capturing the full variability of the data. Both the means $\mu_z$ and variances $\sigma_z^2$ are necessary to sample latent variables $z \sim q_\phi(z|x)$, which the decoder uses to reconstruct $x$ and accurately represent the data distribution.
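
A small numerical sketch of this point, under assumed toy settings rather than a trained model: the per-input means are drawn with variance `s2`, and every posterior shares the variance `1 - s2`, so the aggregate of sampled latents has unit variance. The value `s2 = 0.6` and the sample size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Assumed toy encoder: means spread with variance s2, shared posterior variance 1 - s2.
s2 = 0.6
mu = rng.normal(0.0, np.sqrt(s2), size=n)
sigma = np.sqrt(1.0 - s2)

# Sampling z ~ q(z|x) for each input gives samples from the aggregate posterior.
z = mu + sigma * rng.standard_normal(n)

var_means = mu.var()
var_samples = z.var()
print(f"variance of means:   {var_means:.3f}")
print(f"variance of samples: {var_samples:.3f}")
```

By the law of total variance, the variance of $z$ is the variance of the means plus the average posterior variance, so the means alone are less spread out than the sampled latents the decoder was trained on.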

### (b)

The encoder produces $q_\phi(z|x) = \mathcal{N}(z; f^\mu_\phi(x), f^\sigma_\phi(x))$. The marginal distribution $q_\phi(z)$ is obtained by integrating over the data distribution: $q_\phi(z) = \int q_\phi(z|x) \, p_{\text{data}}(x) \, dx$. Even if $q_\phi(z)$ is a unit Gaussian, this does not imply that the means $f^\mu_\phi(x)$ are Gaussian distributed, because $q_\phi(z)$ is a mixture of the per-input Gaussians and its shape depends on the variances $f^\sigma_\phi(x)$ as well as on the means. By analogy, the Central Limit Theorem (CLT) describes how sums of independent random variables tend towards a Gaussian regardless of their original distributions.
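
As a hedged numerical illustration, the sketch below uses a toy encoder (an assumption, not the model in this document) whose means take only the two values $\pm 0.5$, which is clearly not a Gaussian distribution. With a shared posterior variance of $0.75$, the aggregate samples still have mean near 0, variance near 1, and a small excess kurtosis, so they are close to, though not exactly, a unit Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Assumed toy encoder: the mean is -a or +a with equal probability (two-point, non-Gaussian).
a = 0.5
mu = rng.choice([-a, a], size=n)
sigma = np.sqrt(1.0 - a**2)  # chosen so that Var(z) = a^2 + sigma^2 = 1

# Aggregate posterior samples: z = mu + sigma * eps, with eps ~ N(0, 1).
z = mu + sigma * rng.standard_normal(n)

agg_mean = z.mean()
agg_var = z.var()
# Excess kurtosis is 0 for an exact Gaussian.
excess_kurtosis = ((z - agg_mean) ** 4).mean() / agg_var**2 - 3.0
n_distinct_means = np.unique(mu).size

print(f"mean={agg_mean:.3f}, var={agg_var:.3f}, excess kurtosis={excess_kurtosis:.3f}")
print(f"distinct mean values: {n_distinct_means}")
```

An exact match also occurs in the degenerate case where the encoder outputs mean 0 and unit variance for every input: the aggregate is then exactly $\mathcal{N}(0, I)$ while the means form a point mass.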

## 3

No. Since the autoencoder only minimizes reconstruction error without constraining the latent space, the encoded $z = f(x)$ does not necessarily follow a standard normal distribution; therefore, decoding $z \sim \mathcal{N}(0, I)$ with $g$ won't reproduce the data distribution $p_{\text{data}}$.


## 4

**VAE**

```python
import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, z_dims=4, input_size=784, num_hidden=128):
        super().__init__()
        self.z_dims = z_dims
        self.input_size = input_size

        # Encoder: two hidden layers with ReLU activations
        self.encoder = nn.Sequential(
            nn.Linear(input_size, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
        )

        # Readout layers for the posterior mean and log-variance
        self.mu = nn.Linear(num_hidden, z_dims)
        self.logvar = nn.Linear(num_hidden, z_dims)

        # Decoder: maps z back to pixel space; the sigmoid keeps outputs in [0, 1]
        self.decoder = nn.Sequential(
            nn.Linear(z_dims, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, input_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Flatten each image to a vector of length input_size before encoding
        x = x.flatten(start_dim=1)
        x = self.encoder(x)
        mu = self.mu(x)
        logvar = self.logvar(x)
        # Reparameterization trick: z = mu + std * epsilon, with epsilon ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        output = self.decoder(z)
        return output, mu, logvar
```

## 5

**Training Script**

```python
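import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# LR, NUM_EPOCHS, and train_loader are assumed to be defined earlier.
model = VAE().cuda()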
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0
    for X, _ in tqdm(train_loader):
        X = X.cuda()
        optimizer.zero_grad()
        x_prime, mu, logvar = model(X)

        # Summed squared reconstruction error plus KL(q(z|x) || N(0, I))
        reconstruction_loss = F.mse_loss(
            x_prime, X.flatten(start_dim=1), reduction="sum"
        )
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # negative ELBO
        loss = reconstruction_loss + kl_divergence

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(
        "Epoch: {} Train Loss: {:.4f}".format(
            epoch, train_loss / len(train_loader.dataset)
        )
    )
```

## 6

**Plot Latent Space**

```python
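import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_latents(model, i=0, j=1):
    # Decode a grid of latent points that varies dimensions i and j (all others stay at 0)
    model.eval()
    n = 10  # n x n grid of 28 x 28 digits, i.e. a 280 x 280 pixel image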
    grid_x = np.linspace(-2, 2, n)
    grid_y = np.linspace(-2, 2, n)

    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))

    for xi, x in enumerate(grid_x):
        for yi, y in enumerate(grid_y):
            z = torch.zeros(model.z_dims)
            z[i] = x
            z[j] = y
            z = z.unsqueeze(0).to(next(model.parameters()).device)
            with torch.no_grad():
                x_decoded = model.decoder(z)
            x_decoded = x_decoded.cpu().numpy().reshape(digit_size, digit_size)
            figure[
                (n - yi - 1) * digit_size : (n - yi) * digit_size,
                xi * digit_size : (xi + 1) * digit_size,
            ] = x_decoded

    # Plot the big image
    plt.figure(figsize=(8, 8))
    plt.imshow(figure, cmap="gray")
    plt.xlabel(f"Latent dimension {i}")
    plt.ylabel(f"Latent dimension {j}")
    plt.title(f"Latent space traversal for dimensions {i} and {j}")
    plt.show()
```

