In [None]:
import os
import sys

import torch
import torch.nn.functional as F
from IPython import display
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torchvision.utils import save_image
content_folder = './'

# Seed the global torch RNG once, before any model is built or noise is drawn,
# so weight initialisation, latent sampling and DataLoader shuffling are
# reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 100

# ToTensor() yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], the same range as the generator's Tanh output.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into ./content/mnist on first run.
data = datasets.MNIST(
    root='./content/mnist',
    train=True,
    download=True,
    transform=img_transform
)
data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)
num_batches = len(data_loader)


In [None]:
class Discriminator(torch.nn.Module):
    """MLP discriminator scoring flattened 28x28 images (784 values each).

    Three hidden stages (widths 1024, 512, 256), each Linear -> LeakyReLU(0.2)
    -> Dropout(0.3), feed a single Linear unit followed by a Sigmoid, so the
    output has one value per sample in [0, 1].
    """

    def __init__(self):
        super().__init__()
        in_dim, out_dim = 784, 1

        def hidden(fan_in, fan_out):
            # The Linear sits at index 0 of each Sequential so parameter
            # names stay "hiddenK.0.weight" / "hiddenK.0.bias".
            return nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        # Built in input-to-output order so parameter initialisation draws
        # from the RNG in the same sequence.
        self.hidden0 = hidden(in_dim, 1024)
        self.hidden1 = hidden(1024, 512)
        self.hidden2 = hidden(512, 256)
        self.out = nn.Sequential(nn.Linear(256, out_dim), nn.Sigmoid())

    def forward(self, x):
        """Score ``x`` (last dimension 784); returns one probability per row."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x


In [None]:
class Generator(torch.nn.Module):
    """MLP generator mapping 100-d latent vectors to flattened 28x28 images.

    Hidden widths grow 256 -> 512 -> 1024 with LeakyReLU(0.2) activations; a
    final Linear + Tanh emits 784 values per sample in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        latent_dim, image_dim = 100, 784

        def hidden(fan_in, fan_out):
            # Linear at index 0 keeps parameter names "hiddenK.0.*".
            return nn.Sequential(nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2))

        # Same construction order as the forward pass, so parameter
        # initialisation consumes the RNG identically.
        self.hidden0 = hidden(latent_dim, 256)
        self.hidden1 = hidden(256, 512)
        self.hidden2 = hidden(512, 1024)
        self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

    def forward(self, x):
        """Map latent codes (last dimension 100) to image vectors (last dimension 784)."""
        return self.out(self.hidden2(self.hidden1(self.hidden0(x))))


In [None]:
# Build both networks (discriminator first, then generator) and move them to
# the GPU when one is available; Module.cuda() returns the module itself.
discriminator, generator = Discriminator(), Generator()
if torch.cuda.is_available():
    discriminator = discriminator.cuda()
    generator = generator.cuda()

# One Adam optimizer per player, both with learning rate 2e-4 and otherwise
# default hyper-parameters.
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)

# Binary cross-entropy on the discriminator's sigmoid output, shared by the
# discriminator and generator training steps.
loss = nn.BCELoss()
# Discriminator updates per generator update; not read by the visible cells.
d_steps = 1

In [None]:
def _default_device():
    """Return the device the training tensors live on: CUDA if available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def check_data_real(size):
    """Return a (size, 1) float tensor of ones: BCE targets meaning "real".

    The tensor is created directly on the training device. The old
    ``Variable`` wrapper is a no-op since PyTorch 0.4, and building on the
    device avoids a host-to-device copy on every call.
    """
    return torch.ones(size, 1, device=_default_device())


def noise(size, dim=100):
    """Sample a (size, dim) batch of standard-normal latent vectors.

    ``dim`` defaults to 100, the generator's input width. Samples are drawn
    on the CPU and then moved, so the random stream (and results under a
    fixed seed) is the same whether or not a GPU is present.
    """
    return torch.randn(size, dim).to(_default_device())


def check_data_fake(size):
    """Return a (size, 1) float tensor of zeros: BCE targets meaning "fake"."""
    return torch.zeros(size, 1, device=_default_device())


def images_to_vectors(images):
    """Flatten an (N, 1, 28, 28) image batch into (N, 784) vectors.

    ``reshape`` (rather than ``view``) also accepts non-contiguous input.
    """
    return images.reshape(images.size(0), 784)


def vectors_to_images(vectors):
    """Unflatten (N, 784) vectors into an (N, 1, 28, 28) image batch."""
    return vectors.reshape(vectors.size(0), 1, 28, 28)



In [None]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    Uses the module-level ``discriminator`` network and ``loss`` criterion.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples (flattened images).
        fake_data: batch of generated samples. If it is not detached from the
            generator graph, backward also populates the generator's gradients.

    Returns:
        (error, pred_real, pred_fake): the summed BCE loss and the
        discriminator outputs on each batch, all detached from the spent
        autograd graph (they are only used for logging).
    """
    optimizer.zero_grad()

    # Targets are built from each prediction, so their shape, dtype and device
    # always match. Previously the fake targets were sized from real_data.
    pred_real = discriminator(real_data)
    error_real = loss(pred_real, torch.ones_like(pred_real))
    error_real.backward()

    pred_fake = discriminator(fake_data)
    error_fake = loss(pred_fake, torch.zeros_like(pred_fake))
    error_fake.backward()

    # Gradients from both backward calls have accumulated; apply them once.
    optimizer.step()
    return (error_real + error_fake).detach(), pred_real.detach(), pred_fake.detach()


In [None]:
def train_generator(optimizer, fake_data):
    """Run one generator update and return its BCE loss tensor.

    ``fake_data`` should still be attached to the generator's graph, so the
    loss gradient reaches the generator's parameters. The generator is scored
    against targets of ones ("real"), i.e. it is rewarded for fooling the
    discriminator. The backward pass also writes gradients into the
    discriminator's parameters; train_discriminator clears them with
    zero_grad before its own step.
    """
    optimizer.zero_grad()
    verdict = discriminator(fake_data)
    real_labels = check_data_real(verdict.size(0))
    g_loss = loss(verdict, real_labels)
    g_loss.backward()
    optimizer.step()
    return g_loss


In [None]:
num_test_samples = 64
# A fixed latent batch is reused for every snapshot, so the saved grids show
# the same samples evolving over training.
test_noise = noise(num_test_samples)
num_epochs = 300
log_every = 100  # batches between a log line and an image snapshot
snapshot_dir = './content'
# Create the output folder explicitly instead of relying on the MNIST download
# having created it.
os.makedirs(snapshot_dir, exist_ok=True)

for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        # Discriminator step: real images vs. fakes detached from the
        # generator graph, so only the discriminator is updated.
        real_data = images_to_vectors(real_batch)
        if torch.cuda.is_available():
            real_data = real_data.cuda()
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(
            d_optimizer, real_data, fake_data)

        # Generator step: fresh fakes that stay attached, so gradients reach
        # the generator.
        fake_data = generator(noise(real_batch.size(0)))
        g_error = train_generator(g_optimizer, fake_data)

        # Log and snapshot only every `log_every` batches. Previously a line
        # was printed on every batch of every epoch.
        if n_batch % log_every == 0:
            display.clear_output(True)
            print(f"Epoch {epoch}/{num_epochs}, batch {n_batch}: "
                  f"Generator loss: {g_error.item():.8f}, "
                  f"Discriminator loss: {d_error.item():.8f}")
            # Sampling for visualisation needs no autograd graph.
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise)).cpu()
            save_image(test_images, f"{snapshot_dir}/img_{epoch}_{n_batch}.png",
                       normalize=True)

