Callback inline examples (#564)
* Add mock model and examples for decorators, checkpointers and csv logger.

* Fix no model test

* Fix MockModel in Py2 and dataset variational tests

* More examples

* More examples

* Add some trial examples

* Remove bad characters

* More examples

* Trial inline examples

* Add tensorboard examples

* Fix test
MattPainter01 authored and ethanwharris committed Jun 12, 2019
1 parent 1d88b8d commit b244c5d
Showing 16 changed files with 882 additions and 35 deletions.
41 changes: 17 additions & 24 deletions tests/test_end_to_end.py
@@ -47,8 +47,8 @@ def test_basic_opt(self):
         model = NetWithState(p)
         optim = torch.optim.SGD(model.parameters(), lr=0.01)

-        tbmodel = torchbearer.Trial(model, optim, loss).for_train_steps(training_steps).for_val_steps(1)
-        tbmodel.run()
+        trial = torchbearer.Trial(model, optim, loss).for_train_steps(training_steps).for_val_steps(1)
+        trial.run()

         self.assertAlmostEqual(model.pars[0].item(), 5.0, places=4)
         self.assertAlmostEqual(model.pars[1].item(), 0.0, places=4)
@@ -61,17 +61,17 @@ def test_basic_checkpoint(self):
         model = Net(p)
         optim = torch.optim.SGD(model.parameters(), lr=0.01)

-        tbmodel = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps).for_val_steps(1)
-        tbmodel.run(2) # Simulate 2 'epochs'
+        trial = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps).for_val_steps(1)
+        trial.run(2) # Simulate 2 'epochs'

         # Reload
         p = torch.tensor([2.0, 1.0, 10.0])
         model = Net(p)
         optim = torch.optim.SGD(model.parameters(), lr=0.01)

-        tbmodel = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps)
-        tbmodel.load_state_dict(torch.load('test.pt'))
-        self.assertEqual(len(tbmodel.state[torchbearer.HISTORY]), 2)
+        trial = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps)
+        trial.load_state_dict(torch.load('test.pt'))
+        self.assertEqual(len(trial.state[torchbearer.HISTORY]), 2)
         self.assertAlmostEqual(model.pars[0].item(), 5.0, places=4)
         self.assertAlmostEqual(model.pars[1].item(), 0.0, places=4)
         self.assertAlmostEqual(model.pars[2].item(), 1.0, places=4)
@@ -91,34 +91,27 @@ def custom_loader(state):
             state[torchbearer.X], state[torchbearer.Y_TRUE] = None, None
             test_var['loaded'] = True

-        tbmodel = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps).for_val_steps(1)
-        tbmodel.with_loader(custom_loader)
+        trial = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps).for_val_steps(1)
+        trial.with_loader(custom_loader)
         self.assertTrue(not test_var['loaded'])
-        tbmodel.run(1)
+        trial.run(1)
         self.assertTrue(test_var['loaded'])

         import os
         os.remove('test.pt')

     def test_only_model(self):
         p = torch.tensor([2.0, 1.0, 10.0])

         model = Net(p)

-        tbmodel = torchbearer.Trial(model)
-        self.assertListEqual(tbmodel.run(), [])
+        trial = torchbearer.Trial(model)
+        self.assertListEqual(trial.run(), [])

     def test_no_model(self):
-        tbmodel = torchbearer.Trial(None)
-
-        import warnings
-        with warnings.catch_warnings(record=True) as w:
-            tbmodel.run()
-            self.assertTrue(len(w) == 1)
-
+        trial = torchbearer.Trial(None)
+        trial.run()
+        self.assertTrue(torchbearer.trial.MockModel()(torch.rand(1)) is None)

     def test_no_train_steps(self):
-        tbmodel = torchbearer.Trial(None)
-        tbmodel.for_val_steps(10)
-        tbmodel.run()
+        trial = torchbearer.Trial(None)
+        trial.for_val_steps(10)
+        trial.run()
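The reworked test_no_model above pins down the behaviour of the new mock model: a Trial built without a model simply runs, and calling torchbearer.trial.MockModel directly returns None for any input. A minimal sketch of that behaviour, assuming only the API exercised by these tests:

    import torch
    import torchbearer

    # A Trial with no model runs without raising; a MockModel stands in internally
    trial = torchbearer.Trial(None).for_val_steps(10)
    trial.run()

    # The mock returns None for any input, as the test asserts
    assert torchbearer.trial.MockModel()(torch.rand(1)) is None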
2 changes: 2 additions & 0 deletions torchbearer/callbacks/__init__.py
@@ -99,6 +99,8 @@

 from torchbearer import Callback
 from .callbacks import *
+from .lr_finder import CyclicLR
+from .lsuv import LSUV
 from .checkpointers import *
 from .csv_logger import *
 from .early_stopping import *
60 changes: 57 additions & 3 deletions torchbearer/callbacks/checkpointers.py
@@ -53,6 +53,18 @@ def ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_p
     will be saved with the epoch number and the validation loss in the filename. The torch :class:`.Trial` will be
     saved to filename.

+    Example: ::
+
+        >>> from torchbearer.callbacks import ModelCheckpoint
+        >>> from torchbearer import Trial
+        >>> import torch
+
+        # Example Trial (without optimiser or loss criterion) which uses this checkpointer
+        >>> model = torch.nn.Linear(1,1)
+        >>> checkpoint = ModelCheckpoint('my_path.pt', monitor='val_acc', mode='max')
+        >>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
+
     Args:
         filepath (str): Path to save the model file
         save_model_params_only (bool): If `save_model_params_only=True`, only model parameters will be saved so that
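The templated filepath above is filled from values in state when the checkpoint is written. As a rough end-to-end sketch of that behaviour (the toy data, model and criterion here are illustrative and not part of the diff):

    import torch
    import torchbearer
    from torchbearer.callbacks import ModelCheckpoint

    x, y = torch.rand(32, 1), torch.rand(32, 1)   # toy regression data
    model = torch.nn.Linear(1, 1)
    optimiser = torch.optim.SGD(model.parameters(), lr=0.01)

    # 'epoch' and 'val_loss' in the template are filled in when each file is saved
    checkpoint = ModelCheckpoint('model.{epoch:02d}-{val_loss:.2f}.pt')

    trial = torchbearer.Trial(model, optimiser, torch.nn.MSELoss(), metrics=['loss'],
                              callbacks=[checkpoint])
    trial.with_train_data(x, y, batch_size=8).with_val_data(x, y, batch_size=8)
    trial.run(epochs=2)   # saves e.g. model.01-0.12.pt (numbers depend on the run)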
@@ -77,7 +89,21 @@ def ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.pt', save_model_p


 class MostRecent(_Checkpointer):
-    """Model checkpointer which saves the most recent model to a given filepath.
+    """Model checkpointer which saves the most recent model to a given filepath. `filepath` can contain named
+    formatting options, which will be filled with any values from state. For example: if `filepath` is
+    `weights.{epoch:02d}-{val_loss:.2f}`, then the model checkpoints will be saved with the epoch number and the
+    validation loss in the filename.
+
+    Example: ::
+
+        >>> from torchbearer.callbacks import MostRecent
+        >>> from torchbearer import Trial
+        >>> import torch
+
+        # Example Trial (without optimiser or loss criterion) which uses this checkpointer
+        >>> model = torch.nn.Linear(1,1)
+        >>> checkpoint = MostRecent('my_path.pt')
+        >>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
+
     Args:
         filepath (str): Path to save the model file
@@ -101,7 +127,21 @@ def on_checkpoint(self, state):


 class Best(_Checkpointer):
-    """Model checkpointer which saves the best model according to the given configurations.
+    """Model checkpointer which saves the best model according to the given configurations. `filepath` can contain
+    named formatting options, which will be filled with any values from state. For example: if `filepath` is
+    `weights.{epoch:02d}-{val_loss:.2f}`, then the model checkpoints will be saved with the epoch number and the
+    validation loss in the filename.
+
+    Example: ::
+
+        >>> from torchbearer.callbacks import Best
+        >>> from torchbearer import Trial
+        >>> import torch
+
+        # Example Trial (without optimiser or loss criterion) which uses this checkpointer
+        >>> model = torch.nn.Linear(1,1)
+        >>> checkpoint = Best('my_path.pt', monitor='val_acc', mode='max')
+        >>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
+
     Args:
         filepath (str): Path to save the model file
@@ -178,7 +218,21 @@ def on_checkpoint(self, state):


 class Interval(_Checkpointer):
-    """Model checkpointer which which saves the model every 'period' epochs to the given filepath.
+    """Model checkpointer which saves the model every 'period' epochs to the given filepath. `filepath` can
+    contain named formatting options, which will be filled with any values from state. For example: if `filepath` is
+    `weights.{epoch:02d}-{val_loss:.2f}`, then the model checkpoints will be saved with the epoch number and the
+    validation loss in the filename.
+
+    Example: ::
+
+        >>> from torchbearer.callbacks import Interval
+        >>> from torchbearer import Trial
+        >>> import torch
+
+        # Example Trial (without optimiser or loss criterion) which uses this checkpointer
+        >>> model = torch.nn.Linear(1,1)
+        >>> checkpoint = Interval('my_path.pt', period=100, on_batch=True)
+        >>> trial = Trial(model, callbacks=[checkpoint], metrics=['acc'])
+
     Args:
         filepath (str): Path to save the model file
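Whichever checkpointer is used, the saved file can be restored through Trial.load_state_dict, as test_basic_checkpoint above exercises. A brief sketch of the reload side (the path, data and model here are illustrative):

    import torch
    import torchbearer
    from torchbearer.callbacks import MostRecent

    x, y = torch.rand(32, 1), torch.rand(32, 1)
    model = torch.nn.Linear(1, 1)
    optimiser = torch.optim.SGD(model.parameters(), lr=0.01)

    # Train with MostRecent so 'test.pt' always holds the latest trial state
    trial = torchbearer.Trial(model, optimiser, torch.nn.MSELoss(),
                              callbacks=[MostRecent(filepath='test.pt')])
    trial.with_train_data(x, y, batch_size=8).run(epochs=2)

    # Later: rebuild the trial and restore weights, optimiser state and history
    trial = torchbearer.Trial(model, optimiser, torch.nn.MSELoss())
    trial.load_state_dict(torch.load('test.pt'))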
10 changes: 10 additions & 0 deletions torchbearer/callbacks/csv_logger.py
@@ -9,6 +9,16 @@
 class CSVLogger(Callback):
     """Callback to log metrics to a given csv file.

+    Example: ::
+
+        >>> from torchbearer.callbacks import CSVLogger
+        >>> from torchbearer import Trial
+        >>> import torch
+
+        # Example Trial (without optimiser or loss criterion) which writes metrics to a csv file appending to previous content
+        >>> logger = CSVLogger('my_path.pt', separator=',', append=True)
+        >>> trial = Trial(None, callbacks=[logger], metrics=['acc'])
+
     Args:
         filename (str): The name of the file to output to
         separator (str): The delimiter to use (e.g. comma, tab etc.)
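As with the checkpointers, the logger only writes as the trial runs, appending rows of metric values to the file. A short usage sketch (the file name, data, model and criterion here are illustrative, not from the diff):

    import torch
    import torchbearer
    from torchbearer.callbacks import CSVLogger

    x, y = torch.rand(32, 1), torch.rand(32, 1)
    model = torch.nn.Linear(1, 1)
    optimiser = torch.optim.SGD(model.parameters(), lr=0.01)

    # Metric rows are appended to metrics.csv as the trial runs
    logger = CSVLogger('metrics.csv', separator=',', append=True)

    trial = torchbearer.Trial(model, optimiser, torch.nn.MSELoss(), metrics=['loss'],
                              callbacks=[logger])
    trial.with_train_data(x, y, batch_size=8).run(epochs=2)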
