In [1]:
import torch
import numpy as np
import skimage.io
import skimage
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas as pd
import random
import itertools
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
train_X = np.load("all_img.npy")
train_X = torch.from_numpy(train_X).type(torch.FloatTensor)

In [8]:
# fixed input: two 10-dimensional data
z = torch.randn((1, 72)).repeat(100,1)
c = []
for i in range(10):
    c_1 = torch.zeros(10,10)
    c_1[:,i] = 1
    c_2 = torch.zeros(10,10)
    c_2[[range(10)],[range(10)]] = 1
    c_ = torch.cat((c_1, c_2),1)
    c.append(c_)
c = torch.cat(c)
fixed_input = torch.cat((z, c),1).view(100,92,1,1)
print(fixed_input.size())
fixed_input = Variable(fixed_input).cuda()

torch.Size([100, 92, 1, 1])


In [9]:
class Generator(nn.Module):
    def __init__(self, figsize=64):
        super(Generator, self).__init__()
        self.noise_dim = 72+20
        self.figsize = figsize
        self.decoder = nn.Sequential(
            # input is Z, going into a convolution
            
            nn.ConvTranspose2d( self.noise_dim, self.figsize * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(figsize * 8),
            nn.ReLU(inplace=True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.figsize * 8, self.figsize * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(figsize * 4),
            nn.ReLU(inplace=True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.figsize * 4, self.figsize * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.figsize * 2),
            nn.ReLU(inplace=True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.figsize * 2, self.figsize, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.figsize),
            nn.ReLU(inplace=True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.figsize, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, X):
        output = self.decoder(X)/2.0+0.5
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        figsize = 64
        self.conv_decoder = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(3, figsize, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(figsize, figsize * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(figsize * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(figsize * 2, figsize * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(figsize * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(figsize * 4, figsize * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(figsize * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
                    
        self.D_out = nn.Sequential(
            nn.Conv2d(figsize * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        self.Q_out = nn.Sequential(
            nn.Linear(8192, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, 20),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = self.conv_decoder(x)
        encode_output_size = x.size()
#         print("conv encode size:",x.size()) # (batch_size, 512,4,4)
        real_fake = self.D_out(x).view(-1, 1)
        x = x.view(-1, 512*4*4)
        discrete = self.Q_out(x)
        return real_fake, discrete

In [10]:
BATCH_SIZE = 64
G = Generator().cuda()
D = Discriminator().cuda()

BCE_loss = nn.BCELoss().cuda()
CE_loss = nn.CrossEntropyLoss().cuda()

# setup optimizer
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerInfo = optim.Adam(itertools.chain(G.parameters(), D.parameters()), 
                           lr=0.0002, betas=(0.5, 0.999))

D_loss_list = []
G_loss_list = []
D_fake_acc_list = []
D_real_acc_list = []
Info_loss_list = []
for epoch in range(120):
    print("Epoch:", epoch+1)
    epoch_D_loss = 0.0
    epoch_G_loss = 0.0
    D_fake_acc = 0.0
    D_real_acc = 0.0
    epoch_Info_loss = 0.0
    total_length = len(train_X)
    # shuffle
    perm_index = torch.randperm(total_length)
    train_X_sfl = train_X[perm_index]
    # learning rate decay
    if (epoch+1) == 50:
        optimizerG.param_groups[0]['lr'] /= 2
        optimizerD.param_groups[0]['lr'] /= 2
        optimizerInfo.param_groups[0]['lr'] /= 2
        print("learning rate change!")

    if (epoch+1) == 80:
        optimizerG.param_groups[0]['lr'] /= 2
        optimizerD.param_groups[0]['lr'] /= 2
        optimizerInfo.param_groups[0]['lr'] /= 2
        print("learning rate change!")
        
    # construct training batch
    for index in range(0,total_length ,BATCH_SIZE):
        if index+BATCH_SIZE > total_length:
            break
            
        # zero the parameter gradients
        D.zero_grad()
        input_X = train_X_sfl[index:index+BATCH_SIZE]
        
        #### train with real image -> ground truth = real label
        real_image = Variable(input_X.cuda()) # use GPU 
        real_label = Variable(torch.ones((BATCH_SIZE))).cuda()
        output, disc_output = D(real_image)
        D_real_loss = BCE_loss(output, real_label)
        D_real_acc += np.mean(((output > 0.5).cpu().data.numpy() == real_label.cpu().data.numpy()))
        
        #### train with fake image -> ground truth = fake label
        z_ = torch.randn((BATCH_SIZE, 72,1,1))
        c_1 = np.random.multinomial(1, 10 * [float(1.0 /10)], size=[BATCH_SIZE])
        c_2 = np.random.multinomial(1, 10 * [float(1.0 /10)], size=[BATCH_SIZE])
        c1 = np.concatenate((c_1,c_2),1)
        c1 = torch.from_numpy(c1).type(torch.FloatTensor).view(BATCH_SIZE,20,1,1)
        combine_input = Variable(torch.cat((z_, c1),1)).cuda()
        
        fake_image = G(combine_input)
        fake_label = Variable(torch.zeros((BATCH_SIZE))).cuda()
        output, disc_output = D(fake_image)
        D_fake_loss = BCE_loss(output, fake_label)
        D_fake_acc += np.mean(((output > 0.5).cpu().data.numpy() == fake_label.cpu().data.numpy()))
        
        # update D
        D_train_loss = D_real_loss + D_fake_loss
        epoch_D_loss+=(D_train_loss.data[0])
        D_train_loss.backward(retain_graph=True)
        optimizerD.step()
        
        #### train Generator
        G.zero_grad()
        # generate fake image
        z_ = torch.randn((BATCH_SIZE, 72,1,1))
        c_1 = np.random.multinomial(1, 10 * [float(1.0 /10)], size=[BATCH_SIZE])
        c_2 = np.random.multinomial(1, 10 * [float(1.0 /10)], size=[BATCH_SIZE])
        c1 = np.concatenate((c_1,c_2),1)
        c1 = torch.from_numpy(c1).type(torch.FloatTensor).view(BATCH_SIZE,20,1,1)
        combine_input = Variable(torch.cat((z_, c1),1)).cuda()
        c1 = Variable(c1).cuda()
        
        fake_image = G(combine_input)
        fake_label_for_G = Variable(torch.ones((BATCH_SIZE))).cuda()
        output, disc_output = D(fake_image)
        G_loss = BCE_loss(output, fake_label_for_G)
        epoch_G_loss += (G_loss.data[0])
        G_loss.backward(retain_graph=True)
        optimizerG.step()
        
        # information loss
        disc_loss = BCE_loss(disc_output,c1.view(BATCH_SIZE,-1))
        epoch_Info_loss += (disc_loss.data[0])
        disc_loss.backward()
        optimizerInfo.step()
        
    print("training D Loss:",epoch_D_loss/(total_length))
    print("training G Loss:", epoch_G_loss/(total_length))
    print("training Info Loss:", epoch_Info_loss/(total_length))
    D_loss_list.append(epoch_D_loss/(total_length))
    G_loss_list.append(epoch_G_loss/(total_length))
    Info_loss_list.append(epoch_Info_loss/(total_length))
    
    print("D_real_acc:", D_real_acc/(total_length/BATCH_SIZE))
    print("D_fake_acc:", D_fake_acc/(total_length/BATCH_SIZE))
    
    D_real_acc_list.append(D_real_acc/(total_length/BATCH_SIZE))
    D_fake_acc_list.append(D_fake_acc/(total_length/BATCH_SIZE))
    # evaluation
    G.eval()
    fixed_img_output = G(fixed_input)
    G.train()
    torchvision.utils.save_image(fixed_img_output.cpu().data, './InfoGAN_output2/figB_'+str(epoch+1)+'.jpg',nrow=10)
    
torch.save(G.state_dict(), "./models/InfoG_model_2.pkt")

Epoch: 1


  "Please ensure they have the same size.".format(target.size(), input.size()))


training D Loss: 0.021318719933916326
training G Loss: 0.012957220859712858
training Info Loss: 0.006141772730806231
D_real_acc: 0.9784847844959058
D_fake_acc: 0.06011121278243119
Epoch: 2
training D Loss: 0.02217619225436297
training G Loss: 0.010677009041285134
training Info Loss: 0.004974195961575067
D_real_acc: 0.9810891344642313
D_fake_acc: 0.011919007062246311
Epoch: 3
training D Loss: 0.022112462217464695
training G Loss: 0.0103468227156005
training Info Loss: 0.004215199323676752
D_real_acc: 0.9809483587902678
D_fake_acc: 0.02106942586987635
Epoch: 4
training D Loss: 0.022172785986174934
training G Loss: 0.010191022045678734
training Info Loss: 0.0030382949369348244
D_real_acc: 0.9828488303887755
D_fake_acc: 0.020389010112385912
Epoch: 5
training D Loss: 0.02204310897368805
training G Loss: 0.010173230778032347
training Info Loss: 0.0019389892562971063
D_real_acc: 0.9825438164285212
D_fake_acc: 0.01987283264118627
Epoch: 6
training D Loss: 0.022119849297249687
training G Loss: 

training D Loss: 0.01900921568055286
training G Loss: 0.01291845616695538
training Info Loss: 0.00011852257913664004
D_real_acc: 0.9838811853311747
D_fake_acc: 0.2303559278290045
Epoch: 45
training D Loss: 0.01888232523450567
training G Loss: 0.013050789334762215
training Info Loss: 0.00012115446401297216
D_real_acc: 0.9835057835339387
D_fake_acc: 0.23486074939583773
Epoch: 46
training D Loss: 0.018687899684232498
training G Loss: 0.013169788874587696
training Info Loss: 0.00011853917950921146
D_real_acc: 0.9839984983928111
D_fake_acc: 0.24884446634288263
Epoch: 47
training D Loss: 0.018564921838887426
training G Loss: 0.013415771686013538
training Info Loss: 0.00012720938193254335
D_real_acc: 0.9847493019872833
D_fake_acc: 0.25977804368738416
Epoch: 48
training D Loss: 0.018440007517354078
training G Loss: 0.013640830370869743
training Info Loss: 0.00012367110866817713
D_real_acc: 0.9842331245160836
D_fake_acc: 0.27312827010159313
Epoch: 49
training D Loss: 0.018352269815880062
traini

training D Loss: 0.013808350176677076
training G Loss: 0.015199277566326814
training Info Loss: 0.00013826022835528403
D_real_acc: 0.9948851505126581
D_fake_acc: 0.6253020811337134
Epoch: 88
training D Loss: 0.013712439702404695
training G Loss: 0.015304293225213813
training Info Loss: 0.00014168773738751104
D_real_acc: 0.9944862861030948
D_fake_acc: 0.6299711409868375
Epoch: 89
training D Loss: 0.013731447934226124
training G Loss: 0.015449618201558877
training Info Loss: 0.00014878891779685443
D_real_acc: 0.9951432392482579
D_fake_acc: 0.6309565707045822
Epoch: 90
training D Loss: 0.01370442706543278
training G Loss: 0.01539914546356681
training Info Loss: 0.00013830179371378856
D_real_acc: 0.9945566739400765
D_fake_acc: 0.6281175946129842
Epoch: 91
training D Loss: 0.013674517026469707
training G Loss: 0.01551971129894939
training Info Loss: 0.0001472466230961769
D_real_acc: 0.9940404964688768
D_fake_acc: 0.6278595058773844
Epoch: 92
training D Loss: 0.013606871022726336
training G 

In [None]:
# plot loss
plt.figure(figsize=(16,4))
plt.subplot(1,2,1)
plt.plot(D_real_acc_list, label = "D real accuracy")
plt.plot(D_fake_acc_list, label = "D fake accuracy")
plt.title("Discriminator Accuracy")
plt.xlabel("epoch")
plt.legend()

plt.subplot(1,2,2)
plt.plot(D_loss_list, label="D loss")
plt.plot(G_loss_list, label="G loss")
plt.title("Training Loss")
plt.xlabel("epoch")
plt.legend()
# plt.savefig("./GAN_output/fig2_2.jpg")
plt.savefig("./InfoGAN_output2/figB_2.jpg")
plt.show()