In [None]:
import os 
import tensorflow as tf
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt

# --- Reproducibility -------------------------------------------------------
random_seed = 9292
torch.manual_seed(random_seed)

# --- Model / training hyper-parameters -------------------------------------
LATENT_DIM = 100    # length of the noise vector fed to the generator
BATCH_SIZE = 64
NUM_EPOCHS = 500
lr = 3e-4
# Lower-case alias kept for backward compatibility; always mirrors LATENT_DIM.
latent_dims = LATENT_DIM

# --- Hardware ---------------------------------------------------------------
AVAIL_GPUS = min(1, torch.cuda.device_count())
# os.cpu_count() may return None when the CPU count cannot be determined, which
# would make the original `os.cpu_count() / 2` raise TypeError. Fall back to a
# single CPU in that case; a result of 0 makes DataLoader load in the main process.
NUM_WORKERS = (os.cpu_count() or 1) // 2
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
data_dir = './data'

# ToTensor converts the PIL images to float tensors in [0, 1]; normalising every
# channel with mean 0.5 and std 0.5 then rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Both CIFAR-10 splits are downloaded into data_dir if they are not already there.
training = CIFAR10(data_dir, train=True, transform=transform, download=True)
testing = CIFAR10(data_dir, train=False, transform=transform, download=True)

# Reshuffle the training set every epoch so the batches are not composed of the
# same images in the same order for the whole run. Evaluation order stays fixed.
train_dataloader = DataLoader(
    training,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
test_dataloader = DataLoader(
    testing,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# SimpleAgent comes from the project-local `simple_baseline` module, which is
# not part of this notebook.
from simple_baseline import SimpleAgent
# Built from the notebook-wide latent size and batch size constants.
# NOTE(review): the training cell relies on this object exposing `.gen` and
# `.disc` plus the train_*/test_*/output_* helpers — confirm against simple_baseline.
simple_agent = SimpleAgent(LATENT_DIM, BATCH_SIZE)

In [None]:
# Number of epochs between evaluations on the test split (was a bare `50`).
EVAL_EVERY = 50

for epoch in range(NUM_EPOCHS):
    # ---- one pass over the training set ------------------------------------
    for batch_idx, (real_imgs, _) in enumerate(train_dataloader):
        # Use the configured `device` instead of a hard-coded .cuda() so the
        # notebook also runs on CPU-only machines (noise already used `device`).
        real_imgs = real_imgs.float().to(device)

        # Size the noise from the actual batch: the loader does not drop the
        # last batch, so it can be smaller than BATCH_SIZE.
        n_real = real_imgs.shape[0]
        noise = torch.randn(n_real, LATENT_DIM).to(device)
        fake_imgs = simple_agent.gen(noise)

        disc_loss, real_acc, fake_acc = simple_agent.train_disc(fake_imgs, real_imgs)

        gen_loss = simple_agent.train_gen(fake_imgs)

        simple_agent.print_progress(len(train_dataloader), NUM_EPOCHS, epoch, batch_idx)

        # Drop references to per-batch tensors before the next iteration.
        del disc_loss, gen_loss, real_acc, fake_acc, noise, fake_imgs

    simple_agent.output_train_graphs('.')

    print(f'disc loss: {simple_agent.disc_loss[-1]:5.2f}  gen loss: {simple_agent.gen_loss[-1]:5.2f}  real acc: {simple_agent.real_acc[-1]:5.2f}  fake_acc: {simple_agent.fake_acc[-1]:5.2f}')

    # ---- periodic evaluation on the test set -------------------------------
    if (epoch + 1) % EVAL_EVERY == 0:
        simple_agent.gen.eval()
        simple_agent.disc.eval()

        with torch.no_grad():
            n_test_batches = len(test_dataloader)
            for test_idx, (test_imgs, _) in enumerate(test_dataloader):
                # NOTE(review): images are passed on without being moved to
                # `device`; presumably test_gan handles placement — confirm.
                disc_loss, gen_loss, real_acc, fake_acc = simple_agent.test_gan(test_imgs)

                # 20-character text progress bar, redrawn in place via '\r'.
                progress = int(((test_idx + 1) / n_test_batches) * 20)
                print(f"test [{'='*progress}{' '*(20-progress)}] {test_idx+1}/{n_test_batches}", end='\r', flush=True)
                del disc_loss, gen_loss, real_acc, fake_acc

            test_disc_loss, test_gen_loss, test_real_acc, test_fake_acc, fid_score = simple_agent.output_test_graphs('.')
            simple_agent.output_imgs('.', test_disc_loss, test_gen_loss, fid_score, epoch + 1)

            del test_disc_loss, test_gen_loss, test_real_acc, test_fake_acc, fid_score

        # Put BOTH networks back into training mode. The original called
        # disc.train() twice, leaving the generator stuck in eval mode for all
        # epochs after the first evaluation.
        simple_agent.gen.train()
        simple_agent.disc.train()