# Homework 2

## 1. 3D Reconstruction
Describe the NeRF (Neural Radiance Fields) algorithm by answering the following three points clearly and concisely (3 points):
- What are the inputs and outputs of NeRF?
- What is the training loss? How to compute the gradient?
- What training data does NeRF require?


## 2. Instantaneous Change of Variables
Derive the Instantaneous Change of Variables formula:
$$ \frac{d}{dt} \log p(\mathbf{z}(t)) = -\operatorname{tr}\left( \frac{d f}{d \mathbf{z}(t)} \right),$$
where the dynamics are given by
$$\frac{d\mathbf{z}(t)}{dt} = f(\mathbf{z}(t), t),$$

and the function $f$ is assumed to be uniformly Lipschitz continuous in $\mathbf{z}$ and continuous in $t$. Give at least **two** different proofs. (4 points)

Hints:
- Hint 1: Consider the time limit of the discrete change-of-variables formula.
- Hint 2: You may utilize the Fokker-Planck equation in the deterministic case.
- Hint 3: You may introduce a smooth test function to make the derivation rigorous.
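As a numerical sanity check (not a proof), the formula can be tested on linear dynamics $f(\mathbf{z}, t) = A\mathbf{z}$, for which $\frac{df}{d\mathbf{z}} = A$ and the formula predicts $\log p(\mathbf{z}(T)) - \log p(\mathbf{z}(0)) = -T\operatorname{tr}(A)$. The minimal sketch below assumes NumPy; the matrix $A$, the step count, and the helper name `log_density_change` are illustrative choices, not part of the assignment. It integrates the flow's Jacobian $J(t) = \partial \mathbf{z}(t)/\partial \mathbf{z}(0)$ with forward Euler ($\dot J = AJ$, $J(0) = I$) and compares $-\log|\det J(T)|$, the log-density change given by the discrete change-of-variables formula, against $-T\operatorname{tr}(A)$.

```python
import numpy as np

def log_density_change(A, T=1.0, n_steps=10_000):
    """Change in log p(z) after flowing for time T under dz/dt = A z.

    Hypothetical helper: integrates the Jacobian J(t) = dz(t)/dz(0) with
    forward Euler (dJ/dt = A J, J(0) = I), then applies the discrete change
    of variables: log p(z(T)) - log p(z(0)) = -log|det J(T)|.
    """
    dim = A.shape[0]
    J = np.eye(dim)
    dt = T / n_steps
    for _ in range(n_steps):
        J = J + dt * (A @ J)
    _, logabsdet = np.linalg.slogdet(J)
    return -logabsdet

rng = np.random.default_rng(0)
A = 0.5 * rng.standard_normal((3, 3))  # arbitrary linear dynamics f(z) = A z
T = 1.0

numerical = log_density_change(A, T)
predicted = -T * np.trace(A)  # integral of -tr(df/dz) over [0, T]
print(f"numerical: {numerical:.4f}, predicted: {predicted:.4f}")
```

The two numbers agree up to the $O(\Delta t)$ error of the Euler scheme; any of the hinted proofs should explain why.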

## 3. Large Language Models 

LLMs model a joint distribution over a sequence $x_1, \dots, x_T$ by factorizing it into conditional distributions
$$ p(x_1,\dots,x_T)=\prod_{t=1}^T p(x_t \mid x_{<t}),$$
and learning these conditionals through next-token prediction.
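The factorization can be checked on a toy example. The sketch below assumes a hypothetical three-token vocabulary and, for brevity, conditionals that depend only on the previous token (a bigram model, a special case of $p(x_t \mid x_{<t})$); the start token `<bos>` and all probabilities are made up for illustration. It scores a sequence by summing conditional log-probabilities and confirms that the product of conditionals defines a normalized joint distribution over all sequences of length $T$.

```python
import itertools
import math

VOCAB = ["a", "b", "c"]
BOS = "<bos>"  # hypothetical start token so that x_1 also has a context

# Hypothetical conditionals p(x_t | x_{t-1}); each row sums to 1.
COND = {
    BOS: {"a": 0.5, "b": 0.3, "c": 0.2},
    "a": {"a": 0.1, "b": 0.6, "c": 0.3},
    "b": {"a": 0.4, "b": 0.4, "c": 0.2},
    "c": {"a": 0.3, "b": 0.3, "c": 0.4},
}

def log_joint(seq):
    """log p(x_1, ..., x_T) = sum_t log p(x_t | x_<t), here using only the previous token."""
    prev = BOS
    total = 0.0
    for tok in seq:
        total += math.log(COND[prev][tok])
        prev = tok
    return total

T = 4
total_prob = sum(math.exp(log_joint(seq)) for seq in itertools.product(VOCAB, repeat=T))
print(f"log p(a, b, c, a) = {log_joint(['a', 'b', 'c', 'a']):.4f}")
print(f"sum over all {len(VOCAB) ** T} sequences: {total_prob:.6f}")
```

A real LLM replaces the lookup table with a network that conditions on the entire prefix $x_{<t}$, but the bookkeeping of the joint log-probability is the same.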

Answer the following:
- How is next-token prediction implemented in practice using a Transformer decoder? (1 pt)

- Identify the training setup, including the input, output, and training objective. (2 pts)

- What is the difference between sampling during training and inference? Why? (2 pts)


## 4. Coding: VAE on MNIST
A Python notebook with the code to be completed is provided. Please complete it following the instructions below. This problem is adapted from Problem Set 1 of [Course 6.S978 Deep Generative Models](https://mit-6s978.github.io/schedule.html), taught by Professor Kaiming He at MIT.

VAEs are trained by maximizing the Evidence Lower Bound (ELBO) on the marginal log-likelihood:
$$\log p(x) \geq \mathbb{E}_{q(z|x)}[\log\frac{p(x, z)}{q(z|x)}] = \mathrm{ELBO},$$

where $x$ is the data (binary images for MNIST) and $z$ is the latent code.
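To see the bound numerically, the sketch below uses a hypothetical model with a discrete latent $z \in \{0, 1, 2\}$, where $\log p(x) = \log \sum_z p(z)\,p(x|z)$ can be computed exactly. The probabilities are made up for illustration, and the exact sum over $z$ plays the role of the expectation, so no sampling is needed. It evaluates the ELBO for a uniform $q(z|x)$ and for the exact posterior $p(z|x)$.

```python
import numpy as np

# Hypothetical toy model: discrete latent z in {0, 1, 2} and one fixed observation x.
p_z = np.array([0.2, 0.5, 0.3])          # prior p(z)
p_x_given_z = np.array([0.9, 0.1, 0.4])  # likelihood p(x | z) for the observed x

def elbo(q):
    """ELBO = E_{q(z|x)}[log p(x, z) - log q(z|x)], computed exactly by summing over z."""
    log_joint = np.log(p_z) + np.log(p_x_given_z)
    return np.sum(q * (log_joint - np.log(q)))

log_px = np.log(np.sum(p_z * p_x_given_z))      # exact marginal log-likelihood
q_arbitrary = np.array([1 / 3, 1 / 3, 1 / 3])   # some variational distribution
posterior = p_z * p_x_given_z / np.exp(log_px)  # exact p(z | x)

print(f"log p(x)             = {log_px:.4f}")
print(f"ELBO (uniform q)     = {elbo(q_arbitrary):.4f}")
print(f"ELBO (q = posterior) = {elbo(posterior):.4f}")
```

The uniform $q$ gives a strictly smaller value than $\log p(x)$, while the exact posterior closes the gap; part (a) asks you to show why this holds in general.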

(a) Give a detailed mathematical proof of the ELBO, starting from the marginal
log-likelihood $\log p(x)$. (2 Points)

(b) Complete the implementation of ``self.encoder`` and ``self.decoder`` in the
``VAE()`` model. (2 Points)

(c) Implement the reparameterization trick in the ``reparameterize()`` function. In this assignment, we only sample one latent code $z_{i}$ for each $x_i$. (1 Point)

(d) In practice, the expectation in the ELBO is estimated using Monte Carlo sampling, yielding the generic Stochastic Gradient Variational Bayes (SGVB) estimator,
$$\mathrm{ELBO} \approx \sum_{i} [\log p(x_i|z_{i}) + \log p(z_{i}) - \log q(z_{i}|x_i)], $$
where $z_{i}$ is sampled from $q(z|x_i) = \mathcal{N}(z;\mu_i, \sigma^2_i \mathbf{I})$.

Finalize the SGVB estimator by completing the ``log_normal_pdf()`` function, which computes the log probability density of a normal distribution given its mean and log-variance. (1 Point)

(e) In many cases, Monte Carlo sampling is not necessary to estimate all the terms of the ELBO, because some terms can be integrated analytically. In particular, when both $q(z|x)=\mathcal{N}(z;\mu(x),\mathrm{diag}(\sigma^2(x)))$ and $p(z)=\mathcal{N}(z;0,\mathbf{I})$ are
Gaussian distributions, the ELBO decomposes into an analytical KL divergence plus
the expected reconstruction term:
$$
\begin{aligned}
\mathrm{ELBO} &\approx \sum_{i}\left[-D_{KL}\big(q(z|x_i)\,\|\,p(z)\big) + \log p(x_i|z_{i})\right] \\
&= \sum_{i}\left[\frac{1}{2}\sum_{d}\left(1+\log\sigma_{i,d}^2 - \mu_{i,d}^2 - \sigma_{i,d}^2\right) + \log p(x_i|z_{i})\right],
\end{aligned}
$$
where $d$ indexes the dimensions of the latent space and $i$ indexes the data points.

Run the verification code to check whether the analytical KL divergence matches the Monte Carlo estimate. (2 Points)

(f) Using the above two losses, train two VAE models on the MNIST dataset (manual
tuning of hyperparameters such as epochs, hidden dims, lr, and coeff may be necessary).
Use the provided evaluation code to visualize the reconstruction results and the
generated images (in a 2D grid) for both models. (3 Points)

(g) Latent Interpolation:
Encode two MNIST test images with different digit labels to obtain latent codes $z_1$ and $z_2$. Linearly interpolate between them using $z(\alpha) = (1-\alpha)z_1 + \alpha z_2$ for $\alpha\in\{0,0.1,\dots,1\}$, decode each interpolated code with your trained VAE decoder, and display the generated images in order. Briefly describe how the generated digits gradually transform from one class to the other along the interpolation. (2 Points)
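The interpolation schedule itself takes only a few lines. The minimal NumPy sketch below assumes two hypothetical 2-D latent codes standing in for the encoded test images (one common choice is to use the encoder means $\mu$ as $z_1$ and $z_2$, but that choice is an assumption here); decoding each row with your trained model is left to you.

```python
import numpy as np

# Hypothetical latent codes; in the assignment these come from encoding two test images.
z_1 = np.array([-1.5, 0.8])
z_2 = np.array([1.2, -0.4])

alphas = np.linspace(0.0, 1.0, 11)  # alpha in {0, 0.1, ..., 1}
# z(alpha) = (1 - alpha) * z_1 + alpha * z_2, one row per alpha (broadcast over the latent dim)
z_path = (1.0 - alphas)[:, None] * z_1 + alphas[:, None] * z_2

for a, z in zip(alphas, z_path):
    print(f"alpha={a:.1f}  z={z}")
```

Stacking the rows into one array lets you decode all 11 codes in a single batch.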

# Code to Be Completed

#### Import Libraries and Functions

```python
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

### MNIST Data Loader

```python
tensor_transform = transforms.ToTensor()

batch_size = 256
MNIST_dataset = datasets.MNIST(root="./data",
                               train=True,
                               download=True,
                               transform=tensor_transform)

MNIST_loader = torch.utils.data.DataLoader(dataset=MNIST_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
```


### Training Function

```python
mse = torch.nn.MSELoss()

def loss_func(model, x, reg_func=None, coeff=1e-3):
    output = model(x)
    # MSELoss uses reduction='mean', so err is a scalar averaged over batch and pixels;
    # its negative stands in for log p(x|z) up to constants and scaling.
    err = mse(output['imgs'], x)
    logpx_z = -1.0 * torch.sum(err)

    if reg_func is not None:
      reg = reg_func(output)  # per-sample regularizer, shape [batch]
    else:
      reg = 0.0

    # Negative (weighted) ELBO, averaged over the batch
    return -1.0 * torch.mean(logpx_z + coeff * reg)

def train(dataloader, model, loss_func, optimizer, epochs):
    losses = []

    for epoch in tqdm(range(epochs), desc='Epochs'):
        running_loss = 0.0
        batch_progress = tqdm(dataloader, desc='Batches', leave=False)

        for images, labels in batch_progress:
            batch_size = images.shape[0]
            images = images.reshape(batch_size, -1).to(device)  # flatten to [batch, 784]
            loss = loss_func(model, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            losses.append(loss.item())

        # Each loss is already a per-batch mean, so average over the number of batches
        avg_loss = running_loss / len(dataloader)
        tqdm.write(f'----\nEpoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}\n')

    return losses
```


### Evaluation Function

```python
def plot_latent_images(model, n, digit_size=28):
    grid_x = np.linspace(-2, 2, n)
    grid_y = np.linspace(-2, 2, n)

    image_width = digit_size * n
    image_height = digit_size * n
    image = np.zeros((image_height, image_width))

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z = torch.tensor([[xi, yi]], dtype=torch.float32).to(device)
            with torch.no_grad():
                x_decoded = model.decode(z)
            digit = x_decoded.view(digit_size, digit_size).cpu().numpy()
            image[i * digit_size: (i + 1) * digit_size,
                  j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap='Greys_r')
    plt.axis('Off')
    plt.show()


def eval(model):
    original_imgs = torch.cat([MNIST_dataset[i][0] for i in range(5)])
    with torch.no_grad():
      res = model(original_imgs.reshape(5, -1).to(device))
      reconstructed_imgs = res['imgs']
      reconstructed_imgs = reconstructed_imgs.cpu().reshape(*original_imgs.shape)

    fig, axes = plt.subplots(5, 2, figsize=(10, 25))

    for i in range(5):
        original_image = original_imgs[i].reshape(28, 28)
        axes[i, 0].imshow(original_image, cmap='gray')
        axes[i, 0].set_title(f'Original Image {i+1}')
        axes[i, 0].axis('off')

        reconstructed_image = reconstructed_imgs[i].reshape(28, 28)
        axes[i, 1].imshow(reconstructed_image, cmap='gray')
        axes[i, 1].set_title(f'Reconstructed Image {i+1}')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()
```



### Model

```python
class VAE(torch.nn.Module):
  def __init__(self, input_dim, hidden_dims, decode_dim=-1, use_sigmoid=True):
      '''
      input_dim: The dimensionality of the input data.
      hidden_dims: A list of hidden dimensions for the layers of the encoder and decoder.
      decode_dim: (Optional) Specifies the dimensions to decode, if different from input_dim.
      '''
      super().__init__()

      # The encoder outputs hidden_dims[-1] values: a mean and a log-variance of size z_size each
      self.z_size = hidden_dims[-1] // 2

      self.encoder = torch.nn.Sequential()
      self.decoder = torch.nn.Sequential()
      ##################
      ### Problem 4(b): finish the implementation for encoder and decoder
      ##################

  def encode(self, x):
      mean, logvar = torch.split(self.encoder(x), split_size_or_sections=[self.z_size, self.z_size], dim=-1)
      return mean, logvar

  def reparameterize(self, mean, logvar, n_samples_per_z=1):
      ##################
      ### Problem 4(c): finish the implementation for reparameterization
      ##################
      pass

  def decode(self, z):
      probs = self.decoder(z)
      return probs

  def forward(self, x, n_samples_per_z=1):
      mean, logvar = self.encode(x)

      batch_size, latent_dim = mean.shape
      if n_samples_per_z > 1:
        # Repeat each (mean, logvar) n_samples_per_z times so several z are drawn per input
        mean = mean.unsqueeze(1).expand(batch_size, n_samples_per_z, latent_dim)
        logvar = logvar.unsqueeze(1).expand(batch_size, n_samples_per_z, latent_dim)

        mean = mean.contiguous().view(batch_size * n_samples_per_z, latent_dim)
        logvar = logvar.contiguous().view(batch_size * n_samples_per_z, latent_dim)

      z = self.reparameterize(mean, logvar, n_samples_per_z)
      x_probs = self.decode(z)

      # Average the decoded outputs over the samples drawn for each input
      x_probs = x_probs.reshape(batch_size, n_samples_per_z, -1)
      x_probs = torch.mean(x_probs, dim=[1])

      return {
          "imgs": x_probs,
          "z": z,
          "mean": mean,
          "logvar": logvar
      }

### Test (runs only after the encoder, decoder, and reparameterize() above are implemented)
hidden_dims = [128, 64, 36, 18, 18]
input_dim = 256
test_tensor = torch.randn([1, input_dim]).to(device)

vae_test = VAE(input_dim, hidden_dims).to(device)

with torch.no_grad():
  test_out = vae_test(test_tensor)
```


### Loss Functions
#### Loss 1: Stochastic Gradient Variational Bayes (SGVB) Estimator

```python
##### Loss 1: SGVB #####
log2pi = torch.log(2.0 * torch.tensor(np.pi)).to(device)
torch_zero = torch.tensor(0.0).to(device)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    ##################
    ### Problem 4(d): finish the implementation of the log-probability of a normal distribution with the given mean and log-variance
    ##################
    pass

def loss_SGVB(output):
    # Single-sample estimate of log p(z) - log q(z|x), with prior p(z) = N(0, I) (mean 0, logvar 0)
    logpz = log_normal_pdf(output['z'], torch_zero, torch_zero)
    logqz_x = log_normal_pdf(output['z'], output['mean'], output['logvar'])
    return logpz - logqz_x
```


### Loss 2: KL Divergence w/o Estimation

```python
##### Loss 2: KL w/o Estimation #####
def loss_KL_wo_E(output):
    # Closed-form -KL(q(z|x) || N(0, I)) per sample: 0.5 * sum_d (1 + logvar - mean^2 - var)
    var = torch.exp(output['logvar'])
    logvar = output['logvar']
    mean = output['mean']

    return -0.5 * torch.sum(torch.pow(mean, 2)
                            + var - 1.0 - logvar,
                            dim=[1])
```


### Verifying loss 1 == loss 2


```python
##################
### Problem 4(e): Check if the analytical KL divergence matches the Monte Carlo estimate.
hidden_dims = [128, 32, 16, 4]
image_shape = MNIST_dataset[0][0].shape
input_dim = torch.prod(torch.tensor(image_shape)).item()
vae_test = VAE(input_dim, hidden_dims).to(device)

all_l_sgvb, all_KL_wo_E = [], []
all_n_samples_per_z = list(range(1, 4000, 100))

with torch.no_grad():
    for n_samples_per_z in all_n_samples_per_z:
        # Use a single batch per setting; each forward pass decodes batch_size * n_samples_per_z images
        for _, (imgs, _) in enumerate(MNIST_loader):
            batch_size = imgs.shape[0]
            imgs = imgs.reshape(batch_size, -1).to(device)

            output = vae_test(imgs, n_samples_per_z=n_samples_per_z)

            l_sgvb = torch.mean(loss_SGVB(output))
            l_KL_wo_E = torch.mean(loss_KL_wo_E(output))

            all_l_sgvb.append(l_sgvb.item())
            all_KL_wo_E.append(l_KL_wo_E.item())
            break

# Plot the two curves
plt.figure(figsize=(12, 6))

plt.plot(all_n_samples_per_z, all_l_sgvb, label='SGVB Loss')
plt.plot(all_n_samples_per_z, all_KL_wo_E, label='KL Divergence (w/o E)')

plt.xlabel('Number of Samples per z')
plt.ylabel('Loss')
plt.legend()

plt.grid(True)
plt.show()
##################
```


### Training with SGVB loss



```python
##################
### Problem 4(f): Train VAE with SGVB loss
epochs = 20

hidden_dims = [128, 32, 16, 4]
assert hidden_dims[-1] == 4, "keep hidden_dims[-1] == 4 (2D latent mean + 2D log-variance) so that plot_latent_images can decode a 2D grid"

image_shape = MNIST_dataset[0][0].shape
input_dim = torch.prod(torch.tensor(image_shape)).item()
print("input_dim: ", input_dim)

vae_sgvb = VAE(input_dim, hidden_dims).to(device)
print(vae_sgvb)

coeff = 1e-3

optimizer_vae_sgvb = torch.optim.Adam(vae_sgvb.parameters(),
                                      lr=1e-4,
                                      weight_decay=1e-8)

log_vae_sgvb = train(MNIST_loader, vae_sgvb,
                     lambda model, x: loss_func(model, x, reg_func=loss_SGVB, coeff=coeff),
                     optimizer_vae_sgvb, epochs)

### Evaluation
eval(vae_sgvb)
plot_latent_images(vae_sgvb, n=8)
##################
```

### Training with analytical KL

```python
##################
### Problem 4(f): Train VAE with analytical KL
epochs = 20

hidden_dims = [128, 32, 16, 4]
assert hidden_dims[-1] == 4, "keep hidden_dims[-1] == 4 (2D latent mean + 2D log-variance) so that plot_latent_images can decode a 2D grid"

image_shape = MNIST_dataset[0][0].shape
input_dim = torch.prod(torch.tensor(image_shape)).item()
print("input_dim: ", input_dim)

vae_kl_wo_e = VAE(input_dim, hidden_dims).to(device)
print(vae_kl_wo_e)

coeff = 1e-3

optimizer_vae_kl_wo_e = torch.optim.Adam(vae_kl_wo_e.parameters(),
                                         lr=1e-4,
                                         weight_decay=1e-8)

log_vae_kl_wo_e = train(MNIST_loader, vae_kl_wo_e,
                        lambda model, x: loss_func(model, x, reg_func=loss_KL_wo_E, coeff=coeff),
                        optimizer_vae_kl_wo_e, epochs)

### Evaluation
eval(vae_kl_wo_e)
plot_latent_images(vae_kl_wo_e, n=8)
##################
```
