
Fix/accessing metrics (#659)

* Add get metric method and fix usages

* Update changelog

* Update checks on metrics dict

* Add callback end to end test

* Add extra eval call

* Fix tests

* Test bad checkpoint monitor

* Update changelog
MattPainter01 committed Feb 17, 2020
1 parent e25593e commit d73f5cd3d32ac96bc76fab6580da10164cb058a2
@@ -3,6 +3,17 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [Unreleased]
### Added
### Changed
### Deprecated
### Removed
### Fixed
- Fixed a bug in the access metrics function and in the callbacks that use it
- Fixed a bug where schedulers were called before optimisers
- Fixed a bug where the CSV logger closed the file too early

## [0.5.3] - 2020-01-31
### Added
- Method in bases to access metrics
@@ -6,6 +6,7 @@
from torchbearer.callbacks.checkpointers import _Checkpointer, ModelCheckpoint, MostRecent, Interval, Best
import warnings


class TestCheckpointer(TestCase):
@patch('os.makedirs')
def test_make_dirs(self, mock_dirs):
@@ -338,6 +339,18 @@ def test_auto_shoud_be_max(self, _):
check.on_checkpoint(state)
self.assertTrue(check.mode == 'max')

@patch('torchbearer.callbacks.checkpointers._Checkpointer.save_checkpoint')
def test_bad_monitor(self, _):
state = {torchbearer.METRICS: {'acc_loss': 0.1}}

file_path = 'test_file_{acc_loss:.2f}'
check = Best(file_path, monitor='test_fail')
check.on_start(state)

with warnings.catch_warnings(record=True) as w:
check.on_checkpoint(state)
self.assertTrue(len(w) == 1)

def test_state_dict(self):
check = Best('test')
check.best = 'temp2'
@@ -16,7 +16,7 @@ def test_write_header(self, mock_open):
}

logger = CSVLogger('test_file.log')

logger.on_start(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
logger.on_end(state)
@@ -35,7 +35,7 @@ def test_write_no_header(self, mock_open):
}

logger = CSVLogger('test_file.log', write_header=False)

logger.on_start(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
logger.on_end(state)
@@ -54,7 +54,7 @@ def test_csv_closed(self, mock_open):
}

logger = CSVLogger('test_file.log', write_header=False)

logger.on_start(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
logger.on_end(state)
@@ -70,7 +70,7 @@ def test_append(self, mock_open):
}

logger = CSVLogger('test_file.log', append=True)

logger.on_start(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
logger.on_end(state)
@@ -111,7 +111,7 @@ def test_write_on_epoch(self, mock_open, mock_write):
}

logger = CSVLogger('test_file.log')

logger.on_start(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
logger.on_end(state)
@@ -128,7 +128,7 @@ def test_batch_granularity(self, mock_open, mock_write):
}

logger = CSVLogger('test_file.log', batch_granularity=True)

logger.on_start(state)
logger.on_step_training(state)
logger.on_step_training(state)
logger.on_end_epoch(state)
@@ -44,7 +44,8 @@ def test_torch_scheduler_on_batch_with_monitor(self):
mock_scheduler.reset_mock()

def test_torch_scheduler_on_epoch_with_monitor(self):
state = {torchbearer.EPOCH: 1, torchbearer.METRICS: {'test': 101}, torchbearer.OPTIMIZER: 'optimizer'}
state = {torchbearer.EPOCH: 1, torchbearer.METRICS: {'test': 101}, torchbearer.OPTIMIZER: 'optimizer',
torchbearer.DATA: None}
mock_scheduler = Mock()
mock_scheduler.return_value = mock_scheduler

@@ -108,10 +109,6 @@ def test_torch_scheduler_on_epoch_no_monitor(self):
mock_scheduler.assert_called_once_with('optimizer')
mock_scheduler.reset_mock()

torch_scheduler.on_start_training(state)
mock_scheduler.step.assert_called_once_with(epoch=1)
mock_scheduler.reset_mock()

torch_scheduler.on_sample(state)
mock_scheduler.assert_not_called()
mock_scheduler.reset_mock()
@@ -5,6 +5,7 @@
import torch.nn.init as init

import torchbearer
from torchbearer import callbacks as c


class Net(Module):
@@ -57,6 +58,31 @@ def test_basic_opt(self):
self.assertAlmostEqual(model.pars[1].item(), 0.0, places=4)
self.assertAlmostEqual(model.pars[2].item(), 1.0, places=4)

def test_callbacks(self):
from torch.utils.data import TensorDataset
traingen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))
valgen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))
testgen = TensorDataset(torch.rand(10, 1, 3), torch.rand(10, 1))

model = torch.nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
cbs = []
cbs.extend([c.EarlyStopping(), c.GradientClipping(10, model.parameters()), c.Best('test.pt'),
c.MostRecent('test.pt'), c.ReduceLROnPlateau(), c.CosineAnnealingLR(0.1, 0.01),
c.ExponentialLR(1), c.Interval('test.pt'), c.CSVLogger('test_csv.pt'),
c.L1WeightDecay(), c.L2WeightDecay(), c.TerminateOnNaN(monitor='fail_metric')])

trial = torchbearer.Trial(model, optim, torch.nn.MSELoss(), metrics=['loss'], callbacks=cbs)
trial = trial.with_generators(traingen, valgen, testgen)
trial.run(2)
trial.predict()
trial.evaluate(data_key=torchbearer.TEST_DATA)
trial.evaluate()

import os
os.remove('test.pt')
os.remove('test_csv.pt')

def test_zero_model(self):
model = Linear(3, 1)
init.constant_(model.weight, 0)
@@ -452,6 +452,9 @@ def decorator(inner):


def get_metric(self_tag, state, metric_key):
if torchbearer.DATA in state and state[torchbearer.DATA] == 'test_data' and 'val_' in metric_key:
return None

if metric_key in state[torchbearer.METRICS]:
return state[torchbearer.METRICS][metric_key]
else:
@@ -227,6 +227,8 @@ def on_checkpoint(self, state):
self.epochs_since_last_save = 0

current = get_metric('Best Checkpoint', state, self.monitor)
if current is None:
return

if self.monitor_op(current, self.best):
self.best = current
@@ -42,13 +42,16 @@ def __init__(self, filename, separator=',', batch_granularity=False, write_heade
filemode = 'a'
else:
filemode = 'w'
self.filemode = filemode

self.write_header = write_header

def on_start(self, state):
if sys.version_info[0] < 3:
filemode += 'b'
self.csvfile = open(self.filename, filemode)
self.filemode += 'b'
self.csvfile = open(self.filename, self.filemode)
else:
self.csvfile = open(self.filename, filemode, newline='')
self.write_header = write_header
self.csvfile = open(self.filename, self.filemode, newline='')

def on_step_training(self, state):
super(CSVLogger, self).on_step_training(state)
@@ -71,6 +71,9 @@ def load_state_dict(self, state_dict):

def step(self, state):
current = get_metric('Early Stopping', state, self.monitor)
if current is None:
return

if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
@@ -5,6 +5,7 @@
from torchbearer.callbacks import Callback
from torchbearer.bases import get_metric


class no_print:
def __init__(self):
pass
@@ -66,7 +67,9 @@ def on_start(self, state):
self.batch_plt = PlotLosses(**self._kwargs)

def _on_step_training(self, state):
# These checks shouldn't fail
self.batch_plt.update({k: get_metric('LiveLossPlot', state, k) for k in state[torchbearer.METRICS]})

if state[torchbearer.BATCH] % self.batch_step_size == 0 and not self.draw_once:
with no_print():
self.batch_plt.draw()
@@ -37,12 +37,11 @@ def __init__(self, monitor='running_loss'):
self._monitor = monitor

def _check(self, state):
if self._monitor in state[torchbearer.METRICS]:
value = get_metric('TerminateOnNaN', state, self._monitor)
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Invalid ' + self._monitor + ', terminating')
state[torchbearer.STOP_TRAINING] = True
value = get_metric('TerminateOnNaN', state, self._monitor)
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Invalid ' + self._monitor + ', terminating')
state[torchbearer.STOP_TRAINING] = True

def on_step_training(self, state):
self._check(state)
@@ -1,5 +1,4 @@
import torchbearer
import warnings
from torchbearer.callbacks import Callback
from torchbearer.bases import get_metric

@@ -21,16 +20,24 @@ def on_sample(self, state):
self._scheduler.step()

def on_step_training(self, state):
if self._step_on_batch and self._monitor is not None:
self._scheduler.step(get_metric('Scheduler', state, self._monitor))

def on_start_training(self, state):
if not self._step_on_batch and self._monitor is None:
self._scheduler.step(epoch=state[torchbearer.EPOCH])
if self._step_on_batch:
if self._monitor is not None:
current = get_metric('Scheduler', state, self._monitor)
if current is None:
return
self._scheduler.step(current)
else:
self._scheduler.step()

def on_end_epoch(self, state):
if not self._step_on_batch and self._monitor is not None:
self._scheduler.step(get_metric('Scheduler', state, self._monitor), epoch=state[torchbearer.EPOCH])
if not self._step_on_batch:
if self._monitor is not None:
current = get_metric('Scheduler', state, self._monitor)
if current is None:
return
self._scheduler.step(current, epoch=state[torchbearer.EPOCH])
else:
self._scheduler.step(epoch=state[torchbearer.EPOCH])


class LambdaLR(TorchScheduler):
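
With this change, scheduler stepping happens in on_end_epoch (or per batch when step_on_batch is set) rather than at the start of training, so the optimiser is stepped before the scheduler, and a monitored scheduler skips its step entirely when the monitored metric is unavailable. A minimal sketch of driving a monitored scheduler from a Trial, using tiny random data similar to the end-to-end test above; the shapes and hyper-parameters are illustrative only:

import torch
import torchbearer
from torch.utils.data import DataLoader, TensorDataset
from torchbearer import callbacks as c

model = torch.nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

train_loader = DataLoader(TensorDataset(torch.rand(10, 3), torch.rand(10, 1)), batch_size=5)
val_loader = DataLoader(TensorDataset(torch.rand(10, 3), torch.rand(10, 1)), batch_size=5)

# ReduceLROnPlateau is left at its default monitor; when the monitored metric
# is missing from state, the scheduler step is now skipped
trial = torchbearer.Trial(model, optim, torch.nn.MSELoss(), metrics=['loss'],
                          callbacks=[c.ReduceLROnPlateau()])
trial = trial.with_generators(train_loader, val_loader)
trial.run(2)
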
@@ -778,6 +778,7 @@ def with_data(self, x_train=None, y_train=None, x_val=None, y_val=None, x_test=N
self.with_train_data(x_train, y_train, batch_size, shuffle, num_workers, train_steps)
self.with_val_data(x_val, y_val, batch_size, shuffle, num_workers, val_steps)
self.with_test_data(x_test, batch_size, num_workers, test_steps)
return self
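
The added return self makes with_data chainable in the same way as the other with_* builders. A minimal sketch with made-up tensors; the shapes and hyper-parameters are illustrative only:

import torch
import torchbearer

model = torch.nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
x_train, y_train = torch.rand(10, 3), torch.rand(10, 1)

# with_data now returns the Trial, so construction and data loading can be chained
trial = torchbearer.Trial(model, optim, torch.nn.MSELoss(), metrics=['loss']).with_data(x_train, y_train, batch_size=5)
trial.run(1)
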

# Infinite steps and loading

@@ -1 +1 @@
__version__ = '0.5.3'
__version__ = '0.5.3.dev'
