In [0]:
import torch
import torchvision
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

from matplotlib import pyplot as plt
from IPython import display
%load_ext autoreload
%autoreload 2

In [0]:
# Fashion-MNIST training split (kept under the name `mnist`, which later cells use).
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to
# [-1, 1], the same range as the generator's Tanh output layer.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist = datasets.FashionMNIST(
    root='./mnist',
    train=True,
    download=True,
    transform=transform,
)
mnist.data.shape


In [0]:

# Training configuration.
batch_size = 128
Epoch = 1000                       # number of full passes over data_loader
CUDA = torch.cuda.is_available()   # models and batches are moved to the GPU when True
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
# Show whether a CUDA device was detected (drives every .cuda() call below).
CUDA

True

In [0]:
class Discriminator(torch.nn.Module):
  """Fully-connected discriminator: 784 -> 256 -> 256 -> 1.

  Each hidden stage is Linear -> ReLU -> Dropout(0.1); the output stage is
  Linear -> Sigmoid, so every prediction lies in [0, 1]. Inputs must already be
  flattened to 784 features per sample (the training loop views batches as (N, 784)).
  """

  def __init__(self):
    super().__init__()
    n_in, n_hidden_a, n_hidden_b, n_out = 784, 256, 256, 1
    dropout_p = 0.1

    # Attribute names (hid1/hid2/out) and Sequential layout determine state_dict keys.
    self.hid1 = nn.Sequential(
        nn.Linear(n_in, n_hidden_a),
        nn.ReLU(),
        nn.Dropout(dropout_p),
    )
    self.hid2 = nn.Sequential(
        nn.Linear(n_hidden_a, n_hidden_b),
        nn.ReLU(),
        nn.Dropout(dropout_p),
    )
    self.out = nn.Sequential(
        nn.Linear(n_hidden_b, n_out),
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Return the probability-like score for each row of ``x``."""
    for stage in (self.hid1, self.hid2, self.out):
      x = stage(x)
    return x

In [0]:

class Generator(torch.nn.Module):
    """Fully-connected generator mapping latent noise to flattened images.

    Default architecture: 100 -> 256 -> 256 -> 784. Hidden stages are
    Linear -> ReLU -> Dropout(0.1); the output stage is Linear -> Tanh, so every
    generated value lies in [-1, 1].

    Args:
        in_features: latent (noise) dimension; ``image`` draws noise of this size.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        out_features: size of each generated sample (784 = 28 * 28 by default).
    """

    def __init__(self, in_features=100, hidden1=256, hidden2=256, out_features=784):
        super(Generator, self).__init__()
        # Stored so ``image`` samples noise of exactly the size the first layer
        # expects, instead of repeating a hard-coded 100.
        self.latent_dim = in_features

        self.hid1 = nn.Sequential(nn.Linear(in_features, hidden1), nn.ReLU(), nn.Dropout(0.1))
        self.hid2 = nn.Sequential(nn.Linear(hidden1, hidden2), nn.ReLU(), nn.Dropout(0.1))
        self.out = nn.Sequential(nn.Linear(hidden2, out_features), nn.Tanh())

    def forward(self, x):
        """Map noise of shape (..., latent_dim) to samples of shape (..., out_features)."""
        x = self.hid1(x)
        x = self.hid2(x)
        x = self.out(x)
        return x

    def image(self, size):
        """Generate ``size`` samples from standard-normal noise.

        The noise is created directly on the device holding the model's
        parameters, so this works whether or not the module was moved to the GPU
        (no dependence on a global CUDA flag). Returns a (size, out_features)
        tensor; gradients flow back into the generator unless the caller detaches
        the result or calls this under ``torch.no_grad()``.
        """
        device = next(self.parameters()).device
        noise = torch.randn(size, self.latent_dim, device=device)
        return self.forward(noise)
        

In [0]:
# Instantiate both networks (generator first) and place them on the GPU if one exists.
generator, discriminator = Generator(), Discriminator()
if CUDA:
    # Module.cuda() moves parameters in place, so the return value is not reassigned.
    discriminator.cuda()
    generator.cuda()
   

In [0]:
# One Adam optimizer per network; the generator's learning rate is 5x the discriminator's.
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=generator.parameters(), lr=1e-3)
# Binary cross-entropy on the discriminator's sigmoid output, shared by both training steps.
loss = nn.BCELoss()

In [0]:
def discriminator_train(real, fake, model=None, optimizer=None, criterion=None):
  """Run one optimisation step of the discriminator on a real and a fake batch.

  Args:
    real: batch of real samples in the shape the discriminator accepts.
    fake: batch of generated samples. It should already be detached from the
      generator's graph so this step does not backpropagate into the generator.
    model: discriminator network; defaults to the notebook-level ``discriminator``.
    optimizer: its optimizer; defaults to ``d_optimizer``.
    criterion: loss function; defaults to ``loss``.

  Returns:
    (real_predict, fake_predict, delta): the discriminator outputs for both batches
    and the summed real+fake loss as a detached scalar tensor.
  """
  model = discriminator if model is None else model
  optimizer = d_optimizer if optimizer is None else optimizer
  criterion = loss if criterion is None else criterion

  optimizer.zero_grad()

  # Targets are built with *_like so they match each prediction's shape (N, 1),
  # dtype and device; BCELoss rejects a target whose size differs from the input.
  # Sizing from the prediction also lets the real and fake batches differ in size.
  real_predict = model(real)
  delta_real = criterion(real_predict, torch.ones_like(real_predict))
  delta_real.backward()

  fake_predict = model(fake)
  delta_fake = criterion(fake_predict, torch.zeros_like(fake_predict))
  delta_fake.backward()

  optimizer.step()

  # Detached: the value is only reported, so don't keep the autograd graph alive.
  delta = (delta_real + delta_fake).detach()

  return real_predict, fake_predict, delta

def generator_train(fake, model=None, optimizer=None, criterion=None):
  """Run one optimisation step of the generator.

  The fakes are scored by the discriminator against "real" targets (all ones), so
  minimising the loss pushes the generator towards samples the discriminator
  accepts.

  Args:
    fake: generated batch that is still attached to the generator's graph.
    model: network used to score the fakes; defaults to the notebook-level
      ``discriminator``.
    optimizer: generator optimizer; defaults to ``g_optimizer``.
    criterion: loss function; defaults to ``loss``.

  Returns:
    (predict, delta): the discriminator's outputs on ``fake`` and the generator
    loss as a detached scalar tensor.
  """
  model = discriminator if model is None else model
  optimizer = g_optimizer if optimizer is None else optimizer
  criterion = loss if criterion is None else criterion

  optimizer.zero_grad()

  predict = model(fake)
  # ones_like keeps the target's shape/dtype/device identical to the prediction,
  # which BCELoss requires.
  delta = criterion(predict, torch.ones_like(predict))
  # Note: this backward also accumulates gradients into the scoring model's
  # parameters; only ``optimizer``'s parameters are stepped here, and the
  # discriminator step zeroes its own gradients before use.
  delta.backward()

  optimizer.step()

  return predict, delta.detach()
    

In [0]:
def plot_images(epoch=0, save_every=10):
  """Display a 4x4 grid of generator samples and optionally save it to disk.

  Args:
    epoch: current epoch number; used in the saved file name.
    save_every: the grid is written to ``Epoch_XXXX.png`` when ``epoch`` is a
      non-zero multiple of this value. A falsy value (0/None) disables saving.
  """
  num_of_images = 16

  # Sample in eval mode without autograd: dropout is off for the preview and no
  # graph is built. The caller's train/eval mode is restored afterwards.
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      images = generator.image(num_of_images)
  finally:
    generator.train(was_training)
  images = images.view(images.size(0), 28, 28).cpu().numpy()

  fig, axes = plt.subplots(4, 4)
  for ax, img in zip(axes.flat, images):
    ax.axis('off')
    # A fixed [-1, 1] range matches the generator's Tanh output, so brightness is
    # comparable across tiles (imshow would otherwise rescale each tile separately).
    ax.imshow(img, cmap='gray', vmin=-1, vmax=1)

  if save_every and epoch != 0 and epoch % save_every == 0:
    fig.savefig('Epoch_{:04d}.png'.format(epoch))
  plt.show()
  # Release the figure; this is called once per epoch for many epochs.
  plt.close(fig)

In [0]:
for epoch in range(Epoch):

  for batch, (train_data, _) in enumerate(data_loader):
    # Flatten each image batch to (N, 784) for the fully-connected networks.
    real = train_data.view(train_data.size(0), 784)
    real = real.cuda() if CUDA else real

    # Discriminator step: fakes are generated without autograd tracking, so this
    # step can never push gradients into the generator.
    with torch.no_grad():
      fake = generator.image(real.shape[0])
    real_predict, fake_predict, derror = discriminator_train(real, fake)

    # Generator step: fresh fakes, this time with the graph kept so the loss can
    # backpropagate through the generator.
    fake = generator.image(real.shape[0])
    generator_predict, gerror = generator_train(fake)

  display.clear_output(True)
  # batch is a 0-based index, so batch + 1 is the number of batches processed.
  print("Epoch:{} [{}/{}] DError:{} GError:{}".format(epoch, batch + 1, len(data_loader), derror.item(), gerror.item()))
  # Pass the epoch so plot_images can save periodic snapshots to disk.
  plot_images(epoch)
    
    

