In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Model

- `input_dim`: dimensionality of the (flattened) input
- `h_dim`: dimensionality of the hidden layer
- `z_dim`: dimensionality of the latent space

- We want to ensure the latent space is Gaussian. Why?
- BCE expects targets of 0 or 1, but in the training loop the loss is given an image as its target => how does BCE fit in this case?
- What does `kl_div` do in this case?
- Why are `sigma` and `mu` each computed with a linear layer?

In [28]:
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder with a single hidden layer per side.

    The encoder maps an input vector to two ``z_dim``-sized vectors, ``mu`` and
    ``sigma``. A latent sample ``z = mu + sigma * epsilon`` (``epsilon`` drawn
    from a standard normal) is then decoded back to ``input_dim`` values that
    are squashed by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Size of the latent vector.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()

        # encoder: input -> hidden -> (mu, sigma)
        self.img2hid = nn.Linear(input_dim, h_dim)
        self.hid2mu = nn.Linear(h_dim, z_dim)
        self.hid2sigma = nn.Linear(h_dim, z_dim)

        # decoder: latent -> hidden -> input-sized output
        self.z2hid = nn.Linear(z_dim, h_dim)
        self.hid2img = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, sigma)`` for ``x``; both have ``z_dim`` as last dimension.

        ``sigma`` is the raw output of a linear layer, so nothing here
        constrains it to be positive.
        """
        h = F.relu(self.img2hid(x))
        return self.hid2mu(h), self.hid2sigma(h)

    def decode(self, z):
        """Map latent vectors ``z`` to sigmoid-activated, input-sized outputs."""
        h = F.relu(self.z2hid(z))
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid.
        return torch.sigmoid(self.hid2img(h))

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Returns:
            Tuple ``(x_reconstructed, mu, sigma)``.
        """
        mu, sigma = self.encode(x)
        # Reparameterisation: the randomness lives in epsilon, so z remains a
        # differentiable function of mu and sigma.
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma

In [18]:
# Random stand-in batch for a shape check: 4 samples of 28*28 values each.
x = torch.randn((4, 28 * 28))

In [19]:
# Throwaway instance for the shape check; h_dim and z_dim keep their defaults.
model = VariationalAutoEncoder(input_dim=28 * 28)

In [23]:
# Forward pass: returns the reconstruction together with the encoder's mu and sigma.
x_reconstructed, mu, sigma = model(x)

In [24]:
# Sanity check: the reconstruction should have the same shape as the input batch.
x_reconstructed.shape

torch.Size([4, 784])

In [25]:
# Sanity check: one z_dim-sized mean vector per sample in the batch.
mu.shape

torch.Size([4, 20])

### Training Loop

In [27]:
import torchvision.datasets as datasets
from tqdm import tqdm
from torchvision import transforms as tfms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

In [42]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [33]:
# Model sizes: input width, hidden width, latent width.
INPUT_DIM, H_DIM, Z_DIM = 28 * 28, 200, 20

# Training settings: number of passes over the data, samples per batch, Adam step size.
NUM_EPOCH, BATCH_SIZE = 3, 32
LR = 3e-4

In [31]:
# Scratch check: the scientific-notation literal 3e-4 is the float 0.0003.
3e-4

0.0003

In [38]:
# MNIST training split stored under dataset/; each image is converted to a
# tensor by ToTensor, and download=True fetches the raw files into that folder.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=tfms.ToTensor(),
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [39]:
# Shuffled mini-batches of BATCH_SIZE samples drawn from `dataset`.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [43]:
# Build the VAE from the config constants and move its parameters to DEVICE.
model = VariationalAutoEncoder(
    input_dim=INPUT_DIM, h_dim=H_DIM, z_dim=Z_DIM
).to(DEVICE)

In [44]:
# Adam over every model parameter with learning rate LR.
# NOTE(review): the variable name is misspelled ("optimizier"); it is kept
# unchanged here because other cells may refer to it by this name.
optimizier = torch.optim.Adam(params=model.parameters(), lr=LR)

In [45]:
# Binary cross-entropy summed (not averaged) over all elements.
# `reduction` has to be the string "sum": the builtin `sum` function is stored
# without complaint when the loss is constructed, but is rejected as an invalid
# reduction value (ValueError) the first time the loss is called.
loss_func = nn.BCELoss(reduction="sum")

In [None]:
for epoch in range(NUM_EPOCH):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) for loop:
        x = x.to(DEVICE).view(-1, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        loss = loss_func(x_reconstructed, x)
        kl_div 