cgan.py

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The integer label is embedded into a ``classes``-wide vector and
    concatenated with the noise. The result goes through four
    Linear/LeakyReLU stages (BatchNorm on every stage except the first) and a
    final Linear + Tanh, so generated pixels lie in [-1, 1].

    Args:
        classes: number of label classes (also used as the embedding width).
        channels: number of image channels.
        img_size: side length of the square output image.
        latent_dim: width of the noise vector.
    """

    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (channels, img_size, img_size)
        # Attribute names and module creation order are significant: they fix
        # the state_dict keys and the order in which parameters are initialised.
        self.label_embedding = nn.Embedding(classes, classes)

        widths = [latent_dim + classes, 128, 256, 512, 1024]
        stages = []
        for depth, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the very first stage is left un-normalised.
            stages.extend(self._create_layer(width_in, width_out, normalize=depth != 0))
        stages.append(nn.Linear(widths[-1], int(np.prod(self.img_shape))))
        stages.append(nn.Tanh())
        self.model = nn.Sequential(*stages)

    def _create_layer(self, size_in, size_out, normalize=True):
        """Return [Linear, (BatchNorm1d), LeakyReLU(0.2)] for one hidden stage."""
        block = [nn.Linear(size_in, size_out)]
        block += [nn.BatchNorm1d(size_out)] if normalize else []
        block.append(nn.LeakyReLU(0.2, inplace=True))
        return block

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: tensor whose last dimension is ``latent_dim``.
            labels: integer class indices, one per sample.

        Returns:
            Tensor reshaped to ``(batch, *img_shape)``.
        """
        conditioned = torch.cat([self.label_embedding(labels), noise], dim=-1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), *self.img_shape)

In [3]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, class label) pairs.

    The image is flattened, concatenated with a ``classes``-wide label
    embedding and mapped to a single Sigmoid output in (0, 1). It is trained
    with binary cross-entropy (see ``loss``).

    Args:
        classes: number of label classes (also used as the embedding width).
        channels: number of image channels.
        img_size: side length of the square input image.
        latent_dim: stored for interface symmetry with Generator; unused here.
    """

    def __init__(self, classes, channels, img_size, latent_dim):
        super().__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (channels, img_size, img_size)
        # Attribute names and module creation order fix the state_dict keys and
        # the parameter initialisation order; keep them stable.
        self.label_embedding = nn.Embedding(classes, classes)
        self.adv_loss = torch.nn.BCELoss()

        in_features = classes + int(np.prod(self.img_shape))
        # Each row: (size_in, size_out, drop_out, act_func)
        # NOTE(review): the 256->128 and 128->1 stages have no activation in
        # between, so together they are a single affine map. Kept as-is because
        # changing it would shift Sequential indices and break saved checkpoints.
        specs = (
            (in_features, 1024, False, True),
            (1024, 512, True, True),
            (512, 256, True, True),
            (256, 128, False, False),
            (128, 1, False, False),
        )
        layers = [module for spec in specs for module in self._create_layer(*spec)]
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def _create_layer(self, size_in, size_out, drop_out=True, act_func=True):
        """Return [Linear, (Dropout(0.4)), (LeakyReLU(0.2))] for one stage."""
        parts = [nn.Linear(size_in, size_out)]
        if drop_out:
            parts += [nn.Dropout(0.4)]
        if act_func:
            parts += [nn.LeakyReLU(0.2, inplace=True)]
        return parts

    def forward(self, image, labels):
        """Return per-sample probabilities that (image, label) is real."""
        flat_image = image.view(image.size(0), -1)
        joined = torch.cat([flat_image, self.label_embedding(labels)], dim=-1)
        return self.model(joined)

    def loss(self, output, label):
        """Binary cross-entropy between predictions and 0/1 targets."""
        return self.adv_loss(output, label)

build_gan.py

In [5]:
import os

import numpy as np
import torch
import torchvision.utils as vutils

# from cgan import Generator as cganG
# from cgan import Discriminator as cganD

In [7]:
class Model(object):
    """Build a generator/discriminator pair and their optimizers.

    Args:
        name: model family. Only ``'cgan'`` is handled here; any other value
            raises ``ValueError``.
        device: torch device the networks are moved to.
        data_loader: training data loader (may be ``None`` when only
            evaluating).
        classes: number of label classes.
        channels: number of image channels.
        img_size: side length of the square images.
        latent_dim: width of the generator's noise vector.

    Raises:
        ValueError: if ``name`` is not a supported model family.
    """

    def __init__(self,
                 name,
                 device,
                 data_loader,
                 classes,
                 channels,
                 img_size,
                 latent_dim):
        # Fail early with a clear message instead of a later AttributeError
        # on self.netG.
        if name != 'cgan':
            raise ValueError(
                "Unsupported model name {!r}; expected 'cgan'.".format(name))
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        # The network constructors below read self.img_size, so it must be
        # stored before they run.
        self.img_size = img_size
        self.latent_dim = latent_dim
        # Create G before D to keep the original parameter-initialisation order.
        self.netG = cganG(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netG.to(self.device)
        self.netD = cganD(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netD.to(self.device)
        # Populated by create_optim().
        self.optim_G = None
        self.optim_D = None

    def create_optim(self, lr, alpha=0.5, beta=0.999):
        """Create Adam optimizers for G and D over their trainable parameters.

        Args:
            lr: learning rate shared by both optimizers.
            alpha: Adam beta1 (first-moment decay).
            beta: Adam beta2 (second-moment decay).
        """
        self.optim_G = torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),
                                        lr=lr, betas=(alpha, beta))
        self.optim_D = torch.optim.Adam(filter(lambda p: p.requires_grad, self.netD.parameters()),
                                        lr=lr, betas=(alpha, beta))

In [8]:
def train(self,
          epochs,
          log_interval=100,
          out_dir='',
          verbose=True):
    """Run the conditional-GAN training loop over ``self.data_loader``.

    Written as a Model method: ``self`` must provide ``netG``, ``netD`` (with a
    ``loss`` method), ``optim_G``, ``optim_D``, ``data_loader``,
    ``latent_dim``, ``classes``, ``device``, ``name`` and ``save_to``.
    NOTE(review): in this notebook it is a module-level function and is not
    attached to Model, so ``model.train(...)`` in main() will not reach it.
    Confirm how it is bound.

    Each batch does one generator update, then one discriminator update on the
    real batch and on detached fakes.

    Args:
        epochs: number of passes over the data loader.
        log_interval: log and save sample images every this many batches
            (batch 0 is skipped).
        out_dir: directory for sample images and per-epoch checkpoints.
        verbose: whether to print losses and save sample images.
    """
    self.netG.train()
    self.netD.train()
    viz_batch = self.data_loader.batch_size
    viz_noise = torch.randn(viz_batch, self.latent_dim, device=self.device)
    # Fixed label grid for the sample image: each row cycles through the first
    # `n_cols` classes. It has the same length as viz_noise so the generator's
    # concatenation always lines up, even when the batch size is not a
    # multiple of the column count. Capping at self.classes keeps every label
    # a valid embedding index.
    n_cols = min(8, self.classes)
    viz_label = torch.arange(viz_batch, device=self.device) % n_cols
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            batch_size = data.size(0)
            real_label = torch.full((batch_size, 1), 1., device=self.device)
            fake_label = torch.full((batch_size, 1), 0., device=self.device)

            # Train G: push D to classify fakes (with random labels) as real.
            self.netG.zero_grad()
            z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
            x_fake_labels = torch.randint(0, self.classes, (batch_size,), device=self.device)
            x_fake = self.netG(z_noise, x_fake_labels)
            y_fake_g = self.netD(x_fake, x_fake_labels)
            g_loss = self.netD.loss(y_fake_g, real_label)
            g_loss.backward()
            self.optim_G.step()

            # Train D. g_loss.backward() also wrote grads into D's parameters;
            # zero_grad clears them before D's own update.
            self.netD.zero_grad()
            y_real = self.netD(data, target)
            d_real_loss = self.netD.loss(y_real, real_label)

            # detach() so D's update does not backpropagate into G.
            y_fake_d = self.netD(x_fake.detach(), x_fake_labels)
            d_fake_loss = self.netD.loss(y_fake_d, fake_label)
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            self.optim_D.step()

            if verbose and batch_idx % log_interval == 0 and batch_idx > 0:
                print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f}'.format(epoch, batch_idx, len(self.data_loader),
                                                                              d_loss.mean().item(), g_loss.mean().item()))
                vutils.save_image(data, os.path.join(out_dir, 'real_samples.png'), normalize=True)
                with torch.no_grad():
                    viz_sample = self.netG(viz_noise, viz_label)
                    vutils.save_image(viz_sample, os.path.join(out_dir, 'fake_samples_{}.png'.format(epoch)),
                                      nrow=n_cols, normalize=True)
        self.save_to(path=out_dir, name=self.name, verbose=False)

main.py

In [9]:
import argparse
import os
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

import utils

# from build_gan import Model

In [10]:
FLAGS = None


def main():
    """Train a new model or evaluate a saved one, driven by the global FLAGS.

    FLAGS is filled in by the argument parser in the ``__main__`` section
    before this function is called.
    """
    flags = FLAGS
    device = torch.device('cuda:0' if flags.cuda else 'cpu')
    net_dims = (flags.classes, flags.channels, flags.img_size, flags.latent_dim)

    if not flags.train:
        # Evaluation mode: no data loader, restore weights from out_dir.
        model = Model(flags.model, device, None, *net_dims)
        model.load_from(flags.out_dir)
        model.eval(mode=0, batch_size=flags.batch_size)
        return

    print('Loading data...\n')
    # Resize, then map pixels from [0, 1] to [-1, 1] to match a Tanh generator.
    preprocess = transforms.Compose([
        transforms.Resize(flags.img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = dset.MNIST(root=flags.data_dir, download=True, transform=preprocess)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    print('Creating model...\n')
    model = Model(flags.model, device, dataloader, *net_dims)
    model.create_optim(flags.lr)

    # Train
    model.train(flags.epochs, flags.log_interval, flags.out_dir, True)

    model.save_to('')

In [None]:
if __name__ == '__main__':
    # Script entry point: parse flags, seed RNGs, redirect stdout to a log
    # file, print the configuration table, then run main().
    from utils import boolean_string
    parser = argparse.ArgumentParser(description='Hands-On GANs - Chapter 5')
    parser.add_argument('--model', type=str, default='cgan', help='one of `cgan` and `infogan`.')
    parser.add_argument('--cuda', type=boolean_string, default=True, help='enable CUDA.')
    parser.add_argument('--train', type=boolean_string, default=True, help='train mode or eval mode.')
    parser.add_argument('--data_dir', type=str, default='~/Data/mnist', help='Directory for dataset.')
    parser.add_argument('--out_dir', type=str, default='output', help='Directory for output.')
    parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='size of batches')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--latent_dim', type=int, default=100, help='latent space dimension')
    parser.add_argument('--classes', type=int, default=10, help='number of classes')
    parser.add_argument('--img_size', type=int, default=64, help='size of images')
    parser.add_argument('--channels', type=int, default=1, help='number of image channels')
    parser.add_argument('--log_interval', type=int, default=100, help='interval between logging and image sampling')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    # Module-level assignment: main() reads this global.
    FLAGS = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is not available.
    FLAGS.cuda = FLAGS.cuda and torch.cuda.is_available()

    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        if FLAGS.cuda:
            torch.cuda.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    cudnn.benchmark = True

    if FLAGS.train:
        utils.clear_folder(FLAGS.out_dir)

    log_file = os.path.join(FLAGS.out_dir, 'log.txt')
    print('Logging to {}\n'.format(log_file))
    sys.stdout = utils.StdOut(log_file)
    print('PyTorch version: {}'.format(torch.__version__))
    print('CUDA version: {}\n'.format(torch.version.cuda))

    # Configuration table: name column padded to 20, type column to 10.
    # Longer values are not truncated.
    print(' ' * 9 + 'Args' + ' ' * 9 + '| ' + 'Type' + ' | ' + 'Value')
    print('-' * 50)
    for arg in vars(FLAGS):
        arg_str = str(arg)
        var_str = str(getattr(FLAGS, arg))
        type_str = str(type(getattr(FLAGS, arg)).__name__)
        # Fixed: the original row used `' ' _ var_str`, a SyntaxError that
        # stopped the whole file from parsing.
        print(' {:<20}| {:<10}| {}'.format(arg_str, type_str, var_str))

    main()