In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import random
import numpy as np
from math import log2
from torchvision.utils import save_image
from scipy.stats import truncnorm
from tqdm import tqdm
import wandb



In [2]:
# Sanity check: report whether PyTorch can see a CUDA device before configuring training.
torch.cuda.is_available()

True

In [3]:
# Start the Weights & Biases run; `config` is the run's config object and is
# used below to carry the hyper-parameters in `config.settings`.
wandb.init(project="progan", entity="team-sisyphus")
config = wandb.config

# Channel multipliers applied to `in_channels`; consecutive pairs define the
# in/out widths of each progressive block in Generator / Discriminator.
factors = [1,1,1,1,1/2,1/4,1/8,1/16,1/32]
start_train_at_img_size = 4
#training_dataset = "D:\DataSets\celeb_dataset"
training_dataset = "dupe_512"
checkpoint_gen = "generator.pth"
checkpoint_critic = "critic.pth"
#device = "cuda" 
device = torch.device("cuda")
#device = torch.cuda.set_device()
save_model = False
load_model = False
learning_rate = 1e-3
# Indexed by int(log2(image_size / 4)) in get_loader: one batch size per resolution.
batch_sizes = [32, 32, 32, 16, 16, 8, 8, 4, 4]
image_size = 512
channels_img = 1
z_dim=256
in_channels=256
lambda_gp=10
num_steps = int(log2(start_train_at_img_size/4)) + 1

# Number of epochs spent at each resolution (same length as batch_sizes).
progressive_epochs = [6] * len(batch_sizes)
# Fixed latent batch reused for the periodic preview images logged during training.
fixed_noise = torch.randn(8, z_dim,1,1).to(device)
num_workers = 4

# The flag is `benchmark` (singular); assigning to `benchmarks` only created an
# unused attribute and left the cuDNN autotuner at its default setting.
torch.backends.cudnn.benchmark = True

# NOTE(review): `device` and `fixed_noise` are not plain JSON values — confirm
# wandb stores them in the run config the way you expect.
config.settings = {"factors": factors, "start_train_at_img_size":start_train_at_img_size, "training_dataset": training_dataset,
                   "device": device, "learning_rate": learning_rate, "batch_sizes":batch_sizes, "image_size":image_size, "channels_img":channels_img,
                   "z_dim":z_dim, "in_channels":in_channels, "lambda_gp":lambda_gp, "num_steps":num_steps, "progressive_epochs":progressive_epochs,
                   "fixed_noise": fixed_noise, "num_workers":num_workers}

[34m[1mwandb[0m: Currently logged in as: [33mtekisuto[0m ([33mteam-sisyphus[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Spot-check that the settings dict stored on the wandb config can be read back.
config.settings["z_dim"]

256

In [5]:
class WSConv2d(nn.Module):
    """Conv2d whose input is rescaled at run time by a fixed constant.

    The weights are drawn from a standard normal and the input is multiplied by
    ``sqrt(gain / (in_channels * kernel_size ** 2))`` before the convolution.
    The bias is taken off the wrapped convolution and added afterwards, so it
    is not affected by the scaling.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size (an int; it is squared for the scale).
        stride: convolution stride.
        padding: convolution padding.
        gain: numerator of the scale constant.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5

        # Re-home the bias on this module so the conv itself runs bias-free.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale) + bias`` with the bias broadcast per channel."""
        per_channel_bias = self.bias.view(1, -1, 1, 1)
        scaled_input = x * self.scale
        return self.conv(scaled_input) + per_channel_bias
    
class PixelNorm(nn.Module):
    """Normalise each spatial position by the RMS of its values along dim 1."""

    def __init__(self):
        super().__init__()
        # Added under the square root so an all-zero position does not divide by zero.
        self.epsilon = 1e-8

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, dim=1) + epsilon)`` (mean kept as a size-1 dim)."""
        mean_square = (x ** 2).mean(dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)

class ConvBlock(nn.Module):
    """Two WSConv2d layers, each followed by LeakyReLU(0.2) and optional PixelNorm.

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.
        use_pixelnorm: when True, PixelNorm is applied after each activation.
    """

    def __init__(self,in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pixelnorm = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Apply conv -> LeakyReLU -> (PixelNorm) twice and return the result."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pixelnorm:
                x = self.pn(x)
        return x
    
        
        

In [6]:
class Generator(nn.Module):
    """Progressively grown generator.

    ``initial`` maps the latent input through a 4x4 transposed convolution and
    one WSConv2d. Each entry of ``prog_blocks`` is applied after a 2x
    nearest-neighbour upsample, and ``rgb_layers[k]`` is the 1x1 convolution
    that turns the features available after ``k`` blocks into an image.

    Args:
        z_dim: number of channels of the latent input.
        in_channels: base feature width; per-stage widths are this value times
            the entries of ``config.settings["factors"]``.
        img_channels: number of channels of the generated image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        first_stage = [
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        ]
        self.initial = nn.Sequential(*first_stage)

        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # Consecutive factor pairs give the input/output width of each block.
        multipliers = config.settings["factors"]
        for wide, narrow in zip(multipliers[:-1], multipliers[1:]):
            block_in = int(in_channels * wide)
            block_out = int(in_channels * narrow)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.rgb_layers.append(
                WSConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two images with weight ``alpha`` on ``generated`` and apply tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate images using the first ``steps`` progressive blocks.

        With ``steps == 0`` the output of ``initial_rgb`` is returned directly
        (no tanh). Otherwise the image from the last block is faded in against
        the image computed from that block's upsampled input.
        """
        features = self.initial(x)
        if steps == 0:
            return self.initial_rgb(features)

        for stage in range(steps):
            enlarged = F.interpolate(features, scale_factor=2, mode="nearest")
            features = self.prog_blocks[stage](enlarged)

        # `enlarged` is the input of the final block, i.e. the previous
        # stage's features at the new resolution.
        previous_rgb = self.rgb_layers[steps - 1](enlarged)
        current_rgb = self.rgb_layers[steps](features)
        return self.fade_in(alpha, previous_rgb, current_rgb)
            
    
class Discriminator(nn.Module):
    """Progressively grown critic.

    ``prog_blocks`` and ``rgb_layers`` are built by walking
    ``config.settings["factors"]`` from its last entry to its first, so index 0
    belongs to the stage with the last factor. ``rgb_layers`` has one extra
    trailing entry (``initial_rgb``) that feeds ``final_block`` directly.

    Args:
        z_dim: accepted for signature symmetry with the generator; not used here.
        in_channels: base feature width; per-stage widths are this value times
            the entries of ``config.settings["factors"]``.
        img_channels: number of channels of the input image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Reverse walk over the factors: each block maps width factors[i] to factors[i-1].
        for i in range(len(config.settings["factors"])-1,0,-1):
            conv_in_c = int(in_channels*config.settings["factors"][i])
            conv_out_c = int(in_channels*config.settings["factors"][i-1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0)
            )

        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # The first conv takes in_channels + 1 because minibatch_std appends one channel.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels+1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1)
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the two feature maps with weight ``alpha`` on ``out``."""
        return alpha * out + (1-alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean of the per-element std across the batch.

        ``torch.std`` over a single sample is NaN, which would turn the whole
        critic output into NaN for a trailing batch of size one. For such
        batches a constant zero statistic is used instead.
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            statistic = x.new_zeros(())
        batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score a batch of images at progressive stage ``steps``.

        Returns a tensor reshaped to ``(batch, -1)``.
        """
        # Index of the block / rgb layer at which this stage's image enters.
        cur_steps = len(self.prog_blocks) -steps
        out = self.leaky(self.rgb_layers[cur_steps](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Fade between the pooled image fed through the next rgb layer and the
        # pooled output of the current block.
        downscaled = self.leaky(self.rgb_layers[cur_steps+1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_steps](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_steps+1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
        
        
        
        
        

### Configs 

## Helpers

In [7]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cuda"):
    """Penalise the critic's input-gradient norm for deviating from 1.

    Random per-sample mixes of ``real`` and ``fake`` (detached) are scored by
    the critic; the returned value is the batch mean of
    ``(||d score / d input||_2 - 1) ** 2``.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: batch of real images, indexed as (batch, C, H, W).
        fake: batch of generated images with the same shape as ``real``.
        alpha: forwarded unchanged to the critic.
        train_step: forwarded unchanged to the critic.
        device: device the mixing coefficients are moved to.

    Returns:
        Scalar tensor; the graph is kept so the penalty can be backpropagated.
    """
    batch, channels, height, width = real.shape

    # One random coefficient per sample, copied across every channel and pixel.
    mix = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    blended = mix * real + (1 - mix) * fake.detach()
    blended.requires_grad_(True)

    # Critic scores on the blended images.
    scores = critic(blended, alpha, train_step)

    # Gradient of the scores with respect to the blended images.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The file holds a dict with the keys ``"state_dict"`` and ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from ``checkpoint_file``, then force ``lr``.

    The checkpoint is loaded with ``map_location="cuda"`` and must contain the
    keys ``"state_dict"`` and ``"optimizer"``.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location="cuda")
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # The optimizer state carries the learning rate it was saved with;
    # overwrite it so the caller's value is the one actually used.
    for group in optimizer.param_groups:
        group["lr"] = lr

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and put cuDNN in deterministic mode."""
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def generate_examples(gen, steps, truncation=0.7, n=100):
    """
    Save ``n`` single-image samples from ``gen`` to ``saved_examples/img_{i}.png``.

    Tried using truncation trick here but not sure it actually helped anything, you can
    remove it if you like and just sample from torch.randn

    Args:
        gen: generator, called as ``gen(noise, alpha, steps)`` with ``alpha = 1.0``.
        steps: progressive stage passed through to the generator.
        truncation: latents are drawn from a normal truncated to
            ``[-truncation, truncation]``.
        n: number of images to write.

    The output directory is created if missing, and the generator is switched
    back to train mode even when sampling or saving raises.
    """
    # save_image does not create missing parent directories.
    os.makedirs("saved_examples", exist_ok=True)
    gen.eval()
    alpha = 1.0
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.settings["z_dim"], 1, 1)), device=device, dtype=torch.float32)
                img = gen(noise, alpha, steps)
                # Shift by 0.5 and halve: maps [-1, 1] values to [0, 1] before saving.
                save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
    finally:
        gen.train()

In [8]:
def get_loader(image_shape):
    """Build the shuffled DataLoader for a square training resolution.

    Images from ``training_dataset`` are resized to ``image_shape`` x
    ``image_shape``, converted to tensors, made grayscale and normalised with
    mean 0.5 / std 0.5 per channel. The batch size is looked up in
    ``batch_sizes`` at index ``int(log2(image_shape / 4))``. The resolution is
    also logged to wandb under ``"image_dim"``.

    Returns:
        ``(loader, dataset)``.
    """
    n_channels = config.settings["channels_img"]
    pipeline = [
        transforms.Resize((image_shape, image_shape)),
        transforms.ToTensor(),
        #transforms.RandomHorizontalFlip(p=0.5),
        transforms.Grayscale(),
        transforms.Normalize([0.5] * n_channels, [0.5] * n_channels),
    ]
    transform = transforms.Compose(pipeline)

    wandb.log({"image_dim":image_shape})

    stage = int(log2(image_shape / 4))
    batch_size = batch_sizes[stage]
    dataset = datasets.ImageFolder(root=training_dataset, transform=transform)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options), dataset
    
def train_fn(critic,
             gen,
             loader,
             dataset,
             step,
             alpha,
             opt_critic,
             opt_gen,
             scaler_gen,
             scaler_critic,
             wandb
):
    """Train critic and generator for one pass over ``loader`` at stage ``step``.

    For every batch the critic is updated on the loss
    ``-(mean(D(real)) - mean(D(fake))) + lambda_gp * gp + 0.001 * mean(D(real)**2)``
    and the generator on ``-mean(D(fake))``, both under autocast with their own
    GradScaler. ``alpha`` is increased after each batch so that it reaches 1
    after half of this stage's ``progressive_epochs`` worth of samples.

    Args:
        critic, gen: the two networks, called as ``net(x, alpha, step)``.
        loader: iterable yielding ``(images, labels)``; labels are ignored.
        dataset: only its length is used, for the alpha schedule.
        step: current progressive stage.
        alpha: fade-in weight at the start of this pass.
        opt_critic, opt_gen: optimizers for the two networks.
        scaler_gen, scaler_critic: GradScaler instances for the two losses.
        wandb: logger object providing ``log`` and ``Image``.

    Returns:
        The fade-in weight after the last batch (capped at 1).
    """
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        #Train Critic
        noise = torch.randn(cur_batch_size, config.settings["z_dim"], 1, 1).to(device)

        # One call per batch: every wandb.log call advances the run's step
        # counter, so related values are logged together.
        wandb.log({"alpha": alpha, "step": step})

        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = critic(real, alpha, step)
            critic_fake = critic(fake.detach(), alpha, step)
            # NOTE(review): the penalty gradient is taken on unscaled outputs
            # inside autocast — confirm this is intended with GradScaler.
            gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                +lambda_gp*gp
                +(0.001*torch.mean(critic_real**2))
            )
        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        #Train Generator: reuse `fake` (not detached) so gradients reach gen.
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        alpha += cur_batch_size / (
            (config.settings["progressive_epochs"][step] * 0.5) * len(dataset)
        )
        alpha = min(alpha,1)

        if batch_idx % 50 ==0:
            # Generate the preview outside the try block: a failure here is a
            # model/runtime error, not a logging problem, and must not be
            # reported as one.
            with torch.no_grad():
                fixed_fakes = gen(fixed_noise, alpha, step) * 0.5 + 0.5
            fake_images = wandb.Image(fixed_fakes.detach()[0, :, :, :])
            try:
                wandb.log({
                    "critic loss": loss_critic.item(),
                    "gen loss": loss_gen.item(),
                    "reals": wandb.Image(real.detach()[0, :, :, :]),
                    "fakes": fake_images,
                })
            except Exception as err:
                # Best effort: report why logging failed and still try to
                # record the generated sample.
                print(f"Logging Error: {err!r}")
                wandb.log({"fakes": fake_images})
        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )
    return alpha

        
    
    #generate_examples(gen, step, truncation=0.7, n=1)
    #gen.eval()
    #alpha = 1.0
    #noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
    #fake = gen(noise, alpha, step)
    #save_image(fake[0], f"saved_examples/img_{random.randint(0,10000)}.png")
    #gen.train()
        
        
            
def main():
    """Build both networks and train them through the progressive schedule.

    Starting at the stage derived from ``start_train_at_img_size``, each stage
    trains for its ``progressive_epochs`` entry at resolution ``4 * 2**step``
    with ``alpha`` restarted at 1e-5, then the generator checkpoint is saved.

    Returns:
        The trained generator.
    """
    latent_dim = config.settings["z_dim"]
    img_channels = config.settings["channels_img"]

    gen = Generator(latent_dim, in_channels, img_channels).to(device)
    critic = Discriminator(latent_dim, in_channels, img_channels).to(device)

    adam_options = dict(lr=learning_rate, betas=(0.0, 0.99))
    opt_gen = optim.Adam(gen.parameters(), **adam_options)
    opt_critic = optim.Adam(critic.parameters(), **adam_options)
    scaler_critic = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    for net in (gen, critic):
        net.train()
    for net in (gen, critic):
        wandb.watch(net, log_freq=100)

    first_step = int(log2(config.settings["start_train_at_img_size"]/4))
    schedule = enumerate(progressive_epochs[first_step:], start=first_step)
    for step, num_epochs in schedule:
        # The fade-in weight restarts near zero at every new resolution.
        alpha = 1e-5
        loader, dataset = get_loader(4*2**step)
        wandb.log({"line":num_epochs})
        for epoch in range(num_epochs):
            print(f"epoch: {epoch+1}")
            wandb.log({"epoch":epoch})
            alpha = train_fn(
                critic, gen, loader, dataset, step, alpha,
                opt_critic, opt_gen, scaler_gen, scaler_critic, wandb,
            )
        save_checkpoint(gen, opt_gen, filename="my_checkpoint.pth.tar")
    return gen
        
        



In [9]:
# Run the whole progressive training schedule and keep the returned generator.
gen = main()

epoch: 1


100%|████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81it/s, gp=0.0816, loss_critic=-1.41]


epoch: 2


100%|████████████████████████████████████████████| 63/63 [00:06<00:00, 10.50it/s, gp=0.0231, loss_critic=-.867]


epoch: 3


100%|████████████████████████████████████████████| 63/63 [00:06<00:00, 10.40it/s, gp=0.0203, loss_critic=-.636]


epoch: 4


100%|███████████████████████████████████████████| 63/63 [00:05<00:00, 10.52it/s, gp=0.0445, loss_critic=0.0676]


epoch: 5


100%|████████████████████████████████████████████| 63/63 [00:05<00:00, 10.55it/s, gp=0.0217, loss_critic=0.134]


epoch: 6


100%|████████████████████████████████████████████| 63/63 [00:06<00:00, 10.48it/s, gp=0.0179, loss_critic=0.161]


=> Saving checkpoint
epoch: 1


100%|███████████████████████████████████████████| 63/63 [00:11<00:00,  5.69it/s, gp=0.00472, loss_critic=0.057]


epoch: 2


100%|██████████████████████████████████████████| 63/63 [00:10<00:00,  5.75it/s, gp=0.00735, loss_critic=0.0599]


epoch: 3


100%|██████████████████████████████████████████| 63/63 [00:10<00:00,  5.76it/s, gp=1.44e-5, loss_critic=-.0517]


epoch: 4


100%|███████████████████████████████████████████| 63/63 [00:10<00:00,  5.76it/s, gp=0.00325, loss_critic=-.021]


epoch: 5


100%|██████████████████████████████████████████| 63/63 [00:10<00:00,  5.76it/s, gp=0.00142, loss_critic=-.0985]


epoch: 6


100%|███████████████████████████████████████████| 63/63 [00:10<00:00,  5.77it/s, gp=0.0112, loss_critic=0.0632]


=> Saving checkpoint
epoch: 1


100%|█████████████████████████████████████████| 63/63 [00:22<00:00,  2.86it/s, gp=0.00395, loss_critic=-.00606]


epoch: 2


100%|██████████████████████████████████████████| 63/63 [00:22<00:00,  2.86it/s, gp=0.000274, loss_critic=-.178]


epoch: 3


100%|██████████████████████████████████████████| 63/63 [00:21<00:00,  2.87it/s, gp=0.000304, loss_critic=-.109]


epoch: 4


100%|██████████████████████████████████████████| 63/63 [00:21<00:00,  2.87it/s, gp=0.00173, loss_critic=0.0267]


epoch: 5


100%|██████████████████████████████████████████| 63/63 [00:21<00:00,  2.87it/s, gp=0.00285, loss_critic=0.0326]


epoch: 6


100%|██████████████████████████████████████████████| 63/63 [00:21<00:00,  2.87it/s, gp=0.0365, loss_critic=0.4]


=> Saving checkpoint
epoch: 1


100%|█████████████████████████████████████████| 125/125 [01:26<00:00,  1.45it/s, gp=0.0107, loss_critic=0.0335]


epoch: 2


100%|████████████████████████████████████████| 125/125 [01:26<00:00,  1.45it/s, gp=2.98e-5, loss_critic=-.0708]


epoch: 3


100%|█████████████████████████████████████████| 125/125 [01:25<00:00,  1.45it/s, gp=0.0238, loss_critic=0.0189]


epoch: 4


100%|████████████████████████████████████████| 125/125 [01:26<00:00,  1.45it/s, gp=0.000284, loss_critic=0.113]


epoch: 5


100%|████████████████████████████████████████| 125/125 [01:26<00:00,  1.45it/s, gp=0.00341, loss_critic=0.0678]


epoch: 6


100%|███████████████████████████████████████| 125/125 [01:26<00:00,  1.45it/s, gp=0.000768, loss_critic=0.0609]


=> Saving checkpoint
epoch: 1


  0%|                                                                                  | 0/125 [00:03<?, ?it/s]

Logging Error





RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# NOTE(review): `steps` is not assigned in any visible cell (main() returns only the
# generator) — define the stage to sample at before running this cell.
generate_examples(gen, steps, truncation=0.7, n=100)

In [None]:
#def make_image(gen):
#    x = torch.rand(noise_vec).to("cuda")
#    y = gen(x)
#
#    image = torch.reshape(y, dimensions)
#
#    y = y.detach().to("cpu").numpy()
#    y.ndim
#
#    image = image[0,0].detach().to("cpu").numpy()
#
#    plt.figure(figsize=(40,10))
#    plt.imshow(image)
#
#    im = Image.fromarray(image).convert('RGB')
#    im.save("Generated.jpeg")
#
#    return image