## Libraries

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

## MNIST from Torchvision

In [None]:
import torchvision.datasets as datasets

In [None]:
mnist_trainset = mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

In [None]:
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

In [None]:
mnist_trainset[9892]    # image mode=L means it's grayscale image with 1 channel

In [None]:
mnist_testset[9892]   # image mode=P RGB image with 3 channels

In [None]:
len(mnist_trainset)

In [None]:
len(mnist_testset)

In [None]:
mnist_trainset

In [None]:
mnist_trainset[2]

In [None]:
type(mnist_trainset[0])

In [None]:
type(mnist_trainset[0][1])

In [None]:
mnist_trainset[0][0]

In [None]:
mnist_trainset[0][1]

In [None]:
np_array = np.asarray(mnist_trainset[0][0])
np_array.max()

In [None]:
X_train = [img for img, label in mnist_trainset]       # list comprehension
y_train = [label for img, label in mnist_trainset]

In [None]:
X_test = [img for img, label in mnist_testset]
y_test = [label for img, label in mnist_testset]

In [None]:
type(y_train)

In [None]:
# One Hot Encoing on the labels

# num_classes = np.max(y_train) + 1     # Determine total class
# y_train = np.eye(num_classes)[y_train]

## Dataset

In [None]:
class MNIST(Dataset):
    """Map-style dataset over MNIST images with optional one-hot labels.

    Every image goes through ``transforms.ToTensor`` (which yields [0, 1] for
    uint8/PIL input), is rescaled to [-1, 1], cast to float and reshaped to
    (1, 28, 28). When labels were supplied, the matching label is returned as
    a float one-hot vector of length 10 alongside the image.

    Args:
        x: Sequence of images accepted by ``transforms.ToTensor``.
        y: Optional sequence of integer class labels aligned with ``x``.
    """

    def __init__(self, x, y=None):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        tensor = transforms.ToTensor()(self.x[idx])
        # [0, 1] -> [-1, 1], then force a single-channel 28x28 layout.
        tensor = (2 * tensor - 1).float().reshape(1, 28, 28)

        if self.y is not None:
            one_hot = F.one_hot(torch.tensor(self.y[idx]), num_classes=10).float()
            return tensor, one_hot
        return tensor

In [None]:
train_dataset = MNIST(X_train, y_train)
test_dataset = MNIST(X_test, y_test)

In [None]:
BATCH_SIZE = 32
LR = 0.003
EPOCHS = 10

In [None]:
img, lab = train_dataset[240]
print(type(img))
print(type(lab))

## Dataloader

In [None]:
# Shuffled mini-batches over the training split.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# CVAE Model

In [None]:
# Latent dimensionality; the encoder emits 2 * ZDIM values (mean and log-variance).
ZDIM: int = 4

## Encoder

In [None]:
class Encoder(nn.Module):
    """Conditional convolutional encoder for the CVAE.

    The one-hot label is broadcast to a (num_classes, H, W) map and concatenated
    with the image along the channel axis, then passed through three stride-2
    convolutions and a linear layer. In this notebook the first ZDIM outputs are
    used as the posterior mean and the remaining ZDIM as the log-variance.

    The linear layer expects 128 * 2 * 2 = 512 features, i.e. 28x28 inputs.

    Args:
        in_channels: Number of image channels (1 for MNIST).
        out_channels: Size of the output vector (defaults to 2 * ZDIM).
        num_classes: Length of the one-hot label vector.
    """

    def __init__(self, in_channels=1, out_channels=2*ZDIM, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        # Output size per conv: floor((in + 2*padding - kernel) / stride) + 1
        self.conv1 = nn.Conv2d(in_channels=in_channels + num_classes, out_channels=32, kernel_size=5, stride=2, padding=1)  # (C+10, 28, 28) -> (32, 13, 13)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=1)                         # (32, 13, 13)   -> (64, 6, 6)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=1)                        # (64, 6, 6)     -> (128, 2, 2)
        self.fc1 = nn.Linear(128 * 2 * 2, out_channels)

    def forward(self, x, y):
        """Encode images ``x`` of shape (B, C, H, W) conditioned on labels ``y``.

        ``y`` must hold B * num_classes values, e.g. one-hot rows of shape
        (B, num_classes). Returns a (B, out_channels) tensor.
        """
        B, _, H, W = x.shape
        # (B, num_classes) -> (B, num_classes, 1, 1) -> broadcast to the image size.
        y = y.reshape(B, self.num_classes, 1, 1).expand(-1, -1, H, W)
        x = torch.cat([x, y], dim=1)  # (B, C + num_classes, H, W)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)  # (B, 512) for 28x28 inputs
        return self.fc1(x)

## Decoder

In [None]:
class Decoder(nn.Module):
    """Conditional transposed-convolution decoder for the CVAE.

    The latent vector is concatenated with the one-hot label, projected to a
    (128, 2, 2) feature map and upsampled to (out_channels, 28, 28). No output
    activation is applied, so the returned pixel values are unbounded.

    Args:
        in_channels: Latent dimensionality (defaults to ZDIM).
        out_channels: Number of channels of the generated image (1 for MNIST).
        num_classes: Length of the one-hot label vector.
    """

    def __init__(self, in_channels=ZDIM, out_channels=1, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.fc1 = nn.Linear(in_channels + num_classes, 128 * 2 * 2)
        # Output size per layer: (in - 1) * stride - 2 * padding + kernel
        self.convt1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)            # (128, 2, 2) -> (64, 4, 4)
        self.convt2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=5, stride=2, padding=1)             # (64, 4, 4)  -> (32, 9, 9)
        # Honour the requested number of output channels (was hard-coded to 1).
        self.convt3 = nn.ConvTranspose2d(in_channels=32, out_channels=out_channels, kernel_size=14, stride=2, padding=1)  # (32, 9, 9)  -> (out_channels, 28, 28)

    def forward(self, x, y):
        """Decode latents ``x`` of shape (B, D) conditioned on labels ``y``.

        ``y`` must hold B * num_classes values, e.g. one-hot rows of shape
        (B, num_classes). Returns a (B, out_channels, 28, 28) tensor.
        """
        B, _ = x.shape
        x = torch.cat([x, y.reshape(B, self.num_classes)], dim=1)  # (B, D + num_classes)
        x = F.relu(self.fc1(x)).reshape(B, 128, 2, 2)
        x = F.relu(self.convt1(x))
        x = F.relu(self.convt2(x))
        return self.convt3(x)  # raw, unbounded pixel values

In [None]:
encoder = Encoder()
decoder = Decoder()

In [None]:
ex_idx = 19132
ex_image, label = train_dataset[ex_idx]
print(ex_image.shape)
print(label.shape)
ex_img = torch.unsqueeze(ex_image, dim=0)
label = torch.unsqueeze(label, dim=0)
print(ex_img.shape)
print(label.shape)
ex_pred = encoder(ex_img, label)
print(ex_pred.shape)

In [None]:
print(label)
print(label.shape)

In [None]:
# ex_img2 = torch.randn(ZDIM)
# ex_img2 = torch.unsqueeze(ex_img2, dim=0)
# ex_pred2 = decoder(ex_img2, label)
# ex_pred2.shape

## Training

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
encoder.to(device)
decoder.to(device)

In [None]:
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=LR)

In [None]:
from tqdm.auto import tqdm

for epochs in tqdm(range(EPOCHS)):
    for batch in train_loader:
        image, label = batch
        image = image.to(device)
        label = label.to(device)
        
        mu_logvar = encoder(image, label)

        mu = mu_logvar[:, :ZDIM]
        log_var = mu_logvar[:,ZDIM:]

        var = torch.exp(log_var)
        sigma = torch.sqrt(var)

        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon

        pred = decoder(z, label)
        
        # Loss Calculation
        recon_loss = F.mse_loss(pred, image)*28*28
        kl_div = - 0.5 * torch.mean(torch.sum(1 + log_var - mu**2 - var, dim=1))
        loss = recon_loss + kl_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
latent = torch.randn(ZDIM)
latent = torch.unsqueeze(latent, dim=0).to(device)
ex_label = 5
ex_label = torch.tensor(ex_label)
ex_label = F.one_hot(ex_label, num_classes=10).float()
ex_label = torch.unsqueeze(ex_label, dim=0).to(device)

with torch.no_grad():
    pred = decoder(latent, ex_label)
pred = torch.squeeze(pred, dim=0).to("cpu")
plt.imshow(pred[0], cmap='gray')
plt.show()

# FID Score

In [None]:
from torchvision import models

inception_v3 = models.inception_v3(pretrained=True)

In [None]:
inception_v3.fc = nn.Identity()

In [None]:
inception_v3.fc

In [None]:
num_samples = 1024

labels = torch.randint(low=0, high=10, size=(num_samples,))
labels = F.one_hot(labels, num_classes=10).float()

latents = torch.randn(num_samples, ZDIM)

with torch.no_grad():
    preds = decoder(latents.to(device), labels.to(device))
    
print("Prediction completed")
print(preds.shape)

In [None]:
test_loader = DataLoader(test_dataset, batch_size=1024)
images, _ = next(iter(test_loader))
i_images = images.to("cpu")
i_preds = preds.to("cpu")
B, C, H, W = i_images.shape

In [None]:
i_images.shape

In [None]:
transform = transforms.Compose([
    transforms.Resize((299,299)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
trans_images = []
trans_preds = []
for i_image, i_prd in zip(i_images, i_preds):
    pil_image = transforms.ToPILImage()(i_image)
    pil_pred = transforms.ToPILImage()(i_prd)
    
    trans_images.append(transform(pil_image))
    trans_preds.append(transform(pil_pred))

tensor_images = torch.stack(trans_images, dim=0)
tensor_preds = torch.stack(trans_preds, dim=0)

print(tensor_images.shape)
print(tensor_preds.shape)

In [None]:
def _extract_features(model, images, batch_size):
    """Run ``model`` over ``images`` in chunks and return (N, D) features on ``images.device``."""
    device = None
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
    chunks = []
    for start in range(0, images.shape[0], batch_size):
        batch = images[start:start + batch_size]
        if device is not None:
            batch = batch.to(device)
        out = model(batch)
        # torchvision's Inception-v3 returns InceptionOutputs (with ``.logits``)
        # in training mode and a plain tensor in eval mode; accept both.
        out = getattr(out, "logits", out)
        chunks.append(torch.flatten(out, start_dim=1).to(images.device))
    return torch.cat(chunks, dim=0)


def _trace_sqrtm_product(cov1, cov2):
    """Return trace(sqrtm(cov1 @ cov2)) for symmetric positive semi-definite matrices."""
    # cov1 @ cov2 is similar to S @ cov2 @ S with S = sqrtm(cov1); the latter is
    # symmetric PSD, so the trace of its square root is sum(sqrt(eigenvalues)).
    w1, v1 = torch.linalg.eigh(cov1)
    sqrt_cov1 = (v1 * w1.clamp(min=0).sqrt()) @ v1.T
    eigvals = torch.linalg.eigvalsh(sqrt_cov1 @ cov2 @ sqrt_cov1)
    return eigvals.clamp(min=0).sqrt().sum()


@torch.no_grad()
def calculate_fid(images, gen_images, model=None, batch_size=64):
    """Compute the Fréchet Inception Distance between two image sets.

    FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrtm(C1 @ C2)), where (mu, C) are the
    mean and covariance of the feature vectors that ``model`` produces. Note that
    sqrtm is the *matrix* square root of the product, not an elementwise sqrt.

    Args:
        images: Real images, already preprocessed for ``model``.
        gen_images: Generated images, preprocessed the same way.
        model: Feature extractor; defaults to the module-level ``inception_v3``.
            An ``nn.Module`` is run in eval mode (no dropout, BatchNorm running
            statistics) and its previous train/eval mode is restored afterwards.
        batch_size: Images per forward pass, to bound peak memory.

    Returns:
        A 0-dim tensor with the FID, in the dtype of the extracted features.

    Raises:
        ValueError: If either set has fewer than 2 samples (covariance undefined).
    """
    if images.shape[0] < 2 or gen_images.shape[0] < 2:
        raise ValueError("FID needs at least 2 samples in each set")
    if model is None:
        model = inception_v3

    is_module = isinstance(model, nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()
    try:
        act1 = _extract_features(model, images, batch_size)
        act2 = _extract_features(model, gen_images, batch_size)
    finally:
        if is_module:
            model.train(was_training)

    out_dtype = act1.dtype
    # Float64 for the covariance and eigendecomposition steps.
    act1, act2 = act1.double(), act2.double()
    mu1, mu2 = act1.mean(dim=0), act2.mean(dim=0)
    cov1, cov2 = torch.cov(act1.T), torch.cov(act2.T)

    fid = (
        torch.sum((mu1 - mu2) ** 2)
        + torch.trace(cov1)
        + torch.trace(cov2)
        - 2 * _trace_sqrtm_product(cov1, cov2)
    )
    return fid.to(out_dtype)

In [None]:
fid = calculate_fid(tensor_images, tensor_preds)

In [None]:
print(fid)