In [1]:

import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Prefer the GPU whenever PyTorch can see one, otherwise run on the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
print(device)

cuda


In [2]:
# Directory that receives the sample image grids written during training.
os.makedirs("images", exist_ok=True)

# Training schedule.
n_epochs, batch_size = 1000, 64
# Adam hyper-parameters: learning rate and the (beta1, beta2) pair.
lr, b1, b2 = 0.0002, 0.5, 0.999
# NOTE(review): n_cpu is not referenced by the visible data loaders
# (they use num_workers=2 or the default) — confirm whether it is still needed.
n_cpu = 8

# Model / image geometry.
latent_dim = 100   # length of the generator's input noise vector z
img_size = 128     # side length of the square training images
channels = 3       # number of image channels fed to / produced by the networks

# Checkpoint the models and save a sample grid every `sample_interval` batches.
sample_interval = 200

In [3]:
def weights_init_normal(m):
    """Initialise one module's parameters in place, DCGAN style.

    Meant to be used as ``model.apply(weights_init_normal)``, which calls it on
    every sub-module. Modules are matched by class name:

    * name contains "Conv": ``weight ~ N(0, 0.02)`` (bias is left untouched);
    * otherwise, name contains "BatchNorm2d": ``weight ~ N(1, 0.02)``, ``bias = 0``.

    All other modules are left unchanged. A matched module whose ``weight``/``bias``
    is absent or ``None`` (e.g. ``BatchNorm2d(affine=False)``) has that tensor
    skipped instead of raising ``AttributeError``.

    Args:
        m: a ``torch.nn.Module``.
    """
    classname = m.__class__.__name__
    # nn.Module.__getattr__ raises AttributeError for unknown names, so getattr's
    # default covers both "not defined" and "registered as None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            # nn.init functions already operate under torch.no_grad(), so `.data` is not needed.
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0.0)

In [4]:
class Generator(nn.Module):
    """Maps a noise batch z of shape (N, latent_dim) to a batch of images.

    A linear layer projects z to a 128-channel feature map of side img_size // 4.
    Two 2x upsampling stages (each followed by conv + BatchNorm + LeakyReLU) bring
    it back to img_size. A final conv + Tanh then emits `channels` planes in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Side of the first feature map; two 2x upsamples restore img_size
        # (exactly so only when img_size is divisible by 4).
        self.init_size = img_size // 4
        # Kept inside a Sequential so parameter names stay "l1.0.*" (checkpoint keys).
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): BatchNorm2d's second positional argument is `eps`, so the
        # original `BatchNorm2d(n, 0.8)` means eps=0.8 (possibly intended as momentum).
        # Spelled out as a keyword here; the value is unchanged.
        layers = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        batch = z.shape[0]
        seed_map = self.l1(z).view(batch, 128, self.init_size, self.init_size)
        return self.conv_blocks(seed_map)

In [10]:

class Discriminator(nn.Module):
    """Classifies a batch of images as real (→1) or generated (→0).

    Four stride-2 conv blocks (16 → 32 → 64 → 128 filters) downsample the input.
    A linear layer + sigmoid then produces one probability per image, shape (N, 1).

    Args:
        image_size: side length of the square input images. Defaults to the
            notebook-level ``img_size`` when omitted.
        in_channels: number of input channels. Defaults to the notebook-level
            ``channels`` when omitted.
    """

    def __init__(self, image_size=None, in_channels=None):
        super(Discriminator, self).__init__()

        if image_size is None:
            image_size = img_size
        if in_channels is None:
            in_channels = channels

        def discriminator_block(in_filters, out_filters, bn=True):
            # 3x3 conv, stride 2, padding 1 -> spatial side becomes ceil(n / 2).
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): positional 0.8 is BatchNorm2d's `eps`, not `momentum`;
                # left unchanged to keep training behaviour and checkpoints identical.
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width of the downsampled feature map: four successive ceil(n / 2).
        # (img_size // 2**4 is only correct when img_size is a multiple of 16.)
        ds_size = image_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        # 128 channels x ds_size x ds_size features (8192 for 128x128 inputs, as before).
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        # Flatten everything except the batch dimension.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [11]:
# Adversarial loss: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = torch.nn.BCELoss().to(device)

# Build both networks (generator first, then discriminator) and move them to `device`.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Re-initialise conv / BatchNorm2d parameters with weights_init_normal.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)


Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout2d(p=0.25, inplace=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Dropout2d(p=0.25, inplace=False)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Dropout2d(p=0.25, inplace=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): Dropout2d(p=0.25, inplace=False)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_laye

In [12]:
# Configure data loader (MNIST).
# NOTE: the next cell rebinds `dataloader` to the cats dataset, so this loader is
# only used if that cell is skipped. MNIST digits are single-channel, so each image
# is expanded to `channels` planes to match the Generator/Discriminator; without
# this, the discriminator's first conv rejects the 1-channel batch.
# NOTE(review): "/data/mnist" is an absolute path and may need write permission.
os.makedirs("/data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "/data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(img_size),
                # Grayscale accepts num_output_channels of 1 or 3.
                transforms.Grayscale(num_output_channels=channels),
                transforms.ToTensor(),
                # Map [0, 1] to [-1, 1] per channel, matching the generator's Tanh range.
                transforms.Normalize([0.5] * channels, [0.5] * channels),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [13]:
# Preprocessing for the cat photos:
#   Resize      -> shorter side becomes img_size
#   CenterCrop  -> square img_size x img_size crop
#   ToTensor    -> float tensor in [0, 1]
#   Normalize   -> (x - 0.5) / 0.5, i.e. [-1, 1], matching the generator's Tanh output
cats_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects ./cats/<class_name>/<image files>; class labels are ignored in training.
dataset = datasets.ImageFolder(root="cats", transform=cats_transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:



# Optimizers: one Adam per network, sharing the configured lr and betas.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Legacy tensor-type alias kept for backward compatibility only; the loop below
# creates tensors directly on `device` instead of through this constructor.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------
# Each batch first updates G against the current D, then updates D on the real
# batch and on the (detached) fake batch. Every `sample_interval` batches both
# state_dicts are checkpointed and a 5x5 grid of generated images is saved.

for epoch in range(n_epochs):
    # Start a fresh line: the per-batch progress print below overwrites itself via '\r'.
    print()
    for i, (imgs, _) in enumerate(dataloader):
        n_imgs = imgs.shape[0]

        # Adversarial ground truths (plain constant tensors, no grad tracking).
        valid = torch.ones(n_imgs, 1, device=device)
        fake = torch.zeros(n_imgs, 1, device=device)

        # Configure input: float32 on the training device.
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Standard-normal noise sampled directly on the device.
        z = torch.randn(n_imgs, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the gradients that g_loss.backward() left on D's parameters.
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples;
        # detach() stops D's update from back-propagating into G.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()),
            end='\r'
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            torch.save(generator.state_dict(), 'generator.pt')
            torch.save(discriminator.state_dict(), 'discriminator.pt')
            save_image(gen_imgs.detach()[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)



[Epoch 0/1000] [Batch 246/247] [D loss: 0.748594] [G loss: 0.653998]
[Epoch 1/1000] [Batch 246/247] [D loss: 0.693725] [G loss: 0.712708]
[Epoch 2/1000] [Batch 246/247] [D loss: 0.688455] [G loss: 0.705525]
[Epoch 3/1000] [Batch 246/247] [D loss: 0.665830] [G loss: 0.711854]
[Epoch 4/1000] [Batch 246/247] [D loss: 0.705281] [G loss: 0.732448]
[Epoch 5/1000] [Batch 246/247] [D loss: 0.678582] [G loss: 0.696422]
[Epoch 6/1000] [Batch 246/247] [D loss: 0.684510] [G loss: 0.744285]
[Epoch 7/1000] [Batch 246/247] [D loss: 0.655462] [G loss: 0.741263]
[Epoch 8/1000] [Batch 246/247] [D loss: 0.653474] [G loss: 0.702507]
[Epoch 9/1000] [Batch 246/247] [D loss: 0.680369] [G loss: 0.662454]
[Epoch 10/1000] [Batch 246/247] [D loss: 0.724844] [G loss: 0.722923]
[Epoch 11/1000] [Batch 246/247] [D loss: 0.703375] [G loss: 0.697721]
[Epoch 12/1000] [Batch 246/247] [D loss: 0.704029] [G loss: 0.709544]
[Epoch 13/1000] [Batch 246/247] [D loss: 0.700805] [G loss: 0.713903]
[Epoch 14/1000] [Batch 246/24

[Epoch 116/1000] [Batch 246/247] [D loss: 0.683880] [G loss: 0.711506]
[Epoch 117/1000] [Batch 246/247] [D loss: 0.707301] [G loss: 0.711395]
[Epoch 118/1000] [Batch 246/247] [D loss: 0.692877] [G loss: 0.693057]
[Epoch 119/1000] [Batch 246/247] [D loss: 0.696218] [G loss: 0.704469]
[Epoch 120/1000] [Batch 246/247] [D loss: 0.694530] [G loss: 0.696150]
[Epoch 121/1000] [Batch 246/247] [D loss: 0.680198] [G loss: 0.703968]
[Epoch 122/1000] [Batch 246/247] [D loss: 0.694046] [G loss: 0.705082]
[Epoch 123/1000] [Batch 246/247] [D loss: 0.687916] [G loss: 0.716988]
[Epoch 124/1000] [Batch 246/247] [D loss: 0.687908] [G loss: 0.695407]
[Epoch 125/1000] [Batch 246/247] [D loss: 0.680871] [G loss: 0.746733]
[Epoch 126/1000] [Batch 246/247] [D loss: 0.684452] [G loss: 0.706196]
[Epoch 127/1000] [Batch 246/247] [D loss: 0.678498] [G loss: 0.711816]
[Epoch 128/1000] [Batch 246/247] [D loss: 0.675448] [G loss: 0.715674]
[Epoch 129/1000] [Batch 246/247] [D loss: 0.682618] [G loss: 0.662822]
[Epoch

[Epoch 231/1000] [Batch 246/247] [D loss: 0.748431] [G loss: 0.676746]
[Epoch 232/1000] [Batch 246/247] [D loss: 0.718265] [G loss: 0.708099]
[Epoch 233/1000] [Batch 246/247] [D loss: 0.699739] [G loss: 0.702914]
[Epoch 234/1000] [Batch 246/247] [D loss: 0.664141] [G loss: 0.723364]
[Epoch 235/1000] [Batch 246/247] [D loss: 0.660798] [G loss: 0.685996]
[Epoch 236/1000] [Batch 246/247] [D loss: 0.697227] [G loss: 0.681434]
[Epoch 237/1000] [Batch 246/247] [D loss: 0.671368] [G loss: 0.708908]
[Epoch 238/1000] [Batch 246/247] [D loss: 0.718141] [G loss: 0.685548]
[Epoch 239/1000] [Batch 246/247] [D loss: 0.702648] [G loss: 0.693329]
[Epoch 240/1000] [Batch 246/247] [D loss: 0.684115] [G loss: 0.703644]
[Epoch 241/1000] [Batch 246/247] [D loss: 0.707892] [G loss: 0.678019]
[Epoch 242/1000] [Batch 246/247] [D loss: 0.687991] [G loss: 0.708586]
[Epoch 243/1000] [Batch 246/247] [D loss: 0.699021] [G loss: 0.715744]
[Epoch 244/1000] [Batch 246/247] [D loss: 0.680951] [G loss: 0.707384]
[Epoch

[Epoch 346/1000] [Batch 246/247] [D loss: 0.714936] [G loss: 0.544403]
[Epoch 347/1000] [Batch 246/247] [D loss: 0.503740] [G loss: 0.833417]
[Epoch 348/1000] [Batch 246/247] [D loss: 0.776627] [G loss: 0.587471]
[Epoch 349/1000] [Batch 246/247] [D loss: 0.892745] [G loss: 0.849224]
[Epoch 350/1000] [Batch 246/247] [D loss: 0.668705] [G loss: 0.873441]
[Epoch 351/1000] [Batch 246/247] [D loss: 0.582360] [G loss: 0.725337]
[Epoch 352/1000] [Batch 246/247] [D loss: 0.575313] [G loss: 0.800100]
[Epoch 353/1000] [Batch 246/247] [D loss: 0.930144] [G loss: 0.906160]
[Epoch 354/1000] [Batch 246/247] [D loss: 0.810024] [G loss: 0.753771]
[Epoch 355/1000] [Batch 246/247] [D loss: 1.236760] [G loss: 1.330255]
[Epoch 356/1000] [Batch 246/247] [D loss: 0.635470] [G loss: 1.160454]
[Epoch 357/1000] [Batch 246/247] [D loss: 0.598530] [G loss: 0.593256]
[Epoch 358/1000] [Batch 246/247] [D loss: 0.753454] [G loss: 0.580669]
[Epoch 359/1000] [Batch 246/247] [D loss: 0.994078] [G loss: 0.728437]
[Epoch