Skip to content

Commit

Permalink
Feature/cutout callback (#572)
Browse files Browse the repository at this point in the history
* Update docs

* Update changelog.md

* Add to docs

* Add citation

* Formatting

* Formatting

* Remove utility class from docs

* Rename CutOut -> Cutout

* Update tests
  • Loading branch information
MattPainter01 committed Jun 14, 2019
1 parent 3255f09 commit 97982a7
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added ``with_loader`` trial method that allows running of custom batch loaders
- Added a Mock Model which is set when None is passed as the model to a Trial. Mock Model always returns None.
- Added `__call__(state)` to `StateKey` so that they can now be used as losses
- Added a callback to do cutout regularisation
### Changed
### Deprecated
### Removed
Expand Down
12 changes: 12 additions & 0 deletions docs/code/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ Weight / Bias Initialisation
:members:
:undoc-members:

Regularisers
------------------------------------

.. autoclass:: torchbearer.callbacks.cutout.Cutout
:members:
:undoc-members:

.. autoclass:: torchbearer.callbacks.cutout.RandomErase
:members:
:undoc-members:


Decorators
------------------------------------

Expand Down
69 changes: 69 additions & 0 deletions tests/callbacks/test_cutout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from unittest import TestCase

import torch
import numpy as np

import torchbearer
from torchbearer.callbacks.cutout import Cutout, RandomErase


class TestCutOut(TestCase):
def test_cutout(self):
random_image = torch.rand(2, 3, 100, 100)
co = Cutout(1, 10, seed=7)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)

x = [25, 67]
y = [47, 68]

known_cut = random_image.clone().numpy()
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0
known_cut = torch.from_numpy(known_cut)
known_cut = known_cut.view(-1)

diff = (torch.abs(known_cut-reg_img) > 1e-4).any()
self.assertTrue(diff.item() == 0)

def test_cutout_constant(self):
random_image = torch.rand(2, 3, 100, 100)
co = Cutout(1, 10, constant=0.5, seed=7)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)

x = [25, 67]
y = [47, 68]

known_cut = random_image.clone().numpy()
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0.5
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0.5
known_cut = torch.from_numpy(known_cut)
known_cut = known_cut.view(-1)

diff = (torch.abs(known_cut-reg_img) > 1e-4).any()
self.assertTrue(diff.item() == 0)

# TODO: Find a better test for this
def test_random_erase(self):
random_image = torch.rand(2, 3, 100, 100)
co = RandomErase(1, 10, seed=7)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)

x = [25, 67]
y = [47, 68]

known_cut = random_image.clone().numpy()
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0
known_cut = torch.from_numpy(known_cut)

known_cut = known_cut.view(-1)
masked_pix = known_cut == 0

diff = (torch.abs(known_cut[masked_pix]-reg_img[masked_pix]) > 1e-4).any()
self.assertTrue(diff.item() > 0)
1 change: 1 addition & 0 deletions torchbearer/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from torchbearer import Callback
from .callbacks import *
from .cutout import Cutout, RandomErase
from .lr_finder import CyclicLR
from .lsuv import LSUV
from .checkpointers import *
Expand Down
137 changes: 137 additions & 0 deletions torchbearer/callbacks/cutout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torchbearer
from torchbearer import Callback
import torch
import numpy as np
from torchbearer.bases import cite

cutout = """
@article{devries2017improved,
title={Improved regularization of convolutional neural networks with Cutout},
author={DeVries, Terrance and Taylor, Graham W},
journal={arXiv preprint arXiv:1708.04552},
year={2017}
}
"""

random_erase = """
@article{zhong2017random,
title={Random erasing data augmentation},
author={Zhong, Zhun and Zheng, Liang and Kang, Guoliang and Li, Shaozi and Yang, Yi},
journal={arXiv preprint arXiv:1708.04896},
year={2017}
}
"""


@cite(cutout)
class Cutout(Callback):
""" Cutout callback which randomly masks out patches of image data. Implementation a modified version of the code
found `here <https://github.com/uoguelph-mlrg/Cutout/blob/master/util/Cutout.py>`_.
Example::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import Cutout
# Example Trial which does Cutout regularisation
>>> cutout = Cutout(1, 10)
>>> trial = Trial(None, callbacks=[cutout], metrics=['acc'])
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
constant (float): Constant value for each square patch
seed: Random seed
"""
def __init__(self, n_holes, length, constant=0., seed=None):
super(Cutout, self).__init__()
self.cutter = BatchCutout(n_holes, length, constant=constant, seed=seed)

def on_sample(self, state):
super(Cutout, self).on_sample(state)
state[torchbearer.X] = self.cutter(state[torchbearer.X])


@cite(random_erase)
class RandomErase(Callback):
""" Random erase callback which replaces random patches of image data with random noise.
Implementation a modified version of the cutout code found
`here <https://github.com/uoguelph-mlrg/Cutout/blob/master/util/Cutout.py>`_.
Example::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import RandomErase
# Example Trial which does Cutout regularisation
>>> erase = RandomErase(1, 10)
>>> trial = Trial(None, callbacks=[erase], metrics=['acc'])
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
seed: Random seed
"""
def __init__(self, n_holes, length, seed=None):
super(RandomErase, self).__init__()
self.cutter = BatchCutout(n_holes, length, seed=seed, random_erase=True)

def on_sample(self, state):
super(RandomErase, self).on_sample(state)
state[torchbearer.X] = self.cutter(state[torchbearer.X])


class BatchCutout(object):
"""Randomly mask out one or more patches from a batch of images.
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
seed: Random seed
"""
def __init__(self, n_holes, length, constant=0., random_erase=False, seed=None):
self.n_holes = n_holes
self.length = length
self.random_erasing = random_erase
self.constant = constant
np.random.seed(seed)

def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (B, C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
b = img.size(0)
c = img.size(1)
h = img.size(-2)
w = img.size(-1)

mask = np.ones((b, h, w), np.float32)

for n in range(self.n_holes):
y = np.random.randint(h, size=b)
x = np.random.randint(w, size=b)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

for batch in range(b):
mask[batch, y1[batch]: y2[batch], x1[batch]: x2[batch]] = 0

mask = torch.from_numpy(mask).unsqueeze(1).repeat(1, c, 1, 1)

erase_locations = mask == 0

if self.random_erasing:
random = torch.from_numpy(np.random.rand(*img.shape)).to(torch.float)
else:
random = torch.from_numpy(np.ones_like(img)).to(torch.float) * self.constant

img[erase_locations] = random[erase_locations]

return img
2 changes: 1 addition & 1 deletion torchbearer/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class EarlyStopping(Callback):
"""Callback to stop training when a monitored quantity has stopped improving.
Example: ::
Example::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import EarlyStopping
Expand Down

0 comments on commit 97982a7

Please sign in to comment.