## DCGAN ##

In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size, root='./cifar10/'):
    """Build CIFAR-10 train/test DataLoaders with pixel values normalized to [-1, 1].

    Args:
        batch_size: number of images per batch for both loaders.
        root: directory the dataset is stored in / downloaded to
            (defaults to the previously hard-coded './cifar10/').

    Returns:
        (train_dataloader, test_dataloader). The train loader shuffles and drops
        the last incomplete batch; the test loader keeps a fixed order and every
        sample, so evaluation over it is reproducible and complete.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # scales the shorter side to 32 px (presumably a no-op for 32x32 CIFAR-10 images)
        torchvision.transforms.Resize(32),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1], the generator's Tanh range
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    # drop_last=True on the training loader keeps every batch exactly `batch_size` long.
    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluation data should not be shuffled or truncated.
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

# Mini-batch size; also read as a global by later cells.
batch_size = 32
print("Downloading CIFAR10 dataset...")
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class Generator_DCGAN(nn.Module):
    """DCGAN generator: maps a (N, 100, 1, 1) latent tensor to a (N, 3, 32, 32) image.

    Four transposed convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 32;
    the final Tanh bounds every pixel to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        layers = []
        # (in_channels, out_channels, stride, padding) for each hidden upsampling stage;
        # the first stage projects the 1x1 latent onto a 4x4 feature map.
        hidden_stages = [(100, 1024, 1, 0), (1024, 512, 2, 1), (512, 256, 2, 1)]
        for c_in, c_out, stride, pad in hidden_stages:
            layers += [
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=stride, padding=pad),
                nn.BatchNorm2d(num_features=c_out),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]
        # Output stage: project to 3 RGB channels, squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(in_channels=256, out_channels=3, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())
        # Keep a single Sequential named `net` so state_dict keys (net.0.weight, ...)
        # stay compatible with previously saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Discriminator_DCGAN(nn.Module):
    """DCGAN discriminator: maps (N, 3, 32, 32) images to (N, 1, 1, 1) scores in [0, 1].

    Strided convolutions shrink the spatial size 32 -> 16 -> 8 -> 4 -> 1; the final
    Sigmoid turns the score into the probability-like value BCELoss expects.
    """

    def __init__(self):
        super().__init__()
        # The first stage has no BatchNorm.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        for c_in, c_out in ((256, 512), (512, 1024)):
            layers.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ])
        # Collapse the 4x4 map to one score per image.
        layers.extend([
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ])
        # Single Sequential named `net` keeps state_dict keys checkpoint-compatible.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

print("Instantiating DCGAN generator and discriminator...")
# Module.to() moves parameters in place and returns the module itself.
dcgan_generator = Generator_DCGAN().to(device)
dcgan_discriminator = Discriminator_DCGAN().to(device)
print()

Instantiating DCGAN generator and discriminator...



In [4]:
learning_rate = 2e-4  # Adam step size for both networks (read by train())
epochs = 50           # number of full passes over the training set (read by train())


def _evaluate_epoch(generator, epoch, num_samples=800, grid_size=64):
    """Sample images, print their Inception Score, save a preview grid; return the mean IS."""
    # Evaluation needs no autograd graph; building one for hundreds of samples wastes GPU memory.
    # NOTE(review): the generator stays in train mode (as before), so BatchNorm uses
    # batch statistics here and updates its running stats.
    with torch.no_grad():
        z = torch.randn(num_samples, 100, 1, 1, device=device)
        # normalize Tanh output from [-1, 1] to [0, 1]
        samples = generator(z).mul(0.5).add(0.5)

    assert 0 <= samples.min() and samples.max() <= 1
    inception_score, inception_score_std = get_inception_score(samples)
    print("epoch: " + str(epoch) + ', inception score: ' + str(round(inception_score, 2)) + ' ± ' + str(round(inception_score_std, 2)))

    grid = utils.make_grid(samples[:grid_size].cpu())
    utils.save_image(grid, 'train_generated_images_dcgan/epoch_{}.png'.format(str(epoch)))
    return inception_score


def train(generator, discriminator, train_dataloader):
    """Adversarially train a DCGAN generator/discriminator pair with BCE loss.

    Reads the notebook-level globals ``learning_rate``, ``epochs`` and ``device``.
    Each iteration first updates the discriminator on real (label 1) and generated
    (label 0) images, then updates the generator so the discriminator labels fresh
    samples as real. After every epoch, 800 samples are scored with the Inception
    Score (printed and appended to ``inception_score_dcgan.csv``) and the first 64
    are saved as a grid under ``train_generated_images_dcgan/``.

    Args:
        generator: maps (N, 100, 1, 1) latent tensors to images in [-1, 1].
        discriminator: maps images to Sigmoid scores consumed by BCELoss.
        train_dataloader: yields ``(images, labels)`` batches; labels are ignored.
    """
    loss = nn.BCELoss()
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    os.makedirs('train_generated_images_dcgan', exist_ok=True)

    # `with` guarantees the CSV is flushed and closed even if training is interrupted.
    with open("inception_score_dcgan.csv", "w") as inception_score_file:
        inception_score_file.write('epoch, inception_score \n')

        for epoch in tqdm(range(epochs)):
            for real_images, _ in train_dataloader:
                real_images = real_images.to(device)
                # Size noise/labels from the actual batch, not the global `batch_size`,
                # so partial or differently-sized batches cannot cause a shape mismatch.
                n = real_images.size(0)
                real_labels = torch.ones(n, device=device)
                fake_labels = torch.zeros(n, device=device)

                ### train discriminator
                # compute loss using real images
                preds = discriminator(real_images)
                disc_loss_real = loss(preds.flatten(), real_labels)

                # compute loss using fake images; detach() so this step does not
                # backpropagate through (and waste compute on) the generator
                z = torch.randn(n, 100, 1, 1, device=device)
                fake_images = generator(z).detach()
                preds = discriminator(fake_images)
                disc_loss_fake = loss(preds.flatten(), fake_labels)

                # optimize discriminator
                disc_loss = disc_loss_real + disc_loss_fake
                discriminator.zero_grad()
                disc_loss.backward()
                optim_discriminator.step()

                ### train generator
                # fresh samples, scored against "real" labels
                z = torch.randn(n, 100, 1, 1, device=device)
                fake_images = generator(z)
                preds = discriminator(fake_images)
                gen_loss = loss(preds.flatten(), real_labels)

                # optimize generator
                generator.zero_grad()
                gen_loss.backward()
                optim_generator.step()

            # compute inception score and samples every epoch
            inception_score = _evaluate_epoch(generator, epoch)
            inception_score_file.write(str(epoch) + ', ' + str(round(inception_score, 2)) + '\n')

In [5]:
# train DCGAN
print("training DCGAN model...")
train(dcgan_generator, dcgan_discriminator, train_dataloader)

# save DCGAN to file
# The loading cell below reads exactly these files. Leaving this step commented out
# makes Restart & Run All load stale checkpoints (or fail with FileNotFoundError).
print("saving DCGAN model to file...")
torch.save(dcgan_generator.state_dict(), 'dcgan_generator.pkl')
torch.save(dcgan_discriminator.state_dict(), 'dcgan_discriminator.pkl')

training DCGAN model...


  2%|▏         | 1/50 [02:30<2:02:58, 150.59s/it]

epoch: 0, inception score: 2.51 ± 0.17


  4%|▍         | 2/50 [04:59<1:59:30, 149.38s/it]

epoch: 1, inception score: 2.03 ± 0.1


  6%|▌         | 3/50 [07:22<1:54:43, 146.46s/it]

epoch: 2, inception score: 3.08 ± 0.26


  8%|▊         | 4/50 [09:49<1:52:26, 146.66s/it]

epoch: 3, inception score: 2.95 ± 0.15


 10%|█         | 5/50 [12:19<1:51:03, 148.07s/it]

epoch: 4, inception score: 3.16 ± 0.27


 12%|█▏        | 6/50 [14:51<1:49:30, 149.32s/it]

epoch: 5, inception score: 3.58 ± 0.28


 14%|█▍        | 7/50 [17:23<1:47:34, 150.11s/it]

epoch: 6, inception score: 3.41 ± 0.28


 16%|█▌        | 8/50 [19:54<1:45:27, 150.65s/it]

epoch: 7, inception score: 3.14 ± 0.14


 18%|█▊        | 9/50 [22:26<1:43:14, 151.08s/it]

epoch: 8, inception score: 3.8 ± 0.19


 20%|██        | 10/50 [24:56<1:40:21, 150.54s/it]

epoch: 9, inception score: 3.27 ± 0.18


 22%|██▏       | 11/50 [27:27<1:38:03, 150.86s/it]

epoch: 10, inception score: 3.87 ± 0.3


 24%|██▍       | 12/50 [30:00<1:35:47, 151.25s/it]

epoch: 11, inception score: 4.25 ± 0.33


 26%|██▌       | 13/50 [32:31<1:33:21, 151.38s/it]

epoch: 12, inception score: 4.39 ± 0.22


 28%|██▊       | 14/50 [35:03<1:30:54, 151.50s/it]

epoch: 13, inception score: 4.36 ± 0.36


 30%|███       | 15/50 [37:34<1:28:17, 151.35s/it]

epoch: 14, inception score: 4.74 ± 0.3


 32%|███▏      | 16/50 [40:00<1:24:48, 149.65s/it]

epoch: 15, inception score: 4.58 ± 0.47


 34%|███▍      | 17/50 [42:24<1:21:29, 148.16s/it]

epoch: 16, inception score: 4.91 ± 0.19


 36%|███▌      | 18/50 [44:51<1:18:43, 147.61s/it]

epoch: 17, inception score: 4.55 ± 0.42


 38%|███▊      | 19/50 [47:14<1:15:37, 146.36s/it]

epoch: 18, inception score: 4.82 ± 0.42


 40%|████      | 20/50 [49:42<1:13:23, 146.78s/it]

epoch: 19, inception score: 4.78 ± 0.38


 42%|████▏     | 21/50 [52:07<1:10:45, 146.39s/it]

epoch: 20, inception score: 4.69 ± 0.25


 44%|████▍     | 22/50 [54:35<1:08:30, 146.80s/it]

epoch: 21, inception score: 4.85 ± 0.4


 46%|████▌     | 23/50 [57:02<1:06:04, 146.84s/it]

epoch: 22, inception score: 5.07 ± 0.36


 48%|████▊     | 24/50 [59:27<1:03:21, 146.21s/it]

epoch: 23, inception score: 5.07 ± 0.48


 50%|█████     | 25/50 [1:01:52<1:00:50, 146.01s/it]

epoch: 24, inception score: 5.22 ± 0.25


 52%|█████▏    | 26/50 [1:04:19<58:27, 146.14s/it]  

epoch: 25, inception score: 4.79 ± 0.17


 54%|█████▍    | 27/50 [1:06:43<55:46, 145.50s/it]

epoch: 26, inception score: 5.18 ± 0.28


 56%|█████▌    | 28/50 [1:09:09<53:22, 145.57s/it]

epoch: 27, inception score: 4.92 ± 0.34


 58%|█████▊    | 29/50 [1:11:35<50:59, 145.69s/it]

epoch: 28, inception score: 5.48 ± 0.46


 60%|██████    | 30/50 [1:14:01<48:40, 146.01s/it]

epoch: 29, inception score: 5.06 ± 0.34


 62%|██████▏   | 31/50 [1:16:33<46:48, 147.80s/it]

epoch: 30, inception score: 5.26 ± 0.31


 64%|██████▍   | 32/50 [1:19:04<44:38, 148.79s/it]

epoch: 31, inception score: 5.07 ± 0.33


 66%|██████▌   | 33/50 [1:21:34<42:13, 149.02s/it]

epoch: 32, inception score: 4.93 ± 0.29


 68%|██████▊   | 34/50 [1:23:44<38:15, 143.47s/it]

epoch: 33, inception score: 5.07 ± 0.3


 70%|███████   | 35/50 [1:25:54<34:48, 139.24s/it]

epoch: 34, inception score: 5.39 ± 0.56


 72%|███████▏  | 36/50 [1:28:03<31:48, 136.33s/it]

epoch: 35, inception score: 5.31 ± 0.37


 74%|███████▍  | 37/50 [1:30:13<29:05, 134.30s/it]

epoch: 36, inception score: 5.4 ± 0.31


 76%|███████▌  | 38/50 [1:32:23<26:35, 132.95s/it]

epoch: 37, inception score: 5.49 ± 0.36


 78%|███████▊  | 39/50 [1:34:33<24:12, 132.05s/it]

epoch: 38, inception score: 5.18 ± 0.37


 80%|████████  | 40/50 [1:36:45<22:01, 132.18s/it]

epoch: 39, inception score: 5.02 ± 0.22


 82%|████████▏ | 41/50 [1:38:55<19:44, 131.57s/it]

epoch: 40, inception score: 5.14 ± 0.34


 84%|████████▍ | 42/50 [1:41:06<17:29, 131.23s/it]

epoch: 41, inception score: 5.19 ± 0.47


 86%|████████▌ | 43/50 [1:43:12<15:07, 129.71s/it]

epoch: 42, inception score: 5.46 ± 0.76


 88%|████████▊ | 44/50 [1:45:17<12:49, 128.21s/it]

epoch: 43, inception score: 5.28 ± 0.41


 90%|█████████ | 45/50 [1:47:22<10:37, 127.42s/it]

epoch: 44, inception score: 5.61 ± 0.42


 92%|█████████▏| 46/50 [1:49:25<08:24, 126.09s/it]

epoch: 45, inception score: 5.31 ± 0.42


 94%|█████████▍| 47/50 [1:51:29<06:16, 125.40s/it]

epoch: 46, inception score: 5.19 ± 0.25


 96%|█████████▌| 48/50 [1:53:32<04:09, 124.75s/it]

epoch: 47, inception score: 5.33 ± 0.35


 98%|█████████▊| 49/50 [1:55:35<02:04, 124.30s/it]

epoch: 48, inception score: 5.59 ± 0.48


100%|██████████| 50/50 [1:57:38<00:00, 141.18s/it]

epoch: 49, inception score: 5.23 ± 0.42
saving DCGAN model to file...





In [6]:
def generate_images(generator, num_samples=64, nrow=8, filename='dcgan_generated_images.png'):
    """Sample images from `generator` and save them as a single grid image.

    Args:
        generator: expected to map (N, 100, 1, 1) latent tensors to Tanh-range
            images in [-1, 1] (as Generator_DCGAN does).
        num_samples: number of images to draw; the default of 64 with nrow=8 gives
            the 8x8 grid announced in the message (the previous version drew
            `batch_size` = 32 images, i.e. an 8x4 grid).
        nrow: images per grid row, passed to make_grid.
        filename: output image path.
    """
    # Sampling for display needs no autograd graph.
    with torch.no_grad():
        z = torch.randn(num_samples, 100, 1, 1).to(device)
        # map [-1, 1] to [0, 1] for saving
        samples = generator(z).mul(0.5).add(0.5).cpu()
    grid = utils.make_grid(samples, nrow=nrow)
    utils.save_image(grid, filename)
    n_rows = -(-num_samples // nrow)  # ceiling division: rows make_grid lays out
    # Report only after the file has actually been written.
    print("Grid of {}x{} images saved to '{}'.".format(nrow, n_rows, filename))

def load_model(model, model_filename):
    """Load a state_dict saved with ``torch.save`` into `model`, in place.

    Tensors are mapped onto the device the model already lives on. Without this, a
    checkpoint written on a GPU machine cannot be loaded on a CPU-only one.
    """
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else 'cpu'
    model.load_state_dict(torch.load(model_filename, map_location=map_location))

# Restore both networks from the checkpoints written after training,
# then render a sample grid from the generator.
print("loading DCGAN model...")
load_model(model=dcgan_generator, model_filename='dcgan_generator.pkl')
load_model(model=dcgan_discriminator, model_filename='dcgan_discriminator.pkl')
generate_images(dcgan_generator)

loading DCGAN model...
Grid of 8x8 images saved to 'dcgan_generated_images.png'.
