In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import torchvision
from math import log2
from tqdm import tqdm_notebook
import cv2
import os
from torchvision.utils import save_image
from scipy.stats import truncnorm
import random

In [2]:
START_TRAIN_AT_IMG_SIZE = 32
DATASET = 'celeb_dataset'
CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = False
LEARNING_RATE = 1e-3
BATCH_SIZES = [200,150,100,100,100]
CHANNELS_IMG = 1
Z_DIM = 512  # should be 512 in original paper
IN_CHANNELS = 512  # should be 512 in original paper
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
NUM_WORKERS = 4

In [3]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
#     fake = torch.nn.functional.interpolate(fake, size=[H,W]) #not required --> check
#     print(fake.shape, real.shape)

    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)
    
    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [4]:
class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 5, stride = 1, padding = 2, gain = 2):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain/(in_channels*(kernel_size**2)))**0.5
        self.bias = self.conv.bias
        self.conv.bias = None
        
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)
        
    def forward(self,x):
        return self.conv(x*self.scale)+self.bias.view(1,self.bias.shape[0],1,1)

In [5]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm,self).__init__()
        self.epsilon = 1e-8
        
    def forward(self,x):
        return x/torch.sqrt(torch.mean(x**2, dim = 1, keepdim = True)+self.epsilon)

In [6]:
class Upsampler(nn.Module):
    def __init__(self, factor):
        super(Upsampler, self).__init__()
        self.factor = factor
        
    def forward(self,x):
        input_img = x[0]
        size_ip = torch.squeeze(x[1][0])
        size_tup = [int(size_ip[0]*self.factor),int(size_ip[1]*self.factor)]
        return torch.nn.functional.interpolate(input_img, size = size_tup, mode = 'nearest')

In [7]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm = True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace = True)
        self.pn = PixelNorm()
        
    def forward(self,x):

        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x

In [8]:
factors = [1,1,1/2,1/4,1/32]

    
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=CHANNELS_IMG):
        super(Generator,self).__init__()
        
        self.in_channels = in_channels
        self.z_dim = z_dim
        self.lin = nn.Linear(in_features = z_dim, out_features = z_dim*4*4)
        self.conv1 = WSConv2d(in_channels = z_dim, out_channels = in_channels,kernel_size = 5,
                               stride=1, padding=2)
        self.conv2 = WSConv2d(in_channels = in_channels, out_channels = in_channels,kernel_size = 5,
                               stride=1, padding=2)
        
        self.pn = PixelNorm()
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.initial2 = nn.Sequential(
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )
        self.img_channels = img_channels
        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size = 1, stride=1, padding=0)
        
        self.prog_blocks, self.rgb_layers = (
        nn.ModuleList([]), nn.ModuleList([self.initial_rgb]),)
        self.up_blocks = nn.ModuleList([])
        
        for i in range(1,5):
            self.up_blocks.append(Upsampler(i/4))
            
            
        for i in range(len(factors)-1):
            conv_in_c = int(in_channels*factors[i])
            conv_out_c = int(in_channels*factors[i+1])
            
            self.prog_blocks.append(ConvBlock(conv_in_c,conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c,img_channels, kernel_size = 1, stride = 1, padding=0))
            
    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha*generated+(1-alpha)*upscaled)


    def forward(self,x,alpha, steps):
        x1 = x[0]
        x2 = x[1]
        out = self.pn(x1)
        out = self.lin(x1)
        out = out.view(-1,self.z_dim,4,4)
#         print(out.shape)
        out = self.leaky(self.conv1(out))
        out = self.leaky(self.conv2(out))
        out = self.initial2(out)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upsampled = self.up_blocks[step]([out,x2])
            out = self.prog_blocks[step](upsampled)       

        final_upscaled = self.rgb_layers[steps-1](upsampled)
        final_out = self.rgb_layers[steps](out)

        return self.fade_in(alpha, final_upscaled, final_out)

            
        

# Discriminator failing with 1 sample or steps >0

In [9]:
# # Looking at the generator
# z_dim = 512
# in_channels = 512
# gen = Generator(z_dim,in_channels=in_channels)
# x = torch.randn((3, Z_DIM))
# s = torch.ones(3,1)*torch.tensor([32,32])
# z = gen([x,s],0.6,4)
# for i in range(5):
#     z = gen([x,s],alpha=0.6, steps=i)
#     plt.imshow(torchvision.utils.make_grid(z, normalize=True).permute(2,1,0))
#     plt.title(z.shape[2:])
#     plt.show()
# # d = Discriminator(512,1)
# # p = d(z.detach(),0.5,4)
# # p

In [10]:
class Discriminator(nn.Module):
    def __init__(self, in_channels,img_channels=CHANNELS_IMG):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2, inplace = True)
        self.final_avg = nn.AdaptiveAvgPool2d((1,1))
        
        for i in range(len(factors)-1,0,-1):
            conv_in = int(in_channels*factors[i])
            conv_out = int(in_channels*factors[i-1])
            
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels,conv_in, kernel_size=1, stride=1, padding=0))
                                    
        self.initial_rgb = WSConv2d(img_channels,in_channels,kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.final_block = nn.Sequential(WSConv2d(in_channels+1, in_channels, kernel_size=3,stride=1, padding=1),
                                        nn.LeakyReLU(0.2),
                                        WSConv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1),
                                        nn.LeakyReLU(0.2),
                                        WSConv2d(in_channels,1,kernel_size=1 ,padding=0, stride=1),
                                        self.final_avg)

    def fade_in(self,alpha, downscaled, out):
        return alpha*out+(1-alpha)*downscaled
                    
    
    def minibatch_std(self,x):
        batch_statistics = (torch.std(x,dim=0).mean().repeat(x.shape[0],1,x.shape[2],x.shape[3]))
        return torch.cat([x,batch_statistics],dim=1)
    
    def forward(self,x,alpha,steps):
        cur_step = len(self.prog_blocks)-steps
        out = self.leaky(self.rgb_layers[cur_step](x))
                                    
        if steps ==0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0],-1)
                                    
        downscaled = self.leaky(self.rgb_layers[cur_step+1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
                                    
        out = self.fade_in(alpha, downscaled, out)
        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)                            

In [11]:
def plot_to_tensorboard(
    writer, loss_critic, loss_gen, real, fake, tensorboard_step
):
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

In [12]:
def train_fn(critic, gen, loader, dataset, step, originial_shape, alpha, opt_critic, opt_gen,
            tensorboard_step, writer, scaler_gen, scaler_critic,epoch):
    
    loop = tqdm_notebook(loader, leave=False)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
#         print(cur_batch_size)
        shape_ip = torch.tensor([originial_shape[0], originial_shape[1]])
        shape_ip = (torch.ones(cur_batch_size,1)*shape_ip).to(DEVICE)
        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
#         print(noise.shape)

        with torch.cuda.amp.autocast():
            fake = gen([noise,shape_ip], alpha, step)
#             print(fake.shape)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / (
            (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % 10 == 0:
            with torch.no_grad():
                fixed_fakes = gen([noise,shape_ip], alpha, step) * 0.5 + 0.5
                torchvision.utils.save_image(fixed_fakes,f"fake_i_{epoch}_{batch_idx}.png")
#                 print(fixed_fakes.shape)
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha

In [13]:
def get_loader(img_size):
    image_size1, image_size2 = img_size
    transform = transforms.Compose(
        [
            transforms.Resize((int(image_size1), int(image_size2))),
            transforms.ToTensor(),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
#     batch_size = BATCH_SIZES[step]
    batch_size = 100
    dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
    dataset, _ = torch.utils.data.random_split(dataset,[30000,30000])
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,)
    
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
    return loader, dataset

In [14]:
def main():
    gen = Generator(Z_DIM, IN_CHANNELS).to(DEVICE)
    critic = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()
    
    writer = SummaryWriter(f"logs/gan1")
    
    gen.train()
    critic.train()
    tensorboard_step = 0
    

#     for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    num_epochs = 2
    alpha = 1e-5
    for i in range(1,5):
        loader, dataset = get_loader([(START_TRAIN_AT_IMG_SIZE*(i/4)), (START_TRAIN_AT_IMG_SIZE)*(i/4)])
        print('Current img size',[(START_TRAIN_AT_IMG_SIZE*(i/4)), (START_TRAIN_AT_IMG_SIZE)*(i/4)])
        for epoch in range(num_epochs):
            print(f"Epoch [{epoch+1}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(critic, gen, loader, dataset, i,[32,32],alpha, opt_critic,
                                               opt_gen, tensorboard_step, writer, scaler_gen, scaler_critic,epoch)


In [15]:
main()

Current img size [8.0, 8.0]
Epoch [1/2]


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  loop = tqdm_notebook(loader, leave=False)


  0%|          | 0/300 [00:00<?, ?it/s]

Epoch [2/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Current img size [16.0, 16.0]
Epoch [1/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Epoch [2/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Current img size [24.0, 24.0]
Epoch [1/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Epoch [2/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Current img size [32.0, 32.0]
Epoch [1/2]


  0%|          | 0/300 [00:00<?, ?it/s]

Epoch [2/2]


  0%|          | 0/300 [00:00<?, ?it/s]

In [16]:
torch.save(gen,'gen.pth')
torch.save(dis,'dis.pth')

NameError: name 'gen' is not defined