Skip to content

Commit

Permalink
Add final torchbearer.py tests (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 20, 2018
1 parent b89f285 commit 8735eee
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 29 deletions.
199 changes: 175 additions & 24 deletions tests/test_torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,59 @@


class TestTorchbearer(TestCase):
@patch('torchbearer.cv_utils.get_train_valid_sets')
def test_fit_valid_sets_args(self, gtvs):
x = torch.rand(1,5)
y = torch.rand(1,5)
val_data = (1,2)
val_split = 0.2
shuffle = False

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

gtvs.return_value = (1, 2)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearermodel.fit_generator = Mock()
torchbearermodel.fit(x, y, 1, validation_data=val_data, validation_split=val_split, shuffle=shuffle)

gtvs.assert_called_once()
self.assertTrue(list(gtvs.call_args[0][0].numpy()[0]) == list(x.numpy()[0]))
self.assertTrue(list(gtvs.call_args[0][1].numpy()[0]) == list(y.numpy()[0]))
self.assertTrue(gtvs.call_args[0][2] == val_data)
self.assertTrue(gtvs.call_args[0][3] == val_split)
self.assertTrue(gtvs.call_args[1]['shuffle'] == shuffle)

def test_fit_no_valid(self):
x = torch.rand(1, 5)
y = torch.rand(1, 5)

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearermodel.fit_generator = Mock()
fit = torchbearermodel.fit_generator
torchbearermodel.fit(x, y, 1, validation_split=None)

self.assertTrue(fit.call_args[1]['validation_generator'] is None)

def test_main_loop_metrics(self):
metric = Metric('test')
metric.process = Mock(return_value={'test': 0})
metric.process_final = Mock(return_value={'test': 0})
metric.reset = Mock(return_value=None)
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -25,7 +71,6 @@ def test_main_loop_metrics(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -42,9 +87,39 @@ def test_main_loop_metrics(self):
torchbearerstate[torchbearer.METRIC_LIST].metric_list[0].process_final.assert_called_once()
self.assertTrue(torchbearerstate[torchbearer.METRICS]['test'] == 0)

def test_main_loop_verbose(self):
metric = Metric('test')

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
train_steps = len(data)

epochs = 1

callback = MagicMock()

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

import sys
from io import StringIO
saved_std_err = sys.stderr
out = StringIO()
sys.stderr = out

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearerstate = torchbearermodel.fit_generator(generator, train_steps, epochs, 1, [callback], initial_epoch=0, pass_state=False)

output = out.getvalue().strip()
self.assertTrue(output != '')
sys.stderr = saved_std_err

def test_main_loop_train_steps_positive(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -53,7 +128,6 @@ def test_main_loop_train_steps_positive(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -70,7 +144,6 @@ def test_main_loop_train_steps_positive(self):
@patch("warnings.warn")
def test_main_loop_train_steps_fractional(self, _):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -79,7 +152,6 @@ def test_main_loop_train_steps_fractional(self, _):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -93,9 +165,36 @@ def test_main_loop_train_steps_fractional(self, _):

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == int(train_steps))

def test_main_loop_validation_setup(self):
metric = Metric('test')

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
valgenerator = DataLoader(data)
train_steps = 2

epochs = 1

callback = MagicMock()

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearermodel._test_loop = Mock()
torchbearerstate = torchbearermodel.fit_generator(generator, train_steps, epochs, 0, [callback],
validation_generator=valgenerator, initial_epoch=0,
pass_state=False)

self.assertTrue(torchbearerstate[torchbearer.VALIDATION_STEPS] == len(valgenerator))
self.assertTrue(torchbearerstate[torchbearer.VALIDATION_GENERATOR] == valgenerator)

def test_main_loop_epochs_positive(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -104,7 +203,6 @@ def test_main_loop_epochs_positive(self):
epochs = 2

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -120,7 +218,6 @@ def test_main_loop_epochs_positive(self):

def test_main_loop_epochs_zero(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -129,7 +226,6 @@ def test_main_loop_epochs_zero(self):
epochs = 0

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -145,7 +241,6 @@ def test_main_loop_epochs_zero(self):

def test_main_loop_epochs_negative(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -154,7 +249,6 @@ def test_main_loop_epochs_negative(self):
epochs = -2

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -171,7 +265,6 @@ def test_main_loop_epochs_negative(self):
@patch("warnings.warn")
def test_main_loop_epochs_fractional(self, _):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -180,7 +273,6 @@ def test_main_loop_epochs_fractional(self, _):
epochs = 2.5

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -194,9 +286,32 @@ def test_main_loop_epochs_fractional(self, _):

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == int(epochs)*len(data))

@patch("warnings.warn")
def test_main_loop_epochs_none(self, warning):
metric = Metric('test')

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
train_steps = None

epochs = None

callback = MagicMock()

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearerstate = torchbearermodel.fit_generator(generator, train_steps, epochs, 0, [callback], initial_epoch=0, pass_state=False)

self.assertTrue(warning.call_count == 1)

def test_main_loop_train_steps_too_big(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -205,7 +320,6 @@ def test_main_loop_train_steps_too_big(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -221,7 +335,6 @@ def test_main_loop_train_steps_too_big(self):

def test_main_loop_train_steps_negative(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -230,7 +343,6 @@ def test_main_loop_train_steps_negative(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -246,7 +358,6 @@ def test_main_loop_train_steps_negative(self):

def test_main_loop_pass_state(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -255,7 +366,6 @@ def test_main_loop_pass_state(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -271,7 +381,6 @@ def test_main_loop_pass_state(self):

def test_main_loop_optimizer(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -280,7 +389,6 @@ def test_main_loop_optimizer(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand All @@ -296,7 +404,6 @@ def test_main_loop_optimizer(self):

def test_main_loop_criterion(self):
metric = Metric('test')
metric_list = MetricList([metric])

data = [(torch.Tensor([1]), torch.Tensor([1])), (torch.Tensor([2]), torch.Tensor([2])), (torch.Tensor([3]), torch.Tensor([3]))]
generator = DataLoader(data)
Expand All @@ -305,7 +412,6 @@ def test_main_loop_criterion(self):
epochs = 1

callback = MagicMock()
callback_List = torchbearer.CallbackList([callback])

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
Expand Down Expand Up @@ -690,6 +796,29 @@ def test_test_loop_stop_training(self):

self.assertTrue(torchbearerstate[torchbearer.MODEL].call_count == 1)

def test_evaluate(self):
x = torch.rand(1,5)
y = torch.rand(1,5)
pass_state = False
verbose=0

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearermodel.evaluate_generator = Mock()
ev = torchbearermodel.evaluate_generator
torchbearermodel.evaluate(x, y, verbose=verbose, pass_state=pass_state)

ev.assert_called_once()
self.assertTrue(ev.call_args[0][1] == verbose)
self.assertTrue(ev.call_args[1]['pass_state'] == pass_state)

def test_evaluate_generator_args(self):
torchmodel = MagicMock()
optimizer = MagicMock()
Expand Down Expand Up @@ -754,6 +883,28 @@ def test_evaluate_generator_steps(self):
torchbearermodel.evaluate_generator(generator, 0, steps, pass_state)
self.assertTrue(torchbearermodel._test_loop.call_args[0][4] == steps)

def test_predict(self):
x = torch.rand(1,5)
pass_state = False
verbose=0

torchmodel = MagicMock()
torchmodel.forward = Mock(return_value=1)
optimizer = MagicMock()
metric = Metric('test')

loss = torch.tensor([2], requires_grad=True)
criterion = Mock(return_value=loss)

torchbearermodel = Model(torchmodel, optimizer, criterion, [metric])
torchbearermodel.predict_generator = Mock()
pred = torchbearermodel.predict_generator
torchbearermodel.predict(x, verbose=verbose, pass_state=pass_state)

pred.assert_called_once()
self.assertTrue(pred.call_args[0][1] == verbose)
self.assertTrue(pred.call_args[1]['pass_state'] == pass_state)

def test_predict_generator_args(self):
from torchbearer.callbacks import AggregatePredictions

Expand Down

0 comments on commit 8735eee

Please sign in to comment.