## Convolutional GAN

In [1]:
import os
import torch
import torch.nn as nn

In [2]:
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image

In [3]:
# Create the output directory for sample images and checkpoints.
# makedirs also creates the missing parent ('save'), and exist_ok=True makes
# the cell safe to re-run without a separate existence check.
os.makedirs('save/conv_gan', exist_ok=True)

In [4]:
def to_img(x):
    """Convert generator-range values into a batch of displayable images.

    Values are shifted and scaled from [-1, 1] to [0, 1], anything outside
    that interval is clipped, and the result is reshaped to
    (-1, 1, 28, 28).
    """
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(0, 1).reshape(-1, 1, 28, 28)

In [5]:
# Training hyper-parameters
batch_size = 128   # samples per mini-batch
num_epoches = 100  # number of passes over the training data
z_dimension = 100  # noise dimension

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# MNIST images have a single channel, so Normalize must receive exactly one
# mean and one std value; three-element tuples fail to broadcast against a
# (1, 28, 28) tensor in current torchvision releases.
# With mean=0.5 and std=0.5 the ToTensor output is rescaled to [-1, 1],
# which is the range to_img() maps back to [0, 1].
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [7]:
# download=True fetches MNIST only when it is not already present under the
# root directory, so the notebook also runs on a machine without the data.
mnist = datasets.MNIST('../_data/mnist', transform=img_transform, download=True)
# Shuffled mini-batches, loaded by 4 worker processes.
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=4)

In [8]:
class discriminator(nn.Module):
    """Convolutional discriminator for 28x28 single-channel images.

    Two conv + LeakyReLU + average-pool stages reduce the input to a
    64x7x7 feature map, which a two-layer fully connected head turns into
    one sigmoid probability per image.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14
            )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        x: batch, channel=1, height=28, width=28

        Returns a tensor of shape (batch,) holding the sigmoid output for
        each input image.
        '''
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 64*7*7)
        x = self.fc(x)             # (batch, 1)
        # Drop only the trailing unit dimension. A bare squeeze() would also
        # remove the batch dimension when batch == 1 and return a 0-d tensor
        # that no longer matches the (1,) label tensor.
        return x.squeeze(1)

In [9]:
class generator(nn.Module):
    """Convolutional generator mapping a noise vector to a 1-channel image.

    A linear layer projects the noise to ``num_feature`` values, which are
    reshaped into a square 1-channel map. Two 3x3 convolutions keep that
    spatial size, and a final stride-2 convolution halves it before Tanh.
    With ``num_feature=3136`` (56x56) the output is (batch, 1, 28, 28).

    Args:
        input_size: length of the input noise vector.
        num_feature: number of values produced by the linear projection.
            Must be a perfect square of at least 4, because it is reshaped
            to (1, side, side) and then halved by the last convolution.

    Raises:
        ValueError: if ``num_feature`` is not a perfect square >= 4.
    """

    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        # Derive the side of the square feature map instead of hard-coding 56,
        # so forward() always agrees with the num_feature passed in here.
        side = int(round(num_feature ** 0.5)) if num_feature >= 4 else 0
        if side < 2 or side * side != num_feature:
            raise ValueError(
                'num_feature must be a perfect square >= 4, got {}'.format(num_feature))
        self.side = side
        self.fc = nn.Linear(input_size, num_feature) # batch, num_feature = 1 x side x side
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        # NOTE: despite the "downsample" names (kept so saved state_dict keys
        # stay unchanged), the first two stages preserve the spatial size;
        # only downsample3 reduces it.
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, side, side
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, side, side
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2), # batch, 1, side // 2, side // 2
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from noise ``x`` of shape (batch, input_size).

        Returns a tensor of shape (batch, 1, side // 2, side // 2) with
        values in [-1, 1] (Tanh output).
        """
        x = self.fc(x)
        x = x.view(x.size(0), 1, self.side, self.side)
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x

In [10]:
# Build both networks (discriminator first, then generator) and move them
# onto the selected device.
D = discriminator()  # discriminator model
G = generator(input_size=z_dimension, num_feature=3136)  # generator model
D, G = D.to(device), G.to(device)

# Binary cross entropy between the discriminator output and the 0/1 labels
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=3e-4)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=3e-4)

In [11]:
# train
# Each iteration performs one discriminator update followed by one generator
# update. After every epoch a grid of generated samples is written to disk;
# after the first epoch a grid of real samples is saved once for comparison.
for epoch in range(num_epoches):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        real_img = img.to(device)
        real_label = torch.ones(num_img, device=device)
        fake_label = torch.zeros(num_img, device=device)

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension).to(device)
        # detach(): the discriminator step must not back-propagate into G.
        # Those gradients would be discarded by g_optimizer.zero_grad() below,
        # so computing them is wasted work.
        fake_img = G(z).detach()
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = torch.randn(num_img, z_dimension).to(device)
        fake_img = G(z)
        output = D(fake_img)
        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 100 == 0:
            # Epoch is shown 1-based so it matches the saved image filenames;
            # .item() turns the 0-d tensors into plain Python floats.
            print(f'Epoch [{epoch+1}/{num_epoches}], d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f} '
                  f'D real: {real_scores.mean().item():.6f}, D fake: {fake_scores.mean().item():.6f}')

    if epoch == 0:
        real_images = to_img(real_img)
        save_image(real_images, 'save/conv_gan/real_images.png')

    # Save the last generated batch of the epoch; detach so the image tensor
    # does not carry the autograd graph.
    fake_images = to_img(fake_img.detach())
    save_image(fake_images, f'save/conv_gan/fake_images-{epoch+1:0>2}.png')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/100], d_loss: 0.004862, g_loss: 6.344368 D real: 0.997937, D fake: 0.002733
Epoch [0/100], d_loss: 0.072482, g_loss: 3.796228 D real: 0.959111, D fake: 0.027240
Epoch [0/100], d_loss: 0.270547, g_loss: 3.340040 D real: 0.830400, D fake: 0.022542
Epoch [0/100], d_loss: 0.575651, g_loss: 2.782383 D real: 0.888074, D fake: 0.237495


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/100], d_loss: 0.505719, g_loss: 1.531267 D real: 0.765558, D fake: 0.150167
Epoch [1/100], d_loss: 0.414941, g_loss: 1.646160 D real: 0.866862, D fake: 0.178987
Epoch [1/100], d_loss: 0.417752, g_loss: 1.508854 D real: 0.836492, D fake: 0.157988
Epoch [1/100], d_loss: 0.761354, g_loss: 3.875142 D real: 0.633709, D fake: 0.026091
Epoch [2/100], d_loss: 0.333925, g_loss: 2.182739 D real: 0.865069, D fake: 0.097720
Epoch [2/100], d_loss: 0.334393, g_loss: 3.307405 D real: 0.913608, D fake: 0.169393
Epoch [2/100], d_loss: 0.434161, g_loss: 3.434336 D real: 0.820223, D fake: 0.051994
Epoch [2/100], d_loss: 0.432479, g_loss: 3.617238 D real: 0.959572, D fake: 0.260158
Epoch [3/100], d_loss: 0.346676, g_loss: 1.696260 D real: 0.837161, D fake: 0.063617
Epoch [3/100], d_loss: 0.388021, g_loss: 2.449858 D real: 0.897044, D fake: 0.151859
Epoch [3/100], d_loss: 0.240741, g_loss: 3.794872 D real: 0.923281, D fake: 0.109866
Epoch [3/100], d_loss: 0.373882, g_loss: 1.993526 D real: 0.88701

In [12]:
# Persist the trained weights (state dicts only) of both networks.
torch.save(obj=G.state_dict(), f='save/conv_gan/generator.pytorch')
torch.save(obj=D.state_dict(), f='save/conv_gan/discriminator.pytorch')