## WGAN ##

In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

# Pick the GPU when available; models and tensors below are placed on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Seed every RNG the notebook draws from (torch: weight init, DataLoader shuffling, noise;
# numpy: random draws) so Restart & Run All reproduces the same run.
# NOTE(review): bit-exact reproducibility on GPU additionally needs deterministic cuDNN settings.
SEED = 0
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size, root='./cifar10/'):
    """Build CIFAR-10 train and test DataLoaders with pixels normalized to [-1, 1].

    Args:
        batch_size: number of samples per batch for both loaders.
        root: directory the dataset is read from / downloaded to.

    Returns:
        (train_dataloader, test_dataloader).
        - The train loader shuffles and drops the last incomplete batch, so every
          training step sees exactly `batch_size` images.
        - The test loader keeps a fixed order and yields every test image, including
          a final partial batch.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # CIFAR-10 images are already 32x32; kept to guarantee the size the models expect
        torchvision.transforms.Resize(32),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range onto [-1, 1], matching the generator's Tanh
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    # evaluation must be deterministic and complete: no shuffling, no dropped samples
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

# Global batch size shared by the data pipeline (the training cell re-declares it as well).
batch_size = 64
print("Downloading CIFAR10 dataset...")
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size=batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class Generator_WGAN(nn.Module):
    """MLP generator mapping a latent noise vector to an image.

    Args:
        latent_dim: size of the input noise vector (default 100).
        img_shape: (channels, height, width) of generated images (default (3, 32, 32)).

    The final Tanh keeps every output pixel in [-1, 1].
    """

    def __init__(self, latent_dim=100, img_shape=(3, 32, 32)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        def block(in_feat, out_feat, normalize=True):
            """Linear layer, optional BatchNorm1d, then LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): 0.8 was originally passed positionally, which BatchNorm1d binds to
                # `eps`, not `momentum`. Kept as eps=0.8 so existing checkpoints behave identically;
                # momentum=0.8 was most likely intended - confirm before retraining.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order is unchanged so state_dict keys (net.0, net.2, net.3, ...) still match saved files.
        self.net = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, torch.Size(self.img_shape).numel()),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise `z` of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        img = self.net(z)
        return img.view(img.shape[0], *self.img_shape)

class Discriminator_WGAN(nn.Module):
    """MLP critic that assigns one unbounded real-valued score to each image.

    There is no final sigmoid: the WGAN loss in `train` uses the raw scores directly.

    Args:
        img_shape: (channels, height, width) of input images (default (3, 32, 32)).
    """

    def __init__(self, img_shape=(3, 32, 32)):
        super().__init__()
        self.img_shape = tuple(img_shape)

        # Layer order is unchanged so state_dict keys (model.0/2/4) still match saved files.
        self.model = nn.Sequential(
            nn.Linear(torch.Size(self.img_shape).numel(), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Score a batch of images of shape (batch, *img_shape); returns shape (batch, 1)."""
        # flatten (unlike .view) also accepts non-contiguous inputs such as permuted tensors
        img_flat = torch.flatten(img, 1)
        validity = self.model(img_flat)
        return validity

# Build the WGAN generator/critic pair and move both to the selected device
# (nn.Module.to moves the parameters in place and returns the module itself).
print("Instantiating WGAN generator and discriminator...")
wgan_generator = Generator_WGAN().to(device)
wgan_discriminator = Discriminator_WGAN().to(device)
print()

Instantiating DCGAN generator and discriminator...



In [4]:
# Training hyperparameters, read as globals by `train` below.
# NOTE(review): the weight-clipping WGAN paper used RMSprop with lr=5e-5; Adam(lr=5e-4) is kept as-is.
learning_rate = 5e-4
epochs = 50
batch_size = 64              # also the number of images sampled for each per-epoch evaluation
n_critic = 5                 # the generator is updated on every n_critic-th batch
weight_cliping_limit = 0.01  # critic weights are clamped to [-c, c] after every critic update
latent_dim = 100             # must match the generator's input size


def _sample_latent(n, device):
    """Return an (n, latent_dim) batch of standard-normal noise on `device`."""
    return torch.randn(n, latent_dim, device=device)


def _evaluate_epoch(generator, epoch, device, score_file):
    """Score and save one batch of generator samples for `epoch`.

    Samples `batch_size` images, prints their inception score and appends it to
    `score_file`, then saves an image grid to train_generated_images_wgan/.
    """
    with torch.no_grad():  # evaluation only: don't build an autograd graph
        # generator output is Tanh-bounded: rescale [-1, 1] -> [0, 1]
        samples = generator(_sample_latent(batch_size, device)).mul(0.5).add(0.5)

    assert 0 <= samples.min() and samples.max() <= 1
    inception_score, inception_score_std = get_inception_score(samples)
    print("epoch: " + str(epoch) + ', inception score: ' + str(round(inception_score, 2)) + ' ± ' + str(round(inception_score_std, 2)))

    grid = utils.make_grid(samples[:64].cpu())
    utils.save_image(grid, 'train_generated_images_wgan/epoch_{}.png'.format(epoch))

    score_file.write('{},{}\n'.format(epoch, round(inception_score, 2)))


def train(generator, discriminator, train_dataloader):
    """Train a weight-clipped WGAN.

    Per epoch:
    - Every batch performs one critic update, followed by clipping all critic weights
      to [-weight_cliping_limit, weight_cliping_limit].
    - On every `n_critic`-th batch the generator is also updated.
    - At the end of the epoch the inception score of fresh samples is printed and
      appended to inception_score_wgan.csv (columns: epoch,inception_score), and a
      sample grid is saved to train_generated_images_wgan/epoch_<n>.png.

    Args:
        generator: module mapping (n, latent_dim) noise to images in [-1, 1].
        discriminator: critic module mapping images to one unbounded score each.
        train_dataloader: iterable of (images, labels) batches; labels are ignored.
    """
    # Run on whatever device the models live on, so this works on CPU as well as CUDA.
    model_device = next(generator.parameters()).device

    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))

    os.makedirs('train_generated_images_wgan', exist_ok=True)

    with open("inception_score_wgan.csv", "w") as inception_score_file:
        inception_score_file.write('epoch,inception_score\n')

        for epoch in tqdm(range(epochs)):
            for i, (images, _) in enumerate(train_dataloader):
                real_images = images.to(model_device, dtype=torch.float32)

                # critic step: minimize E[D(fake)] - E[D(real)]
                optim_discriminator.zero_grad()
                z = _sample_latent(real_images.shape[0], model_device)
                fake_images = generator(z).detach()  # no generator gradients in the critic step
                disc_loss = -torch.mean(discriminator(real_images)) + torch.mean(discriminator(fake_images))
                disc_loss.backward()
                optim_discriminator.step()

                # weight clipping, applied outside autograd
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-weight_cliping_limit, weight_cliping_limit)

                # generator step every n_critic batches: maximize E[D(G(z))]
                if i % n_critic == 0:
                    optim_generator.zero_grad()
                    gen_loss = -torch.mean(discriminator(generator(z)))
                    gen_loss.backward()
                    optim_generator.step()

            _evaluate_epoch(generator, epoch, model_device, inception_score_file)

In [5]:
# train WGAN
print("training WGAN model...")
train(wgan_generator, wgan_discriminator, train_dataloader)

# Save WGAN to file. The loading cell below reads these exact paths, so saving must be
# active for Restart & Run All to load the weights just trained (not stale or missing files).
print("saving WGAN model to file...")
torch.save(wgan_generator.state_dict(), 'wgan_generator.pkl')
torch.save(wgan_discriminator.state_dict(), 'wgan_discriminator.pkl')

training WGAN model...


  2%|▏         | 1/50 [00:17<14:21, 17.58s/it]

epoch: 0, inception score: 1.63 ± 0.14


  4%|▍         | 2/50 [00:26<10:09, 12.69s/it]

epoch: 1, inception score: 1.74 ± 0.13


  6%|▌         | 3/50 [00:36<08:41, 11.09s/it]

epoch: 2, inception score: 1.81 ± 0.25


  8%|▊         | 4/50 [00:45<07:52, 10.27s/it]

epoch: 3, inception score: 1.85 ± 0.18


 10%|█         | 5/50 [00:56<08:04, 10.77s/it]

epoch: 4, inception score: 1.88 ± 0.25


 12%|█▏        | 6/50 [01:16<10:02, 13.69s/it]

epoch: 5, inception score: 1.9 ± 0.24


 14%|█▍        | 7/50 [01:35<11:12, 15.64s/it]

epoch: 6, inception score: 1.86 ± 0.26


 16%|█▌        | 8/50 [01:54<11:44, 16.76s/it]

epoch: 7, inception score: 1.89 ± 0.3


 18%|█▊        | 9/50 [02:12<11:37, 17.02s/it]

epoch: 8, inception score: 1.79 ± 0.2


 20%|██        | 10/50 [02:22<09:50, 14.75s/it]

epoch: 9, inception score: 1.78 ± 0.26


 22%|██▏       | 11/50 [02:31<08:32, 13.14s/it]

epoch: 10, inception score: 1.65 ± 0.11


 24%|██▍       | 12/50 [02:40<07:32, 11.91s/it]

epoch: 11, inception score: 1.77 ± 0.34


 26%|██▌       | 13/50 [02:49<06:47, 11.00s/it]

epoch: 12, inception score: 1.69 ± 0.19


 28%|██▊       | 14/50 [02:58<06:16, 10.45s/it]

epoch: 13, inception score: 1.94 ± 0.24


 30%|███       | 15/50 [03:07<05:52, 10.07s/it]

epoch: 14, inception score: 1.82 ± 0.3


 32%|███▏      | 16/50 [03:16<05:31,  9.75s/it]

epoch: 15, inception score: 1.97 ± 0.19


 34%|███▍      | 17/50 [03:26<05:14,  9.53s/it]

epoch: 16, inception score: 1.98 ± 0.43


 36%|███▌      | 18/50 [03:35<04:59,  9.37s/it]

epoch: 17, inception score: 1.88 ± 0.27


 38%|███▊      | 19/50 [03:44<04:53,  9.47s/it]

epoch: 18, inception score: 1.92 ± 0.26


 40%|████      | 20/50 [03:54<04:43,  9.46s/it]

epoch: 19, inception score: 1.93 ± 0.15


 42%|████▏     | 21/50 [04:03<04:33,  9.43s/it]

epoch: 20, inception score: 1.91 ± 0.22


 44%|████▍     | 22/50 [04:12<04:22,  9.37s/it]

epoch: 21, inception score: 1.91 ± 0.25


 46%|████▌     | 23/50 [04:21<04:11,  9.33s/it]

epoch: 22, inception score: 2.17 ± 0.36


 48%|████▊     | 24/50 [04:31<04:01,  9.29s/it]

epoch: 23, inception score: 1.91 ± 0.2


 50%|█████     | 25/50 [04:40<03:52,  9.32s/it]

epoch: 24, inception score: 2.04 ± 0.29


 52%|█████▏    | 26/50 [04:49<03:42,  9.26s/it]

epoch: 25, inception score: 1.86 ± 0.33


 54%|█████▍    | 27/50 [04:58<03:30,  9.15s/it]

epoch: 26, inception score: 1.8 ± 0.14


 56%|█████▌    | 28/50 [05:07<03:21,  9.17s/it]

epoch: 27, inception score: 1.8 ± 0.27


 58%|█████▊    | 29/50 [05:17<03:14,  9.24s/it]

epoch: 28, inception score: 1.9 ± 0.28


 60%|██████    | 30/50 [05:26<03:06,  9.34s/it]

epoch: 29, inception score: 1.99 ± 0.26


 62%|██████▏   | 31/50 [05:36<02:59,  9.43s/it]

epoch: 30, inception score: 1.92 ± 0.19


 64%|██████▍   | 32/50 [05:45<02:49,  9.44s/it]

epoch: 31, inception score: 1.92 ± 0.36


 66%|██████▌   | 33/50 [05:55<02:40,  9.43s/it]

epoch: 32, inception score: 1.95 ± 0.22


 68%|██████▊   | 34/50 [06:04<02:30,  9.38s/it]

epoch: 33, inception score: 1.78 ± 0.18


 70%|███████   | 35/50 [06:14<02:21,  9.44s/it]

epoch: 34, inception score: 1.9 ± 0.15


 72%|███████▏  | 36/50 [06:23<02:12,  9.45s/it]

epoch: 35, inception score: 2.0 ± 0.32


 74%|███████▍  | 37/50 [06:33<02:03,  9.53s/it]

epoch: 36, inception score: 1.87 ± 0.21


 76%|███████▌  | 38/50 [06:43<01:55,  9.63s/it]

epoch: 37, inception score: 1.88 ± 0.23


 78%|███████▊  | 39/50 [06:52<01:45,  9.55s/it]

epoch: 38, inception score: 2.05 ± 0.28


 80%|████████  | 40/50 [07:02<01:36,  9.62s/it]

epoch: 39, inception score: 1.84 ± 0.22


 82%|████████▏ | 41/50 [07:12<01:27,  9.67s/it]

epoch: 40, inception score: 2.0 ± 0.35


 84%|████████▍ | 42/50 [07:21<01:16,  9.59s/it]

epoch: 41, inception score: 2.09 ± 0.27


 86%|████████▌ | 43/50 [07:30<01:06,  9.52s/it]

epoch: 42, inception score: 2.09 ± 0.33


 88%|████████▊ | 44/50 [07:40<00:56,  9.48s/it]

epoch: 43, inception score: 1.88 ± 0.21


 90%|█████████ | 45/50 [07:52<00:51, 10.27s/it]

epoch: 44, inception score: 1.88 ± 0.26


 92%|█████████▏| 46/50 [08:05<00:44, 11.17s/it]

epoch: 45, inception score: 1.96 ± 0.24


 94%|█████████▍| 47/50 [08:14<00:31, 10.61s/it]

epoch: 46, inception score: 1.97 ± 0.22


 96%|█████████▌| 48/50 [08:24<00:20, 10.17s/it]

epoch: 47, inception score: 2.11 ± 0.46


 98%|█████████▊| 49/50 [08:33<00:09,  9.88s/it]

epoch: 48, inception score: 1.92 ± 0.3


100%|██████████| 50/50 [08:42<00:00, 10.45s/it]

epoch: 49, inception score: 1.84 ± 0.22
saving WGAN model to file...





In [6]:
def generate_images(generator, n_samples=None, filename='wgan_generated_images.png'):
    """Sample images from `generator` and save them as one grid image.

    Args:
        generator: module mapping (n, latent_dim) noise to images in [-1, 1].
            The latent size is read from `generator.latent_dim` when present, else 100.
        n_samples: number of images to draw; defaults to the global `batch_size`.
        filename: output path for the grid image.
    """
    if n_samples is None:
        n_samples = batch_size
    # sample on whatever device the generator lives on (CPU or CUDA)
    gen_device = next(generator.parameters()).device
    z_dim = getattr(generator, 'latent_dim', 100)

    with torch.no_grad():  # inference only
        z = torch.randn(n_samples, z_dim, device=gen_device)
        # rescale Tanh output [-1, 1] -> [0, 1] for saving
        samples = generator(z).mul(0.5).add(0.5).cpu()

    grid = utils.make_grid(samples)  # make_grid's default layout: 8 images per row
    utils.save_image(grid, filename)

    n_cols = min(8, n_samples)
    n_rows = -(-n_samples // n_cols)  # ceiling division
    print(f"Grid of {n_rows}x{n_cols} images saved to '{filename}'.")

def load_model(model, model_filename):
    """Load a state_dict saved with `torch.save(model.state_dict(), ...)` into `model` in place.

    Tensors are mapped onto the device `model` already lives on, so a checkpoint written
    on a GPU machine can be restored on a CPU-only one (and vice versa).
    NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    model.load_state_dict(torch.load(model_filename, map_location=target_device))

# Restore the weights written by the training cell, then render a grid of generator samples.
print("loading WGAN model...")
load_model(model=wgan_generator, model_filename='wgan_generator.pkl')
load_model(model=wgan_discriminator, model_filename='wgan_discriminator.pkl')

generate_images(generator=wgan_generator)

loading WGAN model...
Grid of 8x8 images saved to 'wgan_generated_images.png'.
