Skip to content

Commit

Permalink
Test/torchbearer (#213)
Browse files Browse the repository at this point in the history
* Add train() and eval() tests

* Add tests for predict and evaluate generators

* Format imports
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 18, 2018
1 parent c7562c5 commit 39c782d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 5 deletions.
155 changes: 155 additions & 0 deletions tests/test_torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,160 @@ def test_test_loop_stop_training(self):

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == 1)

def test_evaluate_generator_args(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = None

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.METRICS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.evaluate_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][1].callback_list == [])
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_evaluate_generator_verbose(self):
from torchbearer.callbacks import Tqdm

torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = None

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.METRICS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.evaluate_generator(generator, 1, steps, pass_state)
self.assertIsInstance(torchbearermodel._test_loop.call_args[0][1].callback_list[0], Tqdm)

def test_evaluate_generator_pass_state(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = True
steps = None

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.METRICS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.evaluate_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)

def test_evaluate_generator_steps(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = 100

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.METRICS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.evaluate_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_predict_generator_args(self):
from torchbearer.callbacks import AggregatePredictions

torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = None

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.FINAL_PREDICTIONS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.predict_generator(generator, 0, steps, pass_state)
self.assertIsInstance(torchbearermodel._test_loop.call_args[0][1].callback_list[0], AggregatePredictions)
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_predict_generator_verbose(self):
from torchbearer.callbacks import Tqdm

torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = None

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.FINAL_PREDICTIONS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.predict_generator(generator, 1, steps, pass_state)
self.assertIsInstance(torchbearermodel._test_loop.call_args[0][1].callback_list[1], Tqdm)
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_predict_generator_steps(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = 100

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.FINAL_PREDICTIONS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.predict_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_predict_generator_pass_state(self):
torchmodel = MagicMock()
optimizer = MagicMock()
generator = MagicMock()

pass_state = False
steps = 100

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state[torchbearer.FINAL_PREDICTIONS] = 1
torchbearermodel._test_loop = Mock()

torchbearermodel.predict_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][2] == pass_state)

def test_train(self):
torchmodel = torch.nn.Sequential(torch.nn.Linear(1,1))
optimizer = MagicMock()
metric_list = MagicMock()

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state = {torchbearer.MODEL: torchmodel, torchbearer.METRIC_LIST: metric_list}
torchbearermodel.train()
self.assertTrue(torchbearermodel.main_state[torchbearer.MODEL].training == True)
torchbearermodel.main_state[torchbearer.METRIC_LIST].train.assert_called_once()

def test_eval(self):
torchmodel = torch.nn.Sequential(torch.nn.Linear(1,1))
optimizer = MagicMock()
metric_list = MagicMock()

torchbearermodel = Model(torchmodel, optimizer, torch.nn.L1Loss(), [])
torchbearermodel.main_state = {torchbearer.MODEL: torchmodel, torchbearer.METRIC_LIST: metric_list}
torchbearermodel.eval()
self.assertTrue(torchbearermodel.main_state[torchbearer.MODEL].training == False)
torchbearermodel.main_state[torchbearer.METRIC_LIST].eval.assert_called_once()

def test_to_both_args(self):
dev = 'cuda:1'
dtype = torch.float16
Expand Down Expand Up @@ -923,3 +1077,4 @@ def test_load_batch_predict_list(self):
Model._load_batch_predict('training', state)
self.assertTrue(state['x'].item() == items[0][0].item())
self.assertTrue(state['y_true'].item() == items[0][1].item())

8 changes: 3 additions & 5 deletions torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import math
import warnings
from _warnings import warn

import torch
from torch.utils.data import DataLoader, TensorDataset

import torchbearer
from torchbearer.cv_utils import get_train_valid_sets
from torchbearer import metrics as torchbearer_metrics
from torchbearer.callbacks.aggregate_predictions import AggregatePredictions
from torchbearer.callbacks.callbacks import CallbackList
from torchbearer.callbacks.printer import Tqdm
from torchbearer.callbacks.aggregate_predictions import AggregatePredictions
from torchbearer import metrics as torchbearer_metrics
from torchbearer.cv_utils import get_train_valid_sets


class Model:
Expand Down

0 comments on commit 39c782d

Please sign in to comment.