In [1]:
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
from torch.nn import init
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

def my_weight_init(m):
    if isinstance(m, torch.nn.Linear):
        init.xavier_uniform(m.weight.data) #initialising the model with random params 
        init.constant(m.bias.data, 0)


In [2]:
class Generator(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.initial_features=512
        
        self.n_first_layer_inputs=3*3*512
        
        self.fc1=nn.Linear(100,self.n_first_layer_inputs)
        self.bn1=nn.BatchNorm2d(self.initial_features)
        
        self.deconv1=nn.ConvTranspose2d(self.initial_features,self.initial_features//2,kernel_size=3,stride=2,padding=0,output_padding=0)
        self.bn2 = nn.BatchNorm2d(self.initial_features//2)
        self.deconv2=nn.ConvTranspose2d(self.initial_features//2,self.initial_features//4,kernel_size=5,stride=2,padding=2,output_padding=1)
        self.bn3 = nn.BatchNorm2d(self.initial_features//4)
        self.deconv3=nn.ConvTranspose2d(self.initial_features//4,1,kernel_size=5,stride=2,padding=2,output_padding=1)
        
        for m in self.modules():
            my_weight_init(m)
            
        
    def forward(self,x):
            x=self.fc1(x)
            x=x.view(-1,512,3,3)
            x=F.relu(self.bn1(x))
            x=F.relu(self.deconv1(x))
            x=self.bn2(x)
            x=F.relu(self.deconv2(x))
            x=self.bn3(x)
            x=F.relu(self.deconv3(x))
            x=F.tanh(x)
            return x
    
    
            
            
            
            
        
        

In [3]:
 class Discriminator(nn.Module):
    def __init__(self):
            
            super().__init__()
            self.conv1=nn.Conv2d(1,64,kernel_size=5,stride=2,padding=2) #14x14x64
            self.bn1=nn.BatchNorm2d(64)
            self.conv2=nn.Conv2d(64,128,kernel_size=5,stride=2,padding=2) #7x7x128
            self.bn2=nn.BatchNorm2d(128)
            self.conv3=nn.Conv2d(128,256,kernel_size=5,stride=2,padding=2) #4x4x256
            self.bn3=nn.BatchNorm2d(256)
            self.fc4=nn.Linear(4*4*256,1)
            
            for m in self.modules():
                my_weight_init(m)
        
        
    def forward(self,input):
        x=F.relu(self.conv1(input))
        x=self.bn1(x)
        x=F.relu(self.conv2(x))
        x=self.bn2(x)
        x=F.relu(self.conv3(x))
        x=self.bn3(x)
        x=x.view(-1,4*4*256)
        x=self.fc4(x)
        x=F.sigmoid(x)
        return x

In [4]:
def save_generator_output(G, fixed_z, img_str, title):
    n_images = fixed_z.size()[0]
    n_rows = np.sqrt(n_images).astype(np.int32)
    n_cols = np.sqrt(n_images).astype(np.int32)
    
    z_ = fixed_z
    samples = G(z_)
    samples = samples.data.numpy()

    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5,5), sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples):
        ax.axis('off')
        ax.set_adjustable('box-forced')
        ax.imshow(img.reshape((28,28)), cmap='Greys_r', aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.suptitle(title)
    plt.savefig(img_str)
    plt.close(fig)

assets_dir = './assets/'
if not os.path.isdir(assets_dir):
    os.mkdir(assets_dir)

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../Data_sets/MNIST_data', train=True, download=True, transform=transform),
    batch_size=15, shuffle=True)

In [10]:
image_width = 28
image_height = 28
image_channels = 1
x_size = image_channels
z_size = 100
# n_hidden = 128
# n_classes = 10
epochs = 30
batch_size = 64
learning_rate = 0.0002
alpha = 0.2
beta1 = 0.5
print_every = 50
G = Generator()
D = Discriminator()
fixed_z = torch.Tensor(25, z_size).uniform_(-1, 1)
BCE_loss = torch.nn.BCELoss()
G_opt = torch.optim.Adam( G.parameters(), lr=learning_rate, betas=[0.5, 0.999] )
D_opt = torch.optim.Adam( D.parameters(), lr=learning_rate, betas=[0.5, 0.999] )
step=0


  from ipykernel import kernelapp as app
  app.launch_new_instance()


In [9]:
losses=[]
for e in range(10):
    for x,_ in train_loader:
        step+=1
        curr_batch_size=x.size()[0]
        
        y_real=torch.ones(curr_batch_size)
        y_fake=torch.zeros(curr_batch_size)
        
        D_result_real=D(x)
        D_loss_real = BCE_loss(D_result_real, y_real)
        
        #generating the input for getting images using gan
        z1_ = torch.Tensor(curr_batch_size, 100).uniform_(-1, 1)
        x_fake=G(z1_)
        D_result_fake= D(x_fake)
        D_loss_fake = BCE_loss(D_result_fake,y_fake)
        
        D_loss=D_loss_fake + D_loss_real
        
        D.zero_grad()
        
        D_loss.backward()
        D_opt.step()
        
        #training the generator
        
        z2=torch.Tensor(curr_batch_size,100).uniform_(-1,1)
        y_=torch.ones(curr_batch_size)
        G_res=G(z2)
        D_G=D(G_res)
        G_loss=BCE_loss(D_G,y_)
        G.zero_grad()
        G_loss.backward()
        G_opt.step()
        if step % print_every == 0:
            losses.append((D_loss.data[0], G_loss.data[0]))

            print("Epoch {}/{}...".format(e+1, epochs),
                "Discriminator Loss: {:.4f}...".format(D_loss.data[0]),
                "Generator Loss: {:.4f}".format(G_loss.data[0])) 
    # Sample from generator as we're training for viewing afterwards
    image_fn = './assets/epoch_{:d}_pytorch.png'.format(e)
    image_title = 'epoch {:d}'.format(e)
    save_generator_output(G, fixed_z, image_fn, image_title)
        
        

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch 1/30... Discriminator Loss: 0.0102... Generator Loss: 7.4886


KeyboardInterrupt: 