In [None]:
from joblib import dump, load
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
from scipy.stats import entropy
from __future__ import print_function
import argparse
import time
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torchvision
import scipy as sp
# Set random seeds for reproducibility. Every RNG the notebook draws from is
# seeded: Python's `random`, torch, and NumPy's global RNG (get_noise() samples
# the PCA-based latent codes with np.random.normal / multivariate_normal).
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  999


<torch._C.Generator at 0x7fedd14fef78>

In [None]:
# ---- Data loading -----------------------------------------------------------
workers = 2         # DataLoader worker processes
batch_size = 128    # training batch size
image_size = 64     # the transforms resize every image to this spatial size
nc = 3              # image channels (3 for colour images)

# ---- Model sizes ------------------------------------------------------------
nz = 100            # length of the latent vector z fed to the generator
ngf = 64            # base number of generator feature maps
ndf = 64            # base number of discriminator feature maps

# ---- Optimisation -----------------------------------------------------------
num_epochs = 50
lr = 0.0002         # Adam learning rate (both networks)
beta1 = 0.5         # Adam beta1 (both networks)

# ---- Latent distribution ----------------------------------------------------
# One of: "normal", "normal_naive_pca", "normal_multivar_pca", "gm_pca"
# (get_noise implements each of them).
noise_code = "gm_pca"

# ---- Hardware ---------------------------------------------------------------
ngpu = 1            # number of GPUs to use; 0 forces CPU mode
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---- Paths ------------------------------------------------------------------
root_dir = "/content/drive/MyDrive/ENSTA-Bretagne/Semestre 6/Projet/Code"
datasets_dir, models_dir, temp_dir = (
    os.path.join(root_dir, sub) for sub in ("Datasets", "Models", "Temp"))

In [None]:
class DatasetFromTensor(Dataset):
  """Map-style dataset over the first dimension of a tensor.

  Item ``idx`` is the pair ``(sample, 0)``: ``sample`` is ``tensor[idx]``,
  passed through ``transform`` when a truthy transform was given, and ``0`` is
  a placeholder label so the dataset can be consumed like a labelled
  torchvision dataset.
  """

  def __init__(self, tensor, transform = False):
    # transform: callable applied per sample, or a falsy value for none.
    self.tensor = tensor
    self.transform = transform

  def __len__(self):
    n_samples = self.tensor.shape[0]
    return n_samples

  def __getitem__(self, idx):
    sample = self.tensor[idx]
    if self.transform:
      sample = self.transform(sample)
    return sample, 0

In [None]:
# CelebA, stored as a pre-processed tensor file in datasets_dir.
# NOTE: the celeba / cifar10 / fashion-mnist cells each overwrite
# `dataset_name`, `dataset` and `dataloader`; run only the one you want,
# because the last cell executed wins.
dataset_name = "celeba"
transform =transforms.Compose([transforms.Resize(image_size),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Derive the path from datasets_dir (root_dir/Datasets/celeba.pt) rather than
# repeating the hard-coded Drive root.
loadDataPath = os.path.join(datasets_dir, f"{dataset_name}.pt")
# NOTE(review): Normalize uses 3-channel statistics, so the saved tensor is
# presumably (N, 3, H, W) floats in [0, 1] — confirm against the preprocessing.
data = torch.load(loadDataPath)

dataset = DatasetFromTensor(data, transform = transform)

# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

In [None]:
# CIFAR-10 via torchvision (downloaded into datasets_dir on first use).
dataset_name = "cifar10"

# PIL image -> tensor, upscale to image_size, light augmentation, then map to [-1, 1].
# NOTE(review): ColorJitter() is built with no arguments; with torchvision's
# defaults (all factors 0) it likely leaves colours unchanged — tune or drop it.
cifar_steps = [
    transforms.ToTensor(),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(cifar_steps)

# NOTE(review): train=False selects the CIFAR-10 test split rather than the
# training split — confirm this is intended for GAN training.
dataset = torchvision.datasets.CIFAR10(datasets_dir, train=False,
                                       transform=transform, download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

Files already downloaded and verified


In [None]:
# Fashion-MNIST, stored as a pre-processed tensor file in datasets_dir.
# NOTE: the celeba / cifar10 / fashion-mnist cells each overwrite
# `dataset_name`, `dataset` and `dataloader`; run only the one you want,
# because the last cell executed wins.
dataset_name = "fashion-mnist"

transform =transforms.Compose([transforms.Resize(image_size),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Derive the path from datasets_dir (root_dir/Datasets/fashion-mnist.pt) rather
# than repeating the hard-coded Drive root.
loadDataPath = os.path.join(datasets_dir, f"{dataset_name}.pt")
# NOTE(review): Normalize uses 3-channel statistics, so the saved tensor is
# presumably (N, 3, H, W) floats in [0, 1] (grayscale replicated) — confirm.
data = torch.load(loadDataPath)

dataset = DatasetFromTensor(data, transform = transform)
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

In [None]:
# Latent-prior ingredients for the selected dataset, loaded from temp_dir:
# the PCA projections (one row per sample) and the Gaussian mixture fitted to them.
path_load = os.path.join(temp_dir, f"{dataset_name}_PCAproj.pt")
noise = torch.load(path_load)

path_load = os.path.join(temp_dir, f"{dataset_name}_PCAproj_GM.joblib")
gm = load(path_load)

# Per-dimension mean/std and full covariance of the projections; get_noise()
# samples the "normal_*_pca" latent distributions from these.
noise_mean, noise_std = noise.mean(dim=0), noise.std(dim=0)
noise_cov = np.cov(noise, rowvar=False)

def get_noise(batch_size, noise_code):
  """Sample a batch of latent vectors for the generator.

  Args:
    batch_size: number of latent vectors to draw.
    noise_code: latent distribution to sample from:
      "normal"              -- standard normal N(0, I);
      "normal_naive_pca"    -- independent normals per dimension with the
                               per-dimension mean/std of the PCA projections;
      "normal_multivar_pca" -- multivariate normal with the mean and full
                               covariance of the PCA projections;
      "gm_pca"              -- samples from the fitted Gaussian mixture `gm`.

  Returns:
    A float32 tensor of shape (batch_size, nz, 1, 1); callers move it to the
    target device themselves.

  Raises:
    ValueError: if noise_code is not one of the codes above (previously this
      returned None and failed later on `.to(device)`).
  """
  shape = (batch_size, nz, 1, 1)
  if noise_code == "normal":
    return torch.randn(shape, dtype=torch.float)

  if noise_code == "normal_naive_pca":
    samples = np.random.normal(noise_mean, noise_std, size = (batch_size, nz))
  elif noise_code == "normal_multivar_pca":
    samples = np.random.multivariate_normal(noise_mean, noise_cov, size = (batch_size))
  elif noise_code == "gm_pca":
    # gm.sample returns a tuple; element 0 holds the sampled vectors.
    samples = gm.sample(batch_size)[0]
  else:
    raise ValueError(
      f"Unknown noise_code {noise_code!r}; expected one of "
      "'normal', 'normal_naive_pca', 'normal_multivar_pca', 'gm_pca'")
  return torch.tensor(samples, dtype=torch.float).view(shape)

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN weight initialisation, meant for ``module.apply(weights_init)``.

    Dispatches on the class name of ``m``:
      * name contains "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02)
      * name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0
    Any other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: latent batch (N, nz, 1, 1) -> images (N, nc, 64, 64).

    Four transposed-convolution stages, each followed by BatchNorm + ReLU,
    grow the map 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32; a final transposed
    convolution produces `nc` channels at 64x64, squashed into [-1, 1] by Tanh.
    The layer order (and hence the ``main.<i>`` state-dict keys) follows the
    standard DCGAN layout, so saved checkpoints remain loadable.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        def up_stage(c_in, c_out, stride, padding):
            # ConvTranspose (kernel 4, no bias) -> BatchNorm -> ReLU
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        layers = []
        layers += up_stage(nz, ngf * 8, 1, 0)       # -> (ngf*8) x 4 x 4
        layers += up_stage(ngf * 8, ngf * 4, 2, 1)  # -> (ngf*4) x 8 x 8
        layers += up_stage(ngf * 4, ngf * 2, 2, 1)  # -> (ngf*2) x 16 x 16
        layers += up_stage(ngf * 2, ngf, 2, 1)      # -> (ngf) x 32 x 32
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),                              # -> (nc) x 64 x 64
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, nc, 64, 64) -> probabilities (N, 1, 1, 1).

    Strided 4x4 convolutions halve the resolution 64 -> 32 -> 16 -> 8 -> 4,
    each followed by LeakyReLU(0.2) (with BatchNorm on all but the first);
    a final 4x4 convolution collapses the 4x4 map to one value, squashed by
    Sigmoid. The layer order (and hence the ``main.<i>`` state-dict keys)
    follows the standard DCGAN layout.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu

        def down_stage(c_in, c_out, batch_norm):
            # Conv (kernel 4, stride 2, pad 1, no bias) -> [BatchNorm] -> LeakyReLU(0.2)
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        layers = (
            down_stage(nc, ndf, batch_norm=False)             # -> (ndf) x 32 x 32
            + down_stage(ndf, ndf * 2, batch_norm=True)       # -> (ndf*2) x 16 x 16
            + down_stage(ndf * 2, ndf * 4, batch_norm=True)   # -> (ndf*4) x 8 x 8
            + down_stage(ndf * 4, ndf * 8, batch_norm=True)   # -> (ndf*8) x 4 x 4
            + [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
#Defining model
class Identity(torch.nn.Module):
  """Pass-through module: ``forward(x)`` returns ``x`` itself.

  Used to replace layers of a pretrained network (its dropout and final fully
  connected layer) without touching the rest of the architecture.
  """

  def __init__(self):
    super().__init__()

  def forward(self, x):
    return x

def get_feature_extractor(device):
  """Return a pretrained Inception-v3 that outputs pooled features.

  Dropout and the final fully connected layer are replaced by Identity, so the
  forward pass returns the penultimate (2048-d pool) activations used for FID
  and nearest-neighbour search. The model is moved to `device` and put in eval
  mode.

  transform_input=False: images in this notebook are already scaled to [-1, 1]
  (Normalize(0.5, 0.5) for real data, Tanh for generated data), the same
  convention inception_score relies on. With pretrained weights torchvision
  otherwise defaults transform_input to True, which re-normalises the input as
  if it carried ImageNet mean/std statistics and distorts the features.
  """
  model = torchvision.models.inception_v3(pretrained=True, transform_input=False)
  model.dropout = Identity()
  model.fc = Identity()
  model.to(device)
  model.eval()

  return model

def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs

    imgs -- Torch dataset of (3xHxW) images (tensors or numpy arrays) normalized
            in the range [-1, 1]; each item must be the image itself, not an
            (image, label) tuple
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    resize -- if True, bilinearly upsample each batch to 299x299 first
    splits -- number of splits; the score is computed per split

    Returns (mean, std) over splits of exp(mean_x KL(p(y|x) || p(y))).
    Images beyond splits * (N // splits) do not enter the score.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model (inputs are already in [-1, 1]: no input transform)
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval()
    # align_corners=False is the default behaviour; passing it explicitly only
    # silences PyTorch's warning.
    up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).type(dtype)

    def get_pred(x):
        # Class probabilities p(y|x) for one batch, as a (batch, 1000) array.
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x, dim=1).cpu().numpy()

    # Get predictions; inference only, so no autograd graph is recorded.
    preds = np.zeros((N, 1000))

    with torch.no_grad():
        for i, batch in enumerate(dataloader, 0):
            batch = batch.type(dtype)
            batch_size_i = batch.size()[0]

            preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batch)

    # Now compute the mean kl-div
    split_scores = []
    part_len = N // splits

    for k in range(splits):
        part = preds[k * part_len: (k+1) * part_len, :]
        py = np.mean(part, axis=0)
        # entropy(pyx, py) is KL(p(y|x) || p(y)) for one image's row.
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

def calculate_fid(realDataset, fakeDataset, batch_size, device):
  """Frechet Inception Distance between two datasets of images.

  Both datasets must yield (image, label) pairs; labels are ignored. Images are
  upsampled to 299x299 and embedded with get_feature_extractor; the two
  activation sets are compared through their means and covariances:

      FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))

  Each activation matrix is sized by its own dataset, so the two datasets may
  have different lengths.

  Args:
    realDataset, fakeDataset: datasets of (image, label) pairs.
    batch_size: batch size used when running the feature extractor.
    device: device the feature extractor and each batch are moved to.

  Returns:
    The FID (lower means the two feature distributions are closer).
  """
  # Local import: a bare `import scipy` does not guarantee scipy.linalg is
  # loaded on older SciPy versions.
  from scipy import linalg

  model = get_feature_extractor(device)
  # Inception v3 expects 299x299 inputs.
  up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)

  def _activations(dataset):
    # Feature matrix (len(dataset), n_features) in float64.
    loader = torch.utils.data.DataLoader(dataset, batch_size = batch_size)
    feats = []
    with torch.no_grad():
      for imgs, _ in loader:
        # Move to the device first so the upsampling also runs there.
        feats.append(model(up(imgs.to(device))).cpu().numpy())
    return np.concatenate(feats, axis=0).astype(np.float64)

  realAct = _activations(realDataset)
  fakeAct = _activations(fakeDataset)

  # calculate mean and covariance statistics
  mu1, sigma1 = realAct.mean(axis=0), np.cov(realAct, rowvar=False)
  mu2, sigma2 = fakeAct.mean(axis=0), np.cov(fakeAct, rowvar=False)
  # calculate sum squared difference between means
  ssdiff = np.sum((mu1 - mu2)**2.0)
  # calculate sqrt of product between cov
  covmean = linalg.sqrtm(sigma1.dot(sigma2))
  if not np.isfinite(covmean).all():
    # Near-singular product: add a small ridge to both covariances and retry.
    offset = np.eye(sigma1.shape[0]) * 1e-6
    covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
  # keep only the real part if sqrtm returned a complex result
  if np.iscomplexobj(covmean):
    covmean = covmean.real
  # calculate score
  fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

  return fid

#definition of real and fake datasets
def get_fid_full(realDataset, generator, gen_batch_size=512):
  """FID between `realDataset` and as many images sampled from `generator`.

  Latent codes come from get_noise(len(realDataset), noise_code). Images are
  decoded in chunks of `gen_batch_size`, so a large dataset does not need one
  huge forward pass; in eval mode BatchNorm uses its running statistics, so
  chunking does not change the generated images.

  The generator's previous train/eval mode is restored before returning, so
  this can be called inside the training loop without leaving the generator
  in eval mode for the remaining iterations.
  """
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      z = get_noise(len(realDataset), noise_code).to(device)
      out = torch.cat([generator(chunk) for chunk in z.split(gen_batch_size)])
  finally:
    generator.train(was_training)

  fakeDataset = DatasetFromTensor(out)

  fid = calculate_fid(realDataset, fakeDataset, 512, device = device)

  return fid

In [None]:
t_start = time.time()

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  (conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02)).
netG.apply(weights_init)

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  (conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02)).
netD.apply(weights_init)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = get_noise(64, noise_code).to(device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress: G_losses / D_losses get one entry per epoch
# (per-sample average loss over that epoch); img_list collects the generator's
# output on fixed_noise every 500 iterations and at the very end.
FIDs = []
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # Each list entry is (batch-mean loss) * (batch size); dividing their sum by
    # epoch_n_samples gives the per-sample average loss for the epoch.
    epoch_G_losses = []
    epoch_D_losses = []
    epoch_n_samples = 0
    # For each batch in the dataloader
    for i, data in enumerate(dataloader):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors

        noise = get_noise(b_size, noise_code).to(device)

        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        epoch_G_losses.append(errG.item()*b_size)
        epoch_D_losses.append(errD.item()*b_size)
        epoch_n_samples += b_size

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

    # Sample-weighted epoch average. A plain np.mean of the (loss * b_size)
    # entries would be inflated by roughly the batch size.
    G_losses.append(np.sum(epoch_G_losses) / epoch_n_samples)
    D_losses.append(np.sum(epoch_D_losses) / epoch_n_samples)
    #FIDs.append(get_fid_full(dataset, netG))

t_end = time.time()
print(t_end-t_start)

In [None]:
# Bundle everything needed to resume training or re-plot the curves.
checkpoint = {"netD":netD.state_dict(),
              "netG":netG.state_dict(),
              "optimizerD":optimizerD.state_dict(),
              "optimizerG":optimizerG.state_dict(),
              "epoch":num_epochs,
              "G_losses": G_losses,
              "D_losses": D_losses,
              "FIDs":FIDs}

# File-name suffix identifying the latent distribution used for training.
NOISE_CODE_SUFFIXES = {
  "normal": "",
  "normal_naive_pca": "_naivePCA",
  "normal_multivar_pca": "_multiVarPCA",
  "gm_pca": "_GMPCA",
}
if noise_code not in NOISE_CODE_SUFFIXES:
  # Explicit error rather than `assert`, which is stripped under `python -O`.
  raise ValueError(f"Unknown noise_code {noise_code!r}; expected one of {sorted(NOISE_CODE_SUFFIXES)}")
base_name = NOISE_CODE_SUFFIXES[noise_code]

path_checkpoint = os.path.join(models_dir, f"DCGAN{base_name}_{dataset_name}_{num_epochs}epochs.pt")
torch.save(checkpoint, path_checkpoint)

In [None]:
# File-name suffix identifying the latent distribution used for training
# (must match the naming used when the checkpoint was saved).
NOISE_CODE_SUFFIXES = {
  "normal": "",
  "normal_naive_pca": "_naivePCA",
  "normal_multivar_pca": "_multiVarPCA",
  "gm_pca": "_GMPCA",
}
if noise_code not in NOISE_CODE_SUFFIXES:
  # Explicit error rather than `assert`, which is stripped under `python -O`.
  raise ValueError(f"Unknown noise_code {noise_code!r}; expected one of {sorted(NOISE_CODE_SUFFIXES)}")
base_name = NOISE_CODE_SUFFIXES[noise_code]

path_checkpoint = os.path.join(models_dir, f"DCGAN{base_name}_{dataset_name}_{num_epochs}epochs.pt")

# map_location lets a checkpoint written on GPU be loaded on a CPU-only runtime.
loaded_checkpoint = torch.load(path_checkpoint, map_location=device)

# Restore the logged curves as well, so the loss plot works without re-training.
# (img_list is not part of the checkpoint.)
G_losses = loaded_checkpoint["G_losses"]
D_losses = loaded_checkpoint["D_losses"]
FIDs = loaded_checkpoint["FIDs"]

netD = Discriminator(ngpu).to(device)
netD.load_state_dict(loaded_checkpoint["netD"])

netG = Generator(ngpu).to(device)
netG.load_state_dict(loaded_checkpoint["netG"])

<All keys matched successfully>

In [None]:
#%%capture
# Animate the fixed-noise sample grids recorded during training, one frame per grid.
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = []
for frame_grid in img_list:
    artist = plt.imshow(np.transpose(frame_grid, (1, 2, 0)), animated=True)
    ims.append([artist])
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [None]:
# Training curves: per-epoch generator / discriminator losses.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Side-by-side comparison: one batch of real images vs. the last recorded
# fixed-noise grid.
real_batch = next(iter(dataloader))
# make_grid works on CPU tensors directly; no round trip through `device` needed.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15,15))
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(np.transpose(real_grid, (1,2,0)))

# Last grid appended to img_list during training
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(np.transpose(img_list[-1], (1,2,0)))
plt.show()

In [None]:
def normalizeImg(img):
  """Min-max rescale `img` (tensor or array) to [0, 1] for display.

  A constant image (max == min) is returned as all zeros instead of the NaNs
  that the 0/0 division would produce.
  """
  lo = img.min()
  span = img.max() - lo
  if span == 0:
    return img - lo
  return (img - lo) / span

def showClosetImage(dataset, genImg):
  """Plot a generated image beside its two nearest neighbours in `dataset`.

  Neighbours are searched over every item of `dataset`:
    * pixel space: lowest mean squared error to the generated image (panel
      titled "RMSE"; MSE and RMSE share the same minimiser);
    * Inception space: lowest mean squared error between feature vectors from
      get_feature_extractor on 299x299-upsampled images (panel titled "ID").

  Args:
    dataset: indexable dataset of (image, label) pairs; each image must have
      the same shape as one generated image.
    genImg: batch of generated images; only genImg[0] is used.

  NOTE(review): if `dataset` applies random transforms (the CIFAR-10 cell uses
  RandomHorizontalFlip), the neighbour shown is re-fetched with dataset[idx]
  and may be a different random variant from the one that was compared.
  """
  genImg = genImg.cpu()[0]

  # Distances are collected per batch and concatenated, so their positions
  # match dataset indices whatever the loader's batch size is.
  loader = torch.utils.data.DataLoader(dataset, batch_size = 128)

  model = get_feature_extractor(device)
  up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)

  pixel_dists = []
  feat_dists = []
  with torch.no_grad():
    genImgAct = model(up(genImg.unsqueeze(0).to(device)))
    # A single pass over the dataset yields both distances for every image.
    for imgs, _ in loader:
      imgs = imgs.cpu()
      # Per-image mean of squared differences (what nn.MSELoss computes).
      pixel_dists.append((imgs - genImg).pow(2).flatten(1).mean(dim=1))
      feats = model(up(imgs.to(device)))
      feat_dists.append((feats - genImgAct).pow(2).mean(dim=1).cpu())

  # argmin returns the first index among ties.
  min_i_0 = int(torch.cat(pixel_dists).argmin())
  min_i_1 = int(torch.cat(feat_dists).argmin())

  panels = [
    (genImg, "Generated image"),
    (dataset[min_i_0][0], "Closest image in dataset (RMSE)"),
    (dataset[min_i_1][0], "Closest image in dataset (ID)"),
  ]
  fig, ax = plt.subplots(1,3, figsize = (10,10))
  for axis, (img, title) in zip(ax, panels):
    axis.imshow(normalizeImg(img.cpu().permute(1,2,0)))
    axis.set_title(title)
    axis.axis("off")

# Decode one latent sample; no autograd graph is needed for this inference.
with torch.no_grad():
  noise = get_noise(1,noise_code).to(device)
  output = netG(noise)

# Time the nearest-neighbour search (it iterates over the whole `dataset`).
start = time.time()
showClosetImage(dataset, output)
end = time.time()
print(end - start)

In [None]:
# Decode 8 latent samples and show them side by side in a single row.
batch = get_noise(8, noise_code).to(device)
with torch.no_grad():
  output = netG(batch)

sample_grid = vutils.make_grid(output, nrow=8, padding=1, normalize=True, pad_value=1).cpu()
plt.figure(figsize=(15,15))
plt.axis("off")
plt.imshow(np.transpose(sample_grid, (1,2,0)))

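Samples from each Gaussian mixture component: the left column decodes the component mean, and the right column decodes 10 random draws from that component.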

In [None]:
# For every Gaussian-mixture component, decode its mean (left column) and
# n_draws random samples from that component (right column).
# NOTE(review): using gm.covariances_[i] as a full covariance matrix assumes
# the mixture was fitted with covariance_type="full" — confirm.
gm_means = gm.means_
gm_covs = gm.covariances_
n_components = gm_means.shape[0]  # one figure row per component
n_draws = 10                      # samples per component in the right column

# squeeze=False keeps `ax` two-dimensional even for a single component.
fig, ax = plt.subplots(n_components, 2, figsize=(2*31, 4*n_components), squeeze=False)

with torch.no_grad():
  for i in range(n_components):
    # n_draws latent codes drawn from component i, shaped as generator input.
    noise1 = torch.tensor(np.random.multivariate_normal(gm_means[i], gm_covs[i], size = n_draws),dtype =torch.float).view(n_draws,nz,1,1).to(device)
    out1 = netG(noise1)

    # The component mean itself, decoded as a single image.
    noise2 = torch.tensor(gm_means[i], dtype = torch.float).view(1,nz,1,1).to(device)
    out2 = netG(noise2)

    ax[i][0].imshow(np.transpose(vutils.make_grid(out2, nrow = 1 ,padding=1, normalize=True, pad_value=.5).cpu(),(1,2,0)))
    ax[i][1].imshow(np.transpose(vutils.make_grid(out1, nrow = n_draws ,padding=1, normalize=True, pad_value=1).cpu(),(1,2,0)))
    ax[i][1].axis("off")
    ax[i][0].axis("off")

  plt.tight_layout()
  plt.show()