In [35]:
import torch
import torch.nn as nn
import os
import numpy as np
from PIL import Image
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torchvision.utils import save_image, make_grid

In [67]:
# --- Configuration -----------------------------------------------------------
# Dataset directory, built with os.path.join so the separators match the host OS
# (on Windows this yields the same 'dataset\\cats\\Data' string as before).
data_path = os.path.join('dataset', 'cats', 'Data')
DEVICE = torch.device("cpu")

batch_size = 128
img_size = (64, 64) # (width, height)

# Flattened size of one RGB image; derived from img_size so the two cannot drift apart.
input_dim = 3 * img_size[0] * img_size[1]
hidden_dim = 512
latent_dim = 16
# NOTE(review): n_embeddings and commitment_beta are not referenced by any cell visible
# here; kept in case other cells rely on them.
n_embeddings= 512
output_dim = input_dim
commitment_beta = 0.25

lr = 2e-4
epochs = 50

In [68]:
class CatDataset(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Args:
        file_list: sequence of image file paths.
        transform: optional callable applied to each loaded PIL image
            (e.g. a torchvision ``Compose``). When ``None`` the RGB PIL image
            itself is returned.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it after ``self.transform``."""
        path = self.file_list[idx]
        # The context manager closes the file handle; convert() returns an independent,
        # fully loaded copy. Forcing RGB gives every sample three channels, so grayscale
        # or RGBA files cannot produce tensors of a different shape.
        with Image.open(path) as img:
            img = img.convert("RGB")
        # Use the transform handed to this instance rather than a module-level global.
        if self.transform is not None:
            img = self.transform(img)
        return img

In [69]:
# Resize takes (height, width) while img_size is stored as (width, height).
transform = transforms.Compose([
    transforms.Resize((img_size[1], img_size[0])),
    transforms.ToTensor(),
])

# Keep regular files only (a sub-directory would make Image.open fail mid-epoch), use the
# names exactly as listed (stripping whitespace could produce a non-existent path), and
# sort them because os.listdir returns entries in arbitrary order.
files = sorted(
    os.path.join(data_path, name)
    for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name))
)
dataset = CatDataset(files, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [70]:
class Encoder(nn.Module):
    """Fully connected encoder that outputs the parameters of a diagonal Gaussian q(z|x).

    Two hidden layers with LeakyReLU(0.2) feed two parallel linear heads, one for the
    mean and one for the log-variance of the latent distribution.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Not called by forward(); callers pass inputs already flattened to input_dim.
        self.flatten = nn.Flatten()

        # Shared trunk.
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)

        # Output heads.
        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.training = True

    def forward(self, x):
        """Return ``(mean, log_var)`` for input ``x`` whose last dimension is ``input_dim``."""
        hidden = self.FC_input(x)
        hidden = self.LeakyReLU(hidden)
        hidden = self.LeakyReLU(self.FC_input2(hidden))

        # Both heads read the same hidden representation.
        return self.FC_mean(hidden), self.FC_var(hidden)

In [71]:
class Decoder(nn.Module):
    """Fully connected decoder mapping latent vectors to flattened outputs in [0, 1].

    Two hidden layers with LeakyReLU(0.2) are followed by a linear output layer and a
    sigmoid; the result is reshaped to ``(-1, output_dim)``.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim

        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Decode latent input ``x`` (last dimension ``latent_dim``) to ``(-1, output_dim)``."""
        hidden = self.FC_hidden(x)
        hidden = self.LeakyReLU(hidden)
        hidden = self.LeakyReLU(self.FC_hidden2(hidden))

        probabilities = torch.sigmoid(self.FC_output(hidden))
        # Collapse any extra leading dimensions into a single batch axis.
        return probabilities.view(-1, self.output_dim)

In [72]:
class Model(nn.Module):
    """Variational autoencoder that wires an encoder and a decoder together.

    Args:
        Encoder: module returning ``(mean, log_var)`` for an input batch.
        Decoder: module mapping a latent sample ``z`` to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Draw ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Despite its name, ``var`` is used as the scale of the noise; ``forward``
        passes the standard deviation ``exp(0.5 * log_var)``. The parameter name is
        kept so existing keyword callers keep working.
        """
        # randn_like already matches var's device and dtype, so no global DEVICE lookup
        # is needed and the model keeps working after being moved to another device.
        epsilon = torch.randn_like(var)
        return mean + var * epsilon  # reparameterization trick

    def forward(self, x):
        """Return ``(x_hat, mean, log_var)`` for input batch ``x``."""
        mean, log_var = self.Encoder(x)
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var

In [73]:
encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = input_dim)

model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)

In [74]:
from torch.optim import Adam

# NOTE(review): BCE_loss is not referenced by any cell visible here; loss_function calls
# nn.functional.binary_cross_entropy directly. Confirm it is unused before removing.
BCE_loss = nn.BCELoss()

def loss_function(x, x_hat, mean, log_var):
    """Negative ELBO, summed over the batch.

    Adds the summed binary cross-entropy between the reconstruction ``x_hat`` and the
    target ``x`` to the KL divergence between N(mean, exp(log_var)) and N(0, I).
    """
    # Closed-form KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * kl_terms.sum()

    reconstruction = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    return reconstruction + kl_divergence


# Adam over every parameter of `model` (its Encoder and Decoder submodules), with learning rate `lr`.
optimizer = Adam(model.parameters(), lr=lr)

In [82]:
print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0
    samples_seen = 0  # number of images that actually contributed to overall_loss
    for x in loader:
        # Flatten using the batch's real size so the final, smaller batch is trained on
        # instead of being skipped.
        x = x.view(x.size(0), input_dim)
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)

        # loss_function sums over the batch, so the running total is a per-sample sum.
        overall_loss += loss.item()
        samples_seen += x.size(0)

        loss.backward()
        optimizer.step()

    # Average over the samples processed this epoch; max() guards against an empty loader.
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / max(samples_seen, 1))

print("Finish!!")

Start training VAE...
	Epoch 1 complete! 	Average Loss:  7686.889121820863
	Epoch 2 complete! 	Average Loss:  7326.677380373957
	Epoch 3 complete! 	Average Loss:  7226.854448698187
	Epoch 4 complete! 	Average Loss:  7189.889123915156
	Epoch 5 complete! 	Average Loss:  7157.862643668532
	Epoch 6 complete! 	Average Loss:  7138.695481352411
	Epoch 7 complete! 	Average Loss:  7130.442674999163
	Epoch 8 complete! 	Average Loss:  7111.999155999732
	Epoch 9 complete! 	Average Loss:  7099.561285309788
	Epoch 10 complete! 	Average Loss:  7095.293289046008
	Epoch 11 complete! 	Average Loss:  7091.4005776061385
	Epoch 12 complete! 	Average Loss:  7088.033328586268
	Epoch 13 complete! 	Average Loss:  7085.373984267668
	Epoch 14 complete! 	Average Loss:  7082.452863736889
	Epoch 15 complete! 	Average Loss:  7080.335840900714
	Epoch 16 complete! 	Average Loss:  7077.8019384780355
	Epoch 17 complete! 	Average Loss:  7076.178965754113
	Epoch 18 complete! 	Average Loss:  7074.160954830279
	Epoch 19 com

In [83]:
# Save only the model's state_dict (parameter tensors) to model.pth in the working directory.
torch.save(model.state_dict(), "model.pth")

In [93]:
import matplotlib.pyplot as plt
def show_image(x, idx):
    """Display image ``idx`` from a batch of 64x64 images with matplotlib.

    Args:
        x: tensor holding a batch of images. Either ``n`` RGB images with the batch as
            the first dimension, e.g. ``(n, 3*64*64)`` or ``(n, 3, 64, 64)``, or any
            tensor that reshapes to ``(-1, 64, 64)`` single-channel images.
        idx: index of the image to show.
    """
    height, width = 64, 64
    plane = height * width

    # A batch is RGB when each first-dimension entry holds exactly 3 * 64 * 64 values.
    is_rgb = x.dim() >= 2 and x.shape[0] > 0 and x.numel() == x.shape[0] * 3 * plane

    plt.figure()
    if is_rgb:
        images = x.view(x.shape[0], 3, height, width)
        # imshow expects channel-last (H, W, C) data for colour images.
        plt.imshow(images[idx].permute(1, 2, 0).detach().cpu().numpy())
    else:
        # Single-channel layout: any number of 64x64 planes, not just 10.
        images = x.view(-1, height, width)
        plt.imshow(images[idx].detach().cpu().numpy())

In [94]:
# Sample 10 latent vectors from a standard normal and decode them without tracking gradients.
with torch.no_grad():
    noise = torch.randn(10, latent_dim).to(DEVICE)
    generated_images = decoder(noise)

# Reshape the decoder output to (10, 3, 64, 64) and write it as a grid with 5 images per row.
save_image(generated_images.view(10, 3, 64, 64), 'generated_sample.png', nrow=5)

In [129]:
# Pull a single batch from the loader (it shuffles, so the batch differs between runs).
x = next(iter(loader))

In [130]:
def interpolate(gen, z1, z2, steps=10):
    """Decode ``steps`` evenly spaced points on the line between latents ``z1`` and ``z2``.

    Args:
        gen: callable mapping a batch of latents to outputs (e.g. the decoder).
        z1: start latent tensor.
        z2: end latent tensor, broadcast-compatible with ``z1``.
        steps: number of interpolation points, endpoints included.

    Returns:
        ``gen`` applied to the stacked interpolated latents, computed under ``no_grad``.
    """
    # Build the weights on the latents' own device instead of a global DEVICE, so the
    # function works wherever z1/z2 live. Shape (steps, 1, 1, 1) broadcasts against the
    # latents; with a (1, latent_dim) input this yields (steps, 1, 1, latent_dim).
    alphas = torch.linspace(0, 1, steps, device=z1.device).view(-1, 1, 1, 1)
    zs = (1 - alphas) * z1 + alphas * z2
    with torch.no_grad():
        imgs = gen(zs)
    return imgs

In [131]:
# First two images of the batch, each flattened to a (1, input_dim) row on DEVICE.
x1 = x[0].view(1, input_dim).to(DEVICE)
x2 = x[1].view(1, input_dim).to(DEVICE)

# encoder returns a (mean, log_var) pair; element 0 (the mean) is the latent code used below.
z1 = encoder(x1)
z2 = encoder(x2)

# Decode 10 points between the two means and save them as a grid with 5 images per row.
interp_grid = interpolate(decoder, z1[0], z2[0])
save_image(interp_grid.view(10, 3, 64, 64), 'interpolation.png', nrow=5)