In [1]:
import os
import math
import torch
import random
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms



from tqdm import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter

2024-03-07 17:32:18.598519: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-07 17:32:18.672471: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-07 17:32:18.995664: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:
2024-03-07 17:32:18.995699: W tensorflow/compiler/xl

In [2]:
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

if not os.path.exists("./runs/PGGAN"):
    os.makedirs("./runs/PGGAN")

In [3]:
class WSConv2d(nn.Module): ## Weight Standardization Convolution
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
    

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    # x-shape: Batch Size x Channels x H X W
    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x

In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.initial = nn.Sequential(PixelNorm(),
                                    nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
                                    nn.LeakyReLU(0.2),
                                    WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                                    nn.LeakyReLU(0.2),
                                    PixelNorm())

        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks, self.rgb_layers = (nn.ModuleList([]), nn.ModuleList([self.initial_rgb]))

        for i in range(len(factors) - 1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

In [6]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0))

        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.final_block = nn.Sequential(WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
                                         nn.LeakyReLU(0.2),
                                         WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
                                         nn.LeakyReLU(0.2),
                                         WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1))

    def fade_in(self, alpha, downscaled, out):
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3]))

        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        cur_step = len(self.prog_blocks) - steps

        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)
            
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [7]:
def test():
    Z_DIM = 100
    IN_CHANNELS = 128
    IMG_CHANNELS = 3
    gen = Generator(Z_DIM, IN_CHANNELS, IMG_CHANNELS)
    critic = Discriminator(IN_CHANNELS, IMG_CHANNELS)
    
    for img_size in [4,8,16,32,64,128]:
        num_steps = int(math.log2(img_size/4))
        z = torch.randn((1,Z_DIM,1,1))
        generated = gen(z,0.5,steps=num_steps)
        
        assert generated.shape == (1, IMG_CHANNELS, img_size, img_size)
        
        critic_generated = critic(generated,0.5,steps=num_steps)
        
        assert critic_generated.shape == (1,1)
        print(f"Succes at image size {img_size}x{img_size}")

test()

Succes at image size 4x4
Succes at image size 8x8
Succes at image size 16x16
Succes at image size 32x32
Succes at image size 64x64
Succes at image size 128x128


In [8]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

START_TRAIN_IMG_SIZE = 16
DATASET = "/home/pervinco/Datasets/CelebA"
SAVE_DIR = "./runs/PGGAN"

CHECKPOINT_GEN = f"{SAVE_DIR}/generator.pth"
CHECKPOINT_CRITIC = f"{SAVE_DIR}/critic.pth"
SAVE_MODEL = False
LOAD_MODEL = False

LR = 1e-3
BATCH_SIZES = [32,32,32,32,16,16,16,4,4,4] ## modifiable/ Batch_sizes for each step
IMAGE_SIZE = 128 ## 1024 for paper
IMG_CHANNELS = 3
Z_DIM = 256 ## 512 for paper
IN_CHANNELS = 256 ## 512 for paper
LAMBDA_GP = 10
NUM_STEPS = int(math.log2(IMAGE_SIZE/4)) + 1

PROGRESSIVE_EPOCHS = [4] * len(BATCH_SIZES)
FIXED_NOISE = torch.randn(8,Z_DIM,1,1).to(DEVICE)
# NUM_WORKERS = 4
NUM_WORKERS = 2

print(DEVICE)

cuda


In [9]:
def save_on_tensorboard(writer,loss_critic,loss_gen,real,fake,tensorboard_step):
    writer.add_scalar("Loss Critic",loss_critic,global_step=tensorboard_step)
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)
    
    with torch.no_grad():
        img_grid_real = torchvision.utils.make_grid(real[:8],normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8],normalize=True)
        
        writer.add_image("Real",img_grid_real,global_step = tensorboard_step)
        writer.add_image("Fake",img_grid_fake,global_step = tensorboard_step)
        
def gradient_penalty(critic,real,fake,alpha,train_step,device="cpu"):
    BATCH_SIZE,C,H,W = real.shape
    beta = torch.rand((BATCH_SIZE,1,1,1)).repeat(1,C,H,W).to(device)
    
    interpolated_images = real * beta + fake.detach() * (1-beta)
    interpolated_images.requires_grad_(True)
    
    ## Calculate critic scores
    mixed_scores = critic(interpolated_images,alpha,train_step)
    
    ## Take the gradient of the scores with respect to the image
    gradient = torch.autograd.grad(
        inputs = interpolated_images,
        outputs = mixed_scores,
        grad_outputs = torch.ones_like(mixed_scores),
        create_graph = True,
        retain_graph = True
    )[0]
    
    gradient = gradient.view(gradient.shape[0],-1)
    gradient_norm = gradient.norm(2,dim=1)
    penalty = torch.mean((gradient_norm - 1)**2)
    return penalty

def save_checkpoint(model,optimizer,filename="my_checkpoint.pth"):
    print("Saving Checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer" : optimizer.state_dict()
    }
    torch.save(checkpoint,filename)
    
def load_checkpoint(checkpoint_file,model,optimizer,lr):
    print("Loading Checkpoint")
    checkpoint = torch.load(checkpoint_file,map_location="cuda")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
        
def generate_examples(gen,current_epoch,steps,n=16):
    gen.eval()
    alpha = 1.0
    
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1,Z_DIM,1,1).to(DEVICE)
            generated_img = gen(noise,alpha=alpha,steps=steps)
            save_image(generated_img*0.5+0.5,f"{SAVE_DIR}/step{steps}_epoch{current_epoch}_{i}.png")
                
    gen.train()

In [10]:
def get_loader(img_size):
    transform = transforms.Compose(
    [
        transforms.Resize((img_size,img_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5 for _ in range(IMG_CHANNELS)],[0.5 for _ in range(IMG_CHANNELS)])
    ])
    
    batch_size = BATCH_SIZES[int(math.log2(img_size/4))]
    dataset = datasets.ImageFolder(root=DATASET,transform=transform)
    loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=NUM_WORKERS,pin_memory=True)
    
    return loader,dataset

def train_fn(gen,critic,loader,dataset,step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic):
    loop = tqdm(loader,leave=True)
    
    i = 0
    for batch_idx,(real,_) in enumerate(loop):
        i += 1
        if i%2 == 0:
            continue
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size,Z_DIM,1,1).to(DEVICE)
        
        ## Train Critic
        ## Wasserstein Loss : Maximize "E[Critic(real)] - E[Critic(fake)]"   ==   Minimize "-(E[Critic(real)] - E[Critic(fake)])"
        with torch.cuda.amp.autocast():
            fake = gen(noise,alpha,step).to(DEVICE)
            critic_real = critic(real,alpha,step)
            critic_fake = critic(fake.detach(),alpha,step)
            gp = gradient_penalty(critic,real,fake,alpha,step,device=DEVICE)
            loss_critic = -1 * (torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp + 0.001 * torch.mean(critic_real**2)
        
        critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()
        
        ## Train Generator
        ## Maximize "E[Critic(fake)]"   ==   Minimize "- E[Critic(fake)]"
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake,alpha,step)
            loss_gen = -1 * torch.mean(gen_fake)
            
        gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()
    
        alpha += (cur_batch_size/len(dataset)) * (1/PROGRESSIVE_EPOCHS[step]) * 2
        alpha = min(alpha,1)
        
        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE,alpha,step) * 0.5 + 0.5
                save_on_tensorboard(writer,loss_critic.item(),loss_gen.item(),real.detach(),fixed_fakes.detach(),tensorboard_step)
                tensorboard_step += 1
    
    return tensorboard_step,alpha
        
## build model
gen = Generator(Z_DIM,IN_CHANNELS,IMG_CHANNELS).to(DEVICE)
critic = Discriminator(IN_CHANNELS,IMG_CHANNELS).to(DEVICE)

## initialize optimizer,scalers (for FP16 training)
opt_gen = optim.Adam(gen.parameters(),lr=LR,betas=(0.0,0.99))
opt_critic = optim.Adam(critic.parameters(),lr=LR,betas=(0.0,0.99))
scaler_gen = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

## tensorboard writer
writer = SummaryWriter(f"{SAVE_DIR}")
tensorboard_step = 0

## if checkpoint files exist, load model
if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN,gen,opt_gen,LR)
    load_checkpoint(CHECKPOINT_CRITIC,critic,opt_critic,LR)
    
gen.train()
critic.train()

step = int(math.log2(START_TRAIN_IMG_SIZE/4)) ## starts from 0

global_epoch = 0
generate_examples_at = [4,8,12,16,20,24,28,32]

for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-4
    loader,dataset = get_loader(4*2**step)
    print(f"Image size:{4*2**step} | Current step:{step}")
    
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}] Global Epoch:{global_epoch}")
        tensorboard_step,alpha = train_fn(gen,critic,loader,dataset,step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic)
        global_epoch += 1
        if global_epoch in generate_examples_at:
            generate_examples(gen,global_epoch,step,n=6)
        
        if SAVE_MODEL and (epoch+1)%8==0:
            save_checkpoint(gen,opt_gen,filename="CHECKPOINT_GEN")
            save_checkpoint(critic,opt_critic,filename="CHECKPOINT_CRITIC")
            
    step += 1 ## Progressive Growing
    
print("Training finished")

Image size:16 | Current step:2
Epoch [1/4] Global Epoch:0


100%|██████████| 6332/6332 [00:47<00:00, 132.34it/s]


Epoch [2/4] Global Epoch:1


100%|██████████| 6332/6332 [00:45<00:00, 138.75it/s]


Epoch [3/4] Global Epoch:2


100%|██████████| 6332/6332 [00:45<00:00, 138.62it/s]


Epoch [4/4] Global Epoch:3


100%|██████████| 6332/6332 [00:46<00:00, 137.45it/s]


Image size:32 | Current step:3
Epoch [1/4] Global Epoch:4


100%|██████████| 6332/6332 [01:52<00:00, 56.47it/s]


Epoch [2/4] Global Epoch:5


100%|██████████| 6332/6332 [01:52<00:00, 56.35it/s]


Epoch [3/4] Global Epoch:6


100%|██████████| 6332/6332 [01:52<00:00, 56.42it/s]


Epoch [4/4] Global Epoch:7


100%|██████████| 6332/6332 [01:52<00:00, 56.48it/s]


Image size:64 | Current step:4
Epoch [1/4] Global Epoch:8


100%|██████████| 12663/12663 [07:57<00:00, 26.53it/s]


Epoch [2/4] Global Epoch:9


100%|██████████| 12663/12663 [07:56<00:00, 26.55it/s]


Epoch [3/4] Global Epoch:10


100%|██████████| 12663/12663 [07:56<00:00, 26.56it/s]


Epoch [4/4] Global Epoch:11


100%|██████████| 12663/12663 [07:56<00:00, 26.57it/s]


Image size:128 | Current step:5
Epoch [1/4] Global Epoch:12


100%|██████████| 12663/12663 [18:20<00:00, 11.50it/s]


Epoch [2/4] Global Epoch:13


100%|██████████| 12663/12663 [18:21<00:00, 11.50it/s]


Epoch [3/4] Global Epoch:14


100%|██████████| 12663/12663 [18:20<00:00, 11.50it/s]


Epoch [4/4] Global Epoch:15


100%|██████████| 12663/12663 [18:20<00:00, 11.50it/s]


Image size:256 | Current step:6
Epoch [1/4] Global Epoch:16


100%|██████████| 12663/12663 [42:08<00:00,  5.01it/s]


Epoch [2/4] Global Epoch:17


100%|██████████| 12663/12663 [42:07<00:00,  5.01it/s]


Epoch [3/4] Global Epoch:18


 60%|█████▉    | 7579/12663 [25:13<16:51,  5.03it/s]

: 