Feature/closures (#524)
* Base with closures

* wip

* Closure now lambda

* Use base_closure to create default closure

* Update closures

* Test closures

* Update docs

* Update docs

* Update docs

* Update docs

* Update changelog

* Formatting

* Formatting

* Add docstring

* Add docstring

* Formatting

* Formatting
MattPainter01 committed Mar 14, 2019
1 parent ed91a94 commit 0de148f
Showing 7 changed files with 350 additions and 123 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,11 +8,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added cyclic learning rate finder
- Added on_init callback hook to run at the end of trial init
- Added callbacks for weight initialisation in ``torchbearer.callbacks.init``
- Added ``with_closure`` trial method that allows running of custom closures
- Added ``base_closure`` function to bases that allows creation of standard training loop closures
### Changed
### Deprecated
### Removed
### Fixed
- Fixed bug where replay errored when train or val steps were None
- Fixed a bug where the mock optimiser wouldn't call its closure

## [0.3.0] - 2019-02-28
### Added
100 changes: 46 additions & 54 deletions docs/_static/examples/gan.py
@@ -12,6 +12,8 @@
import torchbearer as tb
import torchbearer.callbacks as callbacks
from torchbearer import state_key
from torchbearer.bases import base_closure


os.makedirs('images', exist_ok=True)

@@ -27,6 +29,8 @@
device = 'cuda'
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)
batch = torch.randn(25, latent_dim).to(device)


# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
@@ -36,6 +40,12 @@
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')

DISC_OPT = state_key('disc_opt')
GEN_OPT = state_key('gen_opt')
DISC_MODEL = state_key('disc_model')
DISC_IMGS = state_key('disc_imgs')
DISC_CRIT = state_key('disc_crit')
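# The keys above hold the discriminator's model, optimiser, images and criterion in state, separate from the generator's.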


class Generator(nn.Module):
def __init__(self):
@@ -57,7 +67,8 @@ def block(in_feat, out_feat, normalize=True):
nn.Tanh()
)

def forward(self, z):
def forward(self, real_imgs, state):
z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE])
img = self.model(z)
img = img.view(img.size(0), *img_shape)
return img
@@ -76,48 +87,32 @@ def __init__(self):
nn.Sigmoid()
)

def forward(self, img):
def forward(self, img, state):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)

return validity


class GAN(nn.Module):
def __init__(self):
super().__init__()
self.discriminator = Discriminator()
self.generator = Generator()

def forward(self, real_imgs, state):
# Generator Forward
z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE])
state[GEN_IMGS] = self.generator(z)
state[DISC_GEN] = self.discriminator(state[GEN_IMGS])
# This clears the function graph built up for the discriminator
self.discriminator.zero_grad()

# Discriminator Forward
state[DISC_GEN_DET] = self.discriminator(state[GEN_IMGS].detach())
state[DISC_REAL] = self.discriminator(real_imgs)
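# Generator criterion: the generated images (tb.Y_PRED) should be classified as valid by the discriminator.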
def gen_crit(state):
loss = adversarial_loss(state[DISC_MODEL](state[tb.Y_PRED], state), valid)
state[G_LOSS] = loss
return loss


@callbacks.add_to_loss
def loss_callback(state):
fake_loss = adversarial_loss(state[DISC_GEN_DET], fake)
real_loss = adversarial_loss(state[DISC_REAL], valid)
state[G_LOSS] = adversarial_loss(state[DISC_GEN], valid)
state[D_LOSS] = (real_loss + fake_loss) / 2
return state[G_LOSS] + state[D_LOSS]
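# Discriminator criterion: real images should be classified as valid and the (detached) generated images as fake.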
def disc_crit(state):
real_loss = adversarial_loss(state[DISC_MODEL](state[tb.X], state), valid)
fake_loss = adversarial_loss(state[DISC_MODEL](state[tb.Y_PRED].detach(), state), fake)
loss = (real_loss + fake_loss) / 2
state[D_LOSS] = loss
return loss


batch = torch.randn(25, latent_dim).to(device)
@callbacks.on_step_training
@callbacks.only_if(lambda state: state[tb.BATCH] % sample_interval == 0)
def saver_callback(state):
batches_done = state[tb.EPOCH] * len(state[tb.GENERATOR]) + state[tb.BATCH]
if batches_done % sample_interval == 0:
samples = state[tb.MODEL].generator(batch)
save_image(samples, 'images/%d.png' % batches_done, nrow=5, normalize=True)
samples = state[tb.MODEL](batch, state)
save_image(samples, 'images/%d.png' % state[tb.BATCH], nrow=5, normalize=True)


# Configure data loader
@@ -126,39 +121,36 @@ def saver_callback(state):
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


# Model and optimizer
model = GAN()
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
generator = Generator()
discriminator = Discriminator()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))


@tb.metrics.running_mean
@tb.metrics.mean
class g_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('g_loss')
closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION, tb.LOSS, GEN_OPT)
closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, tb.LOSS, DISC_OPT)
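# Each base_closure call builds a standard closure (zero gradients, forward, loss, backward) from the given state keys.
# The argument roles (input, model, output, target, criterion, loss, optimiser) are inferred from these two calls.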

def process(self, state):
return state[G_LOSS]

def closure(state):
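# Run the generator closure and step its optimiser, then do the same for the discriminator.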
closure_gen(state)
state[GEN_OPT].step()
closure_disc(state)
state[DISC_OPT].step()

@tb.metrics.running_mean
@tb.metrics.mean
class d_loss(tb.metrics.Metric):
def __init__(self):
super().__init__('d_loss')
from torchbearer.metrics import mean, running_mean
metrics = ['loss', mean(running_mean(D_LOSS)), mean(running_mean(G_LOSS))]

def process(self, state):
return state[D_LOSS]
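# Passing None as the optimiser gives a mock optimiser that just runs the closure; the real optimisers are stepped inside the custom closure above.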
trial = tb.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])
trial.with_train_generator(dataloader, steps=200000)
trial.to(device)

new_keys = {DISC_MODEL: discriminator.to(device), DISC_OPT: optimizer_D, GEN_OPT: optimizer_G, DISC_CRIT: disc_crit}
trial.state.update(new_keys)
trial.with_closure(closure)
trial.run(epochs=1)

torchbearertrial = tb.Trial(model, optim, criterion=None, metrics=['loss', g_loss(), d_loss()],
callbacks=[loss_callback, saver_callback])
torchbearertrial.with_train_generator(dataloader)
torchbearertrial.to(device)
torchbearertrial.run(epochs=200)
85 changes: 67 additions & 18 deletions docs/examples/gan.rst
@@ -14,73 +14,122 @@ We first define all constants for the example.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 19-29
:lines: 20-32

We then define a number of state keys for convenience using :func:`.state_key`. This is optional; however, it automatically avoids key conflicts.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 32-37
:lines: 35-47

We then define the dataset and dataloader - for this example, MNIST.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 124-132
:lines: 120-125

Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We use the generator and discriminator from PyTorch_GAN_ and combine them into a model that performs a single forward pass.
We use the generator and discriminator from PyTorch_GAN_.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 86-102
:lines: 50-94

Note that we have to be careful to remove the gradient information from the discriminator after doing the generator forward pass.
We then create the models and optimisers.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 128-132

Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Since our loss computation in this example is complicated, we shall forgo the basic loss criterion used in normal torchbearer trials.
Instead we use a callback to provide the loss; in this case we use the :func:`.add_to_loss` callback decorator.
This decorates a function that returns a loss and automatically adds this to the main loss in training and validation.
GANs usually require two different losses, one for the generator and one for the discriminator.
We define these as functions of state so that we can access the discriminator model for the additional forward passes required.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 105-111
:lines: 97-108

Note that we have summed the separate discriminator and generator losses; since their graphs are separated, this is allowable.
We will see later how we get a torchbearer trial to use these losses.

Metrics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We would like to follow the discriminator and generator losses during training - note that we added these to state during the model definition.
We can then create metrics from these by decorating simple state fetcher metrics.
In torchbearer, state keys are also metrics, so we can take means and running means of them and tell torchbearer to output them as metrics.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 145-146

We will add this metric list to the trial when we create it.


Closures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The training loop of a GAN is a bit different to a standard model training loop.
GANs require separate forward and backward passes for the generator and discriminator.
To achieve this in torchbearer we can write a new closure.
Since the individual training loops for the generator and discriminator are the same as a
standard training loop we can use a :func:`~torchbearer.bases.base_closure`.
The base closure takes state keys for required objects (data, model, optimiser, etc.) and returns a standard closure consisting of:

1. Zero gradients
2. Forward pass
3. Loss calculation
4. Backward pass

We create a separate closure for the generator and the discriminator. We give them separate state keys for most objects so that each closure acts on its own model and optimiser, although the loss is easier to deal with in a single key.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 15, 135-136
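
To make this concrete, each of these closures behaves roughly like the sketch below. This is an illustrative sketch only (shown with the generator's keys and the criterion defined above), not the actual :func:`~torchbearer.bases.base_closure` implementation.

.. code-block:: python

    # Rough sketch of the closure built by base_closure for the generator (illustrative only)
    def sketch_gen_closure(state):
        state[GEN_OPT].zero_grad()                              # 1. zero gradients
        state[tb.Y_PRED] = state[tb.MODEL](state[tb.X], state)  # 2. forward pass
        state[tb.LOSS] = state[tb.CRITERION](state)             # 3. loss calculation
        state[tb.LOSS].backward()                               # 4. backward pass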

We then create a main closure (a simple function of state) that runs both of these and steps the optimisers.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 140-147
:lines: 139-143

We will add this closure to the trial next.


Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can then train the torchbearer trial on the GPU in the standard way.
Note that when torchbearer is passed a ``None`` criterion it automatically sets the base loss to 0.
We now create the torchbearer trial on the GPU in the standard way.
Note that when torchbearer is passed a ``None`` optimiser it creates a mock optimiser that will just run the closure.
Since we are using the standard torchbearer state keys for the generator model and criterion, we can pass them in here.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 148-150

We now update state with the keys required for the discriminator's closure and add the new closure to the trial.
Note that torchbearer doesn't know that the discriminator is a model here, so we have to send it to the GPU ourselves.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 152-154

Finally we run the trial.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 160-164
:lines: 155

Visualising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We borrow the image saving method from PyTorch_GAN_ and put it in a callback, decorated with :func:`~torchbearer.callbacks.decorators.on_step_training`, that saves images during training.
We generate from the same inputs each time to get a better visualisation.

.. literalinclude:: /_static/examples/gan.py
:language: python
:lines: 114-120
:lines: 111-115

Here is a Gif created from the saved images.

