In [3]:
import torch
import torch.nn as nn
import tqdm
import numpy as np
import torch.optim as optim
import torchvision
from torchvision.utils import make_grid, save_image
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [4]:
torch.cuda.get_device_name()

'Tesla T4'

In [5]:
class Critic(nn.Module):
    def __init__(self, img_channels, features_d):
        super().__init__()
        self.c = nn.Sequential(nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1), 
                                nn.LeakyReLU(0.2), 

                                self.conv_block(features_d, features_d * 2, 4, 2, 1), 
                                self.conv_block(features_d * 2, features_d * 4, 4, 2, 1), 
                                self.conv_block(features_d * 4, features_d * 8, 4, 2, 1), 

                                nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0))
        
    def conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
         return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 
                              nn.InstanceNorm2d(out_channels, affine=True), 
                              nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.c(x)

In [6]:
class Generator(nn.Module):
    def __init__(self, noise_channels, img_channels, features_g):
        super().__init__()
        self.g = nn.Sequential(self.conv_block(noise_channels, features_g * 16, 4, 1, 0), 
                               self.conv_block(features_g * 16, features_g * 8, 4, 2, 1),
                               self.conv_block(features_g * 8, features_g * 4, 4, 2, 1), 
                               self.conv_block(features_g * 4, features_g * 2, 4, 2, 1),
                               
                               nn.ConvTranspose2d(features_g * 2, img_channels, kernel_size=4, stride=2, padding=1),
                               nn.Tanh())
        
    def conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 
                             nn.BatchNorm2d(out_channels), 
                             nn.ReLU())
    
    def forward(self, x):
        return self.g(x)

In [7]:
def init_parameters(net):
    for n in net.modules():
        if isinstance(n, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(n.weight.data, 0.0, 0.02)

In [8]:
def gradient_penalty(C, real, fake, device):
    bs, c, h, w = real.shape
    alpha = torch.rand((bs, 1, 1, 1)).repeat(1, c, h, w).to(device)
    juxtaposition = alpha * real + (1 - alpha) * fake
    scores = C(juxtaposition)

    gradient = torch.autograd.grad(outputs=scores, inputs=juxtaposition, grad_outputs=torch.ones_like(scores),
                                   create_graph=True, retain_graph=True)[0]

    gradient = gradient.view(gradient.shape[0], -1)
    grad_norm = gradient.norm(2, dim=1)
    gp = torch.mean((grad_norm - 1) ** 2)
    return gp

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 1e-4
batch_size = 64
img_size = 64
img_channels = 3
noise_dim = 100
epochs = 500
features_c = 64
features_g = 64
critic_iter = 5
lambda_gp = 10

In [10]:
device

device(type='cuda')

In [11]:
transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                 transforms.ToTensor(), 
                                 transforms.Normalize([0.5 for _ in range(img_channels)], [0.5 for _ in range(img_channels)])])

In [12]:
dataset = datasets.ImageFolder(root="faces/", transform=transform)

In [13]:
len(dataset)

15167

In [14]:
dataloader = DataLoader(dataset, batch_size, shuffle=True)

In [15]:
len(dataloader)

237

In [16]:
G = Generator(noise_dim, img_channels, features_g).to(device)
C = Critic(img_channels, features_c).to(device)
init_parameters(G)
init_parameters(C)

In [17]:
opt_G = optim.Adam(G.parameters(), lr, betas=(0.0, 0.9))
opt_C = optim.Adam(C.parameters(), lr, betas=(0.0, 0.9))

In [18]:
fixed_noise = torch.randn(64, noise_dim, 1, 1).to(device)

In [19]:
def save_fake_images(epoch_num):
    with torch.no_grad():
        fake = G(fixed_noise)
    img_grid = make_grid(fake, normalize=True)
    file_name = "fake_images-{0:0=4d}.png".format(epoch_num)
    save_image(fake, file_name, nrow=8)
    print("")
    print("Image Saved!")
    print("")

In [20]:
for real, _ in dataloader:
    img_grid = make_grid(real[:64], normalize=True)
    save_image(img_grid, "real_images.png", nrow=8)
    break

In [21]:
save_fake_images(0)


Image Saved!



In [None]:
for epoch in range(epochs):
    batch_losses_C = []
    batch_losses_G = []
    for real, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
        real = real.to(device)
        curr_batch_size = real.shape[0]

        for _ in range(critic_iter):
            noise = torch.randn(curr_batch_size, noise_dim, 1, 1).to(device)
            fake = G(noise)
            C_real = C(real).reshape(-1)
            C_fake = C(fake).reshape(-1)
            gp = gradient_penalty(C, fake, real, device)
            loss_C = -(torch.mean(C_real) - torch.mean(C_fake)) + lambda_gp * gp

            C.zero_grad()
            loss_C.backward(retain_graph=True)
            opt_C.step()
            batch_losses_C.append(loss_C.item())
        
        output = C(fake).reshape(-1)
        loss_G = -torch.mean(output)

        G.zero_grad()
        loss_G.backward()
        opt_G.step()
        batch_losses_G.append(loss_G.item())

    print(f" Epoch: {epoch} | Loss C: {np.round(sum(batch_losses_C)/len(batch_losses_C), 4)} | Loss G: {np.round(sum(batch_losses_G)/len(batch_losses_G), 4)}")
    if epoch % 10 == 0:
        save_fake_images(epoch)

100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 0 | Loss C: -84.4116 | Loss G: 115.1006

Image Saved!



100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 1 | Loss C: -47.3321 | Loss G: 141.1016


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 2 | Loss C: -23.5482 | Loss G: 127.1348


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 3 | Loss C: -18.1066 | Loss G: 116.9474


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 4 | Loss C: -16.8096 | Loss G: 116.9355


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 5 | Loss C: -15.8057 | Loss G: 115.7994


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 6 | Loss C: -15.0279 | Loss G: 115.9646


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 7 | Loss C: -13.96 | Loss G: 117.0942


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 8 | Loss C: -13.0876 | Loss G: 117.7157


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 9 | Loss C: -12.4576 | Loss G: 119.6281


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 10 | Loss C: -11.9549 | Loss G: 121.8198

Image Saved!



100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 11 | Loss C: -11.4767 | Loss G: 121.264


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 12 | Loss C: -11.2558 | Loss G: 124.0759


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 13 | Loss C: -10.8873 | Loss G: 126.1306


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 14 | Loss C: -10.6267 | Loss G: 126.0


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 15 | Loss C: -10.3785 | Loss G: 127.37


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 16 | Loss C: -10.0412 | Loss G: 128.8744


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 17 | Loss C: -9.9514 | Loss G: 129.9089


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 18 | Loss C: -9.7384 | Loss G: 131.5374


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 19 | Loss C: -9.6147 | Loss G: 132.8775


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 20 | Loss C: -9.4556 | Loss G: 132.3773

Image Saved!



100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 21 | Loss C: -9.2768 | Loss G: 134.7087


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 22 | Loss C: -9.2784 | Loss G: 134.1874


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 23 | Loss C: -9.0235 | Loss G: 135.3083


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 24 | Loss C: -9.0397 | Loss G: 135.3681


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 25 | Loss C: -8.9946 | Loss G: 137.3159


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 26 | Loss C: -8.868 | Loss G: 139.2601


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 27 | Loss C: -8.7005 | Loss G: 141.5486


100%|██████████| 237/237 [04:53<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 28 | Loss C: -8.7754 | Loss G: 141.9692


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 29 | Loss C: -8.701 | Loss G: 143.4269


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 30 | Loss C: -8.4723 | Loss G: 143.8937

Image Saved!



100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 31 | Loss C: -8.7848 | Loss G: 145.9944


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 32 | Loss C: -8.5287 | Loss G: 146.7654


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 33 | Loss C: -8.5135 | Loss G: 148.5421


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 34 | Loss C: -8.5113 | Loss G: 149.3048


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 35 | Loss C: -8.4012 | Loss G: 147.7917


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 36 | Loss C: -8.415 | Loss G: 151.3895


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 37 | Loss C: -8.3483 | Loss G: 147.7635


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 38 | Loss C: -8.1614 | Loss G: 152.1155


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 39 | Loss C: -8.4059 | Loss G: 151.2922


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 40 | Loss C: -8.323 | Loss G: 151.2824

Image Saved!



100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 41 | Loss C: -8.251 | Loss G: 153.9565


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 42 | Loss C: -8.2693 | Loss G: 151.6042


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 43 | Loss C: -8.2761 | Loss G: 152.9363


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 44 | Loss C: -8.2835 | Loss G: 152.0173


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 45 | Loss C: -8.1787 | Loss G: 156.599


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 46 | Loss C: -8.123 | Loss G: 156.572


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 47 | Loss C: -8.2759 | Loss G: 156.9499


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 48 | Loss C: -8.0126 | Loss G: 159.6222


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 49 | Loss C: -8.097 | Loss G: 157.9929


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 50 | Loss C: -7.95 | Loss G: 157.8737

Image Saved!



100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 51 | Loss C: -7.9473 | Loss G: 159.4288


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 52 | Loss C: -8.172 | Loss G: 158.1047


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 53 | Loss C: -8.034 | Loss G: 161.3006


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 54 | Loss C: -7.9479 | Loss G: 161.4078


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 55 | Loss C: -7.9916 | Loss G: 166.3042


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 56 | Loss C: -7.9806 | Loss G: 165.4727


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 57 | Loss C: -7.9813 | Loss G: 163.8937


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 58 | Loss C: -8.0295 | Loss G: 164.2107


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 59 | Loss C: -7.9689 | Loss G: 159.9514


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 60 | Loss C: -7.9585 | Loss G: 163.6567

Image Saved!



100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 61 | Loss C: -7.8502 | Loss G: 168.6348


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 62 | Loss C: -7.9158 | Loss G: 167.3649


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 63 | Loss C: -7.9285 | Loss G: 171.9807


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 64 | Loss C: -7.982 | Loss G: 172.1319


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 65 | Loss C: -7.9109 | Loss G: 173.798


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 66 | Loss C: -7.8768 | Loss G: 171.1866


100%|██████████| 237/237 [04:51<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 67 | Loss C: -7.8801 | Loss G: 173.5048


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 68 | Loss C: -7.9169 | Loss G: 169.5537


100%|██████████| 237/237 [04:52<00:00,  1.24s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 69 | Loss C: -7.8014 | Loss G: 170.2926


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 70 | Loss C: -7.8745 | Loss G: 169.477

Image Saved!



100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 71 | Loss C: -7.8312 | Loss G: 167.0011


100%|██████████| 237/237 [04:52<00:00,  1.23s/it]
  0%|          | 0/237 [00:00<?, ?it/s]

 Epoch: 72 | Loss C: -7.9153 | Loss G: 164.9164


 30%|██▉       | 70/237 [01:26<03:25,  1.23s/it]