In [20]:
import os 
import tensorflow as tf
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt

random_seed = 9292
torch.manual_seed(random_seed)
LATENT_DIM = 100
BATCH_SIZE = 64
NUM_EPOCHS = 500
AVAIL_GPUS = min(1, torch.cuda.device_count())
NUM_WORKERS = int(os.cpu_count() / 2)
lr = 3e-4
latent_dims = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
data_dir = './data'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

training = CIFAR10(data_dir, train=True, transform=transform, download=True)
testing = CIFAR10(data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(training, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(testing, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
fid = FrechetInceptionDistance(features=2048, normalize=True, reset_real_features=False).to(device)

for i, (real_imgs, _) in enumerate(test_dataloader):
    real_imgs = real_imgs.float().cuda()
    fid.update(real_imgs, real=True)

In [None]:
from utils import * 
from models.acgan import ACGANDiscriminator, ACGANGenerator

In [22]:
import numpy as np
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def train_acgan(disc, gen, disc_optim, gen_optim, disc_criterion, aux_criterion, basedir, train_dataloader, test_dataloader):
    disc.train()
    gen.train()
    train_num_batches = len(train_dataloader)
    dict_result = {}
    dict_result['avg_errD_real'], dict_result['avg_gen_loss'], dict_result['avg_errD_fake'], dict_result['avg_accuracy'] = [], [], [], []
    dict_result['test_errD_real'], dict_result['test_gen_loss'], dict_result['test_errD_fake'], dict_result['test_accuracy'], dict_result['fid_score'] = [], [], [], [], []
    
    test_acgan(gen, disc, disc_criterion, aux_criterion, dict_result, 0, basedir, test_dataloader, fid)
    
    for epoch in range(NUM_EPOCHS):
        epoch_errD_real, epoch_errD_fake, epoch_gen_loss, epoch_accuracy = [], [], [], []
        for i, (real_imgs, labels) in enumerate(train_dataloader):
            disc.zero_grad()
            random_labels = torch.randint(0, 10, (real_imgs.shape[0], )).to(device)
            # real disc training
            disc_real, classes_real = disc(real_imgs.to(device))

            _batch_size = labels.shape[0]
            
            labels = labels.resize_(_batch_size).to(device)

            disc_real = disc_criterion(disc_real, torch.ones_like(disc_real))
            aux_real = aux_criterion(classes_real, labels)

            errD_real = disc_real + aux_real 
            errD_real.backward()

            accuracy = compute_acc(classes_real, labels)

            # fake disc training
            real_imgs = real_imgs.float().cuda()
            noise = torch.randn(LATENT_DIM*real_imgs.shape[0]).to(device)
            fake_labels = torch.randint(0, 10, (_batch_size, )).resize_(_batch_size).to(device)
            fake_imgs = gen(noise.reshape(real_imgs.shape[0], LATENT_DIM))

            disc_fake, classes_fake = disc(fake_imgs)

            disc_fake = disc_criterion(disc_fake, torch.zeros_like(disc_fake))
            aux_fake = aux_criterion(classes_fake, fake_labels)
            errD_fake = disc_fake + aux_fake
        
            errD_fake.backward(retain_graph=True)
            disc_optim.step()

            epoch_errD_real.append(errD_real.item()) 
            epoch_errD_fake.append(errD_fake.item()) 
            epoch_accuracy.append(accuracy)

            disc_fake, classes_fake = disc(fake_imgs)
            gen_fake = disc_criterion(disc_fake, torch.ones_like(disc_fake))
            gen_aux = aux_criterion(classes_fake, fake_labels)
            gen_loss = gen_fake + gen_aux
            gen.zero_grad()
            gen_loss.backward()
            gen_optim.step()

            epoch_gen_loss.append(gen_loss.item())

            progress = int(((i+1)/train_num_batches)*20)
            print(f"{epoch+1}/{NUM_EPOCHS} [{'='*progress}{' '*(20-progress)}] {i+1}/{train_num_batches}", end='\r', flush=True)

            del gen_loss, noise, fake_imgs, disc_real, disc_fake, random_labels

        # dcgan_agent.output_train_graphs('.')
        
        avg_errD_real = sum(epoch_errD_real) / len(epoch_errD_real)
        avg_errD_fake = sum(epoch_errD_fake) / len(epoch_errD_fake)
        avg_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        avg_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)

        dict_result['avg_errD_real'].append(avg_errD_real)
        dict_result['avg_errD_fake'].append(avg_errD_fake)
        dict_result['avg_gen_loss'].append(avg_gen_loss)
        dict_result['avg_accuracy'].append(avg_accuracy)

        # plot_training_graphs(basedir, dict_result['avg_disc_loss'], dict_result['avg_gen_loss'], dict_result['avg_real_acc'], dict_result['avg_fake_acc'])

        fig = plt.figure(figsize=(15,9))
        plt.plot(dict_result['avg_errD_real'])
        plt.plot(dict_result['avg_errD_fake'])
        plt.plot(dict_result['avg_gen_loss'])
        plt.legend(['discriminator loss real', 'discriminator loss fake', 'generator loss'])
        plt.title('Discriminator loss vs Generator loss')
        plt.savefig(basedir+'/graphs/train_loss_graph.jpg')
        plt.close()

        fig = plt.figure(figsize=(15,9))
        plt.plot(dict_result['avg_accuracy'])
        plt.title('Accuracy')
        plt.savefig(basedir+'/graphs/train_acc_graph.jpg')
        plt.close()

        print(f'\nerrD_real loss: {avg_errD_real:5.2f}  errD_fake loss: {avg_errD_fake:5.2f}  gen loss: {avg_gen_loss:5.2f}  acc: {avg_accuracy:5.2f}\n')
        del avg_errD_fake, avg_gen_loss, avg_errD_real, avg_accuracy

        if (epoch+1) % 10 == 0:
            test_acgan(gen, disc, disc_criterion, aux_criterion, dict_result, epoch, basedir, test_dataloader, fid)
            torch.save(gen.state_dict(), basedir+f'/saves/model{str(epoch+1)}.pt')

    return dict_result

def test_acgan(gen, disc, disc_criterion, aux_criterion, dict_result, epoch, basedir, test_dataloader, fid):
    gen.eval()
    disc.eval()
    with torch.no_grad():
        epoch_errD_real, epoch_errD_fake, epoch_gen_loss, epoch_accuracy = [], [], [], []
        train_num_batches = len(test_dataloader)
        for i, (real_imgs, labels) in enumerate(test_dataloader):
            # disc_loss, gen_loss, real_acc, fake_acc = dcgan_agent.test_gan(real_imgs)
            random_labels = torch.randint(0, 10, (real_imgs.shape[0], )).to(device)
            # real disc training
            disc_real, classes_real = disc(real_imgs.to(device))

            _batch_size = labels.shape[0]
            
            labels = labels.resize_(_batch_size).to(device)

            disc_real = disc_criterion(disc_real, torch.ones_like(disc_real))
            aux_real = aux_criterion(classes_real, labels)

            errD_real = disc_real + aux_real 

            accuracy = compute_acc(classes_real, labels)

            # fake disc training
            real_imgs = real_imgs.float().cuda()
            noise = torch.randn(LATENT_DIM*real_imgs.shape[0]).to(device)
            fake_labels = torch.randint(0, 10, (_batch_size, )).resize_(_batch_size).to(device)
            fake_imgs = gen(noise.reshape(real_imgs.shape[0], LATENT_DIM))

            disc_fake, classes_fake = disc(fake_imgs)

            disc_fake = disc_criterion(disc_fake, torch.zeros_like(disc_fake))
            aux_fake = aux_criterion(classes_fake, fake_labels)
            errD_fake = disc_fake + aux_fake
        
            epoch_errD_real.append(errD_real.item()) 
            epoch_errD_fake.append(errD_fake.item()) 
            epoch_accuracy.append(accuracy)
            
            disc_fake, classes_fake = disc(fake_imgs)
            # gen_loss = criterion(disc_fake, torch.ones_like(disc_fake))
            gen_fake = disc_criterion(disc_fake, torch.ones_like(disc_fake))
            gen_aux = aux_criterion(classes_fake, fake_labels)
            gen_loss = gen_fake + gen_aux

            epoch_gen_loss.append(gen_loss.item())

            progress = int(((i+1)/train_num_batches)*20)
            print(f"{epoch+1}/{NUM_EPOCHS} [{'='*progress}{' '*(20-progress)}] {i+1}/{train_num_batches}", end='\r', flush=True)

            fid.update(fake_imgs, real=False)

            del gen_loss, noise, fake_imgs, disc_real, disc_fake, random_labels

        test_errD_real = sum(epoch_errD_real) / len(epoch_errD_real)
        test_errD_fake = sum(epoch_errD_fake) / len(epoch_errD_fake)
        test_gen_loss =  sum(epoch_gen_loss) / len(epoch_gen_loss)
        test_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)
        dict_result['test_errD_real'].append(test_errD_real)
        dict_result['test_gen_loss'].append(test_gen_loss)
        dict_result['test_errD_fake'].append(test_errD_fake)
        dict_result['test_accuracy'].append(test_accuracy)
        del epoch_gen_loss, epoch_errD_real, epoch_errD_fake, epoch_accuracy 

        # test_graphs(basedir, dict_result)
        fig = plt.figure(figsize=(15,9))
        plt.plot(dict_result['test_errD_real'])
        plt.plot(dict_result['test_errD_fake'])
        plt.plot(dict_result['test_gen_loss'])
        plt.plot(dict_result['fid_score'])
        plt.legend(['discriminator loss real', 'discriminator loss fake', 'generator loss', 'fid'])
        plt.title('Discriminator loss vs Generator loss vs FID')
        plt.savefig(basedir+'/graphs/test_loss_graph.jpg')
        plt.close()

        fig = plt.figure(figsize=(15,9))
        plt.plot(dict_result['test_accuracy'])
        plt.title('Accuracy')
        plt.savefig(basedir+'/graphs/test_acc_graph.jpg')
        plt.close()


        fid_score = fid.compute().item()
        fid.reset()
        dict_result['fid_score'].append(fid_score)

        # output_imgs(basedir, gen, epoch, test_disc_loss, test_gen_loss, fid_score)
        rows, columns = 8, 6

        noise = torch.randn(LATENT_DIM*rows*columns).to(device)
        fake_imgs = gen(noise.reshape(rows*columns, LATENT_DIM))
        _, classes_predict = disc(fake_imgs)
        classes_predict = classes_predict.data.max(1)[1]
        fig, axs = plt.subplots(rows, columns, figsize=(20, 30))
        for i in range(rows*columns):
            ax = axs[i//columns, i%columns]
            img = (fake_imgs[i].cpu().permute(1,2,0).detach().numpy() + 1) / 2
            ax.imshow(img, cmap='Greys')
            ax.set_title(classes[classes_predict[i]])
            ax.axis('off')

        plt.text(-192, -310, f'epoch: {str(epoch+1)}   discriminator loss real: {test_errD_real:5.2f}\n   discriminator loss fake: {test_errD_fake:5.2f}   generator loss: {test_gen_loss:5.2f}   fid: {fid_score:5.2f}', fontsize=30, ha='left', va='top')
        plt.savefig(basedir+f'/imgs/img{epoch+1}.jpg')
        plt.close()

        del rows, columns, noise, fake_imgs

        gen.train()
        disc.train()
        del fid_score

# compute the current classification accuracy
def compute_acc(preds, labels):
    correct = 0
    preds_ = preds.data.max(1)[1]
    correct = preds_.eq(labels.data).cpu().sum()
    acc = float(correct) / float(len(labels.data)) * 100.0
    return acc

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [23]:
disc = ACGANDiscriminator().to(device)
gen = ACGANGenerator(LATENT_DIM).to(device)

disc.apply(weights_init)
gen.apply(weights_init)

disc_optim = optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
gen_optim = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))

disc_criterion = nn.BCELoss()
aux_criterion = nn.NLLLoss()

ACGAN_results  = train_acgan(disc, gen, disc_optim, gen_optim, disc_criterion, aux_criterion, 'models/ACGAN', train_dataloader, test_dataloader)

errD_real loss:  0.48  errD_fake loss:  0.56  gen loss:  0.93  acc: 18.40

errD_real loss:  0.41  errD_fake loss:  0.53  gen loss:  1.04  acc: 23.19

errD_real loss:  0.37  errD_fake loss:  0.53  gen loss:  1.09  acc: 26.64

errD_real loss:  0.34  errD_fake loss:  0.52  gen loss:  1.06  acc: 29.24

errD_real loss:  0.35  errD_fake loss:  0.54  gen loss:  1.02  acc: 30.39

errD_real loss:  0.31  errD_fake loss:  0.53  gen loss:  0.96  acc: 32.90

errD_real loss:  0.28  errD_fake loss:  0.53  gen loss:  0.98  acc: 35.71

errD_real loss:  0.29  errD_fake loss:  0.55  gen loss:  0.95  acc: 36.94

errD_real loss:  0.28  errD_fake loss:  0.57  gen loss:  0.82  acc: 39.06

errD_real loss:  0.25  errD_fake loss:  0.57  gen loss:  0.78  acc: 42.33

errD_real loss:  0.24  errD_fake loss:  0.57  gen loss:  0.78  acc: 43.19

errD_real loss:  0.24  errD_fake loss:  0.57  gen loss:  0.77  acc: 43.79

errD_real loss:  0.23  errD_fake loss:  0.57  gen loss:  0.79  acc: 44.34

errD_real loss:  0.23  er

: 