In [1]:
import os
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F


if not os.path.exists('./gan_img'):
    os.mkdir('./gan_img')


def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

# Define the hyper-parameter and load the training dataset

In [2]:
# hyper-parameter
train_epoch = 50
batch_size = 64
noise_size = 10
lr = 2e-3

In [3]:
# tensor transform
img_transform = transforms.Compose([
transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=img_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Define the generator and discriminator

In [4]:
#############################################################
# default given GAN model
# class generator(nn.Module):
#     # initializers
#     def __init__(self, input_size=100, n_class = 28*28):
#         super(generator, self).__init__()
#         self.fc1 = nn.Linear(input_size, 256)
#         self.fc2 = nn.Linear(self.fc1.out_features, 512)
#         self.fc3 = nn.Linear(self.fc2.out_features, 1024)
#         self.fc4 = nn.Linear(self.fc3.out_features, n_class)
#         self.tanh = nn.Tanh()

#     # forward method
#     def forward(self, input):
#         x = F.leaky_relu(self.fc1(input), 0.2)
#         x = F.leaky_relu(self.fc2(x), 0.2)
#         x = F.leaky_relu(self.fc3(x), 0.2)
#         x = self.tanh(self.fc4(x))

#         return x
    
# class discriminator(nn.Module):
#     # initializers
#     def __init__(self, input_size=28*28, n_class=1):
#         super(discriminator, self).__init__()
#         self.fc1 = nn.Linear(input_size, 1024)
#         self.fc2 = nn.Linear(self.fc1.out_features, 512)
#         self.fc3 = nn.Linear(self.fc2.out_features, 256)
#         self.fc4 = nn.Linear(self.fc3.out_features, n_class)
#         self.sigmoid = nn.Sigmoid()

#     # forward method
#     def forward(self, input):
#         x = F.leaky_relu(self.fc1(input), 0.2)
#         x = F.dropout(x, 0.3)
#         x = F.leaky_relu(self.fc2(x), 0.2)
#         x = F.dropout(x, 0.3)
#         x = F.leaky_relu(self.fc3(x), 0.2)
#         x = F.dropout(x, 0.3)
#         x = self.sigmoid(self.fc4(x))

#         return x
#############################################################
# modified GAN as exercise 4-1
# revise nnumber of latent space from 100 to 10
class generator(nn.Module):
    # initializers
    def __init__(self, input_size=10, n_class = 28*28):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 1024)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)
        self.tanh = nn.Tanh()

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.tanh(self.fc4(x))

        return x
    
class discriminator(nn.Module):
    # initializers
    def __init__(self, input_size=28*28, n_class=1):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 256)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)
        self.sigmoid = nn.Sigmoid()

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        x = self.sigmoid(self.fc4(x))

        return x
#############################################################

G = generator( input_size = noise_size, n_class=28*28)
D = discriminator(input_size=28*28, n_class=1)

if torch.cuda.is_available():
    G.cuda()
    D.cuda()

print(G)
print(D)


BCE_loss = nn.BCELoss() 

D_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
G_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

generator(
  (fc1): Linear(in_features=10, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
  (tanh): Tanh()
)
discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)


# Start to training and save the reconstruction images

In [5]:
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    for i, (x_, _ )in enumerate(dataloader, 0):
        # train discriminator D
        D.zero_grad()

        x_ = x_.view(-1, 28 * 28)

        mini_batch = x_.size()[0]

        y_real_ = torch.ones((mini_batch,1))
        y_fake_ = torch.zeros((mini_batch,1))
        x_, y_real_, y_fake_ = Variable(x_), Variable(y_real_), Variable(y_fake_)

        if torch.cuda.is_available():
            x_ = x_.cuda()
            y_real_ = y_real_.cuda()
            y_fake_ = y_fake_.cuda()



        D_result = D(x_)

        D_real_loss = BCE_loss(D_result, y_real_)
        D_real_score = D_result

        z_ = torch.randn((mini_batch, noise_size))
        z_ = Variable(z_)
        if torch.cuda.is_available():
            z_ = z_.cuda()
        G_result = G(z_)

        D_result = D(G_result)
        D_fake_loss = BCE_loss(D_result, y_fake_)
        D_fake_score = D_result

        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()

        D_losses.append(D_train_loss.data)

        # train generator G

        G.zero_grad()

        z_ = torch.randn((mini_batch, noise_size))
        y_ = torch.ones((mini_batch,1))

        z_, y_ = Variable(z_), Variable(y_)
        if torch.cuda.is_available():
            z_ = z_.cuda()
            y_ = y_.cuda()
        G_result = G(z_)
        D_result = D(G_result)
        G_train_loss = BCE_loss(D_result, y_)
        G_train_loss.backward()
        G_optimizer.step()

        G_losses.append(G_train_loss.data)
#         if i % 300 == 0:
#             print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
#                   % (epoch, train_epoch, i, len(dataloader),
#                      D_train_loss.data, G_train_loss.data))

    print('[%d/%d]: Loss_D: %.3f, Loss_G: %.3f' % (
        (epoch), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
    if epoch % 1 == 0:
        pic = to_img(G_result.cpu().data)
        save_image(pic, './gan_img/output_{}.png'.format(epoch))

[0/50]: Loss_D: 1.301, Loss_G: 1.084
[1/50]: Loss_D: 1.368, Loss_G: 0.821
[2/50]: Loss_D: 69.503, Loss_G: 1.777
[3/50]: Loss_D: 100.012, Loss_G: 0.000
[4/50]: Loss_D: 100.000, Loss_G: 0.000
[5/50]: Loss_D: 100.000, Loss_G: 0.000
[6/50]: Loss_D: 100.000, Loss_G: 0.002
[7/50]: Loss_D: 100.000, Loss_G: 0.000
[8/50]: Loss_D: 100.000, Loss_G: 0.000
[9/50]: Loss_D: 100.000, Loss_G: 0.000
[10/50]: Loss_D: 100.000, Loss_G: 0.000
[11/50]: Loss_D: 100.000, Loss_G: 0.000
[12/50]: Loss_D: 100.000, Loss_G: 0.000
[13/50]: Loss_D: 100.000, Loss_G: 0.000
[14/50]: Loss_D: 100.000, Loss_G: 0.000
[15/50]: Loss_D: 100.000, Loss_G: 0.000
[16/50]: Loss_D: 100.000, Loss_G: 0.000
[17/50]: Loss_D: 100.000, Loss_G: 0.000
[18/50]: Loss_D: 100.000, Loss_G: 0.000
[19/50]: Loss_D: 100.000, Loss_G: 0.000
[20/50]: Loss_D: 100.000, Loss_G: 0.000
[21/50]: Loss_D: 100.000, Loss_G: 0.000
[22/50]: Loss_D: 100.000, Loss_G: 0.000
[23/50]: Loss_D: 100.000, Loss_G: 0.000
[24/50]: Loss_D: 100.000, Loss_G: 0.000
[25/50]: Loss_D