-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
3 changed files
with
257 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torchvision.transforms as transforms | ||
from torch.autograd import Variable | ||
from torch.utils.data import DataLoader | ||
from torchvision import datasets | ||
from torchvision.utils import save_image | ||
|
||
import torchbearer as tb | ||
from torchbearer.callbacks import Callback | ||
|
||
os.makedirs('images', exist_ok=True) | ||
|
||
# Define constants | ||
epochs = 200 | ||
batch_size = 64 | ||
lr = 0.0002 | ||
nworkers = 8 | ||
latent_dim = 100 | ||
sample_interval = 400 | ||
img_shape = (1, 28, 28) | ||
|
||
|
||
class Generator(nn.Module): | ||
def __init__(self): | ||
super(Generator, self).__init__() | ||
|
||
def block(in_feat, out_feat, normalize=True): | ||
layers = [nn.Linear(in_feat, out_feat)] | ||
if normalize: | ||
layers.append(nn.BatchNorm1d(out_feat, 0.8)) | ||
layers.append(nn.LeakyReLU(0.2, inplace=True)) | ||
return layers | ||
|
||
self.model = nn.Sequential( | ||
*block(latent_dim, 128, normalize=False), | ||
*block(128, 256), | ||
*block(256, 512), | ||
*block(512, 1024), | ||
nn.Linear(1024, int(np.prod(img_shape))), | ||
nn.Tanh() | ||
) | ||
|
||
def forward(self, z): | ||
img = self.model(z) | ||
img = img.view(img.size(0), *img_shape) | ||
return img | ||
|
||
|
||
class Discriminator(nn.Module): | ||
def __init__(self): | ||
super(Discriminator, self).__init__() | ||
|
||
self.model = nn.Sequential( | ||
nn.Linear(int(np.prod(img_shape)), 512), | ||
nn.LeakyReLU(0.2, inplace=True), | ||
nn.Linear(512, 256), | ||
nn.LeakyReLU(0.2, inplace=True), | ||
nn.Linear(256, 1), | ||
nn.Sigmoid() | ||
) | ||
|
||
def forward(self, img): | ||
img_flat = img.view(img.size(0), -1) | ||
validity = self.model(img_flat) | ||
|
||
return validity | ||
|
||
|
||
class GAN(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.discriminator = Discriminator() | ||
self.generator = Generator() | ||
|
||
def forward(self, real_imgs, state): | ||
# Generator Forward | ||
z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE]) | ||
state['gen_imgs'] = self.generator(z) | ||
state['disc_gen'] = self.discriminator(state['gen_imgs']) | ||
# We don't want to discriminator gradients on the generator forward pass | ||
self.discriminator.zero_grad() | ||
|
||
# Discriminator Forward | ||
state['disc_gen_det'] = self.discriminator(state['gen_imgs'].detach()) | ||
state['disc_real'] = self.discriminator(real_imgs) | ||
|
||
|
||
class LossCallback(Callback): | ||
def on_start(self, state): | ||
super().on_start(state) | ||
self.adversarial_loss = torch.nn.BCELoss() | ||
state['valid'] = torch.ones(batch_size, 1, device=state[tb.DEVICE]) | ||
state['fake'] = torch.zeros(batch_size, 1, device=state[tb.DEVICE]) | ||
|
||
def on_criterion(self, state): | ||
super().on_criterion(state) | ||
fake_loss = self.adversarial_loss(state['disc_gen_det'], state['fake']) | ||
real_loss = self.adversarial_loss(state['disc_real'], state['valid']) | ||
state['g_loss'] = self.adversarial_loss(state['disc_gen'], state['valid']) | ||
state['d_loss'] = (real_loss + fake_loss) / 2 | ||
# This is the loss that backward is called on. | ||
state[tb.LOSS] = state['g_loss'] + state['d_loss'] | ||
|
||
|
||
class SaverCallback(Callback): | ||
def on_step_training(self, state): | ||
super().on_step_training(state) | ||
batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH] | ||
if batches_done % sample_interval == 0: | ||
save_image(state['gen_imgs'].data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True) | ||
|
||
|
||
# Configure data loader | ||
os.makedirs('./data/mnist', exist_ok=True) | ||
transform = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
]) | ||
|
||
dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform) | ||
|
||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) | ||
|
||
|
||
# Model and optimizer | ||
model = GAN() | ||
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999)) | ||
|
||
|
||
@tb.metrics.running_mean | ||
@tb.metrics.mean | ||
class g_loss(tb.metrics.Metric): | ||
def __init__(self): | ||
super().__init__('g_loss') | ||
|
||
def process(self, state): | ||
return state['g_loss'] | ||
|
||
|
||
@tb.metrics.running_mean | ||
@tb.metrics.mean | ||
class d_loss(tb.metrics.Metric): | ||
def __init__(self): | ||
super().__init__('d_loss') | ||
|
||
def process(self, state): | ||
return state['d_loss'] | ||
|
||
|
||
def zero_loss(y_pred, y_true): | ||
return torch.zeros(y_true.shape[0], 1) | ||
|
||
|
||
torchbearermodel = tb.Model(model, optim, zero_loss, ['loss', g_loss(), d_loss()]) | ||
torchbearermodel.to('cuda') | ||
torchbearermodel.fit_generator(dataloader, epochs=200, pass_state=True, callbacks=[LossCallback(), SaverCallback()]) |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
Training a GAN | ||
==================================== | ||
|
||
We shall try to implement something more complicated using torchbearer - a Generative Adverserial Network (GAN). | ||
This tutorial is a modified version of the GAN_ from the brilliant collection of GAN implementations PyTorch_GAN_ by eriklindernoren on github. | ||
|
||
.. _PyTorch_GAN: https://github.com/eriklindernoren/PyTorch-GAN | ||
.. _GAN: https://github.com/eriklindernoren/PyTorch-GAN#gan | ||
|
||
Data and Constants | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We first define all constants for the example. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 18-24 | ||
|
||
We then define the dataset and dataloader - for this example, MNIST. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 119-126 | ||
|
||
Model | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We use the generator and discriminator from PyTorch_GAN_ and combine them into a model that performs a single forward pass. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 73-89 | ||
|
||
Note that we have to be careful to remove the gradient information from the discriminator after doing the generator forward pass. | ||
|
||
Loss | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
Since our loss is complicated in this example, we shall forgo the basic loss criterion used in normal torchbearer models. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 154-155 | ||
|
||
Instead use a callback to provide the loss. | ||
We also utilise this callback to add constants to state. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 92-106 | ||
|
||
Note that we have summed the separate discriminator and generator losses since their graphs are separated, this is allowable. | ||
|
||
Metrics | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We would like to follow the discriminator and generator losses during training - note that we added these to state during the model definition. | ||
We can then create metrics from these by decorating simple state fetcher metrics. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 134-141 | ||
|
||
Training | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We can then train the torchbearer model on the GPU in the standard way. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 158-160 | ||
|
||
Visualising | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We borrow the image saving method from PyTorch_GAN_ and put it in a call back to save on training step. | ||
|
||
.. literalinclude:: /_static/examples/gan.py | ||
:language: python | ||
:lines: 109-114 | ||
|
||
After 172400 iterations we see the following. | ||
|
||
.. figure:: /_static/img/172400.png | ||
:scale: 200 % | ||
:alt: GAN generated samples after 172400 iterations | ||
|
||
|
||
Source Code | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
The source code for the example is given below: | ||
|
||
:download:`Download Python source code: gan.py </_static/examples/gan.py>` | ||
|
||
|
||
|