Skip to content

Commit

Permalink
Feature/custom loaders (#526)
Browse files Browse the repository at this point in the history
* Add with_loader

* Formatting

* Formatting

* Update changelog

* Remove sampler

* Update test pass

* Hide loader key

* Unhide loader key

* Revert to a5174b

* Add end to end

* Fix broken tests

* Update CHANGELOG.md
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jun 5, 2019
1 parent 128af12 commit 50caf34
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 125 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added ``with_loader`` trial method that allows running of custom batch loaders
### Changed
### Deprecated
### Removed
Expand Down
21 changes: 21 additions & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ def test_basic_checkpoint(self):
import os
os.remove('test.pt')

def test_with_loader(self):
p = torch.tensor([2.0, 1.0, 10.0])
training_steps = 2

model = Net(p)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
test_var = {'loaded': False}

def custom_loader(state):
state[torchbearer.X], state[torchbearer.Y_TRUE] = None, None
test_var['loaded'] = True

tbmodel = torchbearer.Trial(model, optim, loss, callbacks=[torchbearer.callbacks.MostRecent(filepath='test.pt')]).for_train_steps(training_steps).for_val_steps(1)
tbmodel.with_loader(custom_loader)
self.assertTrue(not test_var['loaded'])
tbmodel.run(1)
self.assertTrue(test_var['loaded'])

import os
os.remove('test.pt')

def test_only_model(self):
p = torch.tensor([2.0, 1.0, 10.0])

Expand Down

0 comments on commit 50caf34

Please sign in to comment.