In [1]:
import os, time
from tqdm import tnrange, tqdm_notebook, tqdm
import torch
import torch.nn.functional as F
from torch import nn, autograd, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [2]:
now = int(time.time())

In [3]:
from matplotlib import rcParams
rcParams['figure.figsize'] = (12, 8)

%matplotlib inline

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
batch_size = 128
num_epochs = 1000

z_dimension = 100

In [6]:
device_ids = [0] #, 1]

In [7]:
wh = 64
img_transform = transforms.Compose([
    transforms.Resize(wh),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

dataset = datasets.ImageFolder('~/data/anime-faces', transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [8]:
def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(1, 1, 1, 1)
    alpha = alpha.expand_as(real_data)
    alpha = alpha.to(device)
    
    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())

    interpolates = interpolates.to(device)
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.view(gradients.size(0), -1)                              
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return gradient_penalty

In [9]:
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [10]:
class Discriminator(nn.Module): # b 3 64 64
    def __init__(self, d=64):
        super(Discriminator, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(3, d, 4, 2, 1)),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(.2, True),
        ) # b d 32 32
        
        self.conv2 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d, d*2, 4, 2, 1)),
            nn.BatchNorm2d(d*2),
            nn.LeakyReLU(.2, True),
        ) # b d*2 16 16
        
        self.conv3 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d*2, d*4, 4, 2, 1)),
            nn.BatchNorm2d(d*4),
            nn.LeakyReLU(.2, True),
        ) # b d*4 8 8
        
        self.conv4 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d*4, d*8, 4, 2, 1)),
            nn.BatchNorm2d(d*8),
            nn.LeakyReLU(.2, True),
        ) # b d*8 4 4
        
        self.conv5 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d*8, d*16, 4, 2, 1)),
            nn.BatchNorm2d(d*16),
            nn.LeakyReLU(.2, True),
        ) # b d*16 2 2
        
        self.conv6 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d*16, d*32, 4, 2, 2)),
            nn.BatchNorm2d(d*32),
            nn.LeakyReLU(.2, True),
        ) # b d*32 2 2
        
        
        self.output = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(d*32, 1, 4, 2, 1)),
            nn.Sigmoid(),
        ) # b 1 1 1
        

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, x): # b 3 w h
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        
        return self.output(out).reshape(x.size(0), -1)

# class Discriminator(nn.Module):
#     def __init__(self, d=128):
#         super(Discriminator, self).__init__() # b d 64 64
        
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(3, d, 4, 2, 2),
#             nn.LeakyReLU(.2, True),
#             nn.BatchNorm2d(d),
#         ) # d 32 32
        
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(d, d*2, 4, 2, 2),
#             nn.LeakyReLU(.2, True),
#             nn.BatchNorm2d(d*2),
#         ) # d 16 16
        
#         self.conv3 = nn.Sequential(
#             nn.Conv2d(d*2, d*4, 4, 2, 1),
#             nn.LeakyReLU(.2, True),
#             nn.BatchNorm2d(d*4),
#         ) # d*4 8 8
        
#         self.conv4 = nn.Sequential(
#             nn.Conv2d(d*4, d*8, 4, 2, 1),
#             nn.LeakyReLU(.2, True),
#             nn.BatchNorm2d(d*8),
#         ) # d*4 4 4
        
#         self.conv5 = nn.Sequential(
#             nn.Conv2d(d*8, d*16, 4, 2, 1),
#             nn.LeakyReLU(.2, True),
#             nn.BatchNorm2d(d*16),
#         ) # d*4 2 2
        
        
#         self.output = nn.Sequential(
#             nn.Conv2d(d*16, 1, 4, 2, 1),
#             nn.Sigmoid(),
#         ) # 1 1 1
        
#     def weight_init(self, mean, std):
#         for m in self._modules:
#             normal_init(self._modules[m], mean, std)
            
#     def forward(self, x): # b 1 32 32

#         out = self.conv1(x)
#         out = self.conv2(out)
#         out = self.conv3(out)
#         out = self.conv4(out)
#         out = self.conv5(out)
        
#         out = self.output(out)
        
#         return out

In [11]:
class Generator(nn.Module):
    def __init__(self, inp_dim, d=64):
        super(Generator, self).__init__()
        
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(inp_dim, d*32, 4, 2, 1),
            nn.BatchNorm2d(d*32),
            nn.ReLU(True),
        ) # b d*16 2 2
        
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(d*32, d*16, 4, 2, 1),
            nn.BatchNorm2d(d*16),
            nn.ReLU(True),
        ) # b d*16 2 2
        
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(d*16, d*8, 4, 2, 1),
            nn.BatchNorm2d(d*8),
            nn.ReLU(True),
        ) # b d*16 4 4
        
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(d*8, d*4, 4, 2, 1),
            nn.BatchNorm2d(d*4),
            nn.ReLU(True),
        ) # b d*4 8 8
        
        self.deconv5 = nn.Sequential(
            nn.ConvTranspose2d(d*4, d*2, 4, 2, 1),
            nn.BatchNorm2d(d*2),
            nn.ReLU(True),
        ) # b d*2 16 16
        
#         self.deconv6 = nn.Sequential(
#             nn.ConvTranspose2d(d*2, d, 4, 2, 1),
#             nn.BatchNorm2d(d),
#             nn.ReLU(True),
#         ) # b d 32 32
        
        self.output = nn.Sequential(
            nn.ConvTranspose2d(d*2, 3, 4, 2, 1),
            nn.Tanh(),
        ) # b 3 64 64
        
        
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
            
    def forward(self, x):
        out = self.deconv1(x)
        out = self.deconv2(out)
        out = self.deconv3(out)
        out = self.deconv4(out)
        out = self.deconv5(out)
#         out = self.deconv6(out)

        out = self.output(out)
        
        return out

# class Generator(nn.Module):
#     def __init__(self, z_dimension, d=128):
#         super(Generator, self).__init__()
        
#         self.upsample1 = nn.Sequential(
#             nn.ConvTranspose2d(z_dimension, d*16, 4, 2, 1),
#             nn.BatchNorm2d(d*16),
#             nn.ReLU(True),
#         ) # b d 2 2 
        
#         self.upsample2 = nn.Sequential(
#             nn.ConvTranspose2d(d*16, d*8, 4, 2, 1),
#             nn.BatchNorm2d(d*8),
#             nn.ReLU(True),
#         ) # b d*8 4 4
        
#         self.upsample3 = nn.Sequential(
#             nn.ConvTranspose2d(d*8, d*4, 4, 2, 1),
#             nn.BatchNorm2d(d*4),
#             nn.ReLU(True),
#         ) # b d*8 8 8
        
#         self.upsample4 = nn.Sequential(
#             nn.ConvTranspose2d(d*4, d*2, 4, 2, 1),
#             nn.BatchNorm2d(d*2),
#             nn.ReLU(True),
#         ) # b d*2 16 16
        
#         self.upsample5 = nn.Sequential(
#             nn.ConvTranspose2d(d*2, d, 4, 2, 1),
#             nn.BatchNorm2d(d),
#             nn.ReLU(True),
#         ) # b d 32 32
        
#         self.output = nn.Sequential(
#             nn.ConvTranspose2d(d, 3, 4, 2, 1),
#             nn.Tanh(),
#         ) # b 3 64 64
        
#     def weight_init(self, mean, std):
#         for m in self._modules:
#             normal_init(self._modules[m], mean, std)
            
#     def forward(self, x): # b 100 1 1

#         outs = self.upsample1(x)
#         outs = self.upsample2(outs)
#         outs = self.upsample3(outs)
#         outs = self.upsample4(outs)
#         outs = self.upsample5(outs)
        
#         outs = self.output(outs)

#         return outs

In [12]:
d = Discriminator(d=64)#.cuda(device_ids[0])
g = Generator(z_dimension, d=64)#.cuda(device_ids[0])

d.weight_init(0.0, 0.02)
g.weight_init(0.0, 0.02)

d = nn.DataParallel(d, device_ids=device_ids).to(device)
g = nn.DataParallel(g, device_ids=device_ids).to(device)

criterion = nn.BCELoss()

d_optimezer = optim.Adam(d.parameters(), lr=2e-4)
# d_optimezer = nn.DataParallel(d_optimezer, device_ids=device_ids)
g_optimezer = optim.Adam(g.parameters(), lr=2e-4)
# g_optimezer = nn.DataParallel(g_optimezer, device_ids=device_ids)

# one = torch.FloatTensor([1])
# mone = one * -1
# one = one.to(device)
# mone = mone.to(device)

In [13]:
writer = SummaryWriter(os.path.join('./log/cnn_wgan_gp_faces', str(now)))

In [14]:
img_path = os.path.join("save_images/cnn_wgan_gp_faces", str(now))
if not os.path.exists(img_path): os.makedirs(img_path)

In [None]:
total_count = len(dataloader)
for epoch in tqdm_notebook(range(num_epochs)):
    
    d_loss_total = .0
    g_loss_total = .0
    _step = epoch * total_count
    for i, (img, _) in enumerate(dataloader):
        
        real_img = img.cuda()
        z = torch.randn(img.size(0), z_dimension, 1, 1).cuda()
        
#         real_labels = torch.ones(img.size(0), 1).cuda()
#         fake_labels = torch.zeros(img.size(0), 1).cuda()
        
        real_labels = torch.from_numpy(np.random.normal(.95, .02, [img.size(0), 1])).float().to(device)
        fake_labels = torch.from_numpy(np.random.normal(.05, .02, [img.size(0), 1])).float().to(device)
        
#         ################### G ###################
        fake_img = g(z)
        fake_out = d(fake_img)
        
        g_loss = -fake_out.mean()
        
        g_optimezer.zero_grad()
        g_loss.backward()
        g_optimezer.step()
#         #########################################
        
#         ################### D ###################
        real_out = d(real_img)
        d_loss_real = real_out.mean()
        real_scores = real_out
        
        fake_out = d(fake_img.detach())
        d_loss_fake = fake_out.mean()
        fake_scores = fake_out
        
        gradient_penalty = calc_gradient_penalty(d, real_img, fake_img)
        
        d_loss = d_loss_fake - d_loss_real + gradient_penalty
        d_optimezer.zero_grad()
        d_loss.backward()
        d_optimezer.step()
#         #########################################

        ################### G ###################
#         fake_img = g(z)
#         fake_out = d(fake_img)
#         g_loss = criterion(fake_out, real_labels)
        
#         g_optimezer.zero_grad()
#         g_loss.backward()
#         g_optimezer.step()
        #########################################
        
        ################### D ###################
#         real_out = d(real_img)
#         d_loss_real = criterion(real_out, real_labels)
#         real_scores = real_out
        
#         fake_out = d(fake_img.detach())
#         d_loss_fake = criterion(fake_out, fake_labels)
#         fake_scores = fake_out
        
#         d_loss = d_loss_real + d_loss_fake
#         d_optimezer.zero_grad()
#         d_loss.backward()
#         d_optimezer.step()
        #########################################
        
        w_dist = d_loss_fake - d_loss_real
        
        d_loss_total += d_loss.item() * img.size(0)
        g_loss_total += g_loss.item() * img.size(0)
        
        step = _step + i + 1
        
        if (i + 1) % 100 == 0:
            writer.add_scalar('Discriminator Real Loss', d_loss_real.item(), step)
            writer.add_scalar('Discriminator Fake Loss', d_loss_fake.item(), step)
            writer.add_scalar('Discriminator Loss', d_loss.item(), step)
            writer.add_scalar('Generator Loss', g_loss.item(), step)
        
        
        if (i + 1) % 200 == 0:
            tqdm.write('Epoch [{}/{}], Step: {:6d}, d_loss: {:.6f}, g_loss: {:.6f}, real_scores: {:.6f}' \
', fake_scores: {:.6f}, W: {:.6f}'.format(epoch+1, num_epochs, (i+1) * batch_size, d_loss, g_loss, real_scores.mean(), fake_scores.mean(), w_dist))
        
        if (i + 1) % 300 == 0:
            fake_images = fake_img.view(-1, 3, wh, wh)[:8].cpu().data
            save_image(fake_images, os.path.join(img_path, 'fake_images_{:04d}_{:06d}.png'.format(epoch + 1, i + 1)))
        
    _d_loss_total = d_loss_total / (total_count * (epoch + 1))
    _g_loss_total = g_loss_total / (total_count * (epoch + 1))
    
    setp = (epoch + 1) * total_count
    writer.add_scalar('Discriminator Total Loss', _d_loss_total, step)
    writer.add_scalar('Generator Total Loss', _g_loss_total, step)
    tqdm.write("Finish Epoch [{}/{}], D Loss: {:.6f}, G Loss: {:.6f}".format(epoch+1, 
                                                                             num_epochs, 
                                                                             _d_loss_total,
                                                                             _g_loss_total, ))
    if epoch == 0:
        real_images = real_img.view(-1, 3, wh, wh).cpu().data
        save_image(real_images, os.path.join(img_path, 'real_images.png'))

    fake_images = fake_img.view(-1, 3, wh, wh).cpu().data
    save_image(fake_images, os.path.join(img_path, 'fake_images-{:03d}.png'.format(epoch + 1)))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

Epoch [1/1000], Step:  25600, d_loss: 2.867093, g_loss: -0.593351, real_scores: 0.249294, fake_scores: 0.593351, W: 0.344057
Epoch [1/1000], Step:  51200, d_loss: 2.406201, g_loss: -0.576729, real_scores: 0.196455, fake_scores: 0.576729, W: 0.380273
Epoch [1/1000], Step:  76800, d_loss: 2.476454, g_loss: -0.234531, real_scores: 0.166005, fake_scores: 0.234531, W: 0.068526
Epoch [1/1000], Step: 102400, d_loss: 2.452682, g_loss: -0.469254, real_scores: 0.104390, fake_scores: 0.469254, W: 0.364864
Finish Epoch [1/1000], D Loss: 590.254739, G Loss: -63.776675
Epoch [2/1000], Step:  25600, d_loss: 2.169613, g_loss: -0.186922, real_scores: 0.136537, fake_scores: 0.186922, W: 0.050385
Epoch [2/1000], Step:  51200, d_loss: 5.621141, g_loss: -0.408960, real_scores: 0.014727, fake_scores: 0.408960, W: 0.394233
Epoch [2/1000], Step:  76800, d_loss: 2.888328, g_loss: -0.984035, real_scores: 0.474159, fake_scores: 0.984035, W: 0.509876
Epoch [2/1000], Step: 102400, d_loss: 3.061282, g_loss: -0.9968

Epoch [15/1000], Step: 102400, d_loss: 1.765272, g_loss: -0.433336, real_scores: 0.375794, fake_scores: 0.433336, W: 0.057542
Finish Epoch [15/1000], D Loss: 20.702914, G Loss: -3.597482
Epoch [16/1000], Step:  25600, d_loss: 2.910476, g_loss: -0.522824, real_scores: 0.312776, fake_scores: 0.522824, W: 0.210048
Epoch [16/1000], Step:  51200, d_loss: 1.244490, g_loss: -0.415155, real_scores: 0.207889, fake_scores: 0.415155, W: 0.207266
Epoch [16/1000], Step:  76800, d_loss: 2.111396, g_loss: -0.435174, real_scores: 0.257718, fake_scores: 0.435174, W: 0.177456
Epoch [16/1000], Step: 102400, d_loss: 1.097329, g_loss: -0.329194, real_scores: 0.197479, fake_scores: 0.329194, W: 0.131716
Finish Epoch [16/1000], D Loss: 7.544042, G Loss: -3.017168
Epoch [17/1000], Step:  25600, d_loss: 0.473427, g_loss: -0.544262, real_scores: 0.297366, fake_scores: 0.544262, W: 0.246896
Epoch [17/1000], Step:  51200, d_loss: 0.271410, g_loss: -0.395777, real_scores: 0.403841, fake_scores: 0.395777, W: -0.008

Epoch [30/1000], Step:  51200, d_loss: -0.146049, g_loss: -0.323078, real_scores: 0.536036, fake_scores: 0.323078, W: -0.212957
Epoch [30/1000], Step:  76800, d_loss: 0.215865, g_loss: -0.405306, real_scores: 0.388321, fake_scores: 0.405306, W: 0.016985
Epoch [30/1000], Step: 102400, d_loss: -0.006345, g_loss: -0.425302, real_scores: 0.617292, fake_scores: 0.425302, W: -0.191990
Finish Epoch [30/1000], D Loss: 0.100050, G Loss: -1.600790
Epoch [31/1000], Step:  25600, d_loss: -0.042016, g_loss: -0.365759, real_scores: 0.519796, fake_scores: 0.365759, W: -0.154038
Epoch [31/1000], Step:  51200, d_loss: 0.101848, g_loss: -0.438727, real_scores: 0.574701, fake_scores: 0.438727, W: -0.135975
Epoch [31/1000], Step:  76800, d_loss: 0.684580, g_loss: -0.210469, real_scores: 0.291605, fake_scores: 0.210469, W: -0.081136
Epoch [31/1000], Step: 102400, d_loss: -0.158750, g_loss: -0.424362, real_scores: 0.666478, fake_scores: 0.424362, W: -0.242116
Finish Epoch [31/1000], D Loss: -0.038106, G Los

Epoch [44/1000], Step: 102400, d_loss: 0.042218, g_loss: -0.364858, real_scores: 0.662431, fake_scores: 0.364857, W: -0.297573
Finish Epoch [44/1000], D Loss: -0.170685, G Loss: -1.215265
Epoch [45/1000], Step:  25600, d_loss: -0.065854, g_loss: -0.157492, real_scores: 0.328739, fake_scores: 0.157492, W: -0.171247
Epoch [45/1000], Step:  51200, d_loss: -0.105051, g_loss: -0.341810, real_scores: 0.512028, fake_scores: 0.341810, W: -0.170218
Epoch [45/1000], Step:  76800, d_loss: 0.544959, g_loss: -0.496788, real_scores: 0.751535, fake_scores: 0.496788, W: -0.254747
Epoch [45/1000], Step: 102400, d_loss: 0.167446, g_loss: -0.244480, real_scores: 0.478204, fake_scores: 0.244480, W: -0.233724
Finish Epoch [45/1000], D Loss: -0.105724, G Loss: -1.157293
Epoch [46/1000], Step:  25600, d_loss: -0.292613, g_loss: -0.492281, real_scores: 0.826865, fake_scores: 0.492281, W: -0.334584
Epoch [46/1000], Step:  51200, d_loss: -0.017016, g_loss: -0.272439, real_scores: 0.381614, fake_scores: 0.272439

In [None]:
writer.close()

In [1]:
torch.save(d.state_dict(), './ser/cnn_wgan_gp_faces_discriminator.pkl')
torch.save(g.state_dict(), './ser/cnn_wgan_gp_faces_generator.pkl')

NameError: name 'torch' is not defined

In [None]:
d.load_state_dict(torch.load('./ser/cnn_wgan_gp_faces_discriminator.pt'))
g.load_state_dict(torch.load('./ser/cnn_wgan_gp_faces_generator.pt'))

In [None]:
z = torch.randn(4, z_dimension).to(device)
images = g(z)
# save_image(images, 'xx.png')
plt.imshow(Image.fromarray(make_grid(images).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()))
plt.show()