In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from model import Generator, Discriminator
from torch.utils.tensorboard import SummaryWriter
from math import log2
from tqdm import tqdm

In [2]:
# --- Training configuration -------------------------------------------------
# Resolution (pixels per side) at which main() starts progressive training.
start_train_at_image_size = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4
# Batch size per resolution stage; get_loader() indexes this list with
# int(log2(image_size / 4)), i.e. index 0 -> 4px, 1 -> 8px, ..., 8 -> 1024px.
batch_sizes = [16, 16, 16, 16, 16, 16, 16, 8, 4]
image_size = 512
channels_img = 3
# Size of the latent vector fed to the generator (noise is z_dim x 1 x 1).
z_dim = 256
in_channels = 256
# Weight of the gradient-penalty term in the discriminator loss.
lambda_gp = 10
# NOTE(review): `steps` is not referenced by any cell visible here.
steps = int(log2(image_size/4)) + 1

# Epochs to train at each resolution stage (same indexing as batch_sizes).
progressive_epochs = [20] * len(batch_sizes)
# Fixed latent batch reused for every TensorBoard sample grid so that
# generated images are comparable between logging steps.
fixed_noise = torch.randn(8, z_dim, 1, 1).to(device)
num_workers = 4

# The cuDNN autotuner flag is `benchmark` (singular). The previous spelling
# `benchmarks` only created an unused attribute and left autotuning disabled.
torch.backends.cudnn.benchmark = True

In [3]:
def gradient_penalty(discriminator, real, fake, alpha, train_step, device='cpu'):
    """Gradient penalty evaluated on random real/fake interpolations.

    Args:
        discriminator: Callable invoked as ``discriminator(images, alpha, train_step)``.
        real: 4-D batch of real images; its shape defines the interpolation shape.
        fake: Batch of generated images (detached here, so no gradient reaches
            the generator through this term).
        alpha: Forwarded unchanged to ``discriminator``.
        train_step: Forwarded unchanged to ``discriminator``.
        device: Device the random mixing coefficients are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        the gradient is that of the discriminator scores w.r.t. the
        interpolated images.
    """
    n_samples, channels, height, width = real.shape

    # One uniform mixing coefficient per sample, tiled over channels and pixels.
    mix = torch.rand((n_samples, 1, 1, 1)).repeat(1, channels, height, width).to(device)

    # Point on the segment between each real/fake pair; it must track
    # gradients so we can differentiate the scores with respect to it.
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    scores = discriminator(blended, alpha, train_step)

    # create_graph=True keeps the penalty itself differentiable.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    # Flatten each sample's gradient and take its L2 norm.
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

In [4]:
def plot_to_tensorboard(writer, loss_disc, loss_gen, real, fake, tensorboard_step):
    """Log both losses and real/fake image grids to TensorBoard.

    Args:
        writer: Object exposing ``add_scalar`` and ``add_image``
            (a ``SummaryWriter`` in this notebook).
        loss_disc: Discriminator loss value to log.
        loss_gen: Generator loss value to log.
        real: Batch of real images; only the first 8 are shown.
        fake: Batch of generated images; only the first 8 are shown.
        tensorboard_step: Global step under which all entries are recorded.
    """
    writer.add_scalar("Loss Discriminator", loss_disc, global_step=tensorboard_step)
    # loss_gen used to be accepted but silently dropped; log it as well so the
    # two curves can be compared.
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # normalize=True rescales each grid into [0, 1] for display.
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

In [5]:
def get_loader(image_size):
    """Build the dataset and a shuffled DataLoader for one resolution.

    Images from ``./dataset`` are resized to ``image_size`` x ``image_size``,
    converted to tensors, randomly flipped horizontally, and normalized with
    mean 0.5 and std 0.5 per channel. The batch size is taken from the
    module-level ``batch_sizes`` at index ``int(log2(image_size / 4))``.

    Args:
        image_size: Side length (in pixels) of the square training images.

    Returns:
        Tuple ``(loader, dataset)``.
    """
    mean = [0.5] * channels_img
    std = [0.5] * channels_img
    pipeline = [
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        T.RandomHorizontalFlip(p=0.5),
        T.Normalize(mean, std),
    ]
    transform = T.Compose(pipeline)

    # Index of this resolution in the 4, 8, 16, ... progression.
    resolution_index = int(log2(image_size / 4))
    per_stage_batch = batch_sizes[resolution_index]

    dataset = ImageFolder(root='./dataset', transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=per_stage_batch,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, dataset


In [6]:
def train_func(disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen, tensorboard_step, writer, scaler_gen, scaler_disc):
    """Run one pass over ``loader``, updating discriminator then generator per batch.

    Args:
        disc: Discriminator, called as ``disc(images, alpha, step)``.
        gen: Generator, called as ``gen(noise, alpha, step)``.
        loader: Iterable yielding ``(images, labels)`` batches; labels are ignored.
        dataset: Dataset behind ``loader``; only ``len(dataset)`` is used.
        step: Index of the current resolution stage.
        alpha: Current fade-in value; grows every batch and is capped at 1.
        opt_disc: Optimizer for the discriminator.
        opt_gen: Optimizer for the generator.
        tensorboard_step: Next TensorBoard global step to write.
        writer: TensorBoard writer handed to ``plot_to_tensorboard``.
        scaler_gen: Gradient scaler used for the generator update.
        scaler_disc: Gradient scaler used for the discriminator update.

    Returns:
        Tuple ``(tensorboard_step, alpha)`` with the updated values.
    """
    for batch_idx, (real, _) in enumerate(tqdm(loader, leave=True)):
        real = real.to(device)
        n_samples = real.shape[0]
        latent = torch.randn(n_samples, z_dim, 1, 1).to(device)

        # ---- Discriminator update -------------------------------------------
        with torch.cuda.amp.autocast():
            fake = gen(latent, alpha, step)
            score_real = disc(real, alpha, step)
            # Detached so this loss does not backpropagate into the generator.
            score_fake = disc(fake.detach(), alpha, step)

            gp = gradient_penalty(disc, real, fake, alpha, step, device=device)

            score_gap = -(torch.mean(score_real) - torch.mean(score_fake))
            # Small extra penalty on the squared real scores.
            real_score_penalty = 0.001 * torch.mean(score_real ** 2)
            loss_disc = score_gap + lambda_gp * gp + real_score_penalty

        opt_disc.zero_grad()
        scaler_disc.scale(loss_disc).backward()
        scaler_disc.step(opt_disc)
        scaler_disc.update()

        # ---- Generator update -----------------------------------------------
        # Re-score the same fakes with the just-updated discriminator, this
        # time keeping the graph back to the generator.
        with torch.cuda.amp.autocast():
            loss_gen = -torch.mean(disc(fake, alpha, step))

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # alpha grows in proportion to the images seen; it reaches 1 after half
        # of this stage's epochs and is clamped there.
        alpha = min(
            alpha + n_samples / (len(dataset) * (progressive_epochs[step]*0.5)),
            1,
        )

        # Log every 500th batch (including the first) using the fixed noise.
        if batch_idx % 500 == 0:
            with torch.no_grad():
                # Shift generator output by *0.5 + 0.5 before logging.
                fixed_fakes = gen(fixed_noise, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_disc.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

    return tensorboard_step, alpha

In [7]:
def main():
    """Train the generator and discriminator progressively over resolutions.

    Builds both networks, their Adam optimizers, gradient scalers and a
    TensorBoard writer, then iterates over resolution stages starting at
    ``start_train_at_image_size``. Each stage trains for its entry in
    ``progressive_epochs`` at resolution ``4 * 2 ** step`` and restarts the
    fade-in ``alpha`` near zero.
    """
    gen = Generator(z_dim, in_channels, channels_img).to(device)
    disc = Discriminator(in_channels, channels_img).to(device)

    opt_gen = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.99))
    opt_disc = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.0, 0.99))

    scaler_disc = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    writer = SummaryWriter('logs/gan')

    tensorboard_step = 0
    step = int(log2(start_train_at_image_size / 4))
    try:
        # NOTE(review): the slice runs to the end of progressive_epochs, i.e.
        # up to index 8 (1024px), although image_size is 512 -- confirm the
        # models in model.py support that many stages.
        for num_epochs in progressive_epochs[step:]:
            alpha = 1e-5  # restart the fade-in for every new resolution
            resolution = 4 * 2 ** step
            loader, dataset = get_loader(resolution)
            print(f"Размер изображения {resolution}")

            for epoch in range(num_epochs):
                print(f'Эпоха: {epoch}/{num_epochs}')
                tensorboard_step, alpha = train_func(
                    disc, gen, loader, dataset, step, alpha,
                    opt_disc, opt_gen, tensorboard_step, writer,
                    scaler_gen, scaler_disc,
                )

            # Advance to the next resolution once per stage. This used to run
            # after every epoch, so the models moved on while the loader still
            # served images of the previous size.
            step += 1
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

In [None]:
# Start the progressive training run defined in main().
main()