In [1]:
import os, sys
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

dataset = 'FashionMNIST'

# Seed torch (CPU and all CUDA devices) so shuffling, weight init and the
# sampled noise are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

trans = transforms.Compose([transforms.ToTensor()])

# torchvision dataset class and local cache directory for each supported name.
dataset_sources = {
    'FashionMNIST': (datasets.FashionMNIST, '../datasets/fashion_mnist'),
    'MNIST': (datasets.MNIST, '../datasets/mnist'),
}
# Fail loudly instead of leaving `train_set` undefined (or stale from an
# earlier run) when the name matches no branch.
if dataset not in dataset_sources:
    raise ValueError("unsupported dataset {!r}; expected one of {}".format(
        dataset, sorted(dataset_sources)))
dataset_cls, data_root = dataset_sources[dataset]
train_set = dataset_cls(data_root, train=True, download=True, transform=trans)

batch_size = 64
z_dim = 62   # length of the generator's noise vector
y_dim = 10   # number of classes (length of the one-hot condition)

train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

In [2]:
def initialize_weights(net):
    """Re-initialise every conv / transposed-conv / linear layer of ``net`` in place.

    Weights are drawn from N(0, 0.02) and biases are set to zero. Layers
    created with ``bias=False`` are supported; their bias is ``None`` and is
    left alone.

    Args:
        net: an ``nn.Module``; all of its sub-modules are visited.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            # nn.init functions run under no_grad, so the parameters stay leaf
            # tensors with requires_grad intact.
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            
class generator(nn.Module):
    """Conditional generator: (noise vector, one-hot label) -> image in (0, 1).

    Network architecture is the same as in infoGAN (https://arxiv.org/abs/1606.03657):
    FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S

    Args:
        dataset: 'mnist' or 'fashion-mnist' (both 28x28, single channel).
        z_dim: length of the noise vector (default 62).
        y_dim: length of the one-hot condition vector (default 10).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset='mnist', z_dim=62, y_dim=10):
        super().__init__()
        # Explicit membership test: the form `dataset == 'mnist' or 'fashion-mnist'`
        # is always true (a non-empty string is truthy) and accepted any name.
        if dataset not in ('mnist', 'fashion-mnist'):
            raise ValueError(
                "unsupported dataset {!r}; expected 'mnist' or 'fashion-mnist'".format(dataset))
        self.input_height = 28
        self.input_width = 28
        self.input_dim = z_dim + y_dim  # noise and condition are concatenated
        self.output_dim = 1

        # Spatial size after the FC stack; two stride-2 deconvs upsample x4 back to 28.
        feat_h = self.input_height // 4
        feat_w = self.input_width // 4
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * feat_h * feat_w),
            nn.BatchNorm1d(128 * feat_h * feat_w),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        initialize_weights(self)

    def forward(self, input, label):
        """Generate images.

        Args:
            input: noise, shape (N, z_dim).
            label: one-hot condition, shape (N, y_dim).

        Returns:
            Tensor of shape (N, 1, 28, 28) with values in (0, 1).
        """
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        return x

In [3]:
class discriminator(nn.Module):
    """Conditional discriminator: (image, label planes) -> probability the image is real.

    Network architecture is the same as in infoGAN (https://arxiv.org/abs/1606.03657):
    (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S

    Args:
        dataset: 'mnist' or 'fashion-mnist' (both 28x28, single channel).
        y_dim: number of label planes concatenated to the image (default 10).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset='mnist', y_dim=10):
        super().__init__()
        # Explicit membership test: the form `dataset == 'mnist' or 'fashion-mnist'`
        # is always true (a non-empty string is truthy) and accepted any name.
        if dataset not in ('mnist', 'fashion-mnist'):
            raise ValueError(
                "unsupported dataset {!r}; expected 'mnist' or 'fashion-mnist'".format(dataset))
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1 + y_dim  # one image channel + one plane per class
        self.output_dim = 1

        # Two stride-2 convs reduce 28x28 to 7x7 with 128 channels.
        self.flat_dim = 128 * (self.input_height // 4) * (self.input_width // 4)
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(self.flat_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        initialize_weights(self)

    def forward(self, input, label):
        """Score images.

        Args:
            input: images, shape (N, 1, 28, 28).
            label: per-class planes, shape (N, y_dim, 28, 28); concatenated
                with ``input`` along the channel axis.

        Returns:
            Tensor of shape (N, 1) with probabilities in (0, 1).
        """
        x = torch.cat([input, label], 1)
        x = self.conv(x)
        x = x.view(-1, self.flat_dim)
        x = self.fc(x)

        return x

In [4]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The model classes use their own dataset names. Translate the torchvision name
# chosen in the first cell instead of overwriting `dataset`; overwriting it made
# the first cell's if/elif match nothing when re-run.
model_dataset = {'FashionMNIST': 'fashion-mnist', 'MNIST': 'mnist'}[dataset]

lrG = 0.0002
lrD = 0.0002
beta1 = 0.5
beta2 = 0.999

G = generator(model_dataset).to(device)
D = discriminator(model_dataset).to(device)
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))

# D ends in a Sigmoid, so plain BCE on probabilities is the matching loss.
BCE_loss = nn.BCELoss()

In [5]:
def one_hot(label_batch, num_classes):
    """Convert integer class labels into float32 one-hot rows.

    Label ``k`` maps to a row with a 1 in column ``k``. The previous
    ``label_batch - 1`` indexing shifted every class by one (label 0 became
    class ``num_classes - 1``).

    Args:
        label_batch: integer tensor of labels in ``[0, num_classes)``, shape (N,).
        num_classes: number of classes.

    Returns:
        float32 tensor of shape (N, num_classes) on ``label_batch``'s device.
    """
    return torch.eye(num_classes, device=label_batch.device)[label_batch]

In [6]:
num_epochs = 25
num_batches = len(train_loader)

# ---- fixed noise & conditions for the sample grid ----
# sample_num // y_dim groups of y_dim rows. All rows in a group share one noise
# vector and cycle through labels 0..y_dim-1, so each group renders the same
# latent code once per class.
sample_num = 100
n_groups = sample_num // y_dim
# Draw from N(0, 1) to match the training noise below. The original torch.rand
# drew U(0, 1) codes from a distribution the generator is never trained on.
sample_z_ = torch.randn(n_groups, z_dim).repeat_interleave(y_dim, dim=0).to(device)
sample_labels = torch.arange(y_dim).repeat(n_groups)
sample_y_ = torch.eye(y_dim)[sample_labels].to(device)

# fill[c] is a (y_dim, 28, 28) stack whose c-th plane is all ones: the class
# label broadcast to image size so D can concatenate it with the image channel.
img_size = 28
fill = torch.zeros(y_dim, y_dim, img_size, img_size)
for c in range(y_dim):
    fill[c, c] = 1
# Keep it on the device once, rather than building and copying an
# (N, y_dim, 28, 28) tensor from the CPU every batch.
fill = fill.to(device)

os.makedirs('logs', exist_ok=True)

with open('logs/loss.log', 'w') as log_fn:

    log_fn.write('epoch,d_error,g_error,n_batch,num_batches\n')

    for epoch in range(num_epochs):
        # The sampling cell may have switched G to eval mode; make re-runs safe.
        G.train()
        D.train()

        for n_batch, (x_, y_) in enumerate(train_loader):
            N = x_.size(0)
            x_ = x_.to(device)
            z_ = torch.randn(N, z_dim, device=device)
            y_vec_ = one_hot(y_, y_dim).to(device)
            # Derive the fill planes from y_vec_ so G and D are always
            # conditioned on the same class index for a given sample.
            y_fill_ = fill[y_vec_.argmax(1)]

            y_real_ = torch.ones(N, 1, device=device)
            y_fake_ = torch.zeros(N, 1, device=device)

            # ---- update D: push real -> 1, fake -> 0 ----
            D_optimizer.zero_grad()
            D_real_loss = BCE_loss(D(x_, y_fill_), y_real_)
            G_ = G(z_, y_vec_)
            # detach(): D's loss must not backpropagate through G.
            D_fake_loss = BCE_loss(D(G_.detach(), y_fill_), y_fake_)
            D_loss = D_real_loss + D_fake_loss
            D_loss.backward()
            D_optimizer.step()

            # ---- update G: make the freshly-updated D call the fakes real ----
            G_optimizer.zero_grad()
            # G's weights have not changed since G_ was produced, so reuse it.
            G_loss = BCE_loss(D(G_, y_fill_), y_real_)
            G_loss.backward()
            G_optimizer.step()

            log_fn.write('{},{:.6f},{:.6f},{},{}\n'.format(epoch, D_loss.item(), G_loss.item(), n_batch, num_batches))
        # Losses of the epoch's final batch; per-batch history is in logs/loss.log.
        print("epoch: {} d_error: {:.4f} g_error: {:.4f}".format(epoch, D_loss.item(), G_loss.item()))

epoch: 0 d_error: 1.2501 g_error: 0.8000
epoch: 1 d_error: 1.3131 g_error: 0.9381
epoch: 2 d_error: 1.2512 g_error: 1.0140
epoch: 3 d_error: 1.2793 g_error: 1.0041
epoch: 4 d_error: 1.2944 g_error: 1.0852
epoch: 5 d_error: 1.2666 g_error: 1.0059
epoch: 6 d_error: 1.2661 g_error: 1.0522
epoch: 7 d_error: 1.2272 g_error: 1.2348
epoch: 8 d_error: 1.1910 g_error: 1.2162
epoch: 9 d_error: 1.2396 g_error: 1.3248
epoch: 10 d_error: 1.2318 g_error: 1.3571
epoch: 11 d_error: 1.2736 g_error: 1.4416
epoch: 12 d_error: 1.1992 g_error: 1.2234
epoch: 13 d_error: 1.3694 g_error: 1.3308
epoch: 14 d_error: 0.9763 g_error: 1.4267
epoch: 15 d_error: 0.8787 g_error: 1.6091
epoch: 16 d_error: 0.9172 g_error: 1.5231
epoch: 17 d_error: 1.1064 g_error: 1.5818
epoch: 18 d_error: 1.1167 g_error: 1.3588
epoch: 19 d_error: 0.9873 g_error: 1.7074
epoch: 20 d_error: 1.1956 g_error: 1.5077
epoch: 21 d_error: 1.0274 g_error: 1.6488
epoch: 22 d_error: 1.0942 g_error: 1.9907
epoch: 23 d_error: 0.6523 g_error: 2.1652
ep

In [7]:
# Render the fixed sample grid. eval() makes BatchNorm use its running
# statistics (and stop updating them); no_grad skips building an autograd graph.
# The previous train/eval mode is restored so re-running training is unaffected.
was_training = G.training
G.eval()
try:
    with torch.no_grad():
        samples = G(sample_z_, sample_y_)
finally:
    G.train(was_training)
# (N, C, H, W) -> (N, H, W, C) for image viewers.
samples = samples.permute(0, 2, 3, 1).cpu().numpy()
os.makedirs('img', exist_ok=True)
np.save('img/cgan', samples)  # np.save appends the suffix: writes img/cgan.npy