## ACGAN ##

In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size): 
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Resize(32), 
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=False, download=True, transform=transform)

    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

print("Downloading CIFAR10 dataset...")
batch_size = 64
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class Generator_ACGAN(nn.Module):
    def __init__(self):
        super(Generator_ACGAN, self).__init__()
        self.emb = nn.Embedding(10, 100)
        self.fc = nn.Linear(100, 128 * 8 ** 2)
        self.main = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        x = torch.mul(self.emb(labels), noise)
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 8, 8)
        x = self.main(x)
        return x

class Discriminator_ACGAN(nn.Module):
    def __init__(self):
        super(Discriminator_ACGAN, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.main = nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        self.adv_layer = nn.Sequential(nn.Linear(128 * 2 ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * 2 ** 2, 10), nn.Softmax())

    def forward(self, img):
        x = self.main(img)
        x = x.view(x.shape[0], -1)
        validity = self.adv_layer(x)
        label = self.aux_layer(x)
        return validity, label


print("Instantiating DCGAN generator and discriminator...")
acgan_generator = Generator_ACGAN()
acgan_discriminator = Discriminator_ACGAN()
acgan_generator.to(device)
acgan_discriminator.to(device)
print()

Instantiating DCGAN generator and discriminator...



In [4]:
epochs=50
learning_rate = 2e-4

def train(generator, discriminator, train_dataloader):
    source_criterion = nn.BCELoss()
    class_criterion = nn.NLLLoss()
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    if not os.path.exists('train_generated_images_acgan_real/'): 
        os.makedirs('train_generated_images_acgan_real')
    if not os.path.exists('train_generated_images_acgan_fake/'): 
        os.makedirs('train_generated_images_acgan_fake')
        
    inception_score_file = open("inception_score_acgan.csv", "w")
    inception_score_file.write('epoch, inception_score \n')

    for epoch in tqdm(range(epochs)): 
        for images, labels in train_dataloader:
            batch_size = images.shape[0]
            real_images = Variable(images.type(torch.cuda.FloatTensor)).to(device)
            real_labels = Variable(labels.type(torch.cuda.LongTensor)).to(device)

            # adversarial ground truth
            fake = torch.zeros(batch_size).to(device)
            valid = torch.ones(batch_size).to(device)

            ### train generator
            optim_generator.zero_grad()
            z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
            generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))

            # generate image batch
            generated_images = generator(z, generated_labels)

            # compute generator loss, optimize generator
            validity, predicted_label = discriminator(generated_images)
            gen_loss = 0.5 * (source_criterion(validity, valid.unsqueeze(1)) + class_criterion(predicted_label, generated_labels))
            gen_loss.backward()
            optim_generator.step()

            ### train discriminator
            optim_discriminator.zero_grad()

            # compute real images loss
            real_pred, real_aux = discriminator(real_images)
            disc_loss_real = 0.5 * (source_criterion(real_pred, valid.unsqueeze(1)) + class_criterion(real_aux, real_labels))

            # compute fake images loss
            fake_pred, fake_aux = discriminator(generated_images.detach())
            disc_loss_fake = 0.5 * (source_criterion(fake_pred, fake.unsqueeze(1)) + class_criterion(fake_aux, generated_labels))

            # compute overall discriminator loss, optimize discriminator
            disc_loss = 0.5 * (disc_loss_real + disc_loss_fake)
            disc_loss.backward()
            optim_discriminator.step()

        # compute inception score and samples every epoch
        z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
        generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))
        samples = generator(z, generated_labels)

        # normalize to [0, 1]
        samples = samples.mul(0.5).add(0.5)
        
        assert 0 <= samples.min() and samples.max() <= 1
        inception_score, inception_score_std = get_inception_score(samples)
        print("epoch: " + str(epoch) + ', inception score: ' + str(round(inception_score, 2)) + ' ± ' + str(round(inception_score_std, 2)))

        samples = samples[:64].data.cpu()
        utils.save_image(samples, 'train_generated_images_acgan_fake/epoch_{}.png'.format(str(epoch)))
        utils.save_image(real_images, 'train_generated_images_acgan_real/epoch_{}.png'.format(str(epoch)))
        
        inception_score_file.write(str(epoch) + ', ' + str(round(inception_score, 2)) + '\n')

    inception_score_file.close()

In [5]:
# train ACGAN
print("training ACGAN model...")
train(acgan_generator, acgan_discriminator, train_dataloader)

# save ACGAN to file
#print("saving ACGAN model to file...")
#torch.save(acgan_generator.state_dict(), 'acgan_generator.pkl')
#torch.save(acgan_discriminator.state_dict(), 'acgan_discriminator.pkl')

training ACGAN model...


  input = module(input)
  2%|▏         | 1/50 [01:06<53:56, 66.04s/it]

epoch: 0, inception score: 1.59 ± 0.2


  4%|▍         | 2/50 [01:38<37:07, 46.40s/it]

epoch: 1, inception score: 1.44 ± 0.13


  6%|▌         | 3/50 [01:57<26:23, 33.69s/it]

epoch: 2, inception score: 1.73 ± 0.19


  8%|▊         | 4/50 [02:16<21:27, 27.98s/it]

epoch: 3, inception score: 1.62 ± 0.17


 10%|█         | 5/50 [02:37<19:01, 25.36s/it]

epoch: 4, inception score: 1.65 ± 0.14


 12%|█▏        | 6/50 [02:57<17:16, 23.55s/it]

epoch: 5, inception score: 1.7 ± 0.25


 14%|█▍        | 7/50 [03:31<19:21, 27.00s/it]

epoch: 6, inception score: 1.54 ± 0.08


 16%|█▌        | 8/50 [04:08<21:05, 30.12s/it]

epoch: 7, inception score: 1.58 ± 0.22


 18%|█▊        | 9/50 [04:44<21:53, 32.04s/it]

epoch: 8, inception score: 1.57 ± 0.15


 20%|██        | 10/50 [05:03<18:43, 28.09s/it]

epoch: 9, inception score: 1.63 ± 0.22


 22%|██▏       | 11/50 [05:22<16:27, 25.31s/it]

epoch: 10, inception score: 1.81 ± 0.28


 24%|██▍       | 12/50 [05:43<15:12, 24.01s/it]

epoch: 11, inception score: 1.76 ± 0.28


 26%|██▌       | 13/50 [06:16<16:25, 26.64s/it]

epoch: 12, inception score: 1.72 ± 0.24


 28%|██▊       | 14/50 [06:42<15:50, 26.39s/it]

epoch: 13, inception score: 1.65 ± 0.15


 30%|███       | 15/50 [07:02<14:24, 24.69s/it]

epoch: 14, inception score: 1.77 ± 0.16


 32%|███▏      | 16/50 [07:22<13:09, 23.22s/it]

epoch: 15, inception score: 1.79 ± 0.21


 34%|███▍      | 17/50 [07:42<12:07, 22.05s/it]

epoch: 16, inception score: 1.58 ± 0.1


 36%|███▌      | 18/50 [08:01<11:19, 21.24s/it]

epoch: 17, inception score: 1.89 ± 0.35


 38%|███▊      | 19/50 [08:20<10:38, 20.60s/it]

epoch: 18, inception score: 1.91 ± 0.34


 40%|████      | 20/50 [08:39<10:07, 20.23s/it]

epoch: 19, inception score: 1.67 ± 0.3


 42%|████▏     | 21/50 [08:59<09:37, 19.90s/it]

epoch: 20, inception score: 1.65 ± 0.14


 44%|████▍     | 22/50 [09:18<09:09, 19.62s/it]

epoch: 21, inception score: 1.76 ± 0.18


 46%|████▌     | 23/50 [09:37<08:44, 19.43s/it]

epoch: 22, inception score: 1.78 ± 0.32


 48%|████▊     | 24/50 [09:56<08:24, 19.40s/it]

epoch: 23, inception score: 1.77 ± 0.25


 50%|█████     | 25/50 [10:15<08:03, 19.35s/it]

epoch: 24, inception score: 1.8 ± 0.14


 52%|█████▏    | 26/50 [10:35<07:51, 19.63s/it]

epoch: 25, inception score: 1.75 ± 0.19


 54%|█████▍    | 27/50 [11:06<08:44, 22.81s/it]

epoch: 26, inception score: 1.77 ± 0.27


 56%|█████▌    | 28/50 [11:45<10:09, 27.71s/it]

epoch: 27, inception score: 1.83 ± 0.24


 58%|█████▊    | 29/50 [12:27<11:13, 32.09s/it]

epoch: 28, inception score: 2.0 ± 0.37


 60%|██████    | 30/50 [12:49<09:41, 29.07s/it]

epoch: 29, inception score: 1.77 ± 0.19


 62%|██████▏   | 31/50 [13:09<08:20, 26.35s/it]

epoch: 30, inception score: 1.83 ± 0.16


 64%|██████▍   | 32/50 [13:30<07:26, 24.82s/it]

epoch: 31, inception score: 1.81 ± 0.21


 66%|██████▌   | 33/50 [14:06<07:59, 28.21s/it]

epoch: 32, inception score: 1.95 ± 0.25


 68%|██████▊   | 34/50 [14:27<06:54, 25.92s/it]

epoch: 33, inception score: 1.87 ± 0.27


 70%|███████   | 35/50 [15:03<07:12, 28.85s/it]

epoch: 34, inception score: 1.83 ± 0.28


 72%|███████▏  | 36/50 [15:29<06:34, 28.15s/it]

epoch: 35, inception score: 1.97 ± 0.33


 74%|███████▍  | 37/50 [15:49<05:33, 25.65s/it]

epoch: 36, inception score: 1.83 ± 0.29


 76%|███████▌  | 38/50 [16:09<04:46, 23.83s/it]

epoch: 37, inception score: 1.87 ± 0.22


 78%|███████▊  | 39/50 [16:29<04:11, 22.89s/it]

epoch: 38, inception score: 1.9 ± 0.18


 80%|████████  | 40/50 [16:50<03:41, 22.16s/it]

epoch: 39, inception score: 2.01 ± 0.28


 82%|████████▏ | 41/50 [17:14<03:26, 22.92s/it]

epoch: 40, inception score: 1.82 ± 0.3


 84%|████████▍ | 42/50 [17:35<02:56, 22.09s/it]

epoch: 41, inception score: 1.9 ± 0.17


 86%|████████▌ | 43/50 [17:56<02:32, 21.76s/it]

epoch: 42, inception score: 2.02 ± 0.15


 88%|████████▊ | 44/50 [18:16<02:08, 21.43s/it]

epoch: 43, inception score: 1.9 ± 0.19


 90%|█████████ | 45/50 [18:38<01:47, 21.57s/it]

epoch: 44, inception score: 1.83 ± 0.24


 92%|█████████▏| 46/50 [19:00<01:26, 21.63s/it]

epoch: 45, inception score: 1.83 ± 0.19


 94%|█████████▍| 47/50 [19:21<01:04, 21.56s/it]

epoch: 46, inception score: 1.83 ± 0.18


 96%|█████████▌| 48/50 [19:45<00:44, 22.24s/it]

epoch: 47, inception score: 1.84 ± 0.28


 98%|█████████▊| 49/50 [20:06<00:21, 21.72s/it]

epoch: 48, inception score: 1.73 ± 0.15


100%|██████████| 50/50 [20:37<00:00, 24.75s/it]

epoch: 49, inception score: 2.04 ± 0.26
saving ACGAN model to file...





In [6]:
def generate_images(generator):
    z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
    generated_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))
    samples = generator(z, generated_labels)
    
    samples = samples.mul(0.5).add(0.5)
    samples = samples.data.cpu()
    grid = utils.make_grid(samples)
    print("Grid of 8x8 images saved to 'acgan_generated_images.png'.")
    utils.save_image(grid, 'acgan_generated_images.png')

def load_model(model, model_filename): 
    model.load_state_dict(torch.load(model_filename))

# load trained model and generate sample images
print("loading ACGAN model...")
load_model(acgan_generator, 'acgan_generator.pkl')
load_model(acgan_discriminator, 'acgan_discriminator.pkl')

generate_images(acgan_generator)

loading ACGAN model...
Grid of 8x8 images saved to 'acgan_generated_images.png'.
