## WGAN ##

In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
def create_CIFAR10_dataloaders(batch_size): 
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Resize(32), 
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=True, download=True, transform=transform)
    test_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=False, download=True, transform=transform)

    train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_CIFAR10_dataloader, test_CIFAR10_dataloader

print("Downloading CIFAR10 dataset...")
batch_size = 64
train_dataloader, test_dataloader = create_CIFAR10_dataloaders(batch_size)

Downloading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified


In [3]:
class Generator_WGAN(nn.Module):
    def __init__(self):
        super(Generator_WGAN, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.net = nn.Sequential(
            *block(100, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 3*32*32),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.net(z)
        img = img.view(img.shape[0], *(3, 32, 32))
        return img

class Discriminator_WGAN(nn.Module):
    def __init__(self):
        super(Discriminator_WGAN, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(3*32*32, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

print("Instantiating DCGAN generator and discriminator...")
wgan_generator = Generator_WGAN()
wgan_discriminator = Discriminator_WGAN()
wgan_generator.to(device)
wgan_discriminator.to(device)
print()

Instantiating DCGAN generator and discriminator...



In [4]:
learning_rate=5e-4
epochs=50
batch_size=64
n_critic = 5
weight_cliping_limit=0.01

def train(generator, discriminator, train_dataloader):
    optim_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
    optim_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))

    if not os.path.exists('train_generated_images_wgan/'): 
        os.makedirs('train_generated_images_wgan')
        
    inception_score_file = open("inception_score_wgan.txt", "w")
    inception_score_file.write('epoch, inception_score \n')

    for epoch in tqdm(range(epochs)): 
        for i, (images, _) in enumerate(train_dataloader):

            real_images = Variable(images.type(torch.cuda.FloatTensor))

            ### train discriminator

            optim_discriminator.zero_grad()
            z = Variable(torch.Tensor(np.random.normal(0, 1, (images.shape[0], 100)))).to(device)
            fake_images = generator(z).detach()
            disc_loss = -torch.mean(discriminator(real_images)) + torch.mean(discriminator(fake_images))
            disc_loss.backward()
            optim_discriminator.step()

            # apply weight clipping
            for p in discriminator.parameters():
                p.data.clamp_(-weight_cliping_limit, weight_cliping_limit)

            # Train generator every n_critic batches
            if i % n_critic == 0:

                optim_generator.zero_grad()
                fake_images = generator(z)
                gen_loss = -torch.mean(discriminator(fake_images))
                gen_loss.backward()
                optim_generator.step()

        # compute inception score and samples every epoch
        z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (images.shape[0], 100)))).to(device)
        samples = generator(z)

        # normalize to [0, 1]
        samples = samples.mul(0.5).add(0.5)
        
        assert 0 <= samples.min() and samples.max() <= 1
        inception_score, inception_score_std = get_inception_score(samples)
        print("epoch: " + str(epoch) + ', inception score: ' + str(round(inception_score, 2)) + ' ± ' + str(round(inception_score_std, 2)))

        samples = samples[:64].data.cpu()
        grid = utils.make_grid(samples)
        utils.save_image(grid, 'train_generated_images_wgan/epoch_{}.png'.format(str(epoch)))
        
        inception_score_file.write(str(epoch) + ', ' + str(round(inception_score, 2)) + '\n')

    inception_score_file.close()

In [5]:
# train WGAN
print("training WGAN model...")
train(wgan_generator, wgan_discriminator, train_dataloader)

# save WGAN to file
print("saving WGAN model to file...")
torch.save(wgan_generator.state_dict(), 'wgan_generator.pkl')
torch.save(wgan_discriminator.state_dict(), 'wgan_discriminator.pkl')

training WGAN model...


  1%|          | 1/180 [00:10<31:21, 10.51s/it]

epoch: 2, inception score: 1.82 ± 0.18


  1%|          | 2/180 [00:29<46:36, 15.71s/it]

epoch: 3, inception score: 1.71 ± 0.21


  2%|▏         | 3/180 [00:49<52:01, 17.64s/it]

epoch: 4, inception score: 1.87 ± 0.29


  3%|▎         | 5/180 [01:27<53:43, 18.42s/it]

epoch: 6, inception score: 1.92 ± 0.29


  3%|▎         | 6/180 [01:36<44:17, 15.27s/it]

epoch: 7, inception score: 1.85 ± 0.19


  4%|▍         | 7/180 [01:50<42:00, 14.57s/it]

epoch: 8, inception score: 1.96 ± 0.26


  4%|▍         | 8/180 [02:07<44:45, 15.61s/it]

epoch: 9, inception score: 1.81 ± 0.33


  6%|▌         | 10/180 [02:42<46:38, 16.46s/it]

epoch: 11, inception score: 2.04 ± 0.56


  6%|▌         | 11/180 [03:00<47:35, 16.90s/it]

epoch: 12, inception score: 2.03 ± 0.26


  7%|▋         | 12/180 [03:18<48:10, 17.21s/it]

epoch: 13, inception score: 1.97 ± 0.42


  8%|▊         | 14/180 [03:52<47:37, 17.21s/it]

epoch: 15, inception score: 1.93 ± 0.48


  8%|▊         | 15/180 [04:10<47:55, 17.43s/it]

epoch: 16, inception score: 1.84 ± 0.3


  9%|▉         | 16/180 [04:28<48:00, 17.56s/it]

epoch: 17, inception score: 1.84 ± 0.33


  9%|▉         | 17/180 [04:46<47:58, 17.66s/it]

epoch: 18, inception score: 1.91 ± 0.21


 11%|█         | 19/180 [05:21<46:56, 17.49s/it]

epoch: 20, inception score: 1.94 ± 0.28


 11%|█         | 20/180 [05:39<46:50, 17.57s/it]

epoch: 21, inception score: 2.08 ± 0.36


 12%|█▏        | 21/180 [05:53<44:15, 16.70s/it]

epoch: 22, inception score: 1.82 ± 0.22


 13%|█▎        | 23/180 [06:17<37:46, 14.44s/it]

epoch: 24, inception score: 2.01 ± 0.3


 13%|█▎        | 24/180 [06:36<41:19, 15.90s/it]

epoch: 25, inception score: 2.19 ± 0.35


 14%|█▍        | 25/180 [06:56<43:50, 16.97s/it]

epoch: 26, inception score: 1.9 ± 0.27


 14%|█▍        | 26/180 [07:15<45:36, 17.77s/it]

epoch: 27, inception score: 1.88 ± 0.23


 16%|█▌        | 28/180 [07:53<46:19, 18.28s/it]

epoch: 29, inception score: 1.89 ± 0.29


 16%|█▌        | 29/180 [08:13<46:54, 18.64s/it]

epoch: 30, inception score: 1.86 ± 0.34


 17%|█▋        | 30/180 [08:32<47:25, 18.97s/it]

epoch: 31, inception score: 1.91 ± 0.14


 18%|█▊        | 32/180 [09:10<46:36, 18.90s/it]

epoch: 33, inception score: 2.22 ± 0.37


 18%|█▊        | 33/180 [09:30<46:42, 19.06s/it]

epoch: 34, inception score: 1.91 ± 0.3


 19%|█▉        | 34/180 [09:49<46:42, 19.20s/it]

epoch: 35, inception score: 2.08 ± 0.25


 19%|█▉        | 35/180 [10:09<46:43, 19.33s/it]

epoch: 36, inception score: 2.09 ± 0.2


 21%|██        | 37/180 [10:47<45:25, 19.06s/it]

epoch: 38, inception score: 1.85 ± 0.15


 21%|██        | 38/180 [11:06<45:20, 19.16s/it]

epoch: 39, inception score: 2.09 ± 0.21


 22%|██▏       | 39/180 [11:26<45:22, 19.31s/it]

epoch: 40, inception score: 2.19 ± 0.57


 22%|██▏       | 40/180 [11:36<38:42, 16.59s/it]

epoch: 41, inception score: 1.9 ± 0.32


 23%|██▎       | 42/180 [12:01<34:10, 14.86s/it]

epoch: 43, inception score: 1.89 ± 0.31


 24%|██▍       | 43/180 [12:19<36:01, 15.78s/it]

epoch: 44, inception score: 2.06 ± 0.36


 24%|██▍       | 44/180 [12:37<37:14, 16.43s/it]

epoch: 45, inception score: 1.95 ± 0.27


 26%|██▌       | 46/180 [13:12<37:42, 16.88s/it]

epoch: 47, inception score: 1.84 ± 0.24


 26%|██▌       | 47/180 [13:30<38:07, 17.20s/it]

epoch: 48, inception score: 2.06 ± 0.41


 27%|██▋       | 48/180 [13:47<38:16, 17.40s/it]

epoch: 49, inception score: 1.85 ± 0.28


 27%|██▋       | 49/180 [14:05<38:15, 17.52s/it]

epoch: 50, inception score: 1.95 ± 0.32


 28%|██▊       | 51/180 [14:40<37:25, 17.41s/it]

epoch: 52, inception score: 1.89 ± 0.29


 29%|██▉       | 52/180 [14:58<37:27, 17.56s/it]

epoch: 53, inception score: 1.84 ± 0.22


 29%|██▉       | 53/180 [15:16<37:22, 17.65s/it]

epoch: 54, inception score: 1.75 ± 0.25


 31%|███       | 55/180 [15:50<36:21, 17.45s/it]

epoch: 56, inception score: 1.79 ± 0.29


 31%|███       | 56/180 [16:08<36:20, 17.58s/it]

epoch: 57, inception score: 1.71 ± 0.11


 32%|███▏      | 57/180 [16:26<36:13, 17.67s/it]

epoch: 58, inception score: 1.73 ± 0.24


 32%|███▏      | 58/180 [16:44<36:11, 17.80s/it]

epoch: 59, inception score: 1.87 ± 0.12


 33%|███▎      | 60/180 [17:19<35:05, 17.54s/it]

epoch: 61, inception score: 1.72 ± 0.26


 34%|███▍      | 61/180 [17:37<34:56, 17.62s/it]

epoch: 62, inception score: 2.02 ± 0.23


 34%|███▍      | 62/180 [17:55<34:47, 17.69s/it]

epoch: 63, inception score: 1.85 ± 0.18


 36%|███▌      | 64/180 [18:30<34:11, 17.69s/it]

epoch: 65, inception score: 1.71 ± 0.2


 36%|███▌      | 65/180 [18:48<34:02, 17.76s/it]

epoch: 66, inception score: 1.99 ± 0.42


 37%|███▋      | 66/180 [19:06<33:51, 17.82s/it]

epoch: 67, inception score: 2.06 ± 0.36


 37%|███▋      | 67/180 [19:24<33:38, 17.86s/it]

epoch: 68, inception score: 2.06 ± 0.31


 38%|███▊      | 69/180 [19:59<32:29, 17.56s/it]

epoch: 70, inception score: 2.02 ± 0.31


 39%|███▉      | 70/180 [20:16<32:18, 17.63s/it]

epoch: 71, inception score: 1.81 ± 0.36


 39%|███▉      | 71/180 [20:34<32:08, 17.70s/it]

epoch: 72, inception score: 1.8 ± 0.3


 40%|████      | 72/180 [20:52<31:45, 17.64s/it]

epoch: 73, inception score: 2.08 ± 0.26


 41%|████      | 74/180 [21:21<29:11, 16.52s/it]

epoch: 75, inception score: 2.0 ± 0.33


 42%|████▏     | 75/180 [21:42<31:05, 17.77s/it]

epoch: 76, inception score: 1.87 ± 0.15


 42%|████▏     | 76/180 [22:02<31:55, 18.42s/it]

epoch: 77, inception score: 1.89 ± 0.22


 43%|████▎     | 78/180 [22:39<31:39, 18.62s/it]

epoch: 79, inception score: 1.89 ± 0.17


 44%|████▍     | 79/180 [22:59<31:45, 18.87s/it]

epoch: 80, inception score: 1.94 ± 0.35


 44%|████▍     | 80/180 [23:19<31:51, 19.12s/it]

epoch: 81, inception score: 1.85 ± 0.38


 45%|████▌     | 81/180 [23:38<31:47, 19.27s/it]

epoch: 82, inception score: 1.96 ± 0.49


 46%|████▌     | 83/180 [24:16<30:45, 19.03s/it]

epoch: 84, inception score: 1.85 ± 0.24


 47%|████▋     | 84/180 [24:35<30:35, 19.12s/it]

epoch: 85, inception score: 1.9 ± 0.25


 47%|████▋     | 85/180 [24:55<30:34, 19.31s/it]

epoch: 86, inception score: 1.95 ± 0.34


 48%|████▊     | 87/180 [25:33<29:35, 19.09s/it]

epoch: 88, inception score: 1.96 ± 0.39


 49%|████▉     | 88/180 [25:52<29:23, 19.17s/it]

epoch: 89, inception score: 1.98 ± 0.26


 49%|████▉     | 89/180 [26:12<29:17, 19.31s/it]

epoch: 90, inception score: 2.05 ± 0.39


 50%|█████     | 90/180 [26:32<29:05, 19.39s/it]

epoch: 91, inception score: 2.06 ± 0.43


 51%|█████     | 92/180 [27:09<28:02, 19.12s/it]

epoch: 93, inception score: 1.87 ± 0.29


 52%|█████▏    | 93/180 [27:29<27:48, 19.18s/it]

epoch: 94, inception score: 1.97 ± 0.4


 52%|█████▏    | 94/180 [27:49<27:46, 19.38s/it]

epoch: 95, inception score: 1.97 ± 0.23


 53%|█████▎    | 96/180 [28:27<26:47, 19.13s/it]

epoch: 97, inception score: 1.98 ± 0.28


 54%|█████▍    | 97/180 [28:46<26:35, 19.22s/it]

epoch: 98, inception score: 2.09 ± 0.52


 54%|█████▍    | 98/180 [29:06<26:28, 19.38s/it]

epoch: 99, inception score: 1.91 ± 0.35


 55%|█████▌    | 99/180 [29:25<26:16, 19.46s/it]

epoch: 100, inception score: 1.92 ± 0.34


 56%|█████▌    | 101/180 [30:03<25:13, 19.16s/it]

epoch: 102, inception score: 2.14 ± 0.38


 57%|█████▋    | 102/180 [30:23<25:10, 19.37s/it]

epoch: 103, inception score: 1.95 ± 0.18


 57%|█████▋    | 103/180 [30:43<25:09, 19.61s/it]

epoch: 104, inception score: 2.05 ± 0.35


 58%|█████▊    | 105/180 [31:23<24:32, 19.63s/it]

epoch: 105, inception score: 2.02 ± 0.32


 59%|█████▉    | 106/180 [31:37<22:24, 18.17s/it]

epoch: 107, inception score: 1.9 ± 0.36


 59%|█████▉    | 107/180 [31:47<18:53, 15.52s/it]

epoch: 108, inception score: 2.03 ± 0.21


 59%|█████▉    | 107/180 [32:01<21:51, 17.96s/it]


KeyboardInterrupt: 

In [12]:
def generate_images(generator):
    z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
    samples = generator(z)
    samples = samples.mul(0.5).add(0.5)
    samples = samples.data.cpu()
    grid = utils.make_grid(samples)
    print("Grid of 8x8 images saved to 'wgan_generated_images.png'.")
    utils.save_image(grid, 'wgan_generated_images.png')

def load_model(model, model_filename): 
    model.load_state_dict(torch.load(model_filename))

# load trained model and generate sample images
print("loading WGAN model...")
load_model(wgan_generator, 'wgan_generator.pkl')
load_model(wgan_discriminator, 'wgan_discriminator.pkl')

generate_images(wgan_generator)

loading WGAN model...
Grid of 8x8 images saved to 'wgan_generated_images.png'.
