In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [0]:
# Load both MNIST splits, downloading them into ./mnist_data/ on first use.
# NOTE: `transform` is only applied when samples are fetched by indexing the
# dataset; the loader cell reads the raw `.data` tensors directly instead.
original_train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
original_test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [0]:
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and the
# generator's noise samples are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Hyper-parameters
BATCH_SIZE = 50
LEARNING_RATE_D = 0.0001  # Adam learning rate for the discriminator
LEARNING_RATE_G = 0.0001  # Adam learning rate for the generator
N_EPOCH = 20

In [0]:
# Build in-memory train/test loaders straight from the raw pixel tensors.
# Pixels are mapped from [0, 255] to [-1, 1] so real images share the range of
# the generator's tanh output; with [0, 1] inputs the discriminator could tell
# real from fake simply by whether any pixel is negative.
train_tensors = original_train_dataset.data.float() / 127.5 - 1
test_tensors = original_test_dataset.data.float() / 127.5 - 1

train_dataset = torch.utils.data.TensorDataset(train_tensors, original_train_dataset.targets)
test_dataset = torch.utils.data.TensorDataset(test_tensors, original_test_dataset.targets)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [27]:
# Length of one flattened MNIST image (rows * cols). It should equal the 784
# hard-coded as the discriminator's input and the generator's output size.
# `.data` is used because `.train_data` is a deprecated alias that warns.
mnist_dim = original_train_dataset.data.size(1) * original_train_dataset.data.size(2)
mnist_dim

784




In [0]:
class Discriminator(nn.Module):
    """Fully connected discriminator mapping a flattened image to P(real).

    Args:
        in_dim: length of the flattened input image (784 = 28 * 28 for MNIST).
        hidden_dim: width of the first hidden layer; each later hidden layer
            halves the previous width (1024 -> 512 -> 256 by default).
        dropout: dropout probability applied after every hidden layer.
    """

    def __init__(self, in_dim=784, hidden_dim=1024, dropout=0.3):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
        self.dropout_p = dropout

    def forward(self, img):
        """Return a (batch, 1) tensor of probabilities that each input is real."""
        x = img
        for fc in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(fc(x), 0.2)
            # F.dropout defaults to training=True, which would keep dropping
            # units even after D.eval(); tie it to the module's mode instead.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc4(x))


class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flattened image.

    Args:
        z_dim: length of the input noise vector.
        out_dim: length of the generated flattened image (784 for MNIST).
    """

    def __init__(self, z_dim=100, out_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, 128)
        self.fc2 = nn.Linear(128, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, out_dim)

    def forward(self, z):
        """Return a (batch, out_dim) tensor of generated pixels in [-1, 1]."""
        x = F.leaky_relu(self.fc1(z), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # tanh bounds outputs to [-1, 1]. Real training images should be scaled
        # to the same range, or D can separate real from fake by range alone.
        return torch.tanh(self.fc4(x))


D = Discriminator()
G = Generator()

In [0]:
# Move both networks' parameters onto the configured device (discriminator first).
D, G = D.to(device), G.to(device)

In [0]:
# One Adam optimizer per network, so each .step() only updates that network's weights.
opt_D, opt_G = (
    optim.Adam(D.parameters(), lr=LEARNING_RATE_D),
    optim.Adam(G.parameters(), lr=LEARNING_RATE_G),
)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
# BCELoss holds no tensors, so .to(device) is effectively a no-op kept for symmetry.
loss_function = nn.BCELoss().to(device)

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 3) # set default size of plots

In [32]:
for epoch in range(N_EPOCH):
    for img, _ in train_loader:
        # The networks are fully connected, so flatten each 28x28 image into a
        # vector. Using the batch's own size (not BATCH_SIZE) keeps a smaller
        # final batch from crashing the reshape.
        real_img = img.reshape(img.size(0), -1).to(device)
        n = real_img.size(0)

        # BCE targets: 1 = real, 0 = fake.
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        z = torch.randn(n, 100, device=device)
        fake_img = G(z)

        opt_D.zero_grad()
        # detach() cuts the graph at G's output, so loss_d.backward() neither
        # computes nor accumulates gradients for the generator.
        fake_loss = loss_function(D(fake_img.detach()), fake_labels)
        real_loss = loss_function(D(real_img), real_labels)
        loss_d = fake_loss + real_loss
        loss_d.backward()
        opt_D.step()

        # ---- Generator step (non-saturating loss) ----
        # Fresh fakes are scored against the *real* labels. Minimizing this BCE
        # therefore maximizes log D(G(z)) instead of minimizing
        # log(1 - D(G(z))), which gives stronger gradients early in training.
        z = torch.randn(n, 100, device=device)
        fake_img = G(z)

        opt_G.zero_grad()
        loss_g = loss_function(D(fake_img), real_labels)
        # This backward also fills D's .grad. That is harmless because
        # opt_D.zero_grad() clears it before the next discriminator update.
        loss_g.backward()
        opt_G.step()

    print("epoch: {} \t last batch loss D: {} \t last batch loss G: {}".format(epoch + 1,
                                                                               loss_d.item(),
                                                                               loss_g.item()))

    # Show up to 30 digits from the epoch's last generated batch in a 3x10 grid;
    # zip() stops early if the batch holds fewer images than grid cells.
    samples = fake_img.detach().cpu().view(-1, 28, 28)
    fig, axes = plt.subplots(3, 10)
    for ax, sample in zip(axes.flat, samples):
        ax.imshow(sample.numpy())
        ax.axis('off')
    plt.show()



Output hidden; open in https://colab.research.google.com to view.