Skip to content

Commit

Permalink
Fix checkpointers (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and MattPainter01 committed Jul 30, 2018
1 parent baacf95 commit 84cbc92
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
### Removed
### Fixed
- Fixed a bug where checkpointers would not save the model in some cases

## [0.1.4] - 2018-07-23
### Added
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def __getattr__(cls, name):
# built documents.
#
# The short X.Y version.
version = '0.1.4'
version = '0.1.5'
# The full version, including alpha/beta/rc tags.
release = '0.1.4'
release = '0.1.5'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

setup(
name='torchbearer',
version='0.1.4',
version='0.1.5',
packages=['torchbearer', 'torchbearer.metrics', 'torchbearer.callbacks', 'tests', 'tests.metrics', 'tests.callbacks'],
url='https://github.com/ecs-vlc/torchbearer',
download_url='https://github.com/ecs-vlc/torchbearer/archive/0.1.4.tar.gz',
download_url='https://github.com/ecs-vlc/torchbearer/archive/0.1.5.tar.gz',
license='GPL-3.0',
author='Matt Painter',
author_email='mp2u16@ecs.soton.ac.uk',
Expand Down
5 changes: 3 additions & 2 deletions torchbearer/callbacks/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ def save_checkpoint(self, model_state, overwrite_most_recent=False):

filepath = self.fileformat.format(**state)

torch.save(model_state[torchbearer.SELF].state_dict(), filepath, pickle_module=self.pickle_module, pickle_protocol=self.pickle_protocol)

if self.most_recent is not None and overwrite_most_recent:
os.remove(self.most_recent)

torch.save(model_state[torchbearer.SELF].state_dict(), filepath, pickle_module=self.pickle_module,
pickle_protocol=self.pickle_protocol)

self.most_recent = filepath


Expand Down

0 comments on commit 84cbc92

Please sign in to comment.