In [1]:
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.autograd as autograd

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import numpy as np


In [2]:
class Front_End(nn.Module):
    '''front end of disc and q remains same only at the end that probability changes'''
    def __init__(self):
        super(Front_End,self).__init__()
        self.main=nn.Sequential(nn.Conv2d(1,64,4,2,1),
        nn.LeakyReLU(0.1,inplace=True),
        nn.Conv2d(64,128,4,2,1,bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.1,inplace=True),
        
        nn.Conv2d(128,1024,7,bias=False),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.1,inplace=True),
        )
    def forward(self,x):
        output=self.main(x)
        return output

class Disc(nn.Module):
    
    def __init__(self):
        super(Disc,self).__init__()
        self.main=nn.Sequential(
        nn.Conv2d(1024,1,1),
        nn.Sigmoid(),)
    def forward(self,x):
        output=self.main(x)
        output=output.view(-1,1)
        return output

class Qr(nn.Module):
    
    def __init__(self):
        super(Qr,self).__init__()
        self.conv=nn.Conv2d(1024,128,1)
        self.conv_disc=nn.Conv2d(128,10,1)
        self.conv_mu=nn.Conv2d(128,2,1)
        self.conv_var=nn.Conv2d(128,2,1)
    
    def forward(self,x):
        y=self.conv(x)
        disc_logits=self.conv_disc(y).squeeze()
        mu=self.conv_mu(y).squeeze()
        var=self.conv_var(y).squeeze()
        
        return disc_logits,mu,var


class Gen(nn.Module):
    
    def __init__(self):
        super(Gen,self).__init__()
        
        self.fc1=nn.Linear(74,1024)
        
        self.main=nn.Sequential(
            
            nn.ConvTranspose2d(1024,128,7,1,bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64,1,4,2,1,bias=False),
            nn.Sigmoid(),
        )
    def forward(self,x):
        x=x.view(-1,74)
        x=self.fc1(x)
        x=x.view(-1,1024,1,1)
        output=self.main(x)
        return output
    

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.autograd as autograd

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import numpy as np


In [4]:
class log_gaussian:

  def __call__(self, x, mu, var):

    logli = -0.5*(var.mul(2*np.pi)+1e-6).log() - \
            (x-mu).pow(2).div(var.mul(2.0)+1e-6)
    
    return logli.sum(1).mean().mul(-1)

In [5]:
bs=15
idx = np.random.randint(10,size=bs)
c = np.zeros((bs, 10))
c[range(bs),idx] = 1.0

dis_c=torch.Tensor(c)
con_c=torch.Tensor(bs,2).uniform_(-1.0,1.0)
noise=torch.Tensor(bs,62).uniform_(-1.0,1.0)
z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)

In [6]:
import torchvision
batch_size=bs
# data_loader normalize [0, 1] ==> [-1, 1]
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('../Data_sets/MNIST_data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

In [7]:
criterionD=nn.BCELoss()
criterionQ_dis=nn.MSELoss()
criterionQ_con=log_gaussian()
FE=Front_End().cuda()
G=Gen().cuda()
D=Disc().cuda()
Q=Qr().cuda()

FE.apply(weights_init)
G.apply(weights_init)
D.apply(weights_init)
Q.apply(weights_init)


Qr(
  (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv_disc): Conv2d(128, 10, kernel_size=(1, 1), stride=(1, 1))
  (conv_mu): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
  (conv_var): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
)

In [None]:
optimD = optim.Adam([{'params':FE.parameters()}, {'params':D.parameters()}], lr=0.0002, betas=(0.5, 0.99))
optimG = optim.Adam([{'params':G.parameters()}, {'params':Q.parameters()}], lr=0.001, betas=(0.5, 0.99))

x = np.linspace(-1, 1, 15).reshape(1, -1)
x = x.reshape(-1, 1)
c1 = np.hstack([x, np.zeros_like(x)])
print(c1)
c2 = np.hstack([np.zeros_like(x), x])

idx = np.random.randint(10,size=bs)
one_hot = np.zeros((bs, 10))
one_hot[range(bs), idx] = 1
print(idx)
fix_noise = torch.Tensor(bs, 62).uniform_(-1, 1)
for epoch in range(40):
    for num_iters,data in enumerate(train_loader,0):
        real_x,labels=data
        idx = np.random.randint(10,size=bs)
        c = np.zeros((bs, 10))
        c[range(bs),idx] = 1.0
        real_x=Variable(real_x.cuda())

        dis_c=torch.Tensor(c)
        con_c=torch.Tensor(bs,2).uniform_(-1.0,1.0)
        noise=torch.Tensor(bs,62).uniform_(-1.0,1.0)
        fe_out1=FE(real_x)
        probs_real=D(fe_out1)
        real_lab=torch.ones(bs)
        real_lab=Variable(real_lab.cuda())
        loss_real=criterionD(probs_real,real_lab)
        loss_real.backward()
        z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
        z=Variable(z.cuda())
        fake_x=G(z)
        fe_out2=FE(fake_x.detach())
        probs_fake=D(fe_out2)
        fake_lab=torch.zeros(bs)
        probs_fake=Variable(probs_fake.cuda(),requires_grad=True)
        fake_lab=Variable(fake_lab.cuda())
        loss_fake=criterionD(probs_fake,fake_lab)                                   
        loss_fake.backward() 
        D_loss=loss_real+loss_fake
        optimD.step()
        
        #G & Q part
        optimG.zero_grad()
        fe_out=FE(fake_x)
        probs_fake=D(fe_out)
        label=torch.ones(bs)
        label=Variable(label.cuda())
        probs_fake=Variable(probs_fake.cuda())
        class_ = torch.LongTensor(idx)
        recon_loss=criterionD(probs_fake,label)
        
        q_logits , q_mu, q_var=Q(fe_out)
        q_logits=Variable(q_logits.cuda())
        class_=Variable(class_.cuda())
        one = np.zeros((bs, 10))
        for i in range(len(class_)):
            one[i,class_[i]]=1
        one=torch.from_numpy(one)
        one=one.type(torch.FloatTensor)
        q_logits=q_logits.type(torch.FloatTensor)
        class_=class_.type(torch.FloatTensor)
        dis_loss = criterionQ_dis(q_logits,one)
        
        con_c=Variable(con_c.cuda())
        con_loss = criterionQ_con(con_c, q_mu, q_var)
        recon_loss=Variable(recon_loss.cuda())
        dis_loss=Variable(dis_loss.cuda())
        
        G_loss= recon_loss + dis_loss + con_loss
        G_loss.backward()
        optimG.step()
        if num_iters %15 == 0:

          print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
            epoch, num_iters, D_loss.data.cpu().numpy(),
            G_loss.data.cpu().numpy())
          )
          f=torch.from_numpy(c1)
          f=f.type(torch.FloatTensor)
          tr=torch.FloatTensor(one_hot)
          tr=tr.type(torch.FloatTensor)
          z = torch.cat([noise, tr, f], 1).view(-1, 74, 1, 1)
          z=Variable(z.cuda())
          x_save = G(z)
          save_image(x_save.data, './tmp/c1.png', nrow=10)

          f2=torch.from_numpy(c1)
          f2=f2.type(torch.FloatTensor)
          z = torch.cat([fix_noise, tr, f2], 1).view(-1, 74, 1, 1)
          z=Variable(z.cuda())
          x_save = G(z)
          save_image(x_save.data, './tmp/c2.png', nrow=10)
        
                                           

[[-1.          0.        ]
 [-0.85714286  0.        ]
 [-0.71428571  0.        ]
 [-0.57142857  0.        ]
 [-0.42857143  0.        ]
 [-0.28571429  0.        ]
 [-0.14285714  0.        ]
 [ 0.          0.        ]
 [ 0.14285714  0.        ]
 [ 0.28571429  0.        ]
 [ 0.42857143  0.        ]
 [ 0.57142857  0.        ]
 [ 0.71428571  0.        ]
 [ 0.85714286  0.        ]
 [ 1.          0.        ]]
[8 7 1 7 3 7 9 7 5 4 6 7 9 6 4]
Epoch/Iter:0/0, Dloss: 1.4552903175354004, Gloss: nan


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch/Iter:0/15, Dloss: 1.8851802349090576, Gloss: nan
Epoch/Iter:0/30, Dloss: 3.1967902183532715, Gloss: 2.4683454036712646
Epoch/Iter:0/45, Dloss: 4.352138519287109, Gloss: nan
Epoch/Iter:0/60, Dloss: 5.97489595413208, Gloss: nan
Epoch/Iter:0/75, Dloss: 6.385566711425781, Gloss: nan
Epoch/Iter:0/90, Dloss: 8.77228832244873, Gloss: nan
Epoch/Iter:0/105, Dloss: 8.574454307556152, Gloss: nan
Epoch/Iter:0/120, Dloss: 9.012786865234375, Gloss: 2.4266304969787598
Epoch/Iter:0/135, Dloss: 11.107748985290527, Gloss: nan
Epoch/Iter:0/150, Dloss: 11.950663566589355, Gloss: nan
Epoch/Iter:0/165, Dloss: 9.981861114501953, Gloss: nan
Epoch/Iter:0/180, Dloss: 13.162750244140625, Gloss: nan
Epoch/Iter:0/195, Dloss: 10.980905532836914, Gloss: nan
Epoch/Iter:0/210, Dloss: 12.234779357910156, Gloss: nan
Epoch/Iter:0/225, Dloss: 13.347461700439453, Gloss: nan
Epoch/Iter:0/240, Dloss: 19.13930320739746, Gloss: nan
Epoch/Iter:0/255, Dloss: 16.24195671081543, Gloss: nan
Epoch/Iter:0/270, Dloss: 13.5755233