Skip to content

Commit

Permalink
Fix/checkpoint no file error (#541)
Browse files Browse the repository at this point in the history
* Fix bug where checkpointers would error if they couldn't find the previous checkpoint file

* Update changelog

* Fix python2 test
  • Loading branch information
MattPainter01 authored and ethanwharris committed Apr 23, 2019
1 parent 86147c0 commit 632fbf6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug with the once and once_per_epoch decorators
- Fixed a bug where the test criterion wouldn't accept a function of state
- Fixed a bug where type inference would not work correctly when chaining ``Trial`` methods
- Fixed a bug where checkpointers would error when they couldn't find the old checkpoint to overwrite

## [0.3.0] - 2019-02-28
### Added
Expand Down
13 changes: 12 additions & 1 deletion tests/callbacks/test_checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@
import torchbearer
from torchbearer import Trial
from torchbearer.callbacks.checkpointers import _Checkpointer, ModelCheckpoint, MostRecent, Interval, Best

import warnings

class TestCheckpointer(TestCase):
@patch('os.makedirs')
def test_make_dirs(self, mock_dirs):
_Checkpointer('thisdirectoryshouldntexist/norshouldthis/model.pt')
mock_dirs.assert_called_once_with('thisdirectoryshouldntexist/norshouldthis')

@patch('torch.save')
@patch('os.makedirs')
def test_no_existing_file(self, mock_dirs, mock_save):
check = _Checkpointer('thisdirectoryshouldntexist/norshouldthis/model.pt')
check.most_recent = 'thisfiledoesnotexist.pt'
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
check.save_checkpoint({torchbearer.METRICS: {}, torchbearer.SELF: Mock()}, True)
self.assertTrue(len(w) == 1)
self.assertTrue('Failed to delete old file' in str(w[-1].message))

@patch("torch.save")
def test_save_checkpoint_save_filename(self, mock_save):
torchmodel = Mock()
Expand Down
6 changes: 5 additions & 1 deletion torchbearer/callbacks/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from torchbearer.callbacks.callbacks import Callback
import os
import warnings


class _Checkpointer(Callback):
Expand All @@ -30,7 +31,10 @@ def save_checkpoint(self, model_state, overwrite_most_recent=False):
filepath = self.fileformat.format(**string_state)

if self.most_recent is not None and overwrite_most_recent:
os.remove(self.most_recent)
try:
os.remove(self.most_recent)
except OSError:
warnings.warn('Failed to delete old file. Are you running two checkpointers with the same filename?')

if self.save_model_params_only:
torch.save(model_state[torchbearer.MODEL].state_dict(), filepath, pickle_module=self.pickle_module,
Expand Down

0 comments on commit 632fbf6

Please sign in to comment.