In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import tensorboard

In [2]:
batch_size = 64
img_size = 64
num_channels = 3
z_dim = 100
features_gen = 64
features_disc = 64
lr = 5.0e-5
clipping_rate = 0.01
alpha = 0.2
num_epochs = 100
critic_iters = 5
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [3]:
print(device)

cuda:1


In [4]:
transform = transforms.Compose([
    transforms.Resize([img_size, img_size]),
    transforms.ToTensor(),
    transforms.Normalize(
    [0.5 for _ in range(num_channels)], [0.5 for _ in range(num_channels)])
])

In [5]:
data_dir = '../../DATASETS/celeba/images' # CELEBA IMAGES
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

In [6]:
images, labels = next(iter(dataloader))
np.shape(images)

torch.Size([64, 3, 64, 64])

In [7]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [8]:
class Generator(nn.Module):
    def __init__(self, z_dim, features_gen, num_channels):
        super().__init__()
        self.z_dim = z_dim
        self.features_gen = features_gen
        self.num_channels = num_channels
        
        self.net = nn.Sequential(
            self._block(in_channels=z_dim, out_channels=features_gen*16, kernel_size=4, stride=1, padding=0),
            self._block(features_gen*16, features_gen*8, 4, 2, 1),
            self._block(features_gen*8, features_gen*4, 4, 2, 1),
            self._block(features_gen*4, features_gen*2, 4, 2, 1),
            nn.ConvTranspose2d(features_gen*2, num_channels, 4, 2, 1),
            nn.Tanh())
        
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.net(x)
    

In [9]:
class Discriminator(nn.Module):
    def __init__(self, z_dim, features_disc, num_channels, alpha):
        super().__init__()
        self.z_dim = z_dim
        self.features_disc = features_disc
        self.num_channels = num_channels
        self.alpha = alpha
        
        
        self.net = nn.Sequential(
                self._block(num_channels, features_disc*2, 4, 2, 1),
                self._block(features_disc*2, features_disc*4, 4, 2, 1),
                self._block(features_disc*4, features_disc*8, 4, 2, 1),
                self._block(features_disc*8, features_disc*16, 4, 2, 1),
                nn.Conv2d(features_disc*16, 1, 4, 2, 0),
#                 nn.Sigmoid()       
        )
        
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(alpha)
        )
    
    def forward(self, x):
        return self.net(x).reshape(-1, 1)

In [10]:
gen = Generator(z_dim, features_gen,  num_channels).to(device)
critic = Discriminator(z_dim, features_disc, num_channels, alpha).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [11]:
gen_optim = optim.RMSprop(params=gen.parameters(), lr = lr)
critic_optim = optim.RMSprop(params=critic.parameters(), lr = lr)
# criterion = criterion = nn.BCELoss().to(device)
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
# 

In [12]:
# To save images in grid layout
def save_image_grid(epoch: int, batch_idx: int, images: torch.Tensor, ncol: int):
    
    image_grid = make_grid(images, ncol, normalize=True)     # Images in a grid
    image_grid = image_grid.permute(1, 2, 0) # Move channel last
    image_grid = image_grid.cpu().numpy()    # To Numpy

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'GEN_IMAGES/generated_{epoch+1:03d}_{batch_idx}.jpg') # FLOWERS GENERATED IMAGES
    plt.close()

In [None]:
for epoch in range(num_epochs):
    critic_losses = []
    g_losses = []
    for batch_idx, (images, _) in enumerate(tqdm(dataloader)):
        images = images.to(device)
        
        # TRAINING DISCRIMINATOR
        for i in range(critic_iters):
            
            critic.train()
            critic_optim.zero_grad()
            critic_real_out = critic(images)

            gen.eval()
            noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
            critic_fake_out = critic(gen(noise))

            # Total loss
            critic_loss = -(torch.mean(critic_real_out) - torch.mean(critic_fake_out))

            # Backprop and Update params
            critic_loss.backward(retain_graph=True)
            critic_optim.step()
            
            # Clamming parameters 
            for p in critic.parameters():
                p.data.clamp_(-clipping_rate, clipping_rate)
        
        # Training Generator
        gen.train()
        gen_optim.zero_grad()
        fake_images = gen(noise)
        gen_loss = -torch.mean(critic(fake_images))
        
        # Backprop and Update params
        gen_loss.backward()
        gen_optim.step()
        
        critic_losses.append(critic_loss.item())
        g_losses.append(gen_loss.item())
        
        if batch_idx % 500 == 0:
            gen.eval()
            generated_images = gen(fixed_noise)
            save_image_grid(epoch, batch_idx, generated_images[:64], 8)
        
            print(f'EPOCH: {epoch+1} | G_LOSS: {np.mean(g_losses):.3f} | D_LOSS: {np.mean(critic_losses):.3f}')
    
   

  0%|          | 1/3165 [00:01<56:20,  1.07s/it]

EPOCH: 1 | G_LOSS: 0.077 | D_LOSS: -0.299


 16%|█▌        | 501/3165 [01:54<10:51,  4.09it/s]

EPOCH: 1 | G_LOSS: 0.894 | D_LOSS: -2.965


 32%|███▏      | 1001/3165 [03:47<08:48,  4.09it/s]

EPOCH: 1 | G_LOSS: 1.127 | D_LOSS: -2.991


 47%|████▋     | 1501/3165 [05:40<06:44,  4.11it/s]

EPOCH: 1 | G_LOSS: 1.204 | D_LOSS: -3.001


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.10it/s]

EPOCH: 1 | G_LOSS: 1.243 | D_LOSS: -3.006


 79%|███████▉  | 2501/3165 [09:27<02:42,  4.09it/s]

EPOCH: 1 | G_LOSS: 1.264 | D_LOSS: -3.004


 95%|█████████▍| 3001/3165 [11:20<00:40,  4.08it/s]

EPOCH: 1 | G_LOSS: 1.279 | D_LOSS: -3.002


100%|██████████| 3165/3165 [11:57<00:00,  4.41it/s]
  0%|          | 1/3165 [00:00<16:05,  3.28it/s]

EPOCH: 2 | G_LOSS: 1.360 | D_LOSS: -3.020


 16%|█▌        | 501/3165 [01:53<10:52,  4.08it/s]

EPOCH: 2 | G_LOSS: 1.345 | D_LOSS: -2.999


 32%|███▏      | 1001/3165 [03:46<08:48,  4.09it/s]

EPOCH: 2 | G_LOSS: 1.338 | D_LOSS: -2.984


 47%|████▋     | 1501/3165 [05:40<06:46,  4.09it/s]

EPOCH: 2 | G_LOSS: 1.337 | D_LOSS: -2.971


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 2 | G_LOSS: 1.324 | D_LOSS: -2.945


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.08it/s]

EPOCH: 2 | G_LOSS: 1.314 | D_LOSS: -2.921


 95%|█████████▍| 3001/3165 [11:19<00:40,  4.09it/s]

EPOCH: 2 | G_LOSS: 1.309 | D_LOSS: -2.910


100%|██████████| 3165/3165 [11:57<00:00,  4.41it/s]
  0%|          | 1/3165 [00:00<15:46,  3.34it/s]

EPOCH: 3 | G_LOSS: 1.295 | D_LOSS: -2.864


 16%|█▌        | 501/3165 [01:53<10:54,  4.07it/s]

EPOCH: 3 | G_LOSS: 1.264 | D_LOSS: -2.846


 32%|███▏      | 1001/3165 [03:47<08:48,  4.09it/s]

EPOCH: 3 | G_LOSS: 1.269 | D_LOSS: -2.841


 47%|████▋     | 1501/3165 [05:40<06:46,  4.09it/s]

EPOCH: 3 | G_LOSS: 1.154 | D_LOSS: -2.845


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 3 | G_LOSS: 1.051 | D_LOSS: -2.833


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.09it/s]

EPOCH: 3 | G_LOSS: 0.951 | D_LOSS: -2.826


 95%|█████████▍| 3001/3165 [11:19<00:40,  4.06it/s]

EPOCH: 3 | G_LOSS: 0.894 | D_LOSS: -2.824


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<15:45,  3.35it/s]

EPOCH: 4 | G_LOSS: 1.403 | D_LOSS: -2.905


 16%|█▌        | 501/3165 [01:53<10:53,  4.08it/s]

EPOCH: 4 | G_LOSS: 0.599 | D_LOSS: -2.852


 32%|███▏      | 1001/3165 [03:46<08:49,  4.08it/s]

EPOCH: 4 | G_LOSS: 0.603 | D_LOSS: -2.832


 47%|████▋     | 1501/3165 [05:39<06:45,  4.11it/s]

EPOCH: 4 | G_LOSS: 0.591 | D_LOSS: -2.814


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 4 | G_LOSS: 0.601 | D_LOSS: -2.808


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.09it/s]

EPOCH: 4 | G_LOSS: 0.623 | D_LOSS: -2.811


 95%|█████████▍| 3001/3165 [11:19<00:40,  4.09it/s]

EPOCH: 4 | G_LOSS: 0.634 | D_LOSS: -2.815


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<15:56,  3.31it/s]

EPOCH: 5 | G_LOSS: 1.465 | D_LOSS: -3.091


 16%|█▌        | 501/3165 [01:53<10:52,  4.09it/s]

EPOCH: 5 | G_LOSS: 0.258 | D_LOSS: -2.795


 32%|███▏      | 1001/3165 [03:46<08:49,  4.08it/s]

EPOCH: 5 | G_LOSS: 0.301 | D_LOSS: -2.758


 47%|████▋     | 1501/3165 [05:40<06:47,  4.09it/s]

EPOCH: 5 | G_LOSS: 0.357 | D_LOSS: -2.783


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 5 | G_LOSS: 0.356 | D_LOSS: -2.784


 79%|███████▉  | 2501/3165 [09:26<02:41,  4.10it/s]

EPOCH: 5 | G_LOSS: 0.418 | D_LOSS: -2.801


 95%|█████████▍| 3001/3165 [11:19<00:40,  4.10it/s]

EPOCH: 5 | G_LOSS: 0.441 | D_LOSS: -2.801


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<16:06,  3.27it/s]

EPOCH: 6 | G_LOSS: 1.441 | D_LOSS: -3.054


 16%|█▌        | 501/3165 [01:53<10:50,  4.09it/s]

EPOCH: 6 | G_LOSS: 0.589 | D_LOSS: -2.735


 32%|███▏      | 1001/3165 [03:46<08:48,  4.10it/s]

EPOCH: 6 | G_LOSS: 0.514 | D_LOSS: -2.774


 47%|████▋     | 1501/3165 [05:39<06:47,  4.09it/s]

EPOCH: 6 | G_LOSS: 0.488 | D_LOSS: -2.774


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 6 | G_LOSS: 0.565 | D_LOSS: -2.793


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.09it/s]

EPOCH: 6 | G_LOSS: 0.572 | D_LOSS: -2.787


 95%|█████████▍| 3001/3165 [11:19<00:39,  4.11it/s]

EPOCH: 6 | G_LOSS: 0.588 | D_LOSS: -2.781


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<15:34,  3.39it/s]

EPOCH: 7 | G_LOSS: 1.091 | D_LOSS: -1.228


 16%|█▌        | 501/3165 [01:53<10:49,  4.10it/s]

EPOCH: 7 | G_LOSS: 0.758 | D_LOSS: -2.832


 32%|███▏      | 1001/3165 [03:46<08:52,  4.06it/s]

EPOCH: 7 | G_LOSS: 0.696 | D_LOSS: -2.803


 47%|████▋     | 1501/3165 [05:39<06:46,  4.10it/s]

EPOCH: 7 | G_LOSS: 0.720 | D_LOSS: -2.791


 63%|██████▎   | 2001/3165 [07:32<04:46,  4.07it/s]

EPOCH: 7 | G_LOSS: 0.752 | D_LOSS: -2.786


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.08it/s]

EPOCH: 7 | G_LOSS: 0.781 | D_LOSS: -2.782


 95%|█████████▍| 3001/3165 [11:19<00:40,  4.09it/s]

EPOCH: 7 | G_LOSS: 0.787 | D_LOSS: -2.776


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<15:35,  3.38it/s]

EPOCH: 8 | G_LOSS: 1.430 | D_LOSS: -3.016


 16%|█▌        | 501/3165 [01:53<10:51,  4.09it/s]

EPOCH: 8 | G_LOSS: 0.832 | D_LOSS: -2.773


 32%|███▏      | 1001/3165 [03:47<08:47,  4.10it/s]

EPOCH: 8 | G_LOSS: 0.832 | D_LOSS: -2.765


 47%|████▋     | 1501/3165 [05:40<06:46,  4.10it/s]

EPOCH: 8 | G_LOSS: 0.845 | D_LOSS: -2.757


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 8 | G_LOSS: 0.833 | D_LOSS: -2.747


 79%|███████▉  | 2501/3165 [09:26<02:42,  4.09it/s]

EPOCH: 8 | G_LOSS: 0.834 | D_LOSS: -2.745


 95%|█████████▍| 3001/3165 [11:20<00:40,  4.09it/s]

EPOCH: 8 | G_LOSS: 0.827 | D_LOSS: -2.742


100%|██████████| 3165/3165 [11:57<00:00,  4.41it/s]
  0%|          | 1/3165 [00:00<15:46,  3.34it/s]

EPOCH: 9 | G_LOSS: -1.103 | D_LOSS: -2.614


 16%|█▌        | 501/3165 [01:53<10:50,  4.09it/s]

EPOCH: 9 | G_LOSS: 0.661 | D_LOSS: -2.620


 32%|███▏      | 1001/3165 [03:46<08:47,  4.11it/s]

EPOCH: 9 | G_LOSS: 0.596 | D_LOSS: -2.612


 47%|████▋     | 1501/3165 [05:39<06:45,  4.10it/s]

EPOCH: 9 | G_LOSS: 0.610 | D_LOSS: -2.600


 63%|██████▎   | 2001/3165 [07:32<04:43,  4.10it/s]

EPOCH: 9 | G_LOSS: 0.620 | D_LOSS: -2.595


 79%|███████▉  | 2501/3165 [09:25<02:42,  4.10it/s]

EPOCH: 9 | G_LOSS: 0.651 | D_LOSS: -2.584


 95%|█████████▍| 3001/3165 [11:18<00:39,  4.10it/s]

EPOCH: 9 | G_LOSS: 0.679 | D_LOSS: -2.581


100%|██████████| 3165/3165 [11:55<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<16:22,  3.22it/s]

EPOCH: 10 | G_LOSS: 1.415 | D_LOSS: -2.027


 16%|█▌        | 501/3165 [01:53<10:50,  4.10it/s]

EPOCH: 10 | G_LOSS: 0.719 | D_LOSS: -2.518


 32%|███▏      | 1001/3165 [03:46<08:48,  4.10it/s]

EPOCH: 10 | G_LOSS: 0.731 | D_LOSS: -2.533


 47%|████▋     | 1501/3165 [05:39<06:46,  4.09it/s]

EPOCH: 10 | G_LOSS: 0.778 | D_LOSS: -2.492


 63%|██████▎   | 2001/3165 [07:32<04:43,  4.10it/s]

EPOCH: 10 | G_LOSS: 0.787 | D_LOSS: -2.476


 79%|███████▉  | 2501/3165 [09:25<02:42,  4.09it/s]

EPOCH: 10 | G_LOSS: 0.788 | D_LOSS: -2.467


 95%|█████████▍| 3001/3165 [11:18<00:40,  4.10it/s]

EPOCH: 10 | G_LOSS: 0.786 | D_LOSS: -2.453


100%|██████████| 3165/3165 [11:55<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<15:44,  3.35it/s]

EPOCH: 11 | G_LOSS: 1.351 | D_LOSS: -2.804


 16%|█▌        | 501/3165 [01:53<11:04,  4.01it/s]

EPOCH: 11 | G_LOSS: 0.791 | D_LOSS: -2.411


 32%|███▏      | 1001/3165 [03:46<08:47,  4.10it/s]

EPOCH: 11 | G_LOSS: 0.808 | D_LOSS: -2.391


 47%|████▋     | 1501/3165 [05:40<07:44,  3.58it/s]

EPOCH: 11 | G_LOSS: 0.810 | D_LOSS: -2.366


 63%|██████▎   | 2001/3165 [07:33<04:44,  4.09it/s]

EPOCH: 11 | G_LOSS: 0.811 | D_LOSS: -2.359


 79%|███████▉  | 2501/3165 [09:26<02:41,  4.10it/s]

EPOCH: 11 | G_LOSS: 0.804 | D_LOSS: -2.357


 95%|█████████▍| 3001/3165 [11:19<00:39,  4.11it/s]

EPOCH: 11 | G_LOSS: 0.784 | D_LOSS: -2.351


100%|██████████| 3165/3165 [11:56<00:00,  4.42it/s]
  0%|          | 1/3165 [00:00<16:01,  3.29it/s]

EPOCH: 12 | G_LOSS: 1.399 | D_LOSS: -2.800


 16%|█▌        | 501/3165 [01:53<10:50,  4.09it/s]

EPOCH: 12 | G_LOSS: 0.778 | D_LOSS: -2.364


 32%|███▏      | 1001/3165 [03:46<08:47,  4.10it/s]

EPOCH: 12 | G_LOSS: 0.780 | D_LOSS: -2.342


 47%|████▋     | 1501/3165 [05:39<06:45,  4.10it/s]

EPOCH: 12 | G_LOSS: 0.801 | D_LOSS: -2.306


 63%|██████▎   | 2001/3165 [07:32<04:43,  4.10it/s]

EPOCH: 12 | G_LOSS: 0.822 | D_LOSS: -2.280


 70%|███████   | 2225/3165 [08:23<03:32,  4.42it/s]

In [1]:
def show_image_grid(images: torch.Tensor, ncol: int):
    
    image_grid = make_grid(images, ncol, normalize=True)     # Images in a grid
    image_grid = image_grid.permute(1, 2, 0) # Move channel last
    image_grid = image_grid.cpu().numpy()    # To Numpy
    plt.figure(figsize = (8, 4))
    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])

NameError: name 'torch' is not defined

In [2]:
for i in range(10):
    gen.eval()
    noise = torch.randn(8, 100, 1, 1).to(device)
    gen_images = gen(noise)
    show_image_grid(gen_images, 4)

NameError: name 'gen' is not defined