<a href="https://colab.research.google.com/github/tejasmeshram99/Practicing-DL-Models/blob/master/GANs/InfoGAN/InfoGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import math
import itertools
from IPython import display

In [None]:
class Generator(nn.Module):
  """Generator network: 74-dim latent vector -> 1x28x28 image.

  A two-layer fully connected stem expands the latent vector to a
  128x7x7 feature map; two stride-2 transposed convolutions then
  upsample it 7 -> 14 -> 28. The final Tanh keeps pixels in [-1, 1].
  """

  def __init__(self):
    super().__init__()

    # Fully connected stem: 74 -> 1024 -> 128*7*7
    dense_layers = [
        nn.Linear(74, 1024),
        nn.BatchNorm1d(1024),
        nn.ReLU(True),
        nn.Linear(1024, 128 * 7 * 7),
        nn.BatchNorm1d(128 * 7 * 7),
        nn.ReLU(True),
    ]
    self.fc = nn.Sequential(*dense_layers)

    # Upsampling head: (128, 7, 7) -> (64, 14, 14) -> (1, 28, 28)
    upsample_layers = [
        nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
        nn.Tanh(),
    ]
    self.gen = nn.Sequential(*upsample_layers)

  def forward(self, x):
    """Return generated images of shape (N, 1, 28, 28) for latents `x` of shape (N, 74)."""
    feature_map = self.fc(x).view(-1, 128, 7, 7)
    return self.gen(feature_map)

In [None]:
class FrontEnd(nn.Module):
  """Convolutional feature extractor whose output feeds both D and Q.

  Two stride-2 convolutions reduce a 1-channel image to 128 channels;
  the result is flattened to 128*7*7 features per sample, which
  corresponds to 28x28 inputs.
  """

  def __init__(self):
    super().__init__()

    conv_stack = [
        nn.Conv2d(1, 64, 4, 2, 1),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Conv2d(64, 128, 4, 2, 1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.1, inplace=True),
    ]
    self.layer = nn.Sequential(*conv_stack)

  def forward(self, x):
    """Return flattened features of shape (N, 128*7*7)."""
    features = self.layer(x)
    return features.view(-1, 128 * 7 * 7)

In [None]:
class Discriminator(nn.Module):
  """Real/fake head applied to the flattened FrontEnd features.

  A four-layer MLP (128*7*7 -> 1024 -> 512 -> 128 -> 1) with batch
  normalisation and LeakyReLU(0.2) between layers, followed by a
  sigmoid so the output can be fed to nn.BCELoss.
  """

  def __init__(self):
    super().__init__()

    self.prob = nn.Sequential(
        nn.Linear(128*7*7,1024),
        nn.BatchNorm1d(1024),
        nn.LeakyReLU(0.2,True),
        nn.Linear(1024,512),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(0.2,True),
        nn.Linear(512,128),
        nn.BatchNorm1d(128),
        nn.LeakyReLU(0.2,True),
        nn.Linear(128,1),
    )

  def forward(self,x):
    """Return probabilities of shape (N, 1) for features `x` of shape (N, 128*7*7)."""
    logits = self.prob(x)
    # torch.sigmoid replaces the deprecated nn.functional.sigmoid alias.
    return torch.sigmoid(logits)

In [None]:
class Recognizer(nn.Module):
  """Q head: predicts the 10-way categorical code from FrontEnd features.

  Same MLP trunk shape as the discriminator (128*7*7 -> 1024 -> 512 -> 128)
  but ends in 10 raw logits, which the training code passes to
  nn.CrossEntropyLoss.
  """

  def __init__(self):
    super().__init__()

    widths = (128 * 7 * 7, 1024, 512, 128)
    stack = []
    # One Linear -> BatchNorm -> LeakyReLU group per consecutive width pair.
    for n_in, n_out in zip(widths, widths[1:]):
      stack.append(nn.Linear(n_in, n_out))
      stack.append(nn.BatchNorm1d(n_out))
      stack.append(nn.LeakyReLU(0.2, True))
    # Final projection to the 10 class logits (no activation).
    stack.append(nn.Linear(128, 10))

    self.classprob = nn.Sequential(*stack)

  def forward(self, x):
    """Return logits of shape (N, 10) for features `x` of shape (N, 128*7*7)."""
    return self.classprob(x)

In [None]:
# MNIST pipeline: convert to tensors and rescale with mean 0.5 / std 0.5,
# i.e. pixels end up in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_images = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 100 images, loaded by two worker processes.
train_loader = torch.utils.data.DataLoader(
    train_images,
    batch_size=100,
    shuffle=True,
    num_workers=2,
)

In [None]:
def weights_init(m):
  """Initialise one module in place; intended for use with `Module.apply`.

  Modules whose class name contains 'Conv' get weights drawn from
  N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
  drawn from N(1, 0.02) and zeroed biases. Every other module is left
  untouched.
  """
  layer_type = type(m).__name__

  if 'Conv' in layer_type:
    m.weight.data.normal_(0.0, 0.02)
    return

  if 'BatchNorm' in layer_type:
    m.weight.data.normal_(1.0, 0.02)
    m.bias.data.fill_(0)

In [None]:
def gen_noise(batch_size, device=None):
  """Sample a batch of generator inputs and the class codes they contain.

  Each row is 64 values drawn uniformly from [-10, 10) followed by a
  10-way one-hot code, giving 74 values in total.

  Args:
    batch_size: number of latent vectors to draw.
    device: where to place the result. Defaults to CUDA when it is
      available and to the CPU otherwise, so the function no longer
      fails on hosts without a GPU.

  Returns:
    (z, idx): `z` is a float32 tensor of shape (batch_size, 74) on
    `device`; `idx` is a NumPy integer array of shape (batch_size,)
    holding the class index encoded in each row's one-hot part.
  """
  if device is None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Categorical code: one random class per sample, one-hot encoded.
  idx = np.random.randint(10, size=batch_size)
  c = np.zeros((batch_size, 10), dtype=np.float32)
  c[np.arange(batch_size), idx] = 1

  # NOTE(review): a +/-10 range is large next to the 0/1 code entries and
  # may drown them out; confirm this range is intended.
  noise = torch.empty(batch_size, 64).uniform_(-10, 10)

  z = torch.cat((noise, torch.from_numpy(c)), 1).to(device)

  return z, idx

In [None]:
# Use the GPU when one is present and fall back to the CPU otherwise,
# instead of failing on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = Generator().to(device)
FE = FrontEnd().to(device)
D = Discriminator().to(device)
Q = Recognizer().to(device)

# weights_init only touches Conv* and BatchNorm* modules; Linear layers
# keep PyTorch's default initialisation.
for net in (G, FE, D, Q):
  net.apply(weights_init)

Recognizer(
  (classprob): Sequential(
    (0): Linear(in_features=6272, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Linear(in_features=512, out_features=128, bias=True)
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [None]:
# Binary cross-entropy for the real/fake head, categorical cross-entropy
# for the recognizer's 10-way code prediction.
criterion_D_FE = nn.BCELoss()
criterion_GQ = nn.CrossEntropyLoss()

# Two optimizers, each with one parameter group per network:
# FE + D are updated together, and G + Q are updated together.
optimizer_D_FE = optim.Adam(
    [{'params': net.parameters()} for net in (FE, D)],
    lr=0.0002,
    betas=(0.5, 0.99),
)
optimizer_GQ = optim.Adam(
    [{'params': net.parameters()} for net in (G, Q)],
    lr=0.001,
    betas=(0.5, 0.99),
)

In [None]:
num_test_samples = 100
num_epochs = 15

# Fixed latent batch for visualising the generator: the categorical code
# cycles through the 10 classes in equal-sized consecutive groups (derived
# from num_test_samples, which is expected to be a multiple of 10).
index = np.arange(10).repeat(num_test_samples // 10)
one_hot = np.zeros((num_test_samples,10))
one_hot[np.arange(num_test_samples), index] = 1

# 64 noise dimensions per sample, drawn from the same [-10, 10) range that
# gen_noise uses for training batches.
test_noise = torch.empty(num_test_samples, 64).uniform_(-10, 10)

test_z = torch.cat((test_noise, torch.from_numpy(one_hot).float()), 1)
# Place on the GPU when available; fall back to the CPU otherwise.
test_z = test_z.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Training
# Every batch is moved to whichever device the networks already live on.
device = next(G.parameters()).device

for epoch in range(num_epochs):
  for n,(images,_) in enumerate(train_loader):
    bs = images.size(0)   #batch size
    images = images.to(device)

    # ---- Discriminator / front-end update ----
    optimizer_D_FE.zero_grad()
    real_prob = D(FE(images))
    # Targets are built with the same shape as D's output. The previous
    # (bs,) targets against the (bs, 1) output produced the
    # binary_cross_entropy size-mismatch warning, which newer PyTorch
    # releases turn into an error.
    real_loss = criterion_D_FE(real_prob, torch.ones_like(real_prob))
    real_loss.backward()

    z,idx = gen_noise(bs)
    z = z.to(device)
    fake_image = G(z)
    # Detach so this backward pass stops at the image: G is not updated by
    # optimizer_D_FE, and its graph stays intact for the generator step
    # below without needing retain_graph.
    fake_prob = D(FE(fake_image.detach()))
    fake_loss = criterion_D_FE(fake_prob, torch.zeros_like(fake_prob))
    fake_loss.backward()

    D_loss = real_loss + fake_loss   # only used for logging
    optimizer_D_FE.step()

    # ---- Generator / recognizer update ----
    optimizer_GQ.zero_grad()
    # Re-run the just-updated FE and D on the same fake images, this time
    # keeping the graph back to G.
    feout_G = FE(fake_image)
    fake_prob = D(feout_G)
    # The generator is rewarded when D labels its images as real.
    reconstruct_loss = criterion_D_FE(fake_prob, torch.ones_like(fake_prob))

    # Q must recover the class code that was fed into G.
    q_logits = Q(feout_G)
    target3 = torch.as_tensor(idx, dtype=torch.long, device=device)
    q_loss = criterion_GQ(q_logits,target3)
    G_loss = reconstruct_loss + q_loss
    G_loss.backward()
    optimizer_GQ.step()

    if n % 100 == 0:
      print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
          epoch, n, D_loss.item(), G_loss.item()))


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Epoch/Iter:0/0, Dloss: 1.4241058826446533, Gloss: 3.062471628189087
Epoch/Iter:0/100, Dloss: 1.4273419380187988, Gloss: 3.118267774581909
Epoch/Iter:0/200, Dloss: 1.4225895404815674, Gloss: 3.0024032592773438
Epoch/Iter:0/300, Dloss: 1.4163697957992554, Gloss: 3.070539712905884
Epoch/Iter:0/400, Dloss: 1.422751545906067, Gloss: 3.08467960357666


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


Epoch/Iter:1/0, Dloss: 1.4215741157531738, Gloss: 3.035616636276245
Epoch/Iter:1/100, Dloss: 1.4263967275619507, Gloss: 3.0620474815368652
Epoch/Iter:1/200, Dloss: 1.4309176206588745, Gloss: 3.0539472103118896
Epoch/Iter:1/300, Dloss: 1.4317901134490967, Gloss: 3.1278605461120605
Epoch/Iter:1/400, Dloss: 1.425169587135315, Gloss: 3.097337484359741
Epoch/Iter:2/0, Dloss: 1.4198166131973267, Gloss: 3.0535776615142822
Epoch/Iter:2/100, Dloss: 1.4231886863708496, Gloss: 3.0554373264312744
Epoch/Iter:2/200, Dloss: 1.4218871593475342, Gloss: 3.1453962326049805
Epoch/Iter:2/300, Dloss: 1.424310326576233, Gloss: 3.0779623985290527
Epoch/Iter:2/400, Dloss: 1.4170933961868286, Gloss: 3.1190378665924072
Epoch/Iter:3/0, Dloss: 1.4311838150024414, Gloss: 3.1377978324890137
Epoch/Iter:3/100, Dloss: 1.4316771030426025, Gloss: 3.076446533203125
Epoch/Iter:3/200, Dloss: 1.4228405952453613, Gloss: 3.0053305625915527
Epoch/Iter:3/300, Dloss: 1.4267805814743042, Gloss: 3.085775375366211
Epoch/Iter:3/400, 

In [None]:
# Render the fixed test latents as a 10-per-row image grid.
# No gradients are needed for visualisation.
with torch.no_grad():
  test_images = G(test_z.to(next(G.parameters()).device))
# G ends in Tanh, so pixels lie in [-1, 1], while save_image clamps to
# [0, 1]. Map back to [0, 1] first so the negative half of the range is
# not clipped to black.
save_image((test_images + 1) / 2,'./data/epoch_{:d}_pytorch.png'.format(epoch),nrow=10)