-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update docs * Update changelog.md * Add to docs * Add citation * Formatting * Formatting * Remove utility class from docs * Rename CutOut -> Cutout * Update tests
- Loading branch information
1 parent
3255f09
commit 97982a7
Showing
6 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from unittest import TestCase | ||
|
||
import torch | ||
import numpy as np | ||
|
||
import torchbearer | ||
from torchbearer.callbacks.cutout import Cutout, RandomErase | ||
|
||
|
||
class TestCutOut(TestCase): | ||
def test_cutout(self): | ||
random_image = torch.rand(2, 3, 100, 100) | ||
co = Cutout(1, 10, seed=7) | ||
state = {torchbearer.X: random_image} | ||
co.on_sample(state) | ||
reg_img = state[torchbearer.X].view(-1) | ||
|
||
x = [25, 67] | ||
y = [47, 68] | ||
|
||
known_cut = random_image.clone().numpy() | ||
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0 | ||
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0 | ||
known_cut = torch.from_numpy(known_cut) | ||
known_cut = known_cut.view(-1) | ||
|
||
diff = (torch.abs(known_cut-reg_img) > 1e-4).any() | ||
self.assertTrue(diff.item() == 0) | ||
|
||
def test_cutout_constant(self): | ||
random_image = torch.rand(2, 3, 100, 100) | ||
co = Cutout(1, 10, constant=0.5, seed=7) | ||
state = {torchbearer.X: random_image} | ||
co.on_sample(state) | ||
reg_img = state[torchbearer.X].view(-1) | ||
|
||
x = [25, 67] | ||
y = [47, 68] | ||
|
||
known_cut = random_image.clone().numpy() | ||
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0.5 | ||
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0.5 | ||
known_cut = torch.from_numpy(known_cut) | ||
known_cut = known_cut.view(-1) | ||
|
||
diff = (torch.abs(known_cut-reg_img) > 1e-4).any() | ||
self.assertTrue(diff.item() == 0) | ||
|
||
# TODO: Find a better test for this | ||
def test_random_erase(self): | ||
random_image = torch.rand(2, 3, 100, 100) | ||
co = RandomErase(1, 10, seed=7) | ||
state = {torchbearer.X: random_image} | ||
co.on_sample(state) | ||
reg_img = state[torchbearer.X].view(-1) | ||
|
||
x = [25, 67] | ||
y = [47, 68] | ||
|
||
known_cut = random_image.clone().numpy() | ||
known_cut[0, :, y[0]-10//2:y[0]+10//2, x[0]-10//2:x[0]+10//2] = 0 | ||
known_cut[1, :, y[1]-10//2:y[1]+10//2, x[1]-10//2:x[1]+10//2] = 0 | ||
known_cut = torch.from_numpy(known_cut) | ||
|
||
known_cut = known_cut.view(-1) | ||
masked_pix = known_cut == 0 | ||
|
||
diff = (torch.abs(known_cut[masked_pix]-reg_img[masked_pix]) > 1e-4).any() | ||
self.assertTrue(diff.item() > 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import torchbearer | ||
from torchbearer import Callback | ||
import torch | ||
import numpy as np | ||
from torchbearer.bases import cite | ||
|
||
cutout = """ | ||
@article{devries2017improved, | ||
title={Improved regularization of convolutional neural networks with Cutout}, | ||
author={DeVries, Terrance and Taylor, Graham W}, | ||
journal={arXiv preprint arXiv:1708.04552}, | ||
year={2017} | ||
} | ||
""" | ||
|
||
random_erase = """ | ||
@article{zhong2017random, | ||
title={Random erasing data augmentation}, | ||
author={Zhong, Zhun and Zheng, Liang and Kang, Guoliang and Li, Shaozi and Yang, Yi}, | ||
journal={arXiv preprint arXiv:1708.04896}, | ||
year={2017} | ||
} | ||
""" | ||
|
||
|
||
@cite(cutout) | ||
class Cutout(Callback): | ||
""" Cutout callback which randomly masks out patches of image data. Implementation a modified version of the code | ||
found `here <https://github.com/uoguelph-mlrg/Cutout/blob/master/util/Cutout.py>`_. | ||
Example:: | ||
>>> from torchbearer import Trial | ||
>>> from torchbearer.callbacks import Cutout | ||
# Example Trial which does Cutout regularisation | ||
>>> cutout = Cutout(1, 10) | ||
>>> trial = Trial(None, callbacks=[cutout], metrics=['acc']) | ||
Args: | ||
n_holes (int): Number of patches to cut out of each image. | ||
length (int): The length (in pixels) of each square patch. | ||
constant (float): Constant value for each square patch | ||
seed: Random seed | ||
""" | ||
def __init__(self, n_holes, length, constant=0., seed=None): | ||
super(Cutout, self).__init__() | ||
self.cutter = BatchCutout(n_holes, length, constant=constant, seed=seed) | ||
|
||
def on_sample(self, state): | ||
super(Cutout, self).on_sample(state) | ||
state[torchbearer.X] = self.cutter(state[torchbearer.X]) | ||
|
||
|
||
@cite(random_erase) | ||
class RandomErase(Callback): | ||
""" Random erase callback which replaces random patches of image data with random noise. | ||
Implementation a modified version of the cutout code found | ||
`here <https://github.com/uoguelph-mlrg/Cutout/blob/master/util/Cutout.py>`_. | ||
Example:: | ||
>>> from torchbearer import Trial | ||
>>> from torchbearer.callbacks import RandomErase | ||
# Example Trial which does Cutout regularisation | ||
>>> erase = RandomErase(1, 10) | ||
>>> trial = Trial(None, callbacks=[erase], metrics=['acc']) | ||
Args: | ||
n_holes (int): Number of patches to cut out of each image. | ||
length (int): The length (in pixels) of each square patch. | ||
seed: Random seed | ||
""" | ||
def __init__(self, n_holes, length, seed=None): | ||
super(RandomErase, self).__init__() | ||
self.cutter = BatchCutout(n_holes, length, seed=seed, random_erase=True) | ||
|
||
def on_sample(self, state): | ||
super(RandomErase, self).on_sample(state) | ||
state[torchbearer.X] = self.cutter(state[torchbearer.X]) | ||
|
||
|
||
class BatchCutout(object): | ||
"""Randomly mask out one or more patches from a batch of images. | ||
Args: | ||
n_holes (int): Number of patches to cut out of each image. | ||
length (int): The length (in pixels) of each square patch. | ||
seed: Random seed | ||
""" | ||
def __init__(self, n_holes, length, constant=0., random_erase=False, seed=None): | ||
self.n_holes = n_holes | ||
self.length = length | ||
self.random_erasing = random_erase | ||
self.constant = constant | ||
np.random.seed(seed) | ||
|
||
def __call__(self, img): | ||
""" | ||
Args: | ||
img (Tensor): Tensor image of size (B, C, H, W). | ||
Returns: | ||
Tensor: Image with n_holes of dimension length x length cut out of it. | ||
""" | ||
b = img.size(0) | ||
c = img.size(1) | ||
h = img.size(-2) | ||
w = img.size(-1) | ||
|
||
mask = np.ones((b, h, w), np.float32) | ||
|
||
for n in range(self.n_holes): | ||
y = np.random.randint(h, size=b) | ||
x = np.random.randint(w, size=b) | ||
|
||
y1 = np.clip(y - self.length // 2, 0, h) | ||
y2 = np.clip(y + self.length // 2, 0, h) | ||
x1 = np.clip(x - self.length // 2, 0, w) | ||
x2 = np.clip(x + self.length // 2, 0, w) | ||
|
||
for batch in range(b): | ||
mask[batch, y1[batch]: y2[batch], x1[batch]: x2[batch]] = 0 | ||
|
||
mask = torch.from_numpy(mask).unsqueeze(1).repeat(1, c, 1, 1) | ||
|
||
erase_locations = mask == 0 | ||
|
||
if self.random_erasing: | ||
random = torch.from_numpy(np.random.rand(*img.shape)).to(torch.float) | ||
else: | ||
random = torch.from_numpy(np.ones_like(img)).to(torch.float) * self.constant | ||
|
||
img[erase_locations] = random[erase_locations] | ||
|
||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters