In [None]:
import os 
import tensorflow as tf
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt

# Reproducibility: seed torch's global RNG before any model or data code runs.
random_seed = 9292
torch.manual_seed(random_seed)

# Hyper-parameters and runtime settings shared by the cells below.
LATENT_DIM = 100  # length of the noise vector fed to the generator
BATCH_SIZE = 64
NUM_EPOCHS = 500
AVAIL_GPUS = min(1, torch.cuda.device_count())
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 instead of raising TypeError on the division. Half of the CPUs
# are used as DataLoader workers (0 workers means loading in the main process).
NUM_WORKERS = (os.cpu_count() or 1) // 2
lr = 3e-4  # NOTE(review): not referenced by the cells visible here; the optimizers set their own rates
latent_dims = LATENT_DIM  # alias kept for backward compatibility, same value as LATENT_DIM
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Directory the CIFAR-10 archives are downloaded to / read from.
data_dir = './data'

# ToTensor converts each image to a float tensor in [0, 1]; Normalize with
# mean = std = 0.5 per channel then maps the values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

training = CIFAR10(data_dir, train=True, transform=transform, download=True)
testing = CIFAR10(data_dir, train=False, transform=transform, download=True)

# Shuffle the training split so that batches are re-drawn every epoch instead
# of always being presented in dataset order. The test split stays in a fixed
# order so evaluation runs are comparable.
train_dataloader = DataLoader(training, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(testing, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
# FID metric living on the same device as the models. reset_real_features=False
# keeps the real-image statistics accumulated below across later fid.reset()
# calls, so they only have to be computed once.
fid = FrechetInceptionDistance(features=2048, normalize=True, reset_real_features=False).to(device)

# NOTE(review): with normalize=True torchmetrics expects float images in
# [0, 1], whereas `transform` maps images to [-1, 1] — confirm whether real
# and generated images should both be rescaled with (x + 1) / 2 before
# fid.update().
for real_imgs, _ in test_dataloader:
    # Follow `device` instead of hard-coding CUDA so the cell also runs on
    # CPU-only machines.
    real_imgs = real_imgs.float().to(device)
    fid.update(real_imgs, real=True)

In [None]:
from utils import *
from models.cdcgan import ConditionalDCGANDiscriminator, ConditionalDCGANGenerator

In [None]:
def train_conditional_gan(disc, gen, disc_optim, gen_optim, criterion, basedir, train_dataloader, test_dataloader, fid):
    """Train a conditional GAN for NUM_EPOCHS epochs and return the metric history.

    For every batch one discriminator step is taken (real images with their
    own labels vs. generated images with random labels), followed by one
    generator step. After each epoch the per-batch averages are appended to
    the history and handed to ``plot_training_graphs``. Before training and
    after every 10th epoch ``test_conditional_gan`` is run, and after every
    10th epoch the generator weights are saved under ``basedir/saves``.

    Args:
        disc: discriminator module, called as ``disc(images, labels)``.
        gen: generator module, called as ``gen(noise, labels)`` with noise of
            shape ``(batch, LATENT_DIM)``.
        disc_optim: optimizer stepping the discriminator.
        gen_optim: optimizer stepping the generator.
        criterion: loss called as ``criterion(prediction, target)`` with
            all-ones / all-zeros targets shaped like the prediction.
        basedir: output directory passed to the plotting helpers; checkpoints
            go to ``basedir/saves``.
        train_dataloader: iterable of ``(images, labels)`` training batches.
        test_dataloader: forwarded to ``test_conditional_gan``.
        fid: forwarded to ``test_conditional_gan``.

    Returns:
        dict mapping metric names (``avg_*`` per training epoch, ``test_*``
        and ``fid_score`` per evaluation) to lists of values.
    """
    disc.train()
    gen.train()
    train_num_batches = len(train_dataloader)
    dict_result = {
        key: []
        for key in (
            'avg_disc_loss', 'avg_gen_loss', 'avg_real_acc', 'avg_fake_acc',
            'test_disc_loss', 'test_gen_loss', 'test_real_acc', 'test_fake_acc', 'fid_score',
        )
    }

    # Baseline evaluation of the untrained models.
    test_conditional_gan(gen, disc, criterion, dict_result, 0, basedir, test_dataloader, fid)

    for epoch in range(NUM_EPOCHS):
        epoch_disc_loss, epoch_gen_loss, epoch_real_acc, epoch_fake_acc = [], [], [], []
        for i, (real_imgs, labels) in enumerate(train_dataloader):
            batch_size = real_imgs.shape[0]
            # Class labels the generator is conditioned on (10 classes).
            random_labels = torch.randint(0, 10, (batch_size, )).to(device)

            # Follow `device` instead of hard-coding CUDA so CPU-only runs work.
            real_imgs = real_imgs.float().to(device)
            noise = torch.randn(batch_size, LATENT_DIM).to(device)
            fake_imgs = gen(noise, random_labels)

            # ---- discriminator step ----
            disc_real = disc(real_imgs, labels.to(device))
            # detach(): the discriminator update must not backpropagate into
            # the generator. This leaves the generator's graph untouched for
            # the generator step below, so retain_graph is not needed.
            disc_fake = disc(fake_imgs.detach(), random_labels)

            # Fraction of samples whose rounded prediction matches the target.
            real_acc = torch.sum(torch.round(disc_real) == torch.ones_like(disc_real)) / len(disc_real)
            fake_acc = torch.sum(torch.round(disc_fake) == torch.zeros_like(disc_fake)) / len(disc_fake)

            real_loss = criterion(disc_real, torch.ones_like(disc_real))
            fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
            disc_loss = (real_loss + fake_loss) / 2
            disc.zero_grad()
            disc_loss.backward()
            disc_optim.step()

            epoch_disc_loss.append(disc_loss.item())
            epoch_real_acc.append(real_acc.item())
            epoch_fake_acc.append(fake_acc.item())

            # ---- generator step ----
            # Re-score the fakes with the just-updated discriminator; the
            # generator is rewarded when they are classified as real.
            disc_fake = disc(fake_imgs, random_labels)
            gen_loss = criterion(disc_fake, torch.ones_like(disc_fake))
            gen.zero_grad()
            gen_loss.backward()
            gen_optim.step()

            epoch_gen_loss.append(gen_loss.item())

            # In-place 20-character progress bar for the current epoch.
            progress = int(((i+1)/train_num_batches)*20)
            print(f"{epoch+1}/{NUM_EPOCHS} [{'='*progress}{' '*(20-progress)}] {i+1}/{train_num_batches}", end='\r', flush=True)

        avg_disc_loss = sum(epoch_disc_loss) / len(epoch_disc_loss)
        avg_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        avg_real_acc = sum(epoch_real_acc) / len(epoch_real_acc)
        avg_fake_acc = sum(epoch_fake_acc) / len(epoch_fake_acc)

        dict_result['avg_disc_loss'].append(avg_disc_loss)
        dict_result['avg_gen_loss'].append(avg_gen_loss)
        dict_result['avg_real_acc'].append(avg_real_acc)
        dict_result['avg_fake_acc'].append(avg_fake_acc)

        plot_training_graphs(basedir, dict_result['avg_disc_loss'], dict_result['avg_gen_loss'], dict_result['avg_real_acc'], dict_result['avg_fake_acc'])

        print(f'disc loss: {avg_disc_loss:5.2f}  gen loss: {avg_gen_loss:5.2f}  real acc: {avg_real_acc:5.2f}  fake_acc: {avg_fake_acc:5.2f}')

        if (epoch+1) % 10 == 0:
            # NOTE(review): the evaluation receives the 0-based `epoch` while
            # the checkpoint file is named with epoch+1 — confirm intended.
            test_conditional_gan(gen, disc, criterion, dict_result, epoch, basedir, test_dataloader, fid)
            # torch.save does not create missing directories.
            os.makedirs(basedir + '/saves', exist_ok=True)
            torch.save(gen.state_dict(), basedir+f'/saves/model{str(epoch+1)}.pt')

    return dict_result


def test_conditional_gan(gen, disc, criterion, dict_result, epoch, basedir, test_dataloader, fid):
    """Evaluate generator and discriminator on the test loader and record the results.

    Both models are switched to eval mode and run under ``torch.no_grad()``.
    For every test batch the discriminator/generator losses and the
    discriminator accuracies are computed the same way as in training, and
    the generated images are fed to ``fid`` as fake samples. The averages and
    the FID score are appended to ``dict_result`` (keys ``test_disc_loss``,
    ``test_gen_loss``, ``test_real_acc``, ``test_fake_acc``, ``fid_score``),
    then ``test_graphs`` and ``output_imgs`` are called. Both models are put
    back into train mode before returning.

    Args:
        gen: generator module, called as ``gen(noise, labels)``.
        disc: discriminator module, called as ``disc(images, labels)``.
        criterion: loss called as ``criterion(prediction, target)``.
        dict_result: history dict that is mutated in place; the keys listed
            above must already map to lists.
        epoch: epoch index forwarded to ``output_imgs``.
        basedir: output directory forwarded to ``test_graphs`` / ``output_imgs``.
        test_dataloader: iterable of ``(images, labels)`` test batches.
        fid: metric object providing ``update(images, real=...)``,
            ``compute()`` and ``reset()``.

    Returns:
        None. Results are reported through ``dict_result``.
    """
    gen.eval()
    disc.eval()
    num_batches = len(test_dataloader)
    with torch.no_grad():
        epoch_disc_loss, epoch_gen_loss, epoch_real_acc, epoch_fake_acc = [], [], [], []

        for i, (real_imgs, labels) in enumerate(test_dataloader):
            batch_size = real_imgs.shape[0]
            random_labels = torch.randint(0, 10, (batch_size, )).to(device)

            # Follow `device` instead of hard-coding CUDA so CPU-only runs work.
            real_imgs = real_imgs.float().to(device)
            noise = torch.randn(batch_size, LATENT_DIM).to(device)
            fake_imgs = gen(noise, random_labels)

            disc_real = disc(real_imgs, labels.to(device))
            disc_fake = disc(fake_imgs, random_labels)

            # Fraction of samples whose rounded prediction matches the target.
            real_acc = torch.sum(torch.round(disc_real) == torch.ones_like(disc_real)) / len(disc_real)
            fake_acc = torch.sum(torch.round(disc_fake) == torch.zeros_like(disc_fake)) / len(disc_fake)

            real_loss = criterion(disc_real, torch.ones_like(disc_real))
            fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
            disc_loss = (real_loss + fake_loss) / 2

            # Generator loss: fakes scored against the "real" target.
            disc_fake = disc(fake_imgs, random_labels)
            gen_loss = criterion(disc_fake, torch.ones_like(disc_fake))

            # NOTE(review): if `fid` was built with normalize=True it expects
            # images in [0, 1]; the generator's output range is not visible
            # here — confirm whether a rescale is needed.
            fid.update(fake_imgs, real=False)

            epoch_disc_loss.append(disc_loss.item())
            epoch_gen_loss.append(gen_loss.item())
            epoch_real_acc.append(real_acc.item())
            epoch_fake_acc.append(fake_acc.item())

            # In-place 20-character progress bar for the evaluation pass.
            progress = int(((i+1)/num_batches)*20)
            print(f"test [{'='*progress}{' '*(20-progress)}] {i+1}/{num_batches}", end='\r', flush=True)

        test_disc_loss = sum(epoch_disc_loss) / len(epoch_disc_loss)
        test_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        test_real_acc = sum(epoch_real_acc) / len(epoch_real_acc)
        test_fake_acc = sum(epoch_fake_acc) / len(epoch_fake_acc)
        dict_result['test_disc_loss'].append(test_disc_loss)
        dict_result['test_gen_loss'].append(test_gen_loss)
        dict_result['test_real_acc'].append(test_real_acc)
        dict_result['test_fake_acc'].append(test_fake_acc)

        test_graphs(basedir, dict_result)

        # Compute the score, then clear the accumulated state so the next
        # evaluation starts fresh.
        fid_score = fid.compute().item()
        fid.reset()
        dict_result['fid_score'].append(fid_score)

        output_imgs(basedir, gen, epoch, test_disc_loss, test_gen_loss, fid_score, labels=True)

        # Hand the models back in training mode.
        gen.train()
        disc.train()


In [None]:
# Instantiate both networks on the selected device. The discriminator is
# built first so the seeded parameter initialisation order stays the same.
disc = ConditionalDCGANDiscriminator().to(device)
gen = ConditionalDCGANGenerator(LATENT_DIM).to(device)

# Binary cross-entropy on the discriminator's outputs.
criterion = nn.BCELoss()

# Separate Adam optimizers; the generator uses a 4x larger learning rate
# than the discriminator.
gen_optim = optim.Adam(params=gen.parameters(), lr=4e-4)
disc_optim = optim.Adam(params=disc.parameters(), lr=1e-4)

# Run the full training loop and keep the returned metric history.
ConditionalDCGAN_results = train_conditional_gan(
    disc=disc,
    gen=gen,
    disc_optim=disc_optim,
    gen_optim=gen_optim,
    criterion=criterion,
    basedir='models/C-DCGAN_run_8',
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    fid=fid,
)

disc loss:  0.34  gen loss:  4.30  real acc:  0.82  fake_acc:  0.82
disc loss:  0.63  gen loss:  1.34  real acc:  0.68  fake_acc:  0.57
disc loss:  0.62  gen loss:  1.76  real acc:  0.69  fake_acc:  0.70
disc loss:  0.62  gen loss:  1.19  real acc:  0.64  fake_acc:  0.69
disc loss:  0.62  gen loss:  1.39  real acc:  0.67  fake_acc:  0.70
disc loss:  0.61  gen loss:  1.30  real acc:  0.65  fake_acc:  0.69
disc loss:  0.62  gen loss:  1.43  real acc:  0.66  fake_acc:  0.69
disc loss:  0.61  gen loss:  2.02  real acc:  0.71  fake_acc:  0.75
disc loss:  0.56  gen loss:  1.65  real acc:  0.69  fake_acc:  0.79
disc loss:  0.54  gen loss:  1.64  real acc:  0.69  fake_acc:  0.77
disc loss:  0.52  gen loss:  1.72  real acc:  0.72  fake_acc:  0.81
disc loss:  0.50  gen loss:  1.82  real acc:  0.72  fake_acc:  0.84
disc loss:  0.45  gen loss:  2.34  real acc:  0.77  fake_acc:  0.88
disc loss:  0.44  gen loss:  2.30  real acc:  0.78  fake_acc:  0.88
disc loss:  0.40  gen loss:  2.31  real acc:  0.