In [2]:
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from PIL import Image
from torch.autograd import Variable
from torchvision import datasets, transforms

from models import Generator, Discriminator

In [3]:
# Training configuration — everything a reader might want to tune.
data = "../data/training_data"   # root directory handed to ImageFolder
batch_size = 300
epochs = 10
lr = 0.00015                     # learning rate for both AdamW optimizers
betas = 0.5, 0.99                # AdamW (beta1, beta2), shared by G and D
weight_decay = 0.0001            # AdamW weight decay, shared by G and D
epoch_test_interval = 1          # log + snapshot only on every N-th epoch ...
batch_test_interval = 10_000     # ... and every N-th batch within such an epoch

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tensors are placed explicitly with .to(device) throughout; no global default
# tensor type is set (a CUDA default type would clash with the shuffling
# DataLoader, whose sampler uses a CPU random generator).
print(f"Will run on {device}.")

torch.manual_seed(0)
# `train` draws its own per-batch noise, but this draw is kept so that the
# seeded RNG sequence -- and therefore `test_noise` below -- stays the same.
noise = torch.randn(1, 100, 1, 1).to(device)
# Fixed latent batch (9 vectors -> a 3x3 grid) reused for every progress
# snapshot, so snapshots from different epochs are directly comparable.
test_noise = torch.randn(9, 100, 1, 1).to(device)


def pil_loader(path):
    """Read the image stored at ``path`` and return an RGBA copy of it.

    Both the underlying file handle and the PIL image opened from it are
    closed before returning; the returned image is the converted copy.
    """
    with open(path, 'rb') as handle, Image.open(handle) as source_img:
        rgba_img = source_img.convert('RGBA')
    return rgba_img


def get_flag_loader(dir=data, batch_size=batch_size, shuffle=True):
    """Return a DataLoader over the ImageFolder dataset rooted at ``dir``.

    Images are read with ``pil_loader`` (RGBA) and converted to tensors with
    ``transforms.ToTensor()``.
    NOTE(review): no Normalize transform is applied, so pixel values stay in
    [0, 1]; confirm this matches the generator's output range.

    Args:
        dir: ImageFolder root directory (defaults to the configured ``data``).
        batch_size: samples per batch (defaults to the configured value).
        shuffle: whether to reshuffle the samples every epoch.
    """
    dataset = datasets.ImageFolder(
        root=dir,
        transform=transforms.ToTensor(),
        loader=pil_loader,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
    )


def _save_generator_samples(epoch, progress_dir):
    """Render G's output for the fixed ``test_noise`` as a 3x3 grid and save it
    to ``<progress_dir>/<epoch>.png``."""
    # NOTE(review): G is still in train mode here, so any normalization layers
    # use batch statistics for this sample -- presumably intended; confirm.
    with torch.no_grad():
        sample = G(test_noise).cpu()
    grid = vutils.make_grid(sample, padding=2, normalize=True, nrow=3)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    img = grid.permute(1, 2, 0).numpy()

    os.makedirs(progress_dir, exist_ok=True)
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title('Epoch {}'.format(epoch))
    fig.savefig(os.path.join(progress_dir, f"{epoch}.png"))
    # Close explicitly: drawing on the shared pyplot figure every snapshot
    # would stack images and keep growing memory over a long training run.
    plt.close(fig)


def train(data_loader, epoch, progress_dir="../data/gan_progress/try4"):
    """Run one epoch of GAN training with the non-saturating generator loss.

    For each batch, the discriminator ``D`` is updated on real images (target 1)
    and freshly generated fakes (target 0); then the generator ``G`` is updated
    so that ``D`` scores its fakes as real. Uses the notebook globals ``G``,
    ``D``, ``loss``, ``d_optim``, ``g_optim``, ``device``, ``test_noise``,
    ``epoch_test_interval`` and ``batch_test_interval``.

    Args:
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        epoch: current epoch number, used in the log line and snapshot name.
        progress_dir: directory receiving generator snapshots ``<epoch>.png``;
            created if missing.
    """
    D.train()
    G.train()

    for batch_idx, (real_batch, _) in enumerate(data_loader):
        n_samples = real_batch.size(0)
        real_target = torch.ones(n_samples, device=device)
        fake_target = torch.zeros(n_samples, device=device)

        # ----- Discriminator step -----
        D.zero_grad()

        # reshape(-1) rather than squeeze(): squeeze() would turn a final batch
        # of size 1 into a 0-d tensor that no longer matches the (1,) target.
        output = D(real_batch.to(device)).reshape(-1)
        real_error = loss(output, real_target)

        # Noise is drawn on the CPU RNG (then moved) to keep the seeded stream.
        latent = torch.randn(n_samples, 100, 1, 1).to(device)
        fake_data = G(latent)
        # detach(): the discriminator loss must not backpropagate into G.
        output = D(fake_data.detach()).reshape(-1)
        fake_error = loss(output, fake_target)

        d_error = real_error + fake_error
        d_error.backward()
        d_optim.step()

        # ----- Generator step -----
        G.zero_grad()

        # Score the non-detached fakes against the "real" target so G is
        # rewarded for fooling D. This also leaves gradients on D's parameters;
        # they are cleared by D.zero_grad() at the start of the next batch.
        output = D(fake_data).reshape(-1)
        g_error = loss(output, real_target)
        g_error.backward()
        g_optim.step()

        if epoch % epoch_test_interval == 0 and batch_idx % batch_test_interval == 0:
            print('({:02d}, {:02d}) \tLoss_D: {:.6f} \tLoss_G: {:.6f}'.format(
                epoch, batch_idx, d_error.item(), g_error.item()))
            _save_generator_samples(epoch, progress_dir)

Will run on cuda.


In [8]:
# Networks, loss and optimizers. G and D share the AdamW hyper-parameters
# from the config cell.
G, D = Generator().to(device), Discriminator().to(device)
# nn.BCELoss expects probabilities, so D presumably ends in a sigmoid --
# TODO confirm in models.py.
loss = nn.BCELoss()
g_optim = optim.AdamW(G.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
d_optim = optim.AdamW(D.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

flag_loader = get_flag_loader()

# Number of completed epochs; the optional checkpoint-loading cell overwrites it.
epoch = 0

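**Optional:** Load previously saved generator and discriminator weights to resume training. Set both file names first; the epoch counter is read from the generator's file name (`net_G_e<epoch>`).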

In [None]:
# Checkpoint files written by the "Save networks" cell, e.g. ".../net_G_e12.pth".
# NOTE(review): other cells use "../data" as the data root; confirm which is intended.
G_file_name = "./data/documentation/..."
D_file_name = "./data/documentation/..."

# Parse the epoch first so a badly named file fails before the models change.
epoch_match = re.search(r"net_G_e(\d+)", G_file_name)
if epoch_match is None:
    raise ValueError(f"Cannot read the epoch number from {G_file_name!r}")

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine (and vice versa).
G.load_state_dict(torch.load(G_file_name, map_location=device))
D.load_state_dict(torch.load(D_file_name, map_location=device))

epoch = int(epoch_match.group(1))

In [None]:
# Train for `epochs` more epochs (config cell), continuing the counter from a
# previous run or a loaded checkpoint, so re-running this cell extends training.
for _ in range(epochs):
    train(data_loader=flag_loader, epoch=epoch)
    epoch += 1

(00, 00) 	Loss_D: 1.140581 	Loss_G: 1.151804


**Optional:** Save the current generator and discriminator weights, tagged with the number of completed epochs.

In [None]:
# Checkpoints are named by the number of completed epochs; the load cell parses
# that number back out of the generator's file name.
checkpoint_dir = "./data/documentation/try3.5"
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(G.state_dict(), os.path.join(checkpoint_dir, f"net_G_e{epoch}.pth"))
torch.save(D.state_dict(), os.path.join(checkpoint_dir, f"net_D_e{epoch}.pth"))