In [1]:
import os 
import tensorflow as tf
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt

random_seed = 9292
torch.manual_seed(random_seed)
LATENT_DIM = 100
BATCH_SIZE = 64
NUM_EPOCHS = 500
AVAIL_GPUS = min(1, torch.cuda.device_count())
NUM_WORKERS = int(os.cpu_count() / 2)
lr = 5e-5
latent_dims = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

2023-02-01 22:51:25.564422: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-01 22:51:26.142661: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/yisan/anaconda3/envs/gpu_env/lib/
2023-02-01 22:51:26.142709: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/yisan/anaconda3/envs/gpu_env/lib/


In [2]:
data_dir = '../data'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

training = CIFAR10(data_dir, train=True, transform=transform, download=True)
testing = CIFAR10(data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(training, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(testing, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
fid = FrechetInceptionDistance(features=2048, normalize=True, reset_real_features=False).to(device)

for i, (real_imgs, _) in enumerate(test_dataloader):
    real_imgs = real_imgs.float().cuda()
    fid.update(real_imgs, real=True)

In [3]:
from utils import *
from models.wgan import WGANDiscriminator, WGANGenerator

In [6]:
def train_gan(disc, gen, disc_optim, gen_optim, basedir, train_dataloader, test_dataloader, fid):
    disc.train()
    gen.train()
    train_num_batches = len(train_dataloader)
    dict_result = {}
    dict_result['avg_wasserstein_d'], dict_result['avg_disc_loss'], dict_result['avg_gen_loss'], dict_result['avg_real_acc'], dict_result['avg_fake_acc'] = [], [], [], [], []
    dict_result['test_wasserstein_d'], dict_result['test_disc_loss'], dict_result['test_gen_loss'], dict_result['test_real_acc'], dict_result['test_fake_acc'], dict_result['fid_score'] = [], [], [], [], [], []

    one = torch.FloatTensor([1]).to(device)
    minus_one = one * -1

    test_gan(gen, disc, dict_result, 0, basedir, test_dataloader, fid)
    
    for epoch in range(NUM_EPOCHS):
        epoch_wasserstein_d, epoch_disc_loss, epoch_gen_loss, epoch_real_acc, epoch_fake_acc = [], [], [], [], []
        for i, (real_imgs, _) in enumerate(train_dataloader):
            real_imgs = real_imgs.float().cuda()
            noise = torch.randn(LATENT_DIM*real_imgs.shape[0]).to(device)
            fake_imgs = gen(noise.reshape(real_imgs.shape[0], LATENT_DIM, 1, 1))

            disc.zero_grad()

            disc_real = disc(real_imgs)
            disc_fake = disc(fake_imgs)

            real_acc = torch.sum(torch.round(disc_real) == torch.ones_like(disc_real)) / len(disc_real)
            fake_acc = torch.sum(torch.round(disc_fake) == torch.zeros_like(disc_fake)) / len(disc_fake)
        
            real_loss = disc_real.mean(0).view(1)
            real_loss.backward(one)

            fake_loss = disc_fake.mean(0).view(1)
            fake_loss.backward(minus_one)

            disc_optim.step()

            disc_loss = fake_loss - real_loss
            wasserstein_d = real_loss - fake_loss

            epoch_wasserstein_d.append(wasserstein_d)
            epoch_disc_loss.append(disc_loss.item()) 
            epoch_real_acc.append(real_acc.item()) 
            epoch_fake_acc.append(fake_acc.item()) 

            gen.zero_grad()

            noise = torch.randn(LATENT_DIM*real_imgs.shape[0]).to(device)
            fake_imgs = gen(noise.reshape(real_imgs.shape[0], LATENT_DIM, 1, 1))

            disc_fake = disc(fake_imgs)
            gen_loss = disc_fake.mean().mean(0).view(1)
            # gen_loss = criterion(disc_fake, torch.ones_like(disc_fake))
            gen_loss.backward(one)
            gen_optim.step()

            epoch_gen_loss.append(gen_loss.item())

            progress = int(((i+1)/train_num_batches)*20)
            print(f"{epoch+1}/{NUM_EPOCHS} [{'='*progress}{' '*(20-progress)}] {i+1}/{train_num_batches}", end='\r', flush=True)

            del disc_loss, gen_loss, real_acc, fake_acc, noise, fake_imgs, disc_real, disc_fake, real_loss, fake_loss, wasserstein_d

        # dcgan_agent.output_train_graphs('.')
        
        avg_wasserstein_d = sum(epoch_wasserstein_d) / len(epoch_wasserstein_d)
        avg_disc_loss = sum(epoch_disc_loss) / len(epoch_disc_loss)
        avg_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        avg_real_acc = sum(epoch_real_acc) / len(epoch_real_acc)
        avg_fake_acc = sum(epoch_fake_acc) / len(epoch_fake_acc)

        dict_result['avg_wasserstein_d'].append(avg_wasserstein_d)
        dict_result['avg_disc_loss'].append(avg_disc_loss)
        dict_result['avg_gen_loss'].append(avg_gen_loss)
        dict_result['avg_real_acc'].append(avg_real_acc)
        dict_result['avg_fake_acc'].append(avg_fake_acc)

        del epoch_disc_loss, epoch_gen_loss, epoch_real_acc, epoch_fake_acc, avg_wasserstein_d

        plot_training_graphs(basedir, dict_result['avg_disc_loss'], dict_result['avg_gen_loss'], dict_result['avg_real_acc'], dict_result['avg_fake_acc'])

        print(f'\ndisc loss: {avg_disc_loss:5.2f}  gen loss: {avg_gen_loss:5.2f}  real acc: {avg_real_acc:5.2f}  fake_acc: {avg_fake_acc:5.2f}\n')
        del avg_disc_loss, avg_gen_loss, avg_real_acc, avg_fake_acc

        if (epoch+1) % 10 == 0:
            test_gan(gen, disc, dict_result, epoch, basedir, test_dataloader, fid)
            torch.save(gen.state_dict(), basedir+f'/saves/model{str(epoch+1)}.pt')

    return dict_result

def test_gan(gen, disc, dict_result, epoch, basedir, test_dataloader, fid):
    gen.eval()
    disc.eval()
    with torch.no_grad():
        epoch_wasserstein_d, epoch_disc_loss, epoch_gen_loss, epoch_real_acc, epoch_fake_acc = [], [], [], [], []

        for i, (real_imgs, _) in enumerate(test_dataloader):
            real_imgs = real_imgs.float().cuda()
            noise = torch.randn(LATENT_DIM*real_imgs.shape[0]).to(device)
            fake_imgs = gen(noise.reshape(real_imgs.shape[0], LATENT_DIM, 1, 1))

            disc.zero_grad()

            disc_real = disc(real_imgs)
            disc_fake = disc(fake_imgs)

            real_acc = torch.sum(torch.round(disc_real) == torch.ones_like(disc_real)) / len(disc_real)
            fake_acc = torch.sum(torch.round(disc_fake) == torch.zeros_like(disc_fake)) / len(disc_fake)
        
            real_loss = disc_real.mean(0).view(1)

            fake_loss = disc_fake.mean(0).view(1)

            disc_loss = fake_loss - real_loss
            wasserstein_d = real_loss - fake_loss

            disc_fake = disc(fake_imgs)
            gen_loss = disc_fake.mean().mean(0).view(1)

            fid.update(fake_imgs, real=False)

            epoch_wasserstein_d.append(wasserstein_d)
            epoch_disc_loss.append(disc_loss.item())
            epoch_gen_loss.append(gen_loss.item())
            epoch_real_acc.append(real_acc.item())
            epoch_fake_acc.append(fake_acc.item())

            progress = int(((i+1)/len(test_dataloader))*20)
            print(f"test [{'='*progress}{' '*(20-progress)}] {i+1}/{len(test_dataloader)}", end='\r', flush=True)

            del disc_loss, gen_loss, real_acc, fake_acc, noise, fake_imgs, disc_real, disc_fake, wasserstein_d

        test_wasserstein_d = sum(epoch_wasserstein_d) / len(epoch_wasserstein_d)
        test_disc_loss = sum(epoch_disc_loss) / len(epoch_disc_loss)
        test_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        test_real_acc = sum(epoch_real_acc) / len(epoch_real_acc)
        test_fake_acc = sum(epoch_fake_acc) / len(epoch_fake_acc)
        dict_result['test_wasserstein_d'].append(test_wasserstein_d)
        dict_result['test_disc_loss'].append(test_disc_loss)
        dict_result['test_gen_loss'].append(test_gen_loss)
        dict_result['test_real_acc'].append(test_real_acc)
        dict_result['test_fake_acc'].append(test_fake_acc)
        del epoch_gen_loss, epoch_disc_loss, epoch_real_acc, epoch_fake_acc, test_wasserstein_d

        test_graphs(basedir, dict_result)

        fid_score = fid.compute().item()
        fid.reset()
        dict_result['fid_score'].append(fid_score)

        output_imgs(basedir, gen, epoch, test_disc_loss, test_gen_loss, fid_score)
        
        gen.train()
        disc.train()
        del fid_score

In [8]:
disc = WGANDiscriminator().to(device)
gen = WGANGenerator(LATENT_DIM).to(device)

disc.apply(weights_init)
gen.apply(weights_init)

disc_optim = optim.RMSprop(disc.parameters(), lr=lr)
gen_optim = optim.RMSprop(gen.parameters(), lr=lr)

train_gan(disc, gen, disc_optim, gen_optim, '../model_saves/WGAN', train_dataloader, test_dataloader, fid)

disc loss: 2903.33  gen loss: 1512.60  real acc:  0.00  fake_acc:  0.00

disc loss: 11285.43  gen loss: 5867.70  real acc:  0.00  fake_acc:  0.00

disc loss: 21235.15  gen loss: 11444.08  real acc:  0.00  fake_acc:  0.00

disc loss: 35339.10  gen loss: 18881.04  real acc:  0.00  fake_acc:  0.00

disc loss: 53371.90  gen loss: 27988.57  real acc:  0.00  fake_acc:  0.00

disc loss: 63399.83  gen loss: 35844.41  real acc:  0.00  fake_acc:  0.00

disc loss: 73323.41  gen loss: 42391.12  real acc:  0.00  fake_acc:  0.00

disc loss: 82263.77  gen loss: 46809.58  real acc:  0.00  fake_acc:  0.00

disc loss: 90613.78  gen loss: 52858.47  real acc:  0.00  fake_acc:  0.00

disc loss: 98776.54  gen loss: 60753.81  real acc:  0.00  fake_acc:  0.00

disc loss: 100209.99  gen loss: 65422.39  real acc:  0.00  fake_acc:  0.00

disc loss: 102437.98  gen loss: 69493.32  real acc:  0.00  fake_acc:  0.00

disc loss: 103425.55  gen loss: 69896.92  real acc:  0.00  fake_acc:  0.00

disc loss: 106514.74  gen

{'avg_wasserstein_d': [tensor([-2903.3267], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-11285.4355], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-21235.1602], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-35339.1172], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-53371.8672], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-63399.8086], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-73323.4453], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-82263.7969], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-90613.7188], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-98776.5078], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-100210.0781], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-102437.9609], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-103425.5938], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-106514.7344], device='cuda:0', grad_fn=<DivBackward0>),
  tensor([-109206.7891], device='cuda:0', grad_fn=<D