In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import random
import numpy as np
from math import log2
from torchvision.utils import save_image
from scipy.stats import truncnorm
from tqdm import tqdm
import wandb


In [16]:
class WSConv2d(nn.Module):
    """2-D convolution whose input is rescaled at run time.

    The wrapped ``nn.Conv2d`` has its weights drawn from a standard normal
    distribution. On every forward pass the input is multiplied by
    ``sqrt(gain / (in_channels * kernel_size ** 2))`` before the convolution.
    The bias is taken away from the inner convolution and added after it, so
    the scale factor never touches the bias.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer side length of the square kernel.
        stride: convolution stride.
        padding: zero padding applied on each side.
        gain: numerator of the scale factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        fan_in = in_channels * kernel_size**2
        self.scale = (gain / fan_in) ** 0.5

        # Re-home the bias on this module and strip it from the inner conv so
        # it is applied exactly once, after the scaled convolution.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale)`` plus the per-channel bias."""
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, -1, 1, 1)
        return self.conv(scaled_input) + channel_bias
    
class PixelNorm(nn.Module):
    """Divide each position's channel vector by its root-mean-square.

    The mean of the squared activations is taken over dimension 1 and a small
    epsilon is added before the square root to avoid division by zero.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        """Return ``x`` normalised along dimension 1."""
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        denominator = torch.sqrt(mean_square + self.epsilon)
        return x / denominator

class ConvBlock(nn.Module):
    """Two weight-scaled 3x3 convolutions, each followed by LeakyReLU(0.2).

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.
        use_pixelnorm: when true, ``PixelNorm`` is applied after each
            activation; when false the activations pass through unchanged.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pixelnorm = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Apply conv -> LeakyReLU -> (optional PixelNorm) twice."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pixelnorm:
                x = self.pn(x)
        return x
    
        
        

In [17]:
class Generator(nn.Module):
    """Progressively grown generator.

    An initial block turns the latent input into a feature map with
    ``in_channels`` channels using a 4x4 transposed convolution. One
    ``ConvBlock`` per consecutive pair in the module-level ``factors`` list is
    then stacked on top, each preceded by a 2x nearest-neighbour upsample in
    ``forward``. Every stage has its own 1x1 "to RGB" convolution.

    Args:
        z_dim: channel count of the latent input.
        in_channels: base channel count; stage widths are
            ``int(in_channels * factors[i])``.
        img_channels: channels of the generated image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        # NOTE: the attribute names (including the "inital" spelling) are part
        # of the state_dict keys and are kept as-is.
        self.inital = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.inital_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks = nn.ModuleList([])
        # rgb_layers[k] converts the features available after k growth steps.
        self.rgb_layers = nn.ModuleList([self.inital_rgb])

        for factor_in, factor_out in zip(factors[:-1], factors[1:]):
            width_in = int(in_channels * factor_in)
            width_out = int(in_channels * factor_out)
            self.prog_blocks.append(ConvBlock(width_in, width_out))
            self.rgb_layers.append(
                WSConv2d(width_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two RGB images with weight ``alpha`` and squash with tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate an image after ``steps`` growth stages.

        With ``steps == 0`` the RGB projection of the initial block is
        returned directly. Otherwise the output is a fade between the previous
        stage's RGB projection of the upsampled features and the newest
        stage's RGB projection.
        """
        features = self.inital(x)
        if steps == 0:
            return self.inital_rgb(features)

        for stage in range(steps):
            enlarged = F.interpolate(features, scale_factor=2, mode="nearest")
            features = self.prog_blocks[stage](enlarged)

        # `enlarged` is the input of the last block, i.e. the previous stage's
        # features at the new resolution.
        rgb_previous = self.rgb_layers[steps - 1](enlarged)
        rgb_current = self.rgb_layers[steps](features)
        return self.fade_in(alpha, rgb_previous, rgb_current)
            
    
class Discriminator(nn.Module):
    """Progressively grown critic, built as the reverse of the generator's stages.

    ``prog_blocks`` and ``rgb_layers`` are ordered from the highest-resolution
    stage to the lowest, so index 0 handles the largest images. The final
    block appends a minibatch standard-deviation channel and reduces the
    feature map to a single score per sample.

    Args:
        z_dim: accepted for signature symmetry; not used by this class.
        in_channels: base channel count; stage widths are
            ``int(in_channels * factors[i])`` using the module-level ``factors``.
        img_channels: channels of the input image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Walk the factor list backwards: widest-resolution stage first.
        for i in reversed(range(1, len(factors))):
            width_in = int(in_channels * factors[i])
            width_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(width_in, width_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, width_in, kernel_size=1, stride=1, padding=0)
            )

        # "From RGB" layer for the lowest resolution; it is also the last
        # entry of rgb_layers.
        self.inital_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.inital_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # in_channels + 1 accounts for the extra minibatch-std channel.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
        )

    def fade_in(self, alpha, downscaled, out):
        """Linearly blend ``out`` (weight ``alpha``) with ``downscaled``."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean of the per-feature batch std."""
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        mean_std = torch.std(x, dim=0).mean()
        std_channel = mean_std.repeat(batch, 1, height, width)
        return torch.cat([x, std_channel], dim=1)

    def forward(self, x, alpha, steps):
        """Score images produced after ``steps`` growth stages.

        Returns a tensor of shape ``(batch, -1)`` obtained by flattening the
        final block's output per sample.
        """
        # Blocks are stored high-res first, so the entry index counts from the end.
        entry = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[entry](x))

        if steps != 0:
            # Fade between the pooled image sent through the next-lower
            # "from RGB" layer and the newest block's pooled output.
            downscaled = self.leaky(self.rgb_layers[entry + 1](self.avg_pool(x)))
            out = self.avg_pool(self.prog_blocks[entry](out))
            out = self.fade_in(alpha, downscaled, out)

            for idx in range(entry + 1, len(self.prog_blocks)):
                out = self.avg_pool(self.prog_blocks[idx](out))

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
        
        
        
        
        

## Configs

In [36]:
# Channel multipliers per resolution stage; stage widths are
# int(in_channels * factors[i]) in Generator / Discriminator.
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]
start_train_at_img_size = 4
# Raw string: the backslashes are literal path separators, not escapes.
training_dataset = r"D:\DataSets\celeb_dataset"
checkpoint_gen = "generator.pth"
checkpoint_critic = "critic.pth"
# Fall back to the CPU so this cell still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
save_model = False
load_model = False
learning_rate = 1e-3
# Indexed by int(log2(image_size / 4)) in get_loader.
batch_sizes = [8, 8, 8, 8, 8, 8, 8, 8, 4]
image_size = 512
channels_img = 3
z_dim = 256
in_channels = 256
lambda_gp = 10
num_steps = int(log2(image_size / 4)) + 1

# One entry per resolution stage; indexed by the training step.
progressive_epochs = [10] * len(batch_sizes)
fixed_noise = torch.randn(8, z_dim, 1, 1).to(device)
num_workers = 4

# The flag is `benchmark` (singular); the misspelled `benchmarks` left the
# cuDNN autotuner switched off.
torch.backends.cudnn.benchmark = True

## Helpers

In [37]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cuda"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: batch of real images, shape ``(N, C, H, W)``.
        fake: batch of generated images with the same shape; it is detached
            before mixing.
        alpha: fade-in coefficient forwarded to the critic.
        train_step: progressive step forwarded to the critic.
        device: device the random mixing weights are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        the gradient is taken w.r.t. the interpolated images.
    """
    batch, channels, height, width = real.shape
    # One mixing weight per sample, copied across channels and pixels.
    mix = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    scores = critic(blended, alpha, train_step)

    # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
    (grads,) = torch.autograd.grad(
        inputs=blended,
        outputs=scores,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``, serialised with ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        checkpoint_file: path to a file holding a dict with the keys
            ``"state_dict"`` and ``"optimizer"``.
        model: module that receives the saved weights.
        optimizer: optimizer that receives the saved state.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    # Load onto the device the model already lives on rather than a
    # hard-coded "cuda", so restoring also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    # NOTE: torch.load unpickles data; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state also restores the checkpoint's learning
    # rate, so overwrite it with the requested one.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also sets ``PYTHONHASHSEED`` in the environment, enables
    ``cudnn.deterministic`` and disables ``cudnn.benchmark``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def generate_examples(gen, steps, truncation=0.7, n=100):
    """Generate ``n`` images and save them under ``saved_examples/``.

    Latents are sampled from a standard normal truncated to
    ``[-truncation, truncation]``. The author's note: the truncation trick was
    tried here without a clear benefit; it can be replaced by plain
    ``torch.randn`` sampling.

    Args:
        gen: generator invoked as ``gen(noise, alpha, steps)``.
        steps: progressive step passed to the generator.
        truncation: half-width of the truncated normal.
        n: number of images to write (``img_0.png`` .. ``img_{n-1}.png``).
    """
    # save_image does not create missing directories.
    os.makedirs("saved_examples", exist_ok=True)
    gen.eval()
    alpha = 1.0
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.tensor(
                    truncnorm.rvs(-truncation, truncation, size=(1, z_dim, 1, 1)),
                    device=device,
                    dtype=torch.float32,
                )
                img = gen(noise, alpha, steps)
                # Map from [-1, 1] to [0, 1] before writing.
                save_image(img * 0.5 + 0.5, f"saved_examples/img_{i}.png")
    finally:
        # Put the generator back in train mode even if generation fails.
        gen.train()

In [38]:
def get_loader(image_shape):
    """Build the shuffled training DataLoader for a square resolution.

    Images from ``training_dataset`` are resized to
    ``image_shape x image_shape``, converted to tensors and normalised with
    mean 0.5 and std 0.5 per channel. The batch size is looked up in
    ``batch_sizes`` at index ``int(log2(image_shape / 4))``.

    Returns:
        ``(loader, dataset)``.
    """
    channel_mean = [0.5] * channels_img
    channel_std = [0.5] * channels_img
    preprocessing = transforms.Compose(
        [
            transforms.Resize((image_shape, image_shape)),
            transforms.ToTensor(),
            # Random horizontal flipping is currently disabled.
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    stage_index = int(log2(image_shape / 4))
    batch_size = batch_sizes[stage_index]

    dataset = datasets.ImageFolder(root=training_dataset, transform=preprocessing)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, dataset
    
def train_fn(critic, gen, loader, dataset, step, alpha, opt_critic, opt_gen, scaler_gen, scaler_critic):
    """Train critic and generator for one pass over ``loader``.

    Each batch performs one critic update (Wasserstein loss, gradient penalty
    weighted by ``lambda_gp``, and a small drift term on the real scores)
    followed by one generator update. ``alpha`` is increased after every batch
    and capped at 1. After the pass a single sample is written to
    ``saved_examples/``.

    Args:
        critic: critic invoked as ``critic(images, alpha, step)``.
        gen: generator invoked as ``gen(noise, alpha, step)``.
        loader: iterable yielding ``(images, labels)`` batches.
        dataset: dataset behind ``loader``; only its length is used.
        step: current progressive step (also indexes ``progressive_epochs``).
        alpha: fade-in coefficient at the start of the pass.
        opt_critic: optimizer for the critic.
        opt_gen: optimizer for the generator.
        scaler_gen: gradient scaler for the generator loss.
        scaler_critic: gradient scaler for the critic loss.

    Returns:
        The value of ``alpha`` at the end of the pass, so the caller can carry
        the fade-in over to the next epoch.
    """
    loop = tqdm(loader, leave=True)
    for real, _ in loop:
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # ---- Critic update ----
        noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + lambda_gp * gp
                + (0.001 * torch.mean(critic_real**2))
            )
        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # ---- Generator update ----
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)
        # The generator loss must be stepped with the generator's optimizer;
        # using opt_critic here would leave the generator untrained.
        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        alpha += cur_batch_size / (len(dataset) * progressive_epochs[step] * .05)
        alpha = min(alpha, 1)

    # Save one sample at full fade-in (alpha = 1.0) without touching the
    # training alpha that is returned below.
    os.makedirs("saved_examples", exist_ok=True)
    gen.eval()
    with torch.no_grad():
        sample_noise = torch.randn(1, z_dim, 1, 1).to(device)
        sample = gen(sample_noise, 1.0, step)
    # Map from [-1, 1] to [0, 1], matching generate_examples.
    save_image(sample[0] * 0.5 + 0.5, f"saved_examples/img_{random.randint(0,10000)}.png")
    gen.train()

    return alpha


def main():
    """Build the models and run the progressive training schedule.

    Starting at ``start_train_at_img_size``, each resolution stage trains for
    its entry in ``progressive_epochs``. ``alpha`` restarts at 1e-5 for every
    new resolution and is carried across the epochs of that stage.
    """
    gen = Generator(z_dim, in_channels, channels_img).to(device)
    critic = Discriminator(z_dim, in_channels, channels_img).to(device)
    opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.0, 0.99))
    opt_critic = optim.Adam(critic.parameters(), lr=learning_rate, betas=(0.0, 0.99))
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    gen.train()
    critic.train()
    step = int(log2(start_train_at_img_size / 4))
    for num_epochs in progressive_epochs[step:]:
        alpha = 1e-5  # restart the fade-in for each new resolution
        loader, dataset = get_loader(4 * 2**step)
        for epoch in range(num_epochs):
            print(f"epoch: {epoch+1}")
            # Carry alpha forward; otherwise every epoch restarts the fade-in.
            alpha = train_fn(critic, gen, loader, dataset, step, alpha, opt_critic, opt_gen, scaler_gen, scaler_critic)
        step += 1
        
        



In [39]:
# Start the progressive training schedule defined in main().
main()

epoch: 1


100%|██████████| 3750/3750 [02:17<00:00, 27.24it/s]


epoch: 2


100%|██████████| 3750/3750 [02:22<00:00, 26.36it/s]


epoch: 3


100%|██████████| 3750/3750 [02:16<00:00, 27.45it/s]


epoch: 4


100%|██████████| 3750/3750 [02:15<00:00, 27.68it/s]


epoch: 5


  0%|          | 0/3750 [00:00<?, ?it/s]