Skip to content
Permalink
Browse files

Fix PyCM callback (#649)

  • Loading branch information
ethanwharris committed Dec 3, 2019
1 parent 1ee52ce commit b6105c36969b83d9126b459d1227196cb5915c1d
Showing with 18 additions and 24 deletions.
  1. +2 −0 CHANGELOG.md
  2. +4 −7 tests/callbacks/test_pycm.py
  3. +12 −17 torchbearer/callbacks/pycm.py
@@ -7,9 +7,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added option to use mixup loss with cutmix
### Changed
- Changed PyCM save methods to use `*args` and `**kwargs`
### Deprecated
### Removed
### Fixed
- Fixed a bug where the PyCM callback would fail when saving

## [0.5.1] - 2019-11-06
### Added
@@ -120,31 +120,28 @@ def test_to_file(self):
cm = MagicMock()
callback._handlers[0](cm, {torchbearer.EPOCH: 1})

cm.save_stat.assert_called_once_with('test 1', address=True, overall_param=None, class_param=None,
class_name=None)
cm.save_stat.assert_called_once_with('test 1')

callback = PyCM()
callback.to_html_file('test {epoch}')
cm = MagicMock()
callback._handlers[0](cm, {torchbearer.EPOCH: 2})

cm.save_html.assert_called_once_with('test 2', address=True, overall_param=None, class_param=None,
class_name=None, color=(0, 0, 0), normalize=False)
cm.save_html.assert_called_once_with('test 2')

callback = PyCM()
callback.to_csv_file('test {epoch}')
cm = MagicMock()
callback._handlers[0](cm, {torchbearer.EPOCH: 3})

cm.save_csv.assert_called_once_with('test 3', address=True, overall_param=None, class_param=None,
class_name=None, matrix_save=True, normalize=False)
cm.save_csv.assert_called_once_with('test 3')

callback = PyCM()
callback.to_obj_file('test {epoch}')
cm = MagicMock()
callback._handlers[0](cm, {torchbearer.EPOCH: 4})

cm.save_obj.assert_called_once_with('test 4', address=True, save_stat=False, save_vector=True)
cm.save_obj.assert_called_once_with('test 4')

@patch('torchbearer.callbacks.pycm._to_pyplot')
def test_to_pyplot(self, mock_to_pyplot):
@@ -181,77 +181,72 @@ def to_console(self):
"""
return self.with_handler(lambda cm, _: print(cm))

def to_pycm_file(self, filename, address=True, overall_param=None, class_param=None, class_name=None):
def to_pycm_file(self, filename, *args, **kwargs):
"""Save `ConfusionMatrix` objects from this callback to `.pycm` files
Args:
filename (str): The name of the file, will be formatted with state to create unique filenames if desired
See:
`PyCM Source <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
`PyCM Source (save_stat) <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
Returns:
PyCM: self
"""
def handler(cm, state):
string_state = {str(key): state[key] for key in state.keys()}
cm.save_stat(filename.format(**string_state), address=address, overall_param=overall_param,
class_param=class_param, class_name=class_name)
cm.save_stat(filename.format(**string_state), *args, **kwargs)
return self.with_handler(handler)

def to_html_file(self, filename, address=True, overall_param=None, class_param=None, class_name=None,
color=(0, 0, 0), normalize=False):
def to_html_file(self, filename, *args, **kwargs):
"""Save `ConfusionMatrix` objects from this callback to `.html` files
Args:
filename (str): The name of the file, will be formatted with state to create unique filenames if desired
See:
`PyCM Source <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
`PyCM Source (save_html) <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
Returns:
PyCM: self
"""
def handler(cm, state):
string_state = {str(key): state[key] for key in state.keys()}
cm.save_html(filename.format(**string_state), address=address, overall_param=overall_param,
class_param=class_param, class_name=class_name, color=color, normalize=normalize)
cm.save_html(filename.format(**string_state), *args, **kwargs)
return self.with_handler(handler)

def to_csv_file(self, filename, address=True, overall_param=None, class_param=None, class_name=None,
matrix_save=True, normalize=False):
def to_csv_file(self, filename, *args, **kwargs):
"""Save `ConfusionMatrix` objects from this callback to `.csv` files
Args:
filename (str): The name of the file, will be formatted with state to create unique filenames if desired
See:
`PyCM Source <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
`PyCM Source (save_csv) <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
Returns:
PyCM: self
"""
def handler(cm, state):
string_state = {str(key): state[key] for key in state.keys()}
cm.save_csv(filename.format(**string_state), address=address, overall_param=overall_param,
class_param=class_param, class_name=class_name, matrix_save=matrix_save, normalize=normalize)
cm.save_csv(filename.format(**string_state), *args, **kwargs)
return self.with_handler(handler)

def to_obj_file(self, filename, address=True, save_stat=False, save_vector=True):
def to_obj_file(self, filename, *args, **kwargs):
"""Save `ConfusionMatrix` objects from this callback to `.obj` files
Args:
filename (str): The name of the file, will be formatted with state to create unique filenames if desired
See:
`PyCM Source <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
`PyCM Source (save_obj) <https://github.com/sepandhaghighi/pycm/blob/master/pycm/pycm_obj.py>`_
Returns:
PyCM: self
"""
def handler(cm, state):
string_state = {str(key): state[key] for key in state.keys()}
cm.save_obj(filename.format(**string_state), address=address, save_stat=save_stat, save_vector=save_vector)
cm.save_obj(filename.format(**string_state), *args, **kwargs)
return self.with_handler(handler)

def to_pyplot(self, normalize=False, title='Confusion matrix', cmap=None):

0 comments on commit b6105c3

Please sign in to comment.
You can’t perform that action at this time.