In [None]:
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from scipy import interpolate
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm_notebook as tqdm

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print("Using {} for computation".format(device))

In [None]:
# Every path below is relative to project_dir ('' = the notebook's working directory).
project_dir = ''
dataset_dir, images_dir, model_dir = (
    project_dir + sub for sub in ('celebA/', 'images/', 'models/'))

In [None]:
# Data / model shape: mini-batch size, square input side, encoder feature width
# and latent dimensionality.
batch_size, image_size = 32, 64
hidden_size, latent_size = 1024, 128
# Optimisation: passes over the training split and Adam learning rate.
epochs, lr = 2, 1e-3
# Per-image training loss, appended once per optimisation step.
train_loss = []

In [None]:
# One-time setup: install 7-Zip, which the next cell uses to unpack the CelebA archive.
# !apt-get install p7zip-full

In [None]:
# One-time setup: extract the aligned CelebA images into the dataset's images/ folder.
# NOTE(review): this uses the absolute path /celebA/ while dataset_dir is the
# relative 'celebA/' — confirm the intended location before running.
# !7z e /celebA/img_align_celeba.zip -o/celebA/images/

In [None]:
class CelebDataset(Dataset):
    """CelebA images paired with their attribute labels.

    ``root_dir`` must contain an ``images/`` folder plus two CSV files:
    ``list_attr_celeba.csv`` (an ``image_id`` column followed by attribute
    columns) and ``list_eval_partition.csv`` (a ``partition`` column that is
    row-aligned with the attribute file: 0 = train, 1 = valid, 2 = test).

    Args:
        root_dir: dataset root directory.
        split: ``'train'``, ``'valid'``, ``'test'`` or ``None`` for every row.
        selected_attr: attribute column names to return, in this order;
            ``None`` returns every attribute column.
        transform: optional callable applied to each PIL image.

    Each item is ``(image, attrs)``. ``attrs`` is a float32 array with one entry
    per selected attribute, taken from the CSV unchanged (no -1 -> 0 remap).
    """

    SPLITS = {'train': 0, 'valid': 1, 'test': 2}

    def __init__(self, root_dir, split=None, selected_attr=None, transform=None):
        self.root_dir = root_dir
        self.images_dir = os.path.join(root_dir, 'images/')
        # Despite the *_dir names these are CSV file paths (names kept for compatibility).
        self.attr_dir = os.path.join(root_dir, 'list_attr_celeba.csv')
        self.partition_dir = os.path.join(root_dir, 'list_eval_partition.csv')
        self.transform = transform
        self.selected_attr = selected_attr
        self.split = split
        self.preprocess()

    def preprocess(self):
        """Load the labels, restrict them to ``split`` and keep image ids aligned.

        Sets ``image_ids`` (file names, one per item), ``attr`` (float32 array,
        one row per item) and ``num_attr``.

        Raises:
            ValueError: if ``split`` is neither ``None`` nor a key of ``SPLITS``.
        """
        attr_df = pd.read_csv(self.attr_dir)

        if self.split is not None:
            if self.split not in self.SPLITS:
                raise ValueError("split must be one of %s or None, got %r"
                                 % (sorted(self.SPLITS), self.split))
            partition = pd.read_csv(self.partition_dir)
            # Positional mask: the partition file is expected to be row-aligned
            # with the attribute file (a length mismatch raises here).
            in_split = (partition['partition'] == self.SPLITS[self.split]).to_numpy()
            attr_df = attr_df[in_split]
        # Ids and labels are filtered together and renumbered, so item ``idx``
        # always refers to the same row in both.
        attr_df = attr_df.reset_index(drop=True)

        if self.selected_attr is None:
            # Every column except the (string) file-name column.
            columns = [c for c in attr_df.columns if c != 'image_id']
        else:
            columns = list(self.selected_attr)

        self.num_attr = len(columns)
        self.image_ids = attr_df['image_id']
        self.attr = attr_df[columns].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.attr)

    def __getitem__(self, idx):
        name = self.image_ids.iloc[idx]
        # Load eagerly and close the file right away (DataLoader workers would
        # otherwise accumulate open handles). convert() guards against
        # grayscale/RGBA files.
        with Image.open(os.path.join(self.images_dir, name)) as img:
            image = img.convert('RGB')
        img_attr = self.attr[idx]
        if self.transform is not None:
            image = self.transform(image)

        return image, img_attr

In [None]:
# All the Attributes available in the dataset
# all_columns = '5_o_Clock_Shadow	Arched_Eyebrows	Attractive	Bags_Under_Eyes	Bald	Bangs	Big_Lips	Big_Nose	Black_Hair	Blond_Hair	Blurry	Brown_Hair	Bushy_Eyebrows	Chubby	Double_Chin	Eyeglasses	Goatee	Gray_Hair	Heavy_Makeup	High_Cheekbones	Male	Mouth_Slightly_Open	Mustache	Narrow_Eyes	No_Beard	Oval_Face	Pale_Skin	Pointy_Nose	Receding_Hairline	Rosy_Cheeks	Sideburns	Smiling	Straight_Hair	Wavy_Hair	Wearing_Earrings	Wearing_Hat	Wearing_Lipstick	Wearing_Necklace	Wearing_Necktie	Young'

# The subset of attributes the model is conditioned on, in label-vector order.
columns = [
    'Black_Hair', 'Blond_Hair', 'Brown_Hair',
    'Male', 'No_Beard', 'Smiling',
    'Straight_Hair', 'Wavy_Hair', 'Young',
]
num_columns = len(columns)

In [None]:
# The encoder (four unpadded kernel-4, stride-2 convs: 64 -> 31 -> 14 -> 6 -> 2,
# i.e. 256 * 2 * 2 = hidden_size features) and the decoder (-> 64 x 64) both
# require square image_size x image_size inputs. Resize(int) only scales the
# shorter edge, so non-square source images are also centre-cropped. All splits
# share one transform and one root so evaluation sees the same preprocessing
# as training.
celeb_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

train_data = CelebDataset(root_dir=dataset_dir, split='train', selected_attr=columns,
                          transform=celeb_transform)
valid_data = CelebDataset(root_dir=dataset_dir, split='valid', selected_attr=columns,
                          transform=celeb_transform)
test_data = CelebDataset(root_dir=dataset_dir, split='test', selected_attr=columns,
                         transform=celeb_transform)

In [None]:
# Shuffle only the training split. Evaluation metrics do not depend on order,
# and a fixed order makes the validation previews reproducible between runs.
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16)
validloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=4)
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
def show_images(images):
    """Display a batch of image tensors (or a single image) as one grid."""
    show_image(torchvision.utils.make_grid(images.detach()))


def show_image(img):
    """Display one (C, H, W) image tensor with matplotlib.

    The tensor is detached and moved to the CPU first. Model outputs that still
    carry autograd history (e.g. reconstructions from a training step) cannot
    be converted to NumPy by matplotlib otherwise.
    """
    pixels = img.detach().cpu().permute(1, 2, 0).numpy()
    plt.imshow(pixels, interpolation="bicubic")
    plt.show()

In [None]:
class Flatten(nn.Module):
    """Flatten all but the batch dimension: (B, ...) -> (B, prod(...))."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class UnFlatten(nn.Module):
    """Reshape a (B, size) tensor into a (B, size, 1, 1) map for ConvTranspose2d.

    Args:
        size: channel count to produce. Defaults to 1024, which is what
            ``UnFlatten()`` always produced.
    """

    def __init__(self, size=1024):
        super().__init__()
        self.size = size

    def forward(self, input, size=None):
        # `size` used to be accepted but ignored (1024 was hard-coded). It now
        # overrides the channel count chosen at construction.
        channels = self.size if size is None else size
        return input.view(input.size(0), channels, 1, 1)

In [None]:
# Sanity check: pull a single training batch and display it as a grid.
images, attr = next(iter(trainloader))
show_images(images)
images = images.to('cpu')

In [None]:
class VAE(nn.Module):
    """Conditional VAE: encodes (image, attributes) and decodes (z, attributes).

    The attribute vector ``a`` (``num_classes`` values per sample) is
    concatenated twice: onto the encoder features before the posterior heads,
    and onto the projected latent sample before the decoder.

    Args:
        image_channels: channels of the input/output images (3 for RGB).
        image_dim: accepted but not used by any layer below.
        hidden_size: width of the flattened encoder output. For a 64x64 input
            the encoder yields 256 * 2 * 2 = 1024 features, and UnFlatten
            reshapes to a hard-coded 1024 channels, so this must stay 1024.
        latent_size: dimensionality of the latent code z.
        num_classes: length of the conditioning attribute vector ``a``.
    """
    def __init__(self, image_channels=3, image_dim=image_size, hidden_size=hidden_size, latent_size=latent_size, num_classes=num_columns):
        super(VAE, self).__init__()

        # Four unpadded kernel-4, stride-2 convs: 64 -> 31 -> 14 -> 6 -> 2
        # spatially (for a 64x64 input), then flatten 256 x 2 x 2 values.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, 4, 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2),
            nn.LeakyReLU(0.2),
            Flatten(),
        )
        # NOTE: registration order below fixes the state_dict / optimizer
        # parameter order; do not reorder or saved checkpoints get confusing.
        # Posterior heads on the attribute-conditioned encoder features.
        self.encoder_mean = nn.Linear(hidden_size, latent_size)
        self.encoder_logvar = nn.Linear(hidden_size, latent_size)
        # Decoder side: fc3 shrinks z so that appending `a` restores
        # latent_size; fc4 lifts the result to hidden_size.
        self.fc3 = nn.Linear(latent_size, latent_size - num_classes)
        self.fc4 = nn.Linear(latent_size, hidden_size)
        # Encoder side: fc1 shrinks the conv features so that appending `a`
        # restores hidden_size; fc2 mixes image features with the attributes.
        # NOTE(review): there is no nonlinearity between fc1/fc2 or fc3/fc4.
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc1 = nn.Linear(hidden_size, hidden_size - num_classes)
        # (B, hidden_size) -> (B, 1024, 1, 1), then 1 -> 5 -> 13 -> 30 -> 64
        # spatially. Sigmoid keeps pixels in [0, 1] for the BCE reconstruction loss.
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(hidden_size, 128, 5, 2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, 2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, 2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, 6, 2),
            nn.Sigmoid()
        )

    def sample(self, log_var, mean):
        """Reparameterisation trick: mean + eps * exp(0.5 * log_var), eps ~ N(0, I).

        Note the argument order: log-variance first, then mean.
        """
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mean)

    def forward(self, x, a):
        """Encode ``x`` conditioned on ``a``, sample z, and decode it conditioned on ``a``.

        Args:
            x: image batch; assumed (B, image_channels, 64, 64) so the encoder
               emits hidden_size features — TODO confirm for other sizes.
            a: attribute batch of shape (B, num_classes), on the same device as x.

        Returns:
            (reconstruction, mean, log_var): decoded images and the posterior
            parameters, each with leading batch dimension B.
        """
        x = self.encoder(x)
        x = self.fc1(x)
        x = torch.cat((x, a), 1)
        x = self.fc2(x)

        log_var = self.encoder_logvar(x)
        mean = self.encoder_mean(x)
        z = self.sample(log_var, mean)

        # Decoder path: project z, re-attach the attributes, lift to hidden_size.
        z = self.fc3(z)
        z = torch.cat((z, a), 1)
        x = self.fc4(z)
        x = self.decoder(x)

        return x, mean, log_var


vae = VAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=lr)

# Resume from the rolling checkpoint written by the training loop, if present.
# On a fresh checkout there is nothing to load, so train from random weights
# instead of failing with FileNotFoundError.
checkpoint_path = model_dir + "Conditional-VAE-full-dataset.pt"
if os.path.exists(checkpoint_path):
    vae.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print("No checkpoint at %s; starting from random weights" % checkpoint_path)

In [None]:
# Train the conditional VAE. Each step minimises
#   BCE(reconstruction, image) + kld_weight * KL(q(z | x, a) || N(0, I)),
# both summed over the batch; `train_loss` records that value per image.

def vae_loss(reconstructed, target, mean, log_var, kld_weight):
    """Return ``(total, reconstruction, kl)`` for one batch, each summed over the batch.

    ``reconstruction`` is pixel-wise binary cross-entropy (pixels in [0, 1]).
    ``kl`` is the closed-form KL divergence of N(mean, exp(log_var)) from N(0, I).
    """
    reconstruction = F.binary_cross_entropy(reconstructed, target, reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction + kld_weight * kl, reconstruction, kl


def kld_weight_for_step(step):
    """KL weight schedule: 0.5 on every third step, 0.01 otherwise."""
    return 0.5 if step % 3 == 0 else 0.01


def evaluate(loader):
    """Per-image reconstruction and KL terms of ``vae`` averaged over ``loader``.

    Returns ``(reconstruction, kl, last_reconstructed_batch)``. The terms are
    reported separately so the metric does not depend on the KL schedule. The
    model is switched to eval mode for the pass and back to train mode after.
    """
    vae.eval()
    total_reconstruction, total_kl, n_images = 0.0, 0.0, 0
    last_batch = None
    with torch.no_grad():
        for batch_images, batch_attr in loader:
            batch_images = batch_images.to(device)
            batch_attr = batch_attr.to(device)
            batch_reconstructed, batch_mean, batch_log_var = vae(batch_images, batch_attr)
            _, reconstruction, kl = vae_loss(
                batch_reconstructed, batch_images, batch_mean, batch_log_var, 0.0)
            total_reconstruction += reconstruction.item()
            total_kl += kl.item()
            n_images += len(batch_images)
            last_batch = batch_reconstructed
    vae.train()
    n_images = max(n_images, 1)
    return total_reconstruction / n_images, total_kl / n_images, last_batch


checkpoint_path = model_dir + "Conditional-VAE-full-dataset.pt"
os.makedirs(model_dir or '.', exist_ok=True)

vae.train()
for epoch in range(epochs):
    for i, (images, attr) in tqdm(enumerate(trainloader), total=len(trainloader)):
        images = images.to(device)
        attr = attr.to(device)
        optimizer.zero_grad()
        reconstructed_image, mean, log_var = vae(images, attr)
        loss, _, _ = vae_loss(reconstructed_image, images, mean, log_var,
                              kld_weight_for_step(i))
        loss.backward()
        train_loss.append(loss.item() / len(images))
        optimizer.step()

        if i % 2000 == 0:
            torch.save(vae.state_dict(), checkpoint_path)

        if i % 500 == 0:
            print("Epoch: %d" % epoch)
            print("Train Loss:")
            print(train_loss[-1])
            print("Original Images")
            show_images(images.cpu())
            print("Reconstructed Images")
            # detach(): this output still carries autograd history.
            show_images(reconstructed_image.detach().cpu())

            # Validate on the validation split; the test split stays held out.
            valid_reconstruction, valid_kl, valid_batch = evaluate(validloader)
            if valid_batch is not None:
                print("Validation Images")
                show_images(valid_batch.cpu())
            print("Validation Loss (per image): reconstruction %.2f, KL %.2f"
                  % (valid_reconstruction, valid_kl))


        

In [None]:
# One point per optimisation step. The curve is jagged partly because the KL
# weight alternates between 0.5 and 0.01 from step to step.
fig, ax = plt.subplots()
ax.plot(train_loss)
ax.set(title="Training loss",
       xlabel="Training step",
       ylabel="Loss per image (BCE + weighted KL)")
plt.show()

In [None]:
# Persist the trained weights under a name distinct from the rolling training checkpoint.
final_checkpoint = model_dir + "Conditional-VAE.pt"
torch.save(vae.state_dict(), final_checkpoint)

In [None]:
# Attribute vector for conditional generation, one entry per name in `columns`.
# Labels are fed to the model exactly as they appear in the attribute CSV,
# which (like the `values` dicts in this notebook) uses +1 = present and
# -1 = absent. "Absent" must therefore be -1.0; 0.0 is a value the model
# never saw during training.
columns = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'No_Beard',
           'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Young']
label = torch.FloatTensor([1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0])

In [None]:
# One latent code drawn from the N(0, I) prior; its size follows the model config.
z = torch.randn(latent_size)

In [None]:
# Target attributes (+1 = present, -1 = absent).
values = {
    'Black_Hair': 1.,
    'Blond_Hair': 1.,
    'Brown_Hair': -1.,
    'Male': -1.,
    'No_Beard': 1.,
    'Smiling': 1.,
    'Straight_Hair': 1.,
    'Wavy_Hair': -1.,
    'Young': 1.
}

# Build the vector in the model's column order explicitly rather than relying
# on dict insertion order; a missing or misspelled attribute raises KeyError.
label = torch.FloatTensor([values[name] for name in columns])

In [None]:
vae.eval()
with torch.no_grad():
    # Decode the prior sample `z` conditioned on `label`, mirroring the decoder
    # path of VAE.forward: fc3(z) -> append attributes -> fc4 -> decoder.
    # Local names keep the cell re-runnable; it used to overwrite the global `z`
    # with fc3's smaller output, so a second run failed.
    z_batch = z.reshape(1, -1).to(device)
    label_batch = label.reshape(1, -1).to(device)
    hidden = vae.fc4(torch.cat((vae.fc3(z_batch), label_batch), dim=1))
    x = vae.decoder(hidden)
    show_images(x.cpu())

In [None]:
# Shape of the decoded image batch from the cell above.
x.size()

In [None]:
vae.eval()
vectors = []
evaluation_dir = images_dir + 'evaluation/'
os.makedirs(evaluation_dir, exist_ok=True)
with torch.no_grad():
    for i, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        # The conditional model needs the attribute vector too. `best` was
        # never defined in this notebook; the trained model is `vae`.
        reconstructed_image, mean, log_var = vae(images, labels.to(device))
        # Keep (attribute vector, posterior mean) pairs for the latent plots below.
        vectors.extend(zip(labels.tolist(), mean.tolist()))
        if i % 100 == 0:
            show_images(reconstructed_image.cpu())
            img_name = evaluation_dir + 'noKD1' + str(i).zfill(6) + '.png'
            torchvision.utils.save_image(
                torchvision.utils.make_grid(reconstructed_image), img_name)

In [None]:
# Files available for the single-image reconstruction test below
# (`os` is already imported in the first cell).
os.listdir(images_dir + 'test/')

In [None]:
vae.eval()
# Target attributes for the edited reconstruction (+1 = present, -1 = absent).
values = {
    'Black_Hair': 1.,
    'Blond_Hair': 1.,
    'Brown_Hair': 1.,
    'Male': 1.,
    'No_Beard': 1.,
    'Smiling': -1.,
    'Straight_Hair': 1.,
    'Wavy_Hair': -1.,
    'Young': 1.
}

label = torch.FloatTensor([values[name] for name in columns])
eval_imgs = []

# Same preprocessing as the training data. The model only accepts square
# image_size x image_size inputs, and Resize(int) alone keeps the aspect ratio.
loader = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor()
])

with torch.no_grad():
    img_name = '14.jpg'
    # Close the file promptly; convert() handles grayscale/RGBA uploads.
    with Image.open(images_dir + 'test/' + img_name) as raw_image:
        image = loader(raw_image.convert('RGB'))
    image = image.unsqueeze(0).to(device)
    show_images(image.cpu())

    label = label.reshape(1, -1).to(device)
    reconstructed_image, mean, log_var = vae(image, label)
    show_images(reconstructed_image.cpu())
    eval_imgs.append(reconstructed_image)
    img_name = images_dir + 'test/reconstructednosmile' + img_name
    torchvision.utils.save_image(
        torchvision.utils.make_grid(reconstructed_image), img_name)

In [None]:
# Concatenate the collected reconstructions (each a batch of one) into a single
# batch for display. `.squeeze()` on the plain list raised AttributeError. The
# isinstance guard keeps the cell re-runnable once eval_imgs is already a tensor.
if isinstance(eval_imgs, list):
    eval_imgs = torch.cat(eval_imgs)

In [None]:
# Display every collected reconstruction in one grid.
eval_imgs_cpu = eval_imgs.cpu()
show_images(eval_imgs_cpu)

In [None]:
vae.eval()
# Walk the latent space along a piecewise-linear path start -> middle -> end and
# decode each point conditioned on the current `label`.
# NOTE(review): these anchors are 32-dimensional, presumably recorded from an
# earlier model, while latent_size is 128. When the sizes disagree, random
# anchors drawn from the N(0, I) prior are used instead.
start = np.array([-1.8611,  0.3629, -0.1625,  0.6801,  1.2033,  1.0312,  0.5436,  1.3066,
                  0.2905,  0.1377,  0.5122, -0.1663,  2.3431, -0.0896, -0.5873, -1.4804,
                  0.8141, -1.2197,  0.0484,  0.6414, -0.8172, -0.9543, -0.8818, -1.1465,
                  0.2720,  1.1792,  1.8410, -0.4715,  1.4380,  0.5139,  1.2099, -0.5012])
middle = np.array([-0.4763, -0.4644, -0.3850,  0.6598,  0.9110,  0.4451,  0.4617, -0.0526,
                   0.2808,  0.6080,  0.5532, -1.5506, -0.5199,  0.1359,  0.0373,  0.4284,
                   -0.4134, -1.7078, -0.0309, -1.0195, -0.3151, -0.5569,  0.2832, -0.9132,
                   -1.1339, -1.3196,  2.1297,  0.8122,  0.6849, -0.6710, -0.3507, -0.9001])
end = np.array([-1.6239,  0.2496, -1.0690, -0.8745,  0.4133,  2.2452, -0.2385, -0.6532,
                0.3818, -0.9425,  0.9404,  1.3901, -0.3327, -0.3719, -0.0365,  0.3240,
                0.4928, -0.4988, -1.2228, -0.1638,  0.6093, -0.5264, -1.6963, -0.3718,
                2.1971,  0.2166, -0.0821, -0.1722, -0.1896, -1.6610, -0.1497,  1.0655])
if start.shape[0] != latent_size:
    print("Stored anchors have %d dims but latent_size is %d; using random anchors."
          % (start.shape[0], latent_size))
    start, middle, end = np.random.randn(3, latent_size)

points = 50
# Knots at frames 1, points/2 and points (interp1d interpolates linearly by default).
linfit = interpolate.interp1d(
    [1, points/2, points], np.vstack([start, middle, end]), axis=0)
condition = label.reshape(1, -1).to(device)
interpolate_dir = images_dir + 'interpolate/'
os.makedirs(interpolate_dir, exist_ok=True)
with torch.no_grad():
    # Include both end frames; the old range(2, points - 1) skipped frames 1, 49 and 50.
    for i in range(1, points + 1):
        z = torch.as_tensor(linfit(i), dtype=torch.float32, device=device).reshape(1, -1)
        # VAE has no `fc` layer; decode exactly as VAE.forward does.
        hidden = vae.fc4(torch.cat((vae.fc3(z), condition), dim=1))
        generated_images = vae.decoder(hidden)
        # Show and save the full RGB frame (view(-1, 64, 64)[0] kept only channel 0).
        show_images(generated_images.cpu())
        img_name = interpolate_dir + str(i).zfill(3) + '.png'
        torchvision.utils.save_image(generated_images, img_name)

In [None]:
# m = '-0.106484	-0.009962	-0.019561	-0.008069	0.006305	-0.015933	0.019840	0.067150	-0.027852	0.033752	-0.023144	-0.103557	-0.020161	0.013926	-0.017485	-0.051400	-0.008751	-0.042782	-0.024165	-0.133409	-0.030269	0.002624	0.012955	0.073078	0.025009	-0.022863	-0.008983	0.060253	-0.023170	0.030583	-0.039951	0.073296'
# m = '-0.129811	-0.024494	0.004528	-0.003351	0.012485	-0.032028	0.029316	0.063780	-0.031832	0.020394	-0.026757	-0.109279	-0.050319	-0.006987	-0.029990	-0.054825	-0.022480	-0.048136	-0.018980	-0.136170	-0.013889	-0.001449	0.033935	0.058430	-0.006667	-0.036061	0.019455	0.048937	-0.009147	0.017413	-0.019323	0.070512'
# A recorded 32-value vector, written directly as floats.
m = [
    5.665212, 0.582472, 0.629977, 0.606327, 0.676091, 0.872109, 0.620518, 0.688579,
    0.609230, 0.607343, 0.544618, 0.473328, 0.643389, 0.590706, 0.584597, 0.541984,
    0.649343, 0.534719, 0.539281, 0.486277, 0.567356, 0.586760, 0.627318, 0.705939,
    0.652336, 0.535795, 0.626065, 0.665714, 0.550902, 0.598984, 0.573023, 0.617176,
]

In [None]:
# Shape of the latent tensor `z` as last set by an earlier cell.
print(z.size())

In [None]:
vae.eval()
generated_dir = images_dir + 'generated/'
os.makedirs(generated_dir, exist_ok=True)
# Decode prior samples conditioned on the current `label` attribute vector.
# VAE has no `fc` layer; decoding goes fc3 -> append attributes -> fc4 -> decoder,
# as in VAE.forward.
condition = label.reshape(1, -1).to(device)
with torch.no_grad():
    for i in range(10):
        z = torch.randn(1, latent_size, device=device)
        hidden = vae.fc4(torch.cat((vae.fc3(z), condition), dim=1))
        generated_image = vae.decoder(hidden).view(3, image_size, image_size)
        show_images(generated_image.cpu())
        img_name = generated_dir + str(i).zfill(3) + '.png'
        torchvision.utils.save_image(generated_image, img_name)

In [None]:
# Separate the collected (attribute vector, posterior mean) pairs into two
# parallel tuples, then turn the means into one (N, latent_size) tensor.
labels, z_vectors = zip(*vectors)
z_vectors = torch.tensor(z_vectors)
# PCA sketch, kept for reference:
# z_mean = torch.mean(torch.tensor(z_vectors), 0)
# z_vectors.sub_(z_mean.expand_as(z_vectors))
# U, S, V = torch.svd(torch.t(z_vectors))
# C = torch.mm(z_vectors, U[:, :2]).tolist()
# C = [x + [labels[i]] for i, x in enumerate(C)]

In [None]:
# One row per image, one column per latent dimension.
df = pd.DataFrame(data=z_vectors.numpy())
df.head(n=5)

In [None]:
# Per-dimension summary statistics of the posterior means (default quartiles).
df.describe(percentiles=[0.25, 0.5, 0.75])

In [None]:
# Name the latent dimensions '0', '1', ... so seaborn can address them by string.
# The count comes from the frame itself: the latent has latent_size (128) dims,
# so the old hard-coded 32 raised a length-mismatch ValueError.
df.columns = [str(i) for i in range(df.shape[1])]

In [None]:
# Scatter the first two latent dimensions, coloured by a binary attribute.
# Hueing by a continuous latent dimension (the old hue='3') makes one legend
# entry per distinct float, which is unreadable for this many points.
hue_attr = 'Male'
plot_df = df.assign(**{hue_attr: [attrs[columns.index(hue_attr)] for attrs in labels]})
sns.lmplot(x='0', y='1', data=plot_df, fit_reg=False, hue=hue_attr,
           scatter_kws={'s': 2, 'alpha': 0.3})

## 9. Saving the model
Save the model in case we need to load it again.

In [None]:
# Persist the current weights so they can be reloaded with load_state_dict later.
new_structure_path = model_dir + "Conditional-VAE-new-structure.pt"
torch.save(vae.state_dict(), new_structure_path)