Skip to content

Commit

Permalink
Fix /update to_pyplot method (#640)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Nov 4, 2019
1 parent 81731eb commit bd935a3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added BCPlus callback for between-class learning
- Added support for PyTorch 1.3
- Added a show flag to the `ImagingCallback.to_pyplot` method, set to false to stop it from calling `plt.show`
### Changed
- Changed the default behaviour of `ImagingCallback.to_pyplot` to turn off the axis
### Deprecated
### Removed
### Fixed
- Fixed a bug in imaging where passing a title to `to_pyplot` was not possible

## [0.5.0] - 2019-09-17
### Added
Expand Down
1 change: 1 addition & 0 deletions tests/callbacks/imaging/test_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_to_pyplot(self, plt):
plt.title.assert_called_once_with('test')

plt.imshow.assert_called_once_with(mock.mul().clamp().byte().permute().cpu().numpy())
plt.axis.assert_called_once_with('off')
self.assertTrue(plt.show.call_count == 1)

@patch('torchbearer.callbacks.tensor_board')
Expand Down
23 changes: 14 additions & 9 deletions torchbearer/callbacks/imaging/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ def handler(image, index, _):
return handler


def _to_pyplot(title=None):
def _to_pyplot(title=None, show=True):
import matplotlib.pyplot as plt

def handler(image, index, _):
ndarr = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
plt.imshow(ndarr)
if title is not None:
plt.title(title.format(index=str(index)))
plt.show()
plt.axis('off')

if show:
plt.show()

return handler

Expand Down Expand Up @@ -168,7 +171,7 @@ def with_handler(self, handler, index=None):
Args:
handler: A function of image and state which stores the given image in some way
index (int or list or None): if not None, only apply the handler on this index / list of indices
index (int or list or None): If not None, only apply the handler on this index / list of indices
Returns:
ImagingCallback: self
Expand All @@ -180,31 +183,33 @@ def to_file(self, filename, index=None):
"""Send images from this callback to the given file
Args:
filename (str): the filename to store the image to
index (int or list or None): if not None, only apply the handler on this index / list of indices
filename (str): The filename to store the image to
index (int or list or None): If not None, only apply the handler on this index / list of indices
Returns:
ImagingCallback: self
"""
return self.with_handler(_to_file(filename), index=index)

def to_pyplot(self, index=None):
def to_pyplot(self, title=None, show=True, index=None):
"""Show images from this callback with pyplot
Args:
index (int or list or None): if not None, only apply the handler on this index / list of indices
title (str or None): If not None, plt.title will be called with the given string
show (bool): If True (default), show will be called after each image is plotted
index (int or list or None): If not None, only apply the handler on this index / list of indices
Returns:
ImagingCallback: self
"""
return self.with_handler(_to_pyplot(), index=index)
return self.with_handler(_to_pyplot(title=title, show=show), index=index)

def to_state(self, keys, index=None):
"""Put images from this callback in state with the given key
Args:
keys (StateKey or list[StateKey]): The state key or keys to use for the images
index (int or list or None): if not None, only apply the handler on this index / list of indices
index (int or list or None): If not None, only apply the handler on this index / list of indices
Returns:
ImagingCallback: self
Expand Down

0 comments on commit bd935a3

Please sign in to comment.