Skip to content

Commit

Permalink
Some code style changes (#267)
Browse files Browse the repository at this point in the history
* Some code style changes

* Fix bug in callback list
  • Loading branch information
MattPainter01 authored and ethanwharris committed Jul 30, 2018
1 parent 3e46052 commit adcb24b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torchbearer/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(self, callback_list):
for i in range(len(callback_list)):
callback = callback_list[i]

if type(callback) is CallbackList:
if isinstance(callback, CallbackList):
self.callback_list = self.callback_list + callback.callback_list
else:
self.callback_list.append(callback)
Expand Down
6 changes: 3 additions & 3 deletions torchbearer/callbacks/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def __init__(self, filepath='model.{epoch:02d}-{val_loss:.2f}.pt', pickle_module
super().__init__(filepath, pickle_module=pickle_module, pickle_protocol=pickle_protocol)
self.filepath = filepath

def on_end_epoch(self, model_state):
super().on_end_training(model_state)
self.save_checkpoint(model_state, overwrite_most_recent=True)
def on_end_epoch(self, state):
super().on_end_training(state)
self.save_checkpoint(state, overwrite_most_recent=True)


class Best(_Checkpointer):
Expand Down
4 changes: 2 additions & 2 deletions torchbearer/cv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def train_valid_splitter(x, y, split, shuffle=True):
:type y: torch.Tensor
:param split: Fraction of dataset to be used for validation
:type split: float
:param shuffle: If True randomize tensor order before splitting else do not randomize
:param shuffle: If True randomize tensor order before splitting else do not randomize
:type shuffle: bool
:return: Training and validation tensors (training data, training labels, validation data, validation labels)
:rtype: tuple
Expand All @@ -39,7 +39,7 @@ def get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True):
:type x: torch.Tensor
:param y: Label tensor for dataset
:type y: torch.Tensor
:param validation_data: Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
:param validation_data: Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
:type validation_data: (torch.Tensor, torch.Tensor)
:param validation_split: Fraction of dataset to be used for validation
:type validation_split: float
Expand Down
3 changes: 1 addition & 2 deletions torchbearer/torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validati
:return: The final state context dictionary
:rtype: dict[str,any]
"""

trainset, valset = torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=shuffle)
trainloader = DataLoader(trainset, batch_size, shuffle=shuffle, num_workers=workers)

Expand Down Expand Up @@ -531,7 +530,7 @@ def _update_device_and_dtype_from_args(main_state, *args, **kwargs):
:return: Updated main state dictionary
:rtype: dict[str,any]
"""
for key, val in kwargs.items():
for key, _ in kwargs.items():
if key == torchbearer.DATA_TYPE:
main_state[torchbearer.DATA_TYPE] = kwargs[torchbearer.DATA_TYPE]
elif torchbearer.DEVICE in kwargs:
Expand Down

0 comments on commit adcb24b

Please sign in to comment.