In [1]:
import os
from tqdm import tnrange, tqdm_notebook, tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [2]:
from matplotlib import rcParams
rcParams['figure.figsize'] = (12, 8)

%matplotlib inline

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
batch_size = 64
num_epochs = 1000

z_dimension = 100
num_feature_x1 = 192
num_feature_x2 = 192

In [5]:
device_ids = [0, 1]

In [6]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

dataset = datasets.ImageFolder('./datas/faces', transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [7]:
class Discriminator(nn.Module): # b 3 96 96
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(2, 2), 
        ) # b 32 48 48
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(2, 2),
        ) # b 64 24 24
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(4, 4),
        ) # b 64 6 6
        
        self.fc = nn.Sequential(
            nn.Linear(64 * 6 * 6, 1024),
            nn.LeakyReLU(.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ) # b 1
    
    def forward(self, x): # b 1 28 28
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        
        out = out.view(x.size(0), -1)
        return self.fc(out)

In [8]:
class Generator(nn.Module):
    def __init__(self, inp_dim, num_feature_x1, num_feature_x2):
        super(Generator, self).__init__()
        
        self.num_feature_x1 = num_feature_x1
        self.num_feature_x2 = num_feature_x2
        
        self.fc = nn.Sequential(
            nn.Linear(inp_dim, num_feature_x1 * num_feature_x2)
        ) # b h*w
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.LeakyReLU(.2, True),
        ) # b 1 192 192
        
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, True),
        ) # b 64 192 192
        
        self.downsample2 = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(.2, True),
        ) # b 32 192 192
        
        self.downsample3 = nn.Sequential(
            nn.Conv2d(32, 3, 3, padding=1, stride=2),
            nn.Tanh(),
        ) # b 3 96 96
        
    def forward(self, x):
        out = self.fc(x)
        out = out.view(x.size(0), 1, self.num_feature_x1, self.num_feature_x2)
        out = self.br(out)

        out = self.downsample1(out)
        out = self.downsample2(out)
        out = self.downsample3(out)
        return out

In [9]:
d = Discriminator()#.cuda(device_ids[0])
g = Generator(z_dimension, num_feature_x1, num_feature_x2)#.cuda(device_ids[0])

d = nn.DataParallel(d, device_ids=device_ids).to(device)
g = nn.DataParallel(g, device_ids=device_ids).to(device)


criterion = nn.BCELoss()

d_optimezer = optim.Adam(d.parameters(), lr=1e-4)
# d_optimezer = nn.DataParallel(d_optimezer, device_ids=device_ids)
g_optimezer = optim.Adam(g.parameters(), lr=1e-4)
# g_optimezer = nn.DataParallel(g_optimezer, device_ids=device_ids)

    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [10]:
writer = SummaryWriter('./log/cnn_gan_faces_multip_gpu')

In [None]:
total_count = len(dataloader)
for epoch in tqdm_notebook(xrange(num_epochs)):
    
    d_loss_total = .0
    g_loss_total = .0
    for i, (img, _) in enumerate(dataloader):
        
        real_img = img.cuda()
        real_labels = torch.ones(img.size(0), 1).cuda()
        fake_labels = torch.zeros(img.size(0), 1).cuda()
        
        real_out = d(real_img)
        d_loss_real = criterion(real_out, real_labels)
        real_scores = real_out
        
        z = torch.randn(img.size(0), z_dimension).cuda()
        fake_img = g(z)
        fake_out = d(fake_img)
        d_loss_fake = criterion(fake_out, fake_labels)
        fake_scores = fake_out
        
        d_loss = d_loss_real + d_loss_fake
        d_optimezer.zero_grad()
        d_loss.backward()
        d_optimezer.step()
        
        z = torch.randn(img.size(0), z_dimension).cuda()
        fake_img = g(z)
        fake_out = d(fake_img)
        g_loss = criterion(fake_out, real_labels)
        
        g_optimezer.zero_grad()
        g_loss.backward()
        g_optimezer.step()
        
        d_loss_total += d_loss.item() * img.size(0)
        g_loss_total += g_loss.item() * img.size(0)
        
        step = epoch * total_count + i + 1
        
        if (i + 1) % 100 == 0:
            writer.add_scalar('Discriminator Real Loss', d_loss_real.item(), step)
            writer.add_scalar('Discriminator Fake Loss', d_loss_fake.item(), step)
            writer.add_scalar('Discriminator Loss', d_loss.item(), step)
            writer.add_scalar('Generator Loss', g_loss.item(), step)
        
        
        if (i + 1) % 500 == 0:
            tqdm.write('Epoch [{}/{}], Step: {:6d}, d_loss: {:.6f}, g_loss: {:.6f}, real_scores: {:.6f}' \
', fake_scores: {:.6f}'.format(epoch+1, num_epochs, (i+1) * batch_size, d_loss, g_loss, real_scores.mean(), fake_scores.mean()))
    
    _d_loss_total = d_loss_total / (total_count * (epoch + 1))
    _g_loss_total = g_loss_total / (total_count * (epoch + 1))
    
    writer.add_scalar('Discriminator Total Loss', _d_loss_total, step)
    writer.add_scalar('Generator Total Loss', _g_loss_total, step)
    tqdm.write("Finish Epoch [{}/{}], D Loss: {:.6f}, G Loss: {:.6f}".format(epoch+1, 
                                                                             num_epochs, 
                                                                             _d_loss_total,
                                                                             _g_loss_total, ))
    if epoch == 0:
        real_images = real_img.view(-1, 3, 96, 96).cpu().data
        save_image(real_images, './cnn_gan_faces_multip_gpu/real_images.png')

    fake_images = fake_img.view(-1, 3, 96, 96).cpu().data
    save_image(fake_images, './cnn_gan_faces_multip_gpu/fake_images-{}.png'.format(epoch+1))

Epoch [1/1000], Step:  32000, d_loss: 0.054164, g_loss: 7.191885, real_scores: 0.975190, fake_scores: 0.007158
Finish Epoch [1/1000], D Loss: 21.932961, G Loss: 245.072496
Epoch [2/1000], Step:  32000, d_loss: 0.267216, g_loss: 3.399000, real_scores: 0.896402, fake_scores: 0.100556
Finish Epoch [2/1000], D Loss: 16.346100, G Loss: 82.266996
Epoch [3/1000], Step:  32000, d_loss: 0.392418, g_loss: 2.323717, real_scores: 0.870723, fake_scores: 0.149981
Finish Epoch [3/1000], D Loss: 15.015400, G Loss: 47.042996
Epoch [4/1000], Step:  32000, d_loss: 0.141008, g_loss: 4.168742, real_scores: 0.932382, fake_scores: 0.027019
Finish Epoch [4/1000], D Loss: 11.034464, G Loss: 32.813597
Epoch [5/1000], Step:  32000, d_loss: 0.476203, g_loss: 2.338293, real_scores: 0.882729, fake_scores: 0.233534
Finish Epoch [5/1000], D Loss: 7.580066, G Loss: 29.334026
Epoch [6/1000], Step:  32000, d_loss: 0.444863, g_loss: 1.939400, real_scores: 0.814008, fake_scores: 0.170997
Finish Epoch [6/1000], D Loss: 6.9

Epoch [49/1000], Step:  32000, d_loss: 0.515902, g_loss: 1.914449, real_scores: 0.876801, fake_scores: 0.240451
Finish Epoch [49/1000], D Loss: 0.605097, G Loss: 3.303579
Epoch [50/1000], Step:  32000, d_loss: 0.319007, g_loss: 2.090187, real_scores: 0.930550, fake_scores: 0.166616
Finish Epoch [50/1000], D Loss: 0.576846, G Loss: 3.320785
Epoch [51/1000], Step:  32000, d_loss: 0.538726, g_loss: 3.460716, real_scores: 0.761366, fake_scores: 0.032149
Finish Epoch [51/1000], D Loss: 0.536505, G Loss: 3.378083
Epoch [52/1000], Step:  32000, d_loss: 0.356284, g_loss: 2.320139, real_scores: 0.877112, fake_scores: 0.132763
Finish Epoch [52/1000], D Loss: 0.513332, G Loss: 3.336861
Epoch [53/1000], Step:  32000, d_loss: 0.506870, g_loss: 2.407567, real_scores: 0.846604, fake_scores: 0.190875
Finish Epoch [53/1000], D Loss: 0.491287, G Loss: 3.404518
Epoch [54/1000], Step:  32000, d_loss: 0.525008, g_loss: 1.897930, real_scores: 0.838367, fake_scores: 0.151591
Finish Epoch [54/1000], D Loss: 0

In [None]:
writer.close()

In [None]:
torch.save(d.state_dict(), './ser/faces_discriminator.pkl')
torch.save(g.state_dict(), './ser/faces_generator.pkl')

In [None]:
z = torch.randn(4, z_dimension).to(device)
images = g(z)
# save_image(images, 'xx.png')
plt.imshow(Image.fromarray(make_grid(images).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()))
plt.show()