# GAN metric evaluation

In [43]:
%matplotlib  inline
import matplotlib.pyplot as plt
import os
import numpy as np
import math
import pickle

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

### Prepare initial set up

In [44]:
n_epochs = 2000 # number of epochs of training
batch_size = 1000 # size of  the batches
lr = 0.0002 # learning rate
b1 = 0.5  #  adam: decay of first order momentum of gradient
b2 = 0.999  #  adam: decay of first order momentum of gradient"
n_cpu = 8  #  number of cpu threads to use during batch generation
latent_dim = 2 #5 # dimensionality of the latent space
img_size = 2 #5 # size of each image dimension
channels = 1 # number of image channels
sample_interval = 400 # interval betwen image samples
outf = 'models3' # save models at each iteration during epochs

In [45]:
os.makedirs("images_generated3", exist_ok=True)
os.makedirs(outf, exist_ok=True)

In [46]:
img_shape = (channels, img_size, img_size)

In [47]:
img_shape

(1, 2, 2)

In [48]:
cuda = True if torch.cuda.is_available() else False

In [49]:
cuda

True

### Helper functions

In [50]:
from torch.utils import data

class Dataset(data.Dataset):
    #'Characterizes a dataset for PyTorch'
    def __init__(self, list_IDs, labels):
    #'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        #'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        # 'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]

        # Load data and get label
        X = torch.load('my_data/' + ID + '.pt')
        y = self.labels[ID]

        return X, y

### Models

In [51]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 4, normalize=False),
            *block(4, 8),
            *block(8, 16),
            *block(16, 32),
            nn.Linear(32, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [52]:
28*28

784

In [53]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(4, 1),
            #nn.LeakyReLU(0.2, inplace=True),
            #nn.Linear(8, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [54]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

In [55]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

In [56]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

### Load data

In [57]:
# load data
[partition_list, labels] = pickle.load(open('train_test_info.pkl', 'rb'))

In [58]:
training_dataset = Dataset(partition_list['train'], labels)
test_dataset = Dataset(partition_list['test'], labels)

In [59]:
train_loader = data.DataLoader(training_dataset, batch_size=batch_size)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size)

### Train data

In [60]:
generator_loss = []
discriminator_loss = []
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, n_epochs, i, len(train_loader), d_loss.item(), g_loss.item())
    )
    
    generator_loss.append(d_loss.item())
    discriminator_loss.append(g_loss.item())

        #batches_done = epoch * len(train_loader) + i
        #if batches_done % sample_interval == 0:
            #save_image(gen_imgs.data[:25], "images_generated/%d.png" % batches_done, nrow=5, normalize=True)
    torch.save(generator.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
    torch.save(discriminator.state_dict(), '%s/netD_epoch_%d.pth' % (outf, epoch))

[Epoch 0/2000] [Batch 9/10] [D loss: 0.689490] [G loss: 0.607883]
[Epoch 1/2000] [Batch 9/10] [D loss: 0.689383] [G loss: 0.609073]
[Epoch 2/2000] [Batch 9/10] [D loss: 0.689295] [G loss: 0.610248]
[Epoch 3/2000] [Batch 9/10] [D loss: 0.689242] [G loss: 0.611379]
[Epoch 4/2000] [Batch 9/10] [D loss: 0.689205] [G loss: 0.612498]
[Epoch 5/2000] [Batch 9/10] [D loss: 0.689225] [G loss: 0.613534]
[Epoch 6/2000] [Batch 9/10] [D loss: 0.689261] [G loss: 0.614557]
[Epoch 7/2000] [Batch 9/10] [D loss: 0.689323] [G loss: 0.615552]
[Epoch 8/2000] [Batch 9/10] [D loss: 0.689428] [G loss: 0.616486]
[Epoch 9/2000] [Batch 9/10] [D loss: 0.689554] [G loss: 0.617400]
[Epoch 10/2000] [Batch 9/10] [D loss: 0.689714] [G loss: 0.618276]
[Epoch 11/2000] [Batch 9/10] [D loss: 0.689899] [G loss: 0.619130]
[Epoch 12/2000] [Batch 9/10] [D loss: 0.690123] [G loss: 0.619942]
[Epoch 13/2000] [Batch 9/10] [D loss: 0.690336] [G loss: 0.620793]
[Epoch 14/2000] [Batch 9/10] [D loss: 0.690578] [G loss: 0.621618]
[Epoc

[Epoch 123/2000] [Batch 9/10] [D loss: 0.694545] [G loss: 0.706689]
[Epoch 124/2000] [Batch 9/10] [D loss: 0.692862] [G loss: 0.709305]
[Epoch 125/2000] [Batch 9/10] [D loss: 0.694195] [G loss: 0.705728]
[Epoch 126/2000] [Batch 9/10] [D loss: 0.694977] [G loss: 0.703325]
[Epoch 127/2000] [Batch 9/10] [D loss: 0.694616] [G loss: 0.703012]
[Epoch 128/2000] [Batch 9/10] [D loss: 0.695693] [G loss: 0.699869]
[Epoch 129/2000] [Batch 9/10] [D loss: 0.694821] [G loss: 0.700630]
[Epoch 130/2000] [Batch 9/10] [D loss: 0.695349] [G loss: 0.698713]
[Epoch 131/2000] [Batch 9/10] [D loss: 0.694232] [G loss: 0.700049]
[Epoch 132/2000] [Batch 9/10] [D loss: 0.695057] [G loss: 0.697598]
[Epoch 133/2000] [Batch 9/10] [D loss: 0.694410] [G loss: 0.698038]
[Epoch 134/2000] [Batch 9/10] [D loss: 0.693123] [G loss: 0.699776]
[Epoch 135/2000] [Batch 9/10] [D loss: 0.692933] [G loss: 0.699379]
[Epoch 136/2000] [Batch 9/10] [D loss: 0.692370] [G loss: 0.699683]
[Epoch 137/2000] [Batch 9/10] [D loss: 0.691787]

[Epoch 244/2000] [Batch 9/10] [D loss: 0.677311] [G loss: 0.719594]
[Epoch 245/2000] [Batch 9/10] [D loss: 0.677286] [G loss: 0.718769]
[Epoch 246/2000] [Batch 9/10] [D loss: 0.677044] [G loss: 0.718221]
[Epoch 247/2000] [Batch 9/10] [D loss: 0.677089] [G loss: 0.717106]
[Epoch 248/2000] [Batch 9/10] [D loss: 0.677000] [G loss: 0.716235]
[Epoch 249/2000] [Batch 9/10] [D loss: 0.677728] [G loss: 0.713597]
[Epoch 250/2000] [Batch 9/10] [D loss: 0.676529] [G loss: 0.714936]
[Epoch 251/2000] [Batch 9/10] [D loss: 0.677209] [G loss: 0.712423]
[Epoch 252/2000] [Batch 9/10] [D loss: 0.676623] [G loss: 0.712489]
[Epoch 253/2000] [Batch 9/10] [D loss: 0.676814] [G loss: 0.710912]
[Epoch 254/2000] [Batch 9/10] [D loss: 0.675891] [G loss: 0.711457]
[Epoch 255/2000] [Batch 9/10] [D loss: 0.675690] [G loss: 0.710817]
[Epoch 256/2000] [Batch 9/10] [D loss: 0.675953] [G loss: 0.709116]
[Epoch 257/2000] [Batch 9/10] [D loss: 0.675302] [G loss: 0.709291]
[Epoch 258/2000] [Batch 9/10] [D loss: 0.674837]

[Epoch 365/2000] [Batch 9/10] [D loss: 0.710694] [G loss: 0.683301]
[Epoch 366/2000] [Batch 9/10] [D loss: 0.709973] [G loss: 0.685327]
[Epoch 367/2000] [Batch 9/10] [D loss: 0.711714] [G loss: 0.682202]
[Epoch 368/2000] [Batch 9/10] [D loss: 0.710253] [G loss: 0.685467]
[Epoch 369/2000] [Batch 9/10] [D loss: 0.710247] [G loss: 0.685572]
[Epoch 370/2000] [Batch 9/10] [D loss: 0.709702] [G loss: 0.686682]
[Epoch 371/2000] [Batch 9/10] [D loss: 0.709434] [G loss: 0.687157]
[Epoch 372/2000] [Batch 9/10] [D loss: 0.708106] [G loss: 0.689633]
[Epoch 373/2000] [Batch 9/10] [D loss: 0.707492] [G loss: 0.690705]
[Epoch 374/2000] [Batch 9/10] [D loss: 0.707182] [G loss: 0.690986]
[Epoch 375/2000] [Batch 9/10] [D loss: 0.707236] [G loss: 0.690337]
[Epoch 376/2000] [Batch 9/10] [D loss: 0.706558] [G loss: 0.691202]
[Epoch 377/2000] [Batch 9/10] [D loss: 0.705648] [G loss: 0.692454]
[Epoch 378/2000] [Batch 9/10] [D loss: 0.705805] [G loss: 0.691532]
[Epoch 379/2000] [Batch 9/10] [D loss: 0.704363]

[Epoch 486/2000] [Batch 9/10] [D loss: 0.666976] [G loss: 0.730759]
[Epoch 487/2000] [Batch 9/10] [D loss: 0.667103] [G loss: 0.732882]
[Epoch 488/2000] [Batch 9/10] [D loss: 0.668317] [G loss: 0.732580]
[Epoch 489/2000] [Batch 9/10] [D loss: 0.667416] [G loss: 0.736483]
[Epoch 490/2000] [Batch 9/10] [D loss: 0.666571] [G loss: 0.740070]
[Epoch 491/2000] [Batch 9/10] [D loss: 0.666194] [G loss: 0.742561]
[Epoch 492/2000] [Batch 9/10] [D loss: 0.665740] [G loss: 0.744999]
[Epoch 493/2000] [Batch 9/10] [D loss: 0.666484] [G loss: 0.744645]
[Epoch 494/2000] [Batch 9/10] [D loss: 0.666654] [G loss: 0.745398]
[Epoch 495/2000] [Batch 9/10] [D loss: 0.666394] [G loss: 0.746939]
[Epoch 496/2000] [Batch 9/10] [D loss: 0.667847] [G loss: 0.744708]
[Epoch 497/2000] [Batch 9/10] [D loss: 0.667173] [G loss: 0.746974]
[Epoch 498/2000] [Batch 9/10] [D loss: 0.668633] [G loss: 0.744505]
[Epoch 499/2000] [Batch 9/10] [D loss: 0.670348] [G loss: 0.741456]
[Epoch 500/2000] [Batch 9/10] [D loss: 0.670236]

[Epoch 607/2000] [Batch 9/10] [D loss: 0.680068] [G loss: 0.702656]
[Epoch 608/2000] [Batch 9/10] [D loss: 0.681178] [G loss: 0.701636]
[Epoch 609/2000] [Batch 9/10] [D loss: 0.680778] [G loss: 0.703566]
[Epoch 610/2000] [Batch 9/10] [D loss: 0.680731] [G loss: 0.704588]
[Epoch 611/2000] [Batch 9/10] [D loss: 0.679247] [G loss: 0.708688]
[Epoch 612/2000] [Batch 9/10] [D loss: 0.679834] [G loss: 0.708282]
[Epoch 613/2000] [Batch 9/10] [D loss: 0.679001] [G loss: 0.710753]
[Epoch 614/2000] [Batch 9/10] [D loss: 0.678480] [G loss: 0.712506]
[Epoch 615/2000] [Batch 9/10] [D loss: 0.678624] [G loss: 0.712826]
[Epoch 616/2000] [Batch 9/10] [D loss: 0.677165] [G loss: 0.716382]
[Epoch 617/2000] [Batch 9/10] [D loss: 0.677203] [G loss: 0.716774]
[Epoch 618/2000] [Batch 9/10] [D loss: 0.676499] [G loss: 0.718664]
[Epoch 619/2000] [Batch 9/10] [D loss: 0.675888] [G loss: 0.720214]
[Epoch 620/2000] [Batch 9/10] [D loss: 0.675558] [G loss: 0.721121]
[Epoch 621/2000] [Batch 9/10] [D loss: 0.674954]

[Epoch 728/2000] [Batch 9/10] [D loss: 0.634917] [G loss: 0.790429]
[Epoch 729/2000] [Batch 9/10] [D loss: 0.631975] [G loss: 0.794893]
[Epoch 730/2000] [Batch 9/10] [D loss: 0.629001] [G loss: 0.799428]
[Epoch 731/2000] [Batch 9/10] [D loss: 0.626874] [G loss: 0.801877]
[Epoch 732/2000] [Batch 9/10] [D loss: 0.625007] [G loss: 0.803586]
[Epoch 733/2000] [Batch 9/10] [D loss: 0.621628] [G loss: 0.809295]
[Epoch 734/2000] [Batch 9/10] [D loss: 0.619428] [G loss: 0.811954]
[Epoch 735/2000] [Batch 9/10] [D loss: 0.617175] [G loss: 0.814718]
[Epoch 736/2000] [Batch 9/10] [D loss: 0.614421] [G loss: 0.818962]
[Epoch 737/2000] [Batch 9/10] [D loss: 0.612800] [G loss: 0.820102]
[Epoch 738/2000] [Batch 9/10] [D loss: 0.610399] [G loss: 0.823492]
[Epoch 739/2000] [Batch 9/10] [D loss: 0.608784] [G loss: 0.824761]
[Epoch 740/2000] [Batch 9/10] [D loss: 0.606463] [G loss: 0.827970]
[Epoch 741/2000] [Batch 9/10] [D loss: 0.604267] [G loss: 0.830929]
[Epoch 742/2000] [Batch 9/10] [D loss: 0.601945]

[Epoch 849/2000] [Batch 9/10] [D loss: 0.719539] [G loss: 0.661425]
[Epoch 850/2000] [Batch 9/10] [D loss: 0.717745] [G loss: 0.664278]
[Epoch 851/2000] [Batch 9/10] [D loss: 0.717446] [G loss: 0.664444]
[Epoch 852/2000] [Batch 9/10] [D loss: 0.717350] [G loss: 0.664750]
[Epoch 853/2000] [Batch 9/10] [D loss: 0.716875] [G loss: 0.665968]
[Epoch 854/2000] [Batch 9/10] [D loss: 0.715918] [G loss: 0.667505]
[Epoch 855/2000] [Batch 9/10] [D loss: 0.714198] [G loss: 0.669189]
[Epoch 856/2000] [Batch 9/10] [D loss: 0.716047] [G loss: 0.666057]
[Epoch 857/2000] [Batch 9/10] [D loss: 0.715838] [G loss: 0.666439]
[Epoch 858/2000] [Batch 9/10] [D loss: 0.718133] [G loss: 0.662353]
[Epoch 859/2000] [Batch 9/10] [D loss: 0.718164] [G loss: 0.662570]
[Epoch 860/2000] [Batch 9/10] [D loss: 0.718558] [G loss: 0.661501]
[Epoch 861/2000] [Batch 9/10] [D loss: 0.720268] [G loss: 0.659176]
[Epoch 862/2000] [Batch 9/10] [D loss: 0.718815] [G loss: 0.661324]
[Epoch 863/2000] [Batch 9/10] [D loss: 0.721365]

[Epoch 970/2000] [Batch 9/10] [D loss: 0.653305] [G loss: 0.753074]
[Epoch 971/2000] [Batch 9/10] [D loss: 0.652375] [G loss: 0.755184]
[Epoch 972/2000] [Batch 9/10] [D loss: 0.652436] [G loss: 0.754933]
[Epoch 973/2000] [Batch 9/10] [D loss: 0.651851] [G loss: 0.756114]
[Epoch 974/2000] [Batch 9/10] [D loss: 0.652002] [G loss: 0.755804]
[Epoch 975/2000] [Batch 9/10] [D loss: 0.651669] [G loss: 0.756431]
[Epoch 976/2000] [Batch 9/10] [D loss: 0.652229] [G loss: 0.755105]
[Epoch 977/2000] [Batch 9/10] [D loss: 0.652164] [G loss: 0.755133]
[Epoch 978/2000] [Batch 9/10] [D loss: 0.651502] [G loss: 0.756537]
[Epoch 979/2000] [Batch 9/10] [D loss: 0.651512] [G loss: 0.756388]
[Epoch 980/2000] [Batch 9/10] [D loss: 0.651243] [G loss: 0.756933]
[Epoch 981/2000] [Batch 9/10] [D loss: 0.652099] [G loss: 0.754949]
[Epoch 982/2000] [Batch 9/10] [D loss: 0.652367] [G loss: 0.754308]
[Epoch 983/2000] [Batch 9/10] [D loss: 0.652325] [G loss: 0.754466]
[Epoch 984/2000] [Batch 9/10] [D loss: 0.653303]

[Epoch 1090/2000] [Batch 9/10] [D loss: 0.688337] [G loss: 0.740105]
[Epoch 1091/2000] [Batch 9/10] [D loss: 0.688994] [G loss: 0.736818]
[Epoch 1092/2000] [Batch 9/10] [D loss: 0.687424] [G loss: 0.737973]
[Epoch 1093/2000] [Batch 9/10] [D loss: 0.685908] [G loss: 0.739111]
[Epoch 1094/2000] [Batch 9/10] [D loss: 0.685779] [G loss: 0.737371]
[Epoch 1095/2000] [Batch 9/10] [D loss: 0.686052] [G loss: 0.734837]
[Epoch 1096/2000] [Batch 9/10] [D loss: 0.685852] [G loss: 0.733131]
[Epoch 1097/2000] [Batch 9/10] [D loss: 0.685740] [G loss: 0.731358]
[Epoch 1098/2000] [Batch 9/10] [D loss: 0.685484] [G loss: 0.729798]
[Epoch 1099/2000] [Batch 9/10] [D loss: 0.687957] [G loss: 0.722427]
[Epoch 1100/2000] [Batch 9/10] [D loss: 0.689500] [G loss: 0.717301]
[Epoch 1101/2000] [Batch 9/10] [D loss: 0.695283] [G loss: 0.704503]
[Epoch 1102/2000] [Batch 9/10] [D loss: 0.704355] [G loss: 0.687296]
[Epoch 1103/2000] [Batch 9/10] [D loss: 0.714417] [G loss: 0.669970]
[Epoch 1104/2000] [Batch 9/10] [D 

[Epoch 1209/2000] [Batch 9/10] [D loss: 0.659969] [G loss: 0.737673]
[Epoch 1210/2000] [Batch 9/10] [D loss: 0.661097] [G loss: 0.735299]
[Epoch 1211/2000] [Batch 9/10] [D loss: 0.660944] [G loss: 0.735616]
[Epoch 1212/2000] [Batch 9/10] [D loss: 0.661741] [G loss: 0.734324]
[Epoch 1213/2000] [Batch 9/10] [D loss: 0.661856] [G loss: 0.734519]
[Epoch 1214/2000] [Batch 9/10] [D loss: 0.664286] [G loss: 0.730368]
[Epoch 1215/2000] [Batch 9/10] [D loss: 0.667677] [G loss: 0.724370]
[Epoch 1216/2000] [Batch 9/10] [D loss: 0.671935] [G loss: 0.716288]
[Epoch 1217/2000] [Batch 9/10] [D loss: 0.680902] [G loss: 0.700052]
[Epoch 1218/2000] [Batch 9/10] [D loss: 0.685908] [G loss: 0.693547]
[Epoch 1219/2000] [Batch 9/10] [D loss: 0.686745] [G loss: 0.695803]
[Epoch 1220/2000] [Batch 9/10] [D loss: 0.685665] [G loss: 0.701798]
[Epoch 1221/2000] [Batch 9/10] [D loss: 0.683788] [G loss: 0.709092]
[Epoch 1222/2000] [Batch 9/10] [D loss: 0.681642] [G loss: 0.716585]
[Epoch 1223/2000] [Batch 9/10] [D 

[Epoch 1328/2000] [Batch 9/10] [D loss: 0.654603] [G loss: 0.713550]
[Epoch 1329/2000] [Batch 9/10] [D loss: 0.652934] [G loss: 0.714895]
[Epoch 1330/2000] [Batch 9/10] [D loss: 0.651211] [G loss: 0.716459]
[Epoch 1331/2000] [Batch 9/10] [D loss: 0.649381] [G loss: 0.718334]
[Epoch 1332/2000] [Batch 9/10] [D loss: 0.647449] [G loss: 0.720492]
[Epoch 1333/2000] [Batch 9/10] [D loss: 0.645248] [G loss: 0.723278]
[Epoch 1334/2000] [Batch 9/10] [D loss: 0.643087] [G loss: 0.726168]
[Epoch 1335/2000] [Batch 9/10] [D loss: 0.641863] [G loss: 0.727221]
[Epoch 1336/2000] [Batch 9/10] [D loss: 0.640321] [G loss: 0.729037]
[Epoch 1337/2000] [Batch 9/10] [D loss: 0.639322] [G loss: 0.729863]
[Epoch 1338/2000] [Batch 9/10] [D loss: 0.636985] [G loss: 0.733517]
[Epoch 1339/2000] [Batch 9/10] [D loss: 0.636390] [G loss: 0.733735]
[Epoch 1340/2000] [Batch 9/10] [D loss: 0.634387] [G loss: 0.737038]
[Epoch 1341/2000] [Batch 9/10] [D loss: 0.633724] [G loss: 0.737706]
[Epoch 1342/2000] [Batch 9/10] [D 

[Epoch 1447/2000] [Batch 9/10] [D loss: 0.693279] [G loss: 0.662410]
[Epoch 1448/2000] [Batch 9/10] [D loss: 0.691429] [G loss: 0.662595]
[Epoch 1449/2000] [Batch 9/10] [D loss: 0.692029] [G loss: 0.658723]
[Epoch 1450/2000] [Batch 9/10] [D loss: 0.691480] [G loss: 0.656562]
[Epoch 1451/2000] [Batch 9/10] [D loss: 0.689595] [G loss: 0.657554]
[Epoch 1452/2000] [Batch 9/10] [D loss: 0.689012] [G loss: 0.656195]
[Epoch 1453/2000] [Batch 9/10] [D loss: 0.687740] [G loss: 0.656561]
[Epoch 1454/2000] [Batch 9/10] [D loss: 0.686229] [G loss: 0.657477]
[Epoch 1455/2000] [Batch 9/10] [D loss: 0.686205] [G loss: 0.655521]
[Epoch 1456/2000] [Batch 9/10] [D loss: 0.685607] [G loss: 0.654764]
[Epoch 1457/2000] [Batch 9/10] [D loss: 0.685241] [G loss: 0.653883]
[Epoch 1458/2000] [Batch 9/10] [D loss: 0.683738] [G loss: 0.655303]
[Epoch 1459/2000] [Batch 9/10] [D loss: 0.684189] [G loss: 0.653355]
[Epoch 1460/2000] [Batch 9/10] [D loss: 0.684019] [G loss: 0.652468]
[Epoch 1461/2000] [Batch 9/10] [D 

[Epoch 1566/2000] [Batch 9/10] [D loss: 0.680336] [G loss: 0.691442]
[Epoch 1567/2000] [Batch 9/10] [D loss: 0.680903] [G loss: 0.689152]
[Epoch 1568/2000] [Batch 9/10] [D loss: 0.680580] [G loss: 0.688467]
[Epoch 1569/2000] [Batch 9/10] [D loss: 0.680222] [G loss: 0.687891]
[Epoch 1570/2000] [Batch 9/10] [D loss: 0.680745] [G loss: 0.686061]
[Epoch 1571/2000] [Batch 9/10] [D loss: 0.681246] [G loss: 0.684370]
[Epoch 1572/2000] [Batch 9/10] [D loss: 0.680812] [G loss: 0.684293]
[Epoch 1573/2000] [Batch 9/10] [D loss: 0.680612] [G loss: 0.683647]
[Epoch 1574/2000] [Batch 9/10] [D loss: 0.681688] [G loss: 0.680633]
[Epoch 1575/2000] [Batch 9/10] [D loss: 0.680390] [G loss: 0.681869]
[Epoch 1576/2000] [Batch 9/10] [D loss: 0.682436] [G loss: 0.677354]
[Epoch 1577/2000] [Batch 9/10] [D loss: 0.681455] [G loss: 0.677601]
[Epoch 1578/2000] [Batch 9/10] [D loss: 0.683648] [G loss: 0.672837]
[Epoch 1579/2000] [Batch 9/10] [D loss: 0.685083] [G loss: 0.668297]
[Epoch 1580/2000] [Batch 9/10] [D 

[Epoch 1685/2000] [Batch 9/10] [D loss: 0.659550] [G loss: 0.717617]
[Epoch 1686/2000] [Batch 9/10] [D loss: 0.659195] [G loss: 0.715719]
[Epoch 1687/2000] [Batch 9/10] [D loss: 0.657912] [G loss: 0.715853]
[Epoch 1688/2000] [Batch 9/10] [D loss: 0.656067] [G loss: 0.717570]
[Epoch 1689/2000] [Batch 9/10] [D loss: 0.657111] [G loss: 0.713627]
[Epoch 1690/2000] [Batch 9/10] [D loss: 0.658360] [G loss: 0.709635]
[Epoch 1691/2000] [Batch 9/10] [D loss: 0.656033] [G loss: 0.712324]
[Epoch 1692/2000] [Batch 9/10] [D loss: 0.657841] [G loss: 0.707596]
[Epoch 1693/2000] [Batch 9/10] [D loss: 0.657375] [G loss: 0.707097]
[Epoch 1694/2000] [Batch 9/10] [D loss: 0.657286] [G loss: 0.706303]
[Epoch 1695/2000] [Batch 9/10] [D loss: 0.658348] [G loss: 0.703111]
[Epoch 1696/2000] [Batch 9/10] [D loss: 0.656158] [G loss: 0.706681]
[Epoch 1697/2000] [Batch 9/10] [D loss: 0.657234] [G loss: 0.704242]
[Epoch 1698/2000] [Batch 9/10] [D loss: 0.656669] [G loss: 0.705053]
[Epoch 1699/2000] [Batch 9/10] [D 

[Epoch 1804/2000] [Batch 9/10] [D loss: 0.749603] [G loss: 0.697678]
[Epoch 1805/2000] [Batch 9/10] [D loss: 0.749824] [G loss: 0.697963]
[Epoch 1806/2000] [Batch 9/10] [D loss: 0.754165] [G loss: 0.691287]
[Epoch 1807/2000] [Batch 9/10] [D loss: 0.754368] [G loss: 0.691797]
[Epoch 1808/2000] [Batch 9/10] [D loss: 0.752773] [G loss: 0.695017]
[Epoch 1809/2000] [Batch 9/10] [D loss: 0.754173] [G loss: 0.692559]
[Epoch 1810/2000] [Batch 9/10] [D loss: 0.757038] [G loss: 0.687758]
[Epoch 1811/2000] [Batch 9/10] [D loss: 0.755912] [G loss: 0.688617]
[Epoch 1812/2000] [Batch 9/10] [D loss: 0.756969] [G loss: 0.685597]
[Epoch 1813/2000] [Batch 9/10] [D loss: 0.757532] [G loss: 0.682391]
[Epoch 1814/2000] [Batch 9/10] [D loss: 0.758764] [G loss: 0.679875]
[Epoch 1815/2000] [Batch 9/10] [D loss: 0.758040] [G loss: 0.679740]
[Epoch 1816/2000] [Batch 9/10] [D loss: 0.756955] [G loss: 0.679246]
[Epoch 1817/2000] [Batch 9/10] [D loss: 0.756310] [G loss: 0.678791]
[Epoch 1818/2000] [Batch 9/10] [D 

[Epoch 1923/2000] [Batch 9/10] [D loss: 0.705449] [G loss: 0.685266]
[Epoch 1924/2000] [Batch 9/10] [D loss: 0.704968] [G loss: 0.685054]
[Epoch 1925/2000] [Batch 9/10] [D loss: 0.706049] [G loss: 0.681972]
[Epoch 1926/2000] [Batch 9/10] [D loss: 0.705083] [G loss: 0.682891]
[Epoch 1927/2000] [Batch 9/10] [D loss: 0.706766] [G loss: 0.678321]
[Epoch 1928/2000] [Batch 9/10] [D loss: 0.706460] [G loss: 0.678041]
[Epoch 1929/2000] [Batch 9/10] [D loss: 0.706839] [G loss: 0.676094]
[Epoch 1930/2000] [Batch 9/10] [D loss: 0.706704] [G loss: 0.675859]
[Epoch 1931/2000] [Batch 9/10] [D loss: 0.707145] [G loss: 0.674420]
[Epoch 1932/2000] [Batch 9/10] [D loss: 0.706036] [G loss: 0.675930]
[Epoch 1933/2000] [Batch 9/10] [D loss: 0.706666] [G loss: 0.674142]
[Epoch 1934/2000] [Batch 9/10] [D loss: 0.707394] [G loss: 0.672342]
[Epoch 1935/2000] [Batch 9/10] [D loss: 0.707406] [G loss: 0.671967]
[Epoch 1936/2000] [Batch 9/10] [D loss: 0.706976] [G loss: 0.672539]
[Epoch 1937/2000] [Batch 9/10] [D 

In [61]:
pickle.dump([generator_loss, discriminator_loss], open('gen_dis_loss3.pkl','wb'))

### Evaluate results for Generator

In [20]:
[mean_MNpdf, cov_MNpdf] = torch.load( 'MultiVariateNormalParameters.pt')

In [21]:
if cuda:
    mean_MNpdf = mean_MNpdf.cuda()
    cov_MNpdf = cov_MNpdf.cuda()

In [22]:
N_generate_images = 1

In [23]:
generator.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=5, out_features=8, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Linear(in_features=8, out_features=16, bias=True)
    (3): BatchNorm1d(16, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Linear(in_features=16, out_features=32, bias=True)
    (6): BatchNorm1d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Linear(in_features=32, out_features=64, bias=True)
    (9): BatchNorm1d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Linear(in_features=64, out_features=25, bias=True)
    (12): Tanh()
  )
)

In [24]:
inverse_cov = torch.inverse(cov_MNpdf)
determinant_cov = torch.cholesky(cov_MNpdf).diag().prod()
const_part = -(img_size*img_size)/2 * torch.log(torch.tensor(2*3.141592653)) - 1/2*torch.log(determinant_cov)

def pdf(X):
    X_flattened =  X.flatten()
    diff = (X_flattened - mean_MNpdf)
    pdf_value = torch.matmul( torch.matmul(diff.reshape(1,-1), inverse_cov), diff)[0]
    return -1/2*pdf_value + const_part

In [25]:
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (batch_size, latent_dim))))

# Generate a batch of images
gen_imgs = generator(z)

In [26]:
gen_imgs.shape

torch.Size([1000, 1, 5, 5])

In [27]:
tot_log_likelihood = 0
for each_img in gen_imgs:
    tot_log_likelihood +=   pdf( each_img.squeeze(0) ) 

In [28]:
tot_log_likelihood

tensor(-34720.4844, device='cuda:0', grad_fn=<AddBackward0>)

In [29]:
from scipy.stats import multivariate_normal

In [30]:
var = multivariate_normal(mean=np.zeros(28*28), cov=np.eye(28*28,28*28))

In [31]:
my_x = np.random.multivariate_normal(np.zeros(28*28), np.eye(28*28,28*28))

In [32]:
my_x = np.random.uniform(high=1, low=0, size=784)

In [33]:
np.random.normal(loc=0, scale=1, s-ze=1)

SyntaxError: keyword can't be an expression (<ipython-input-33-de24086a803a>, line 1)

In [None]:
var.pdf(my_x)

In [None]:
prob_list = []
for each_img in gen_imgs:
    my_x = each_img.squeeze(0).cpu().detach().numpy()
    my_x = my_x.flatten()
    probs = multivariate_normal.pdf(my_x, mean=np.zeros(28*28), cov=np.eye(28*28,28*28))
    prob_list.append(probs)

In [None]:
my_x

In [None]:
prob_list