In [1]:
'''import zipfile
with zipfile.ZipFile('cheap_sneakers.zip', 'r') as zip_ref:
    zip_ref.extractall('cheap_sneakers')'''

"import zipfile\nwith zipfile.ZipFile('cheap_sneakers.zip', 'r') as zip_ref:\n    zip_ref.extractall('cheap_sneakers')"

In [2]:
'''!rm -rf "out3"'''

'!rm -rf "out3"'

In [3]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
manualSeed = 24
random.seed(manualSeed)
torch.manual_seed(manualSeed)
cudnn.benchmark = True

In [5]:
nc = 3
ngpu = 1
nz = 100
ngf = 64
ndf = 64
batchSize = 64
dataset = dset.ImageFolder(root='cheap_sneakers',
              transform=transforms.Compose([
                  transforms.Resize((64, 64)),
                  transforms.ToTensor(),
                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
              ]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                      shuffle=True, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

In [7]:
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

In [8]:
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
print(netG)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


In [9]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)

In [10]:
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
print(netD)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)


In [11]:
criterion = nn.BCELoss()

fixed_noise = torch.randn(batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
d_lr = 0.00002
g_lr = 0.0002
optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(0.55, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=g_lr, betas=(0.55, 0.999))

In [12]:
def train(niter):
    outf = 'out1'
    d_loss, g_loss = [], []
    for epoch in tqdm(range(niter)):
        for i, data in enumerate(dataloader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label,
                               dtype=real_cpu.dtype, device=device)

            output = netD(real_cpu)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

        d_loss.append(errD.item())
        g_loss.append(errG.item())
        
        print('[%d/%d] | Loss_D: %.4f | Loss_G: %.4f | D(x): %.4f | D(G(z)): %.4f / %.4f'
                % (epoch+1, niter, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        if epoch < 10:
            vutils.save_image(real_cpu, '%s/real_samples.png' % outf, normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(), '%s/fake_samples_epoch_%03d.png' % (outf, epoch), normalize=True)
        
        if epoch % 10 == 0:
            vutils.save_image(real_cpu, '%s/real_samples.png' % outf, normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(), '%s/fake_samples_epoch_%03d.png' % (outf, epoch), normalize=True)
        
        if epoch % 100 == 0:
            # do checkpointing
            torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
            torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (outf, epoch))
            
    return d_loss, g_loss

In [None]:
d_loss, g_loss = train(400)

HBox(children=(FloatProgress(value=0.0, max=400.0), HTML(value='')))

[1/400] | Loss_D: 0.1725 | Loss_G: 5.6853 | D(x): 0.9205 | D(G(z)): 0.0749 / 0.0048
[2/400] | Loss_D: 0.6948 | Loss_G: 6.6517 | D(x): 0.9272 | D(G(z)): 0.4459 / 0.0027
[3/400] | Loss_D: 0.5919 | Loss_G: 2.4416 | D(x): 0.6877 | D(G(z)): 0.0917 / 0.0977
[4/400] | Loss_D: 2.2743 | Loss_G: 0.1934 | D(x): 0.1655 | D(G(z)): 0.0924 / 0.8266
[5/400] | Loss_D: 0.7856 | Loss_G: 0.7416 | D(x): 0.5396 | D(G(z)): 0.1080 / 0.5008
[6/400] | Loss_D: 1.3121 | Loss_G: 4.3243 | D(x): 0.9439 | D(G(z)): 0.6904 / 0.0142
[7/400] | Loss_D: 0.6766 | Loss_G: 3.4773 | D(x): 0.8698 | D(G(z)): 0.3875 / 0.0355
[8/400] | Loss_D: 1.0446 | Loss_G: 0.5815 | D(x): 0.4305 | D(G(z)): 0.1178 / 0.5775
[9/400] | Loss_D: 1.1806 | Loss_G: 0.7106 | D(x): 0.4629 | D(G(z)): 0.2390 / 0.5099
[10/400] | Loss_D: 1.8448 | Loss_G: 3.8247 | D(x): 0.7988 | D(G(z)): 0.7802 / 0.0257
[11/400] | Loss_D: 1.1217 | Loss_G: 2.0990 | D(x): 0.6159 | D(G(z)): 0.4275 / 0.1338
[12/400] | Loss_D: 1.6178 | Loss_G: 2.5343 | D(x): 0.8161 | D(G(z)): 0.716

[98/400] | Loss_D: 0.1486 | Loss_G: 4.2952 | D(x): 0.9440 | D(G(z)): 0.0823 / 0.0184
[99/400] | Loss_D: 0.1810 | Loss_G: 3.5503 | D(x): 0.8573 | D(G(z)): 0.0239 / 0.0446
[106/400] | Loss_D: 4.9816 | Loss_G: 0.0499 | D(x): 0.0526 | D(G(z)): 0.0278 / 0.9574
[107/400] | Loss_D: 0.1465 | Loss_G: 4.3067 | D(x): 0.8802 | D(G(z)): 0.0144 / 0.0173
[108/400] | Loss_D: 0.1126 | Loss_G: 4.4750 | D(x): 0.9702 | D(G(z)): 0.0767 / 0.0184
[109/400] | Loss_D: 0.2162 | Loss_G: 6.4302 | D(x): 0.9552 | D(G(z)): 0.1432 / 0.0025
[110/400] | Loss_D: 0.6133 | Loss_G: 0.7060 | D(x): 0.6186 | D(G(z)): 0.1016 / 0.5465
[111/400] | Loss_D: 0.2553 | Loss_G: 3.0386 | D(x): 0.8454 | D(G(z)): 0.0643 / 0.0517
[112/400] | Loss_D: 0.1273 | Loss_G: 5.2784 | D(x): 0.9894 | D(G(z)): 0.1052 / 0.0090
[113/400] | Loss_D: 0.1825 | Loss_G: 2.9609 | D(x): 0.8704 | D(G(z)): 0.0377 / 0.0602
[114/400] | Loss_D: 0.2316 | Loss_G: 3.4749 | D(x): 0.8629 | D(G(z)): 0.0738 / 0.0391
[115/400] | Loss_D: 0.1491 | Loss_G: 3.0803 | D(x): 0.89

[200/400] | Loss_D: 0.0657 | Loss_G: 5.4359 | D(x): 0.9410 | D(G(z)): 0.0026 / 0.0077
[201/400] | Loss_D: 0.0265 | Loss_G: 6.0503 | D(x): 0.9800 | D(G(z)): 0.0061 / 0.0051
[202/400] | Loss_D: 0.0240 | Loss_G: 5.7913 | D(x): 0.9858 | D(G(z)): 0.0094 / 0.0061
[203/400] | Loss_D: 0.0180 | Loss_G: 5.9171 | D(x): 0.9865 | D(G(z)): 0.0042 / 0.0051
[204/400] | Loss_D: 0.0287 | Loss_G: 6.7562 | D(x): 0.9907 | D(G(z)): 0.0184 / 0.0019
[205/400] | Loss_D: 0.0074 | Loss_G: 6.2512 | D(x): 0.9960 | D(G(z)): 0.0033 / 0.0025
[206/400] | Loss_D: 0.1729 | Loss_G: 1.3609 | D(x): 0.8517 | D(G(z)): 0.0060 / 0.4449
[207/400] | Loss_D: 0.0091 | Loss_G: 6.1273 | D(x): 0.9961 | D(G(z)): 0.0051 / 0.0042
[208/400] | Loss_D: 0.0465 | Loss_G: 5.6537 | D(x): 0.9616 | D(G(z)): 0.0066 / 0.0093
[209/400] | Loss_D: 0.1020 | Loss_G: 4.0875 | D(x): 0.9119 | D(G(z)): 0.0046 / 0.0313
[210/400] | Loss_D: 0.0080 | Loss_G: 7.0544 | D(x): 0.9965 | D(G(z)): 0.0044 / 0.0032
[211/400] | Loss_D: 0.2642 | Loss_G: 1.3067 | D(x): 0.

[296/400] | Loss_D: 0.9626 | Loss_G: 0.1017 | D(x): 0.4676 | D(G(z)): 0.0056 / 0.9125
[297/400] | Loss_D: 0.0218 | Loss_G: 5.7742 | D(x): 0.9893 | D(G(z)): 0.0107 / 0.0050
[298/400] | Loss_D: 0.0373 | Loss_G: 6.1221 | D(x): 0.9918 | D(G(z)): 0.0281 / 0.0025
[299/400] | Loss_D: 0.0174 | Loss_G: 6.1585 | D(x): 0.9991 | D(G(z)): 0.0162 / 0.0028
[300/400] | Loss_D: 0.0705 | Loss_G: 4.6529 | D(x): 0.9435 | D(G(z)): 0.0094 / 0.0151
[301/400] | Loss_D: 0.0072 | Loss_G: 7.7151 | D(x): 0.9949 | D(G(z)): 0.0021 / 0.0013
[302/400] | Loss_D: 0.0081 | Loss_G: 7.8249 | D(x): 0.9934 | D(G(z)): 0.0015 / 0.0012
[303/400] | Loss_D: 0.0089 | Loss_G: 6.8933 | D(x): 0.9986 | D(G(z)): 0.0074 / 0.0018
[304/400] | Loss_D: 0.0636 | Loss_G: 8.5974 | D(x): 0.9824 | D(G(z)): 0.0434 / 0.0005
[305/400] | Loss_D: 0.0156 | Loss_G: 5.8675 | D(x): 0.9921 | D(G(z)): 0.0076 / 0.0034
[306/400] | Loss_D: 0.0532 | Loss_G: 6.4475 | D(x): 0.9562 | D(G(z)): 0.0068 / 0.0082
[307/400] | Loss_D: 0.0433 | Loss_G: 5.5477 | D(x): 0.

In [None]:
plt.figure(figsize=(15, 10))
plt.plot(d_loss, label="Discriminator Loss")
plt.plot(g_loss, label="Generator Loss")
plt.title('Cheap Sneaker GAN Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(frameon=False);