Commit
refactored AdditiveBlock and AffineBlock to AdditiveCoupling and AffineCoupling #29
silvandeleemput committed Dec 14, 2019
1 parent 1bdc5a7 commit c4af809
Showing 4 changed files with 39 additions and 24 deletions.
20 changes: 13 additions & 7 deletions memcnn/models/additive.py
@@ -1,16 +1,13 @@
+import warnings
 import torch
 import torch.nn as nn
 import copy
 from torch import set_grad_enabled
 import numpy as np


-class NonMemorySavingWarning(UserWarning):
-    pass
-
-
-class AdditiveBlock(nn.Module):
-    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
+class AdditiveCoupling(nn.Module):
+    def __init__(self, Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1):
         """The AdditiveBlock

         Parameters
@@ -29,7 +26,7 @@ def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
             Switch between different Additive Operation implementations for inverse pass. Default = 1

         """
-        super(AdditiveBlock, self).__init__()
+        super(AdditiveCoupling, self).__init__()
         # mirror the passed module, without parameter sharing...
         if Gm is None:
             Gm = copy.deepcopy(Fm)
@@ -81,6 +78,15 @@ def inverse(self, y):
         return x


+class AdditiveBlock(AdditiveCoupling):
+    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
+        warnings.warn("This class has been deprecated. Use the AdditiveCoupling class instead.",
+                      DeprecationWarning)
+        super(AdditiveBlock, self).__init__(Fm=Fm, Gm=Gm,
+                                            implementation_fwd=implementation_fwd,
+                                            implementation_bwd=implementation_bwd)
+
+
 class AdditiveBlockFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, Fm, Gm, *weights):
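Note: for callers this change is only a rename plus a deprecation shim. A minimal sketch of the behaviour, assuming memcnn at this commit is importable; the nn.Conv2d sub-network below is an illustrative stand-in, not code from the commit:

import warnings
import torch.nn as nn
from memcnn.models.additive import AdditiveBlock, AdditiveCoupling

fm = nn.Conv2d(5, 5, kernel_size=3, padding=1)  # illustrative Fm sub-network

# The old name still constructs, but the shim above now emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = AdditiveBlock(fm)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New code should use the renamed class directly; the alias is a thin subclass of it.
coupling = AdditiveCoupling(fm, implementation_fwd=-1, implementation_bwd=-1)
assert isinstance(legacy, AdditiveCoupling)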
19 changes: 14 additions & 5 deletions memcnn/models/affine.py
@@ -41,8 +41,8 @@ def forward(self, x):
         return scale, shift


-class AffineBlock(nn.Module):
-    def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=1, implementation_bwd=1):
+class AffineCoupling(nn.Module):
+    def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1):
         """The AffineBlock

         Parameters
@@ -60,13 +60,13 @@ def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=1, implementati
             s, t are respectively the scale and shift tensors for the affine coupling.

         implementation_fwd : int
-            Switch between different Affine Operation implementations for forward pass. Default = 1
+            Switch between different Affine Operation implementations for forward pass. Default = -1

         implementation_bwd : int
-            Switch between different Affine Operation implementations for inverse pass. Default = 1
+            Switch between different Affine Operation implementations for inverse pass. Default = -1

         """
-        super(AffineBlock, self).__init__()
+        super(AffineCoupling, self).__init__()
         # mirror the passed module, without parameter sharing...
         if Gm is None:
             Gm = copy.deepcopy(Fm)
@@ -119,6 +119,15 @@ def inverse(self, y):
         return x


+class AffineBlock(AffineCoupling):
+    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
+        warnings.warn("This class has been deprecated. Use the AffineCoupling class instead.",
+                      DeprecationWarning)
+        super(AffineBlock, self).__init__(Fm=Fm, Gm=Gm,
+                                          implementation_fwd=implementation_fwd,
+                                          implementation_bwd=implementation_bwd)
+
+
 class AffineBlockFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, Fm, Gm, *weights):
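The affine variant follows the same pattern; AffineCoupling keeps the adapter argument. A construction-only sketch (the Conv2d stand-in for Fm is an assumption; memcnn's tests use a small SubModule helper instead):

import torch.nn as nn
from memcnn.models.affine import AffineCoupling, AffineAdapterNaive

fm = nn.Conv2d(5, 5, kernel_size=3, padding=1)  # stand-in sub-network

# The adapter maps the sub-network output to the (scale, shift) pair the coupling expects.
coupling = AffineCoupling(Fm=fm, adapter=AffineAdapterNaive,
                          implementation_fwd=-1, implementation_bwd=-1)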
12 changes: 6 additions & 6 deletions memcnn/models/revop.py
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 import warnings
-from memcnn.models.additive import AdditiveBlock
-from memcnn.models.affine import AffineBlock
+from memcnn.models.additive import AdditiveCoupling
+from memcnn.models.affine import AffineCoupling
 from memcnn.models.utils import pytorch_version_one_and_above


@@ -246,11 +246,11 @@ def __init__(self, Fm, Gm=None, coupling='additive', keep_input=False, keep_inpu

 def create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):
     if coupling == 'additive':
-        fn = AdditiveBlock(Fm, Gm,
-                           implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
+        fn = AdditiveCoupling(Fm, Gm,
+                              implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
     elif coupling == 'affine':
-        fn = AffineBlock(Fm, Gm, adapter=adapter,
-                         implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
+        fn = AffineCoupling(Fm, Gm, adapter=adapter,
+                            implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
     else:
         raise NotImplementedError('Unknown coupling method: %s' % coupling)
     return fn
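The create_coupling factory shown above now returns the renamed classes. A hedged usage sketch; the Fm module is again a placeholder:

import torch.nn as nn
from memcnn.models.revop import create_coupling

fm = nn.Conv2d(5, 5, kernel_size=3, padding=1)  # placeholder sub-network

# 'additive' yields an AdditiveCoupling, 'affine' an AffineCoupling (with the optional adapter);
# any other string raises NotImplementedError, as in the diff above.
fn = create_coupling(Fm=fm, coupling='additive',
                     implementation_fwd=-1, implementation_bwd=-1)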
12 changes: 6 additions & 6 deletions memcnn/models/tests/test_revop.py
@@ -5,10 +5,10 @@
 import torch.nn
 import numpy as np
 import copy
-from memcnn.models.affine import AffineAdapterNaive, AffineAdapterSigmoid, AffineBlock
+from memcnn.models.affine import AffineAdapterNaive, AffineAdapterSigmoid, AffineCoupling
 from memcnn import ReversibleBlock
 from memcnn.models.revop import ReversibleModule, create_coupling, is_invertible_module
-from memcnn.models.additive import AdditiveBlock
+from memcnn.models.additive import AdditiveCoupling


def set_seeds(seed):
@@ -60,7 +60,7 @@ def is_memory_cleared(var, isclear, shape):
 def test_is_invertible_module():
     X = torch.zeros(1, 10, 10, 10)
     assert not is_invertible_module(torch.nn.Conv2d(10, 10, kernel_size=(1, 1)), X)
-    fn = AdditiveBlock(SubModule(),implementation_bwd=-1, implementation_fwd=-1)
+    fn = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
     assert is_invertible_module(fn, X)
     class FakeInverse(torch.nn.Module):
         def forward(self, x):
@@ -146,9 +146,9 @@ def test_input_output_invertible_function_share_tensor():


 @pytest.mark.parametrize('fn', [
-    AdditiveBlock(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),
-    AffineBlock(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive),
-    AffineBlock(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterSigmoid),
+    AdditiveCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),
+    AffineCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive),
+    AffineCoupling(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterSigmoid),
     MultiplicationInverse()
 ])
 @pytest.mark.parametrize('bwd', [False, True])
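The updated test exercises the same invertibility check against the renamed class. A trimmed sketch of that pattern, with a toy shape-preserving convolution standing in for the test's SubModule helper:

import torch
import torch.nn as nn
from memcnn.models.additive import AdditiveCoupling
from memcnn.models.revop import is_invertible_module

sub = nn.Conv2d(5, 5, kernel_size=(3, 3), padding=1)  # stand-in for SubModule; operates on half of X's 10 channels
X = torch.zeros(1, 10, 10, 10)

# A plain convolution is not invertible, but the additive coupling built from `sub` is.
assert not is_invertible_module(nn.Conv2d(10, 10, kernel_size=(1, 1)), X)
fn = AdditiveCoupling(sub, implementation_bwd=-1, implementation_fwd=-1)
assert is_invertible_module(fn, X)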
