Skip to content

Commit

Permalink
Fix code style issues (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and MattPainter01 committed Jul 31, 2018
1 parent 04081b5 commit cf24e5e
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 115 deletions.
10 changes: 5 additions & 5 deletions tests/callbacks/test_tensor_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_simple_case(self, mock_board, mock_grid):

state = {'x': torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 1, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardImages(name='test', key='x', write_each_epoch=False, num_images=18, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
tboard = TensorBoardImages(name='test', key='x', write_each_epoch=False, num_images=18, nrow=9, padding=3, normalize=True, norm_range='tmp', scale_each=True, pad_value=1)

tboard.on_start(state)
tboard.on_step_validation(state)
Expand All @@ -150,7 +150,7 @@ def test_multi_batch(self, mock_board, mock_grid):

state = {'x': torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 1, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardImages(name='test', key='x', write_each_epoch=False, num_images=36, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
tboard = TensorBoardImages(name='test', key='x', write_each_epoch=False, num_images=36, nrow=9, padding=3, normalize=True, norm_range='tmp', scale_each=True, pad_value=1)

tboard.on_start(state)
tboard.on_step_validation(state)
Expand All @@ -170,7 +170,7 @@ def test_multi_epoch(self, mock_board, mock_grid):

state = {'x': torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 1, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=36, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=36, nrow=9, padding=3, normalize=True, norm_range='tmp', scale_each=True, pad_value=1)

tboard.on_start(state)
tboard.on_step_validation(state)
Expand All @@ -191,7 +191,7 @@ def test_single_channel(self, mock_board, mock_grid):

state = {'x': torch.ones(18, 10, 10), torchbearer.EPOCH: 1, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=18, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=18, nrow=9, padding=3, normalize=True, norm_range='tmp', scale_each=True, pad_value=1)

tboard.on_start(state)
tboard.on_step_validation(state)
Expand All @@ -210,7 +210,7 @@ def test_odd_batches(self, mock_board, mock_grid):

state = {'x': torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 1, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=40, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
tboard = TensorBoardImages(name='test', key='x', write_each_epoch=True, num_images=40, nrow=9, padding=3, normalize=True, norm_range='tmp', scale_each=True, pad_value=1)

tboard.on_start(state)
tboard.on_step_validation(state)
Expand Down
14 changes: 8 additions & 6 deletions torchbearer/callbacks/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,30 @@ def __init__(self, validation_label_letter='v'):
super().__init__()
self.validation_label = validation_label_letter

def _step(self, state, letter, steps):
@staticmethod
def _step(state, letter, steps):
epoch_str = '{:d}/{:d}({:s}): '.format(state[torchbearer.EPOCH], state[torchbearer.MAX_EPOCHS], letter)
batch_str = '{:d}/{:d} '.format(state[torchbearer.BATCH], steps)
stats_str = ', '.join(['{0}:{1:.03g}'.format(key, value) for (key, value) in state[torchbearer.METRICS].items()])
print('\r' + epoch_str + batch_str + stats_str, end='')

def _end(self, state, letter):
@staticmethod
def _end(state, letter):
epoch_str = '{:d}/{:d}({:s}): '.format(state[torchbearer.EPOCH], state[torchbearer.MAX_EPOCHS], letter)
stats_str = ', '.join(['{0}:{1:.03g}'.format(key, value) for (key, value) in state[torchbearer.METRICS].items()])
print('\r' + epoch_str + stats_str)

def on_step_training(self, state):
self._step(state, 't', state[torchbearer.TRAIN_STEPS])
ConsolePrinter._step(state, 't', state[torchbearer.TRAIN_STEPS])

def on_end_training(self, state):
self._end(state, 't')
ConsolePrinter._end(state, 't')

def on_step_validation(self, state):
self._step(state, self.validation_label, state[torchbearer.VALIDATION_STEPS])
ConsolePrinter._step(state, self.validation_label, state[torchbearer.VALIDATION_STEPS])

def on_end_validation(self, state):
self._end(state, self.validation_label)
ConsolePrinter._end(state, self.validation_label)


class Tqdm(Callback):
Expand Down
8 changes: 4 additions & 4 deletions torchbearer/callbacks/tensor_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(self, log_dir='./logs',
nrow=8,
padding=2,
normalize=False,
range=None,
norm_range=None,
scale_each=False,
pad_value=0):
"""Create TensorBoardImages callback which writes images from the given key to the given path. Full name of
Expand All @@ -132,7 +132,7 @@ def __init__(self, log_dir='./logs',
:param nrow: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param padding: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param normalize: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param range: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param norm_range: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param scale_each: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
:param pad_value: See `torchvision.utils.make_grid https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid`
"""
Expand All @@ -145,7 +145,7 @@ def __init__(self, log_dir='./logs',
self.nrow = nrow
self.padding = padding
self.normalize = normalize
self.range = range
self.norm_range = norm_range
self.scale_each = scale_each
self.pad_value = pad_value

Expand Down Expand Up @@ -182,7 +182,7 @@ def on_step_validation(self, state):
nrow=self.nrow,
padding=self.padding,
normalize=self.normalize,
range=self.range,
range=self.norm_range,
scale_each=self.scale_each,
pad_value=self.pad_value
)
Expand Down
45 changes: 10 additions & 35 deletions torchbearer/callbacks/terminate_on_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,19 @@ def __init__(self, monitor='running_loss'):
super(TerminateOnNaN, self).__init__()
self._monitor = monitor

def _step_training(self, state):
value = state[torchbearer.METRICS][self._monitor]
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Batch %d: Invalid ' % (state[torchbearer.BATCH]) + self._monitor + ', terminating training')
state[torchbearer.STOP_TRAINING] = True
def _check(self, state):
if self._monitor in state[torchbearer.METRICS]:
value = state[torchbearer.METRICS][self._monitor]
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Invalid ' + self._monitor + ', terminating')
state[torchbearer.STOP_TRAINING] = True

def on_step_training(self, state):
if self._monitor in state[torchbearer.METRICS]:
self.on_step_training = lambda inner_state: self._step_training(inner_state)
return self._step_training(state)
else:
self.on_step_training = lambda inner_state: ...

def _end_epoch(self, state):
value = state[torchbearer.METRICS][self._monitor]
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Epoch %d: Invalid ' % (state[torchbearer.EPOCH]) + self._monitor + ', terminating')
state[torchbearer.STOP_TRAINING] = True
self._check(state)

def on_end_epoch(self, state):
if self._monitor in state[torchbearer.METRICS]:
self.on_end_epoch = lambda inner_state: self._end_epoch(inner_state)
return self._end_epoch(state)
else:
self.on_end_epoch = lambda inner_state: ...

def _step_validation(self, state):
value = state[torchbearer.METRICS][self._monitor]
if value is not None:
if math.isnan(value) or math.isinf(value):
print('Batch %d: Invalid ' % (state[torchbearer.BATCH]) + self._monitor + ', terminating validation')
state[torchbearer.STOP_TRAINING] = True
self._check(state)

def on_step_validation(self, state):
if self._monitor in state[torchbearer.METRICS]:
self.on_step_validation = lambda inner_state: self._step_validation(inner_state)
return self._step_validation(state)
else:
self.on_step_validation = lambda inner_state: ...
self._check(state)
29 changes: 14 additions & 15 deletions torchbearer/metrics/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _process_train(self, *args):
:return: The metric value.
"""
...
pass

@abstractmethod
def _step(self, cache):
Expand All @@ -53,7 +53,7 @@ def _step(self, cache):
:return: The new metric value.
"""
...
pass

def process_train(self, *args):
"""Add the current metric value to the cache and call '_step' is needed.
Expand Down Expand Up @@ -94,7 +94,8 @@ class RunningMean(RunningMetric):
def __init__(self, name, batch_size=50, step_size=10):
super().__init__(name, batch_size=batch_size, step_size=step_size)

def _process_train(self, data):
def _process_train(self, *args):
data = args[0]
return data.mean().item()

def _step(self, cache):
Expand All @@ -111,13 +112,14 @@ class Std(metrics.Metric):
def __init__(self, name):
super(Std, self).__init__(name)

def process(self, data):
def process(self, *args):
"""Compute values required for the std from the input.
:param data: The output of some previous call to :meth:`.Metric.process`.
:type data: torch.Tensor
:param args: The output of some previous call to :meth:`.Metric.process`.
:type args: torch.Tensor
"""
data = args[0]
self._sum += data.sum().item()
self._sum_sq += data.pow(2).sum().item()

Expand All @@ -126,11 +128,9 @@ def process(self, data):
else:
self._count += data.size(0)

def process_final(self, data):
def process_final(self, *args):
"""Compute and return the final standard deviation.
:param data: The output of some previous call to :meth:`.Metric.process_final`.
:type data: torch.Tensor
:return: The standard deviation of each observation since the last reset call.
"""
Expand Down Expand Up @@ -161,25 +161,24 @@ class Mean(metrics.Metric):
def __init__(self, name):
super(Mean, self).__init__(name)

def process(self, data):
def process(self, *args):
"""Add the input to the rolling sum.
:param data: The output of some previous call to :meth:`.Metric.process`.
:type data: torch.Tensor
:param args: The output of some previous call to :meth:`.Metric.process`.
:type args: torch.Tensor
"""
data = args[0]
self._sum += data.sum().item()

if data.size() == torch.Size([]):
self._count += 1
else:
self._count += data.size(0)

def process_final(self, data):
def process_final(self, *args):
"""Compute and return the mean of all metric values since the last call to reset.
:param data: The output of some previous call to :meth:`.Metric.process_final`.
:type data: torch.Tensor
:return: The mean of the metric values since the last call to reset.
"""
Expand Down

0 comments on commit cf24e5e

Please sign in to comment.