<a href="https://colab.research.google.com/github/shipra-saxena/miniature-potato/blob/master/Gan_Using_keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:

import torch
from torch import nn,optim
from torch.autograd.variable import Variable
from torchvision import transforms,datasets

In [0]:


from utils import Logger

In [0]:
# Preprocessing for MNIST images.
# ToTensor scales the 8-bit pixels to [0, 1]; Normalize((0.5,), (0.5,)) then
# maps them to [-1, 1] via (x - 0.5) / 0.5. This matches the generator's Tanh
# output range, so real and generated samples live on the same scale.
# (MNIST mean/std normalization would put real pixels in roughly
# [-0.42, 2.82], a range a Tanh-bounded generator can never reach.)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [0]:
# MNIST training split, downloaded into ./data/ on first use; every image is
# passed through `transform`.
mnist = datasets.MNIST(root='./data/',
                       train=True,
                       transform=transform,
                       download=True)

# Number of samples per minibatch.
batch_size = 100

data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

# Minibatches per epoch, used for progress logging. It can only be computed
# once `data_loader` exists.
num_batches = len(data_loader)


In [0]:
class DiscriminatorNet(nn.Module):
    """MLP discriminator for flattened 28x28 images.

    Expects inputs whose last dimension is 784. It returns one value per
    sample in (0, 1), produced by a final Sigmoid. There are three hidden
    stages (784 -> 1024 -> 512 -> 256), each Linear -> LeakyReLU(0.2) ->
    Dropout(0.3).
    """

    def __init__(self):
        super().__init__()
        in_dim = 784
        out_dim = 1

        def stage(fan_in, fan_out):
            # Wrapped in its own Sequential so state_dict keys stay
            # "hiddenK.0.weight" / "hiddenK.0.bias".
            return nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        self.hidden0 = stage(in_dim, 1024)
        self.hidden1 = stage(1024, 512)
        self.hidden2 = stage(512, 256)
        self.out = nn.Sequential(
            nn.Linear(256, out_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        for layer in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = layer(x)
        return x


discriminator = DiscriminatorNet()

In [0]:
def images_to_vectors(images):
    """Flatten a batch of 28x28 images into a (batch, 784) tensor.

    Uses ``reshape`` rather than ``view``, so non-contiguous inputs (for
    example transposed tensors) are accepted. Contiguous inputs still come
    back as a view, without a copy.

    Raises:
        RuntimeError: if a sample does not hold exactly 784 elements.
    """
    return images.reshape(images.size(0), 784)


def vectors_to_images(vectors):
    """Unflatten a (batch, 784) tensor into (batch, 1, 28, 28) images.

    This is the inverse of ``images_to_vectors``. Like it, this function
    uses ``reshape`` so non-contiguous inputs are accepted.

    Raises:
        RuntimeError: if a row does not hold exactly 784 elements.
    """
    return vectors.reshape(vectors.size(0), 1, 28, 28)


In [0]:
class GeneratorNet(nn.Module):
    """MLP generator producing flattened 28x28 images.

    Expects inputs whose last dimension is 100 (e.g. the output of
    ``noise``). It returns 784 values per sample, squashed into (-1, 1) by a
    final Tanh. The hidden widths grow 256 -> 512 -> 1024, each followed by
    LeakyReLU(0.2).
    """

    def __init__(self):
        super().__init__()
        latent_dim = 100
        image_dim = 784

        def stage(fan_in, fan_out):
            # Wrapped in its own Sequential so state_dict keys stay
            # "hiddenK.0.weight" / "hiddenK.0.bias".
            return nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
            )

        self.hidden0 = stage(latent_dim, 256)
        self.hidden1 = stage(256, 512)
        self.hidden2 = stage(512, 1024)
        self.out = nn.Sequential(
            nn.Linear(1024, image_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        for layer in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = layer(x)
        return x


generator = GeneratorNet()

In [0]:
def noise(size):
    """Return a (size, 100) batch of standard-normal latent vectors.

    The width of 100 matches the input layer of ``GeneratorNet``. The result
    is a plain tensor with ``requires_grad=False``; the deprecated
    ``Variable`` wrapper is no longer needed.
    """
    return torch.randn(size, 100)

In [0]:
# One Adam optimizer per network, so each step() updates only that network's
# parameters. Both share the same learning rate.
LEARNING_RATE = 0.0002
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)


In [0]:
# Binary cross-entropy on probabilities, averaged over the batch. The
# discriminator's final Sigmoid supplies those probabilities.
loss = nn.BCELoss(reduction='mean')

In [0]:
def ones_target(size):
    """Return a (size, 1) tensor of ones: BCE targets meaning "real".

    This is a plain tensor with ``requires_grad=False``; the deprecated
    ``Variable`` wrapper is no longer needed.
    """
    return torch.ones(size, 1)


def zeros_target(size):
    """Return a (size, 1) tensor of zeros: BCE targets meaning "fake".

    This is a plain tensor with ``requires_grad=False``; the deprecated
    ``Variable`` wrapper is no longer needed.
    """
    return torch.zeros(size, 1)

In [0]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    The module-level ``discriminator`` is pushed towards outputting 1 for
    ``real_data`` and 0 for ``fake_data``, using the module-level ``loss``.
    Gradients from both halves accumulate before a single
    ``optimizer.step()``.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: flattened real images, one row per sample.
        fake_data: flattened generated images. The caller detaches these from
            the generator so no generator gradients are computed here.

    Returns:
        Tuple ``(total_error, prediction_real, prediction_fake)``, where
        ``total_error`` is the sum of the real and fake losses.
    """
    # Reset gradients before accumulating this step's contributions.
    optimizer.zero_grad()

    # Real samples should be classified as 1.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(real_data.size(0)))
    error_real.backward()

    # Fake samples should be classified as 0. Size the target from fake_data
    # itself so the fake batch need not match the real batch's size.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(fake_data.size(0)))
    error_fake.backward()

    # Update the weights with the accumulated gradients.
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake

In [0]:
def train_generator(optimizer, fake_data):
    """Run one generator update against the current discriminator.

    The generator is rewarded when the module-level ``discriminator`` labels
    its samples as real. So ``fake_data``, which is still attached to the
    generator's graph, is scored against all-ones targets, and the loss is
    backpropagated through the discriminator into the generator.

    Args:
        optimizer: optimizer over the generator's parameters (``g_optimizer``
            at the call site).
        fake_data: generator output that has NOT been detached.

    Returns:
        The loss tensor for this batch.
    """
    batch = fake_data.size(0)

    optimizer.zero_grad()

    # Forward the fakes through the discriminator to get its verdicts.
    verdict = discriminator(fake_data)

    # The generator wants every verdict to be "real" (1).
    g_loss = loss(verdict, ones_target(batch))
    g_loss.backward()

    optimizer.step()
    return g_loss

In [0]:
# A fixed batch of noise, reused at every logging step so generated samples
# from different points in training are directly comparable.
num_test_samples = 16
test_noise = noise(num_test_samples)

In [0]:
# Create the Logger instance

logger = Logger(model_name="VGAN", data_name="MNIST")

# Total number of epochs
num_epochs = 200

for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):

        N = real_batch.size(0)

        # 1. Train the discriminator on real images and on generator samples.
        #    The fakes are detached so this step computes no generator
        #    gradients.
        real_data = images_to_vectors(real_batch)
        fake_data = generator(noise(N)).detach()

        d_error, d_pred_real, d_pred_fake = train_discriminator(
            d_optimizer, real_data, fake_data
        )

        # 2. Train the generator on a fresh batch of samples that stays
        #    attached to its graph.
        fake_data = generator(noise(N))

        g_error = train_generator(g_optimizer, fake_data)

        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        if n_batch % 100 == 0:
            # Render the fixed test noise. no_grad avoids building an
            # autograd graph for these display-only samples.
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise))
            logger.log_images(
                test_images, num_test_samples,
                epoch, n_batch, num_batches,
            )

            # Display status logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake,
            )

