In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import seed_everything, to_var, count_parameters, show_image_grid
# from models import VAE
import torch.nn.functional as F
import time
from matplotlib import pyplot as plt
from typing import List
import numpy as np

In [None]:
# Seed first (seed_everything is imported from the local utils module), then
# run a garbage-collection pass before anything large is allocated.
seed_everything(42)
gc.collect()

# ---- Hyper-parameters and paths -------------------------------------------
BATCH_SIZE = 128
ENC_HID_DIM = [4, 8, 16, 32]  # encoder channel widths, one entry per conv stage
DEC_HID_DIM = [8, 32, 16, 8]  # decoder channel widths inside each ConvBlock
LATENT_DIM = 16
PRINT_FREQ = 100  # log every PRINT_FREQ training batches
EPOCHES = 15
KL_WEIGHT = 0.002  # multiplier applied to the KL term of the loss
LEARNING_RATE = 2e-4
data_path = "../assets/mnist/"

# ---- Compute backend: MPS is tried first, then CUDA, then CPU --------------
MPS_FLAG = torch.backends.mps.is_available()
if MPS_FLAG:
    backend_name = "mps"
elif torch.cuda.is_available():
    backend_name = "cuda"
else:
    backend_name = "cpu"
device = torch.device(backend_name)
print("Backend device: {}".format(device))

# ---- Data ------------------------------------------------------------------
# ToTensor is the only preprocessing step applied to the images.
transform = transforms.Compose([transforms.ToTensor()])

# Only the training split is created with download=True.
dataset1 = datasets.MNIST(data_path, train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(data_path, train=False, transform=transform)

# Shuffle the training batches; the test loader keeps the dataset order.
train_loader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset2, batch_size=BATCH_SIZE)

In [None]:
class Encoder(nn.Module):
    """Strided-convolution encoder that outputs the mean and log-variance of q(z|x).

    Every entry of ``hidden_dims`` adds one
    ``Conv2d(kernel_size=3, stride=2, padding=1) -> BatchNorm2d -> LeakyReLU``
    stage. The final feature map is flattened and fed to two independent
    linear heads, ``fc_mu`` and ``fc_var``.

    Args:
        in_channels: Number of channels of the input images.
        hidden_dims: Output channel count of each convolution stage, in order.
        latent_dim: Size of the latent vector produced by each head.
        device: Only stored on ``self.device``; this class does not move any
            tensor itself. ``None`` (default) selects CUDA if available, else
            MPS if available, else CPU.
        input_size: Height/width of the square input images. Used to size the
            linear heads; the default of 28 reproduces the previous
            hard-coded ``hidden_dims[-1] * 4`` for four stages.
    """

    def __init__(
        self,
        in_channels=1,
        hidden_dims=[4, 8, 16, 32],
        latent_dim=16,
        device=None,
        input_size=28,
    ) -> None:
        super().__init__()
        if device is None:
            # Resolve at construction time (not at class-definition time) and
            # fall back to CPU instead of naming a backend that is not there.
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        # input shape [b, in_channels, input_size, input_size]
        self.latent_dim = latent_dim
        self.device = device
        # Copy so the shared default list can never be mutated through self.
        self.hidden_dims = list(hidden_dims)

        stages = []
        prev_channels = in_channels
        spatial = input_size
        for h_dim in self.hidden_dims:
            stages.append(
                nn.Sequential(
                    nn.Conv2d(
                        prev_channels,
                        out_channels=h_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            prev_channels = h_dim
            # Conv2d output size for k=3, s=2, p=1: floor((n + 2 - 3) / 2) + 1
            spatial = (spatial - 1) // 2 + 1

        self.encoder = nn.Sequential(*stages)

        # Flattened feature width seen by both heads (28 -> 14 -> 7 -> 4 -> 2
        # for the defaults, i.e. hidden_dims[-1] * 4).
        flat_dim = self.hidden_dims[-1] * spatial * spatial
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Encode ``x`` and return ``[mu, log_var]``, each of shape ``[b, latent_dim]``."""
        o = self.encoder(x)
        o = torch.flatten(o, start_dim=1)
        mu = self.fc_mu(o)
        log_var = self.fc_var(o)
        return [mu, log_var]


class ConvBlock(nn.Module):
    """Stack of same-resolution convolution stages wrapped in a skip connection.

    Consecutive pairs of ``hidden_dims`` define the (in, out) channels of each
    ``Conv2d(kernel_size=3, stride=1, padding=1) -> BatchNorm2d -> LeakyReLU``
    stage. ``forward`` adds the block input to the stack output, so the first
    and last channel counts have to be addable (equal, or broadcastable).

    Args:
        hidden_dims: Channel counts; ``len(hidden_dims) - 1`` stages are built.
        latent_dim: Accepted but not used anywhere in this class.
    """

    def __init__(
        self, hidden_dims: List[int] = [8, 32, 16, 8], latent_dim: int = 16
    ) -> None:
        super().__init__()
        # Pair every channel count with its successor: (c0, c1), (c1, c2), ...
        channel_pairs = zip(hidden_dims[:-1], hidden_dims[1:])
        stages = [
            nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            )
            for c_in, c_out in channel_pairs
        ]
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolution stack applied to ``x``."""
        transformed = self.conv_block(x)
        return x + transformed


class Decoder(nn.Module):
    """Decode a latent vector into a single-channel image with values in (0, 1).

    Pipeline: linear projection to ``hidden_dims[0]`` feature maps of 7x7,
    bilinear x4 upsampling (7x7 -> 28x28), two ``ConvBlock`` stages, then a
    3x3 convolution to one channel followed by a sigmoid.

    Args:
        hidden_dims: Channel counts forwarded to both ``ConvBlock`` instances;
            ``hidden_dims[0]`` is the width of the projected feature map and
            ``hidden_dims[-1]`` the input width of the output convolution.
        latent_dim: Size of the latent vector ``z``.
        device: Only stored on ``self.device``; this class does not move any
            tensor itself. ``None`` (default) selects CUDA if available, else
            MPS if available, else CPU.
    """

    def __init__(
        self,
        hidden_dims=[8, 32, 16, 8],
        latent_dim=16,
        device=None,
    ) -> None:
        super().__init__()
        if device is None:
            # Resolve at construction time (not at class-definition time) and
            # fall back to CPU instead of naming a backend that is not there.
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        self.latent_dim = latent_dim
        self.device = device
        self.hidden_dims = hidden_dims

        # 49 = 7 * 7: the projection is reshaped to [b, hidden_dims[0], 7, 7].
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 49)
        self.upsampling = nn.Upsample(scale_factor=4, mode="bilinear")
        self.conv_block1 = ConvBlock(hidden_dims, latent_dim)
        self.conv_block2 = ConvBlock(hidden_dims, latent_dim)
        self.output_layer = nn.Sequential(
            nn.Conv2d(hidden_dims[-1], out_channels=1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map ``z`` of shape ``[b, latent_dim]`` to an image batch ``[b, 1, 28, 28]``."""
        o = self.decoder_input(z)
        o = o.view(-1, self.hidden_dims[0], 7, 7)
        o = self.upsampling(o)
        # NOTE(review): ConvBlock.forward already returns x + conv(x), so the
        # extra "+ o" below yields 2 * o + conv(o). It looks like an accidental
        # double skip connection; it is kept as-is because removing it would
        # change the output of already-trained checkpoints.
        o = self.conv_block1(o) + o
        o = self.conv_block2(o) + o
        o = self.output_layer(o)
        return o


class VAE(nn.Module):
    """Variational auto-encoder composed of an ``Encoder`` and a ``Decoder``.

    Args:
        encoder_hidden_dim: ``hidden_dims`` passed to the encoder.
        decoder_hidden_dim: ``hidden_dims`` passed to the decoder.
        latent_dim: Size of the latent vector shared by both halves.
        device: Device both sub-modules are moved to. ``None`` (default)
            selects CUDA if available, else MPS if available, else CPU.
    """

    def __init__(
        self,
        encoder_hidden_dim=[4, 8, 16, 32],
        decoder_hidden_dim=[8, 32, 16, 8],
        latent_dim=16,
        device=None,
    ) -> None:
        super().__init__()
        if device is None:
            # Resolve at construction time and fall back to CPU: moving the
            # sub-modules to an unavailable backend would raise below.
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")

        self.device = device
        self.latent_dim = latent_dim
        # Hand the same device to both halves so their .device attributes
        # agree with where their parameters actually live.
        self.enc = Encoder(
            hidden_dims=encoder_hidden_dim, latent_dim=self.latent_dim, device=device
        ).to(device)
        self.dec = Decoder(
            hidden_dims=decoder_hidden_dim, latent_dim=self.latent_dim, device=device
        ).to(device)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``.

        The noise carries no gradient, so gradients reach ``mu`` and
        ``log_var`` only through the affine transform.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def loss_function(self, xhat, x, mu, log_var, kl_weight=0.0025) -> dict:
        """Compute the weighted VAE objective.

        Args:
            xhat: Reconstruction produced by the decoder.
            x: Target images, same shape as ``xhat``.
            mu: Posterior means, shape ``[b, latent_dim]``.
            log_var: Posterior log-variances, shape ``[b, latent_dim]``.
            kl_weight: Multiplier applied to the KL term.

        Returns:
            Dict with ``"loss"`` (differentiable total), and detached
            ``"recon_loss"`` and ``"KLD"`` for logging.
        """
        # Reconstruction term: MSE averaged over every element of the batch.
        recon_loss = F.mse_loss(xhat, x, reduction="mean")
        # KL(q(z|x) || N(0, I)): summed over latent dims, averaged over the batch.
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)
        loss = recon_loss + kl_weight * kld_loss
        return {
            "loss": loss,
            "recon_loss": recon_loss.detach(),
            "KLD": kld_loss.detach(),
        }

    def forward(self, x) -> List[torch.Tensor]:
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``[xhat, mu, log_var]``.
        """
        mu, log_var = self.enc(x)
        z = self.reparameterize(mu, log_var)
        xhat = self.dec(z)
        return [xhat, mu, log_var]

In [None]:
# Build the model from the notebook-level configuration and report its size.
model = VAE(
    encoder_hidden_dim=ENC_HID_DIM,
    decoder_hidden_dim=DEC_HID_DIM,
    latent_dim=LATENT_DIM,
    device=device,
).to(device)
print("Encoder Parameters: {}".format(count_parameters(model.enc)))
print("Decoder Parameters: {}".format(count_parameters(model.dec)))
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Start from +inf so the first epoch always writes a checkpoint; the next
# cell loads ./vae_nano.pt and would fail if no epoch ever beat the sentinel.
best_loss = float("inf")
for epoch in range(EPOCHES):
    start_time = time.time()

    # ---- training pass ----
    model.train()
    for i, (images, _labels) in enumerate(train_loader):
        images = images.to(device)
        xhat, mu, log_var = model(images)
        loss_dict = model.loss_function(
            xhat=xhat, x=images, mu=mu, log_var=log_var, kl_weight=KL_WEIGHT
        )
        loss = loss_dict["loss"]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % PRINT_FREQ == 0:
            # The last field is the time elapsed since this epoch started.
            print(
                "[{:0>2d}][{}/{}]: loss: {:.4f}\t rec_loss: {:.4f}\t kld_loss: {:.4f}\t Epoch time {:.3f}".format(
                    epoch,
                    i,
                    len(train_loader),
                    loss.item(),
                    loss_dict["recon_loss"].item(),
                    loss_dict["KLD"].item(),
                    time.time() - start_time,
                )
            )

    # ---- validation pass (eval mode, no gradients) ----
    model.eval()
    loss_ = 0
    with torch.no_grad():
        for i, (images, _labels) in enumerate(test_loader):
            images = images.to(device)
            xhat, mu, log_var = model(images)
            loss_ += model.loss_function(
                xhat=xhat, x=images, mu=mu, log_var=log_var, kl_weight=KL_WEIGHT
            )["loss"].item()

    # Mean of the per-batch losses, computed once and reused below.
    val_loss = loss_ / len(test_loader)
    print("[{:0>2d}]\tValidation Loss: {:.4f}".format(epoch, val_loss))

    # Keep only the best checkpoint seen so far.
    if val_loss <= best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "./vae_nano.pt")
        print("Model Saved.")
# Author's note from an earlier run: 0.0357 at KL_WEIGHT = 0.0025
# (metric not stated; presumably the validation loss).

In [None]:
# Rebuild the network with the same configuration that was used for training
# so the checkpoint's parameter shapes match, and place it on the active device.
model_restore = VAE(
    encoder_hidden_dim=ENC_HID_DIM,
    decoder_hidden_dim=DEC_HID_DIM,
    latent_dim=LATENT_DIM,
    device=device,
).to(device)
# map_location lets a checkpoint saved on one backend load on another.
model_restore.load_state_dict(torch.load("./vae_nano.pt", map_location=device))
# Inference only: without eval() BatchNorm would normalise with the statistics
# of the displayed batch and keep updating its running averages.
model_restore.eval()

# Smaller batches for display; drop_last keeps every training batch full.
BATCH_SIZE = 64
train_loader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset2, batch_size=BATCH_SIZE)


def show_reconstructions(vae, loader, n_images, target_device):
    """Show the first batch of ``loader`` and its reconstruction by ``vae``.

    Args:
        vae: Model whose forward pass returns ``[xhat, mu, log_var]``.
        loader: DataLoader yielding ``(images, labels)`` batches.
        n_images: Second argument forwarded to ``show_image_grid``.
        target_device: Device the images are moved to before the forward pass.

    Returns:
        The ``[xhat, mu, log_var]`` list produced for that batch.
    """
    batch_images, _ = next(iter(loader))
    show_image_grid(batch_images, n_images)
    with torch.no_grad():
        outputs = vae(batch_images.to(target_device))
    show_image_grid(outputs[0].cpu(), n_images)
    return outputs


# One batch from each split: originals first, reconstructions second.
show_reconstructions(model_restore, train_loader, BATCH_SIZE, device)
# Keep the test-batch outputs in `xhat`: the following cells plot
# xhat[1] (mu) and xhat[2] (log_var).
xhat = show_reconstructions(model_restore, test_loader, BATCH_SIZE, device)

In [None]:
# xhat[1] is mu, the second element of the [xhat, mu, log_var] list returned
# by VAE.forward for the last batch run through the restored model.
print(xhat[1].shape,xhat[1].mean(),xhat[1].std())
fig, ax = plt.subplots()
im = ax.imshow(xhat[1].cpu().detach().numpy(), cmap='gray')
ax.set(
    title="Posterior mean (mu)",
    xlabel="latent dimension",
    ylabel="sample index in batch",
)
# Trailing semicolon keeps the Colorbar repr out of the cell output.
fig.colorbar(im, ax=ax);

In [None]:
# xhat[2] is log_var, the third element of the [xhat, mu, log_var] list
# returned by VAE.forward; exp() turns it into the posterior variance.
print(xhat[2].shape,xhat[2].exp().mean(),xhat[2].exp().std())
fig, ax = plt.subplots()
im = ax.imshow(xhat[2].exp().cpu().detach().numpy())
ax.set(
    title="Posterior variance (exp(log_var))",
    xlabel="latent dimension",
    ylabel="sample index in batch",
)
# Trailing semicolon keeps the Colorbar repr out of the cell output.
fig.colorbar(im, ax=ax);