Skip to content

Commit

Permalink
Feature/bc (#637)
Browse files Browse the repository at this point in the history
* Add BCPlus

* Fix torch 1.0.0 error, update docs
  • Loading branch information
ethanwharris committed Sep 30, 2019
1 parent d43d6a8 commit 4330007
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 55 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Expand Up @@ -3,6 +3,14 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added BCPlus callback for between-class learning
### Changed
### Deprecated
### Removed
### Fixed

## [0.5.0] - 2019-09-17
### Added
- Added PyTorch CyclicLR scheduler
Expand Down
122 changes: 87 additions & 35 deletions docs/_static/notebooks/regularisers.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/code/callbacks.rst
Expand Up @@ -123,6 +123,10 @@ Regularisers
:members:
:undoc-members:

.. autoclass:: torchbearer.callbacks.between_class.BCPlus
:members:
:undoc-members:

.. autoclass:: torchbearer.callbacks.sample_pairing.SamplePairing
:members:
:undoc-members:
Expand Down
70 changes: 70 additions & 0 deletions tests/callbacks/test_between_class.py
@@ -0,0 +1,70 @@
from unittest import TestCase

import torch

import torchbearer
from torchbearer.callbacks import BCPlus


class TestBCPlus(TestCase):
def test_on_val(self):
bcplus = BCPlus(classes=4)
state = {torchbearer.TARGET: torch.tensor([1, 3, 2])}
bcplus.on_sample_validation(state)
self.assertTrue((state[torchbearer.TARGET] -
torch.tensor([[0, 1, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0]]).float()).abs().lt(1e-4).all())

bcplus = BCPlus(classes=4)
state = {torchbearer.TARGET: torch.tensor([[0, 1, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0]])}
bcplus.on_sample_validation(state)
self.assertTrue((state[torchbearer.TARGET] -
torch.tensor([[0, 1, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0]]).float()).abs().lt(1e-4).all())

def test_bc_loss(self):
prediction = torch.tensor([[10.0, 0.01]])
target = torch.tensor([[0., 0.8]])
loss = BCPlus.bc_loss({torchbearer.PREDICTION: prediction, torchbearer.TARGET: target})
self.assertTrue((loss - 7.81).abs().le(1e-2).all())

def test_sample_targets(self):
# Test mixup
bcplus = BCPlus(classes=4, mixup_loss=True)
state = {torchbearer.INPUT: torch.zeros(3, 10, 10), torchbearer.TARGET: torch.tensor([1, 3, 2]), torchbearer.DEVICE: 'cpu'}
bcplus.on_sample(state)

self.assertTrue(torchbearer.MIXUP_LAMBDA in state)
self.assertTrue(torchbearer.MIXUP_PERMUTATION in state)
self.assertTrue(len(state[torchbearer.TARGET]) == 2)

# Test bcplus
bcplus = BCPlus(classes=4)
state = {torchbearer.INPUT: torch.zeros(3, 10, 10), torchbearer.TARGET: torch.tensor([1, 3, 2]),
torchbearer.DEVICE: 'cpu'}
bcplus.on_sample(state)

self.assertTrue(state[torchbearer.TARGET].dim() == 2)
self.assertTrue(not (torchbearer.MIXUP_PERMUTATION in state))

def test_sample_inputs(self):
torch.manual_seed(7)

batch = torch.tensor([[
[0.1, 0.5, 0.6],
[0.8, 0.6, 0.5],
[0.2, 0.4, 0.7]
]])
target = torch.tensor([1])
state = {torchbearer.INPUT: batch, torchbearer.TARGET: target, torchbearer.DEVICE: 'cpu'}

bcplus = BCPlus(classes=4)
bcplus.on_sample(state)

lam = torch.ones(1) * 0.2649

self.assertTrue(((state[torchbearer.INPUT] * (lam.pow(2) + (1 - lam).pow(2)).sqrt()) - (batch - batch.mean())).abs().le(1e-4).all())
15 changes: 10 additions & 5 deletions tests/callbacks/test_cutout.py
Expand Up @@ -9,7 +9,8 @@
class TestCutOut(TestCase):
def test_cutout(self):
random_image = torch.rand(2, 3, 100, 100)
co = Cutout(1, 10, seed=7)
torch.manual_seed(7)
co = Cutout(1, 10)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)
Expand All @@ -28,7 +29,8 @@ def test_cutout(self):

def test_cutout_constant(self):
random_image = torch.rand(2, 3, 100, 100)
co = Cutout(1, 10, constant=0.5, seed=7)
torch.manual_seed(7)
co = Cutout(1, 10, constant=0.5)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)
Expand All @@ -48,7 +50,8 @@ def test_cutout_constant(self):
# TODO: Find a better test for this
def test_random_erase(self):
random_image = torch.rand(2, 3, 100, 100)
co = RandomErase(1, 10, seed=7)
torch.manual_seed(7)
co = RandomErase(1, 10)
state = {torchbearer.X: random_image}
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)
Expand All @@ -70,7 +73,8 @@ def test_random_erase(self):
def test_cutmix(self):
random_image = torch.rand(5, 3, 100, 100)
state = {torchbearer.X: random_image, torchbearer.Y_TRUE: torch.randint(10, (5,)).long(), torchbearer.DEVICE: 'cpu'}
co = CutMix(0.25, classes=10, seed=7)
torch.manual_seed(7)
co = CutMix(0.25, classes=10)
co.on_sample(state)
reg_img = state[torchbearer.X].view(-1)

Expand All @@ -94,7 +98,8 @@ def test_cutmix(self):

def test_cutmix_targets(self):
random_image = torch.rand(2, 3, 100, 100)
co = CutMix(1.0, classes=4, seed=7)
torch.manual_seed(7)
co = CutMix(1.0, classes=4)
target = torch.tensor([
[0., 1., 0., 0.],
[0., 0., 0., 1.]
Expand Down
1 change: 1 addition & 0 deletions torchbearer/callbacks/__init__.py
Expand Up @@ -21,3 +21,4 @@
from .mixup import Mixup, MixupAcc
from .sample_pairing import SamplePairing
from .label_smoothing import LabelSmoothingRegularisation
from .between_class import BCPlus
109 changes: 109 additions & 0 deletions torchbearer/callbacks/between_class.py
@@ -0,0 +1,109 @@
import torchbearer
from torchbearer import Callback
import torch
import torch.nn.functional as F
from torch.distributions import Beta

from torchbearer.bases import cite

bc = """
@inproceedings{tokozume2018between,
title={Between-class learning for image classification},
author={Tokozume, Yuji and Ushiku, Yoshitaka and Harada, Tatsuya},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={5486--5494},
year={2018}
}
"""


@cite(bc)
class BCPlus(Callback):
"""BC+ callback which mixes images by treating them as waveforms. For standard BC, see :class:`.Mixup`.
This callback can optionally convert labels to one hot before combining them according to the lambda parameters,
sampled from a beta distribution, use alpha=1 to replicate the paper. Use with :meth:`BCPlus.bc_loss` or set
`mixup_loss = True` and use :meth:`.Mixup.mixup_loss`.
.. note::
This callback first sets all images to have zero mean. Consider adding an offset (e.g. 0.5) back before
visualising.
Example: ::
>>> from torchbearer import Trial
>>> from torchbearer.callbacks import BCPlus
# Example Trial which does BCPlus regularisation
>>> bcplus = BCPlus(classes=10)
>>> trial = Trial(None, criterion=BCPlus.bc_loss, callbacks=[bcplus], metrics=['acc'])
Args:
mixup_loss (bool): If True, the lambda and targets will be stored for use with the mixup loss function.
alpha (float): The alpha value for the beta distribution.
classes (int): The number of classes for conversion to one hot.
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored and correctly normalised
- :attr:`torchbearer.state.Y_TRUE`: State should have the current data stored
"""

def __init__(self, mixup_loss=False, alpha=1, classes=-1):
super(BCPlus, self).__init__()
self.mixup_loss = mixup_loss
self.classes = classes
self.dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(alpha)]))

@staticmethod
def bc_loss(state):
"""The KL divergence between the outputs of the model and the ratio labels. Model ouputs should be un-normalised
logits as this function performs a log_softmax.
Args:
state: The current :class:`Trial` state.
"""
prediction, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]

entropy = - (target[target.nonzero().split(1, dim=1)] * target[target.nonzero().split(1, dim=1)].log()).sum()
cross = - (target * F.log_softmax(prediction, dim=1)).sum()

return (cross - entropy) / prediction.size(0)

def _to_one_hot(self, target):
if target.dim() == 1:
target = target.unsqueeze(1)
one_hot = torch.zeros_like(target).repeat(1, self.classes)
one_hot.scatter_(1, target, 1)
return one_hot
return target.float()

def on_sample(self, state):
super(BCPlus, self).on_sample(state)

lam = self.dist.sample().to(state[torchbearer.DEVICE])

permutation = torch.randperm(state[torchbearer.X].size(0))

batch1 = state[torchbearer.X]
batch1 = batch1 - batch1.view(batch1.size(0), -1).mean(1, keepdim=True).view(*tuple([batch1.size(0)] + [1] * (batch1.dim() - 1)))
g1 = batch1.view(batch1.size(0), -1).std(1, keepdim=True).view(*tuple([batch1.size(0)] + [1] * (batch1.dim() - 1)))

batch2 = batch1[permutation]
g2 = g1[permutation]

p = 1. / (1 + ((g1 / g2) * ((1 - lam) / lam)))

state[torchbearer.X] = (batch1 * p + batch2 * (1 - p)) / (p.pow(2) + (1 - p).pow(2)).sqrt()

if not self.mixup_loss:
target = self._to_one_hot(state[torchbearer.TARGET]).float()
state[torchbearer.Y_TRUE] = lam * target + (1 - lam) * target[permutation]
else:
state[torchbearer.MIXUP_LAMBDA] = lam
state[torchbearer.MIXUP_PERMUTATION] = permutation
state[torchbearer.Y_TRUE] = (state[torchbearer.Y_TRUE], state[torchbearer.Y_TRUE][state[torchbearer.MIXUP_PERMUTATION]])

def on_sample_validation(self, state):
super(BCPlus, self).on_sample_validation(state)
if not self.mixup_loss:
state[torchbearer.TARGET] = self._to_one_hot(state[torchbearer.TARGET]).float()
16 changes: 3 additions & 13 deletions torchbearer/callbacks/cutout.py
Expand Up @@ -52,16 +52,13 @@ class Cutout(Callback):
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
constant (float): Constant value for each square patch
seed: Random seed
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
"""
def __init__(self, n_holes, length, constant=0., seed=None):
def __init__(self, n_holes, length, constant=0.):
super(Cutout, self).__init__()
self.constant = constant
if seed is not None:
torch.manual_seed(seed)
self.cutter = BatchCutout(n_holes, length, length)

def on_sample(self, state):
Expand Down Expand Up @@ -90,15 +87,12 @@ class RandomErase(Callback):
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
seed: Random seed
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
"""
def __init__(self, n_holes, length, seed=None):
def __init__(self, n_holes, length):
super(RandomErase, self).__init__()
if seed is not None:
torch.manual_seed(seed)
self.cutter = BatchCutout(n_holes, length, length)

def on_sample(self, state):
Expand Down Expand Up @@ -127,17 +121,14 @@ class CutMix(Callback):
Args:
alpha (float): The alpha value for the beta distribution.
classes (int): The number of classes for conversion to one hot.
seed: Random seed
State Requirements:
- :attr:`torchbearer.state.X`: State should have the current data stored
- :attr:`torchbearer.state.Y_TRUE`: State should have the current data stored
"""
def __init__(self, alpha, classes=-1, seed=None):
def __init__(self, alpha, classes=-1):
super(CutMix, self).__init__()
self.classes = classes
if seed is not None:
torch.manual_seed(seed)
self.dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(alpha)]))

def _to_one_hot(self, target):
Expand Down Expand Up @@ -176,7 +167,6 @@ class BatchCutout(object):
n_holes (int): Number of patches to cut out of each image.
width (int): The width (in pixels) of each square patch.
height (int): The height (in pixels) of each square patch.
seed: Random seed
"""
def __init__(self, n_holes, width, height):
self.n_holes = n_holes
Expand Down
2 changes: 1 addition & 1 deletion torchbearer/callbacks/mixup.py
Expand Up @@ -66,7 +66,7 @@ class Mixup(Callback):
# Example Trial which does Mixup regularisation
>>> mixup = Mixup(0.9)
>>> trial = Trial(None, criterion=Mixup.loss, callbacks=[mixup], metrics=['acc'])
>>> trial = Trial(None, criterion=Mixup.mixup_loss, callbacks=[mixup], metrics=['acc'])
Args:
alpha (float): The alpha value to use in the beta distribution.
Expand Down
2 changes: 1 addition & 1 deletion torchbearer/version.py
@@ -1 +1 @@
__version__ = '0.5.0'
__version__ = '0.5.1.dev'

0 comments on commit 4330007

Please sign in to comment.