In [None]:
# Notebook logging setup: INFO-level messages are written to stdout.
import logging
import importlib
# Re-import the logging module before configuring it (workaround from the linked answer).
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
# getLogger() without a name returns the root logger.
log = logging.getLogger()
log.setLevel('INFO')
import sys

# Timestamped "<time> <level> : <message>" records, sent to stdout rather than
# the default stderr.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [None]:
import os
# NOTE(review): machine-specific absolute path, prepended to the import search
# path -- presumably so that `memcnn` below resolves to this local checkout.
# Adjust it (or install memcnn) before running the notebook elsewhere.
os.sys.path.insert(0, '/home/schirrmr/code/memcnn/')

import unittest
import torch
import torch.nn
from torch import nn
from torch.autograd import Variable
import memcnn.models.revop as revop
import numpy as np
import copy

## Parts of original test

In [None]:
# Shape shared by the input and target arrays used by every experiment below.
dims = (2, 10, 8, 8)

# Draw from a locally seeded generator so the input/target arrays are identical
# on every kernel restart (the global NumPy random state is left untouched).
_rng = np.random.RandomState(0)
data = _rng.random_sample(dims).astype(np.float32)
target_data = _rng.random_sample(dims).astype(np.float32)

class SubModule(torch.nn.Module):
    """Convolution followed by batch normalisation, preserving the input shape.

    Args:
        channels: number of input and output channels of the 3x3 convolution
            and number of features of the batch norm. Defaults to ``10 // 2``,
            the value that was previously hard-coded (presumably half of the
            10 channels in ``dims`` -- confirm against how the reversible
            block splits its input).
    """

    def __init__(self, channels=10 // 2):
        super(SubModule, self).__init__()
        # Attribute order is kept (bn, then conv) so that the order of
        # `parameters()` is unchanged.
        self.bn = torch.nn.BatchNorm2d(channels)
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = torch.nn.Conv2d(channels, channels, (3, 3), padding=1)

    def forward(self, x):
        """Return ``bn(conv(x))``."""
        return self.bn(self.conv(x))

# Template sub-network: each experiment below works on its own copy.deepcopy(Gm),
# so every run starts from the same initial weights.
Gm = SubModule()

### Reconstruction of X fails

In [None]:
# Single reversible block, keep_input=False, implementation 0 in both directions.
bwd = False
coupling = 'additive'
keep_input = False
implementation_fwd = 0
implementation_bwd = 0

# The input is kept if requested explicitly or if backward implementation -1 is used.
this_iter_keep_input = keep_input or (implementation_bwd == -1)

X = Variable(torch.from_numpy(data.copy())).clone()
Ytarget = Variable(torch.from_numpy(target_data.copy())).clone()
Xshape = X.shape
# Work on a private copy so the shared template Gm is never trained.
Gm2 = copy.deepcopy(Gm)
rb = revop.ReversibleBlock(Gm2, coupling=coupling, implementation_fwd=implementation_fwd,
                           implementation_bwd=implementation_bwd, keep_input=this_iter_keep_input)
rb.train()
rb.zero_grad()

optim = torch.optim.RMSprop(rb.parameters())
optim.zero_grad()
# NOTE: only the forward direction is exercised in this cell; with bwd=True
# the names Y / Yrev / Xinv would be undefined below.
if not bwd:
    Y = rb(X)
    Yrev = Y.clone()
    Xinv = rb.inverse(Yrev)
loss = torch.nn.MSELoss()(Y, Ytarget)

if this_iter_keep_input:
    # Plain asserts: there is no TestCase `self` at notebook top level, so the
    # former self.assertTrue(...) calls raised NameError when this branch ran.
    assert (X.data.shape == Xshape)
    assert (Y.data.shape == Yrev.shape)
else:
    # Check that X and Yrev are now empty (0-dim, or 1-dim with 0 elements);
    # report instead of aborting so the rest of the cell still runs.
    try:
        assert (len(X.data.shape) == 0 or (len(X.data.shape) == 1 and X.data.shape[0] == 0))
        assert (len(Yrev.data.shape) == 0 or (len(Yrev.data.shape) == 1
                                              and Yrev.data.shape[0] == 0))
    except AssertionError:
        print("Test would have failed")


optim.zero_grad()
loss.backward()
optim.step()


# After the backward pass X must again match the original input data.
assert (Y.shape == Xshape)
assert (X.data.numpy().shape == data.shape)
np.testing.assert_allclose(X.data.numpy(), data, atol=1e-06)

## Stacking 2 rev blocks will lead to wrong gradients as well

In [None]:
class DoubleBlocks(nn.Module):
    """Two reversible blocks applied one after the other.

    The second block is a deep copy of the first, so both start from identical
    but independent parameters.

    Args:
        revblock: module providing ``forward`` and ``inverse``; used as the
            first block and deep-copied to create the second.
    """

    def __init__(self, revblock):
        super(DoubleBlocks, self).__init__()
        self.revblock1 = revblock
        self.revblock2 = copy.deepcopy(revblock)

    def forward(self, x):
        """Return ``revblock2(revblock1(x))``."""
        return self.revblock2(self.revblock1(x))

    def inverse(self, y):
        """Invert ``forward``: undo the second block first, then the first.

        This previously called the blocks' forward passes
        (``revblock1(revblock2(y))``), which is not the inverse of ``forward``.
        """
        return self.revblock1.inverse(self.revblock2.inverse(y))
        

### Collect correct gradients with naive implementation

In [None]:
# Reference run on two stacked blocks with implementation -1 in both
# directions; its parameter gradients are stored as the expected values.
bwd = False
coupling = 'additive'
keep_input = False
implementation_fwd = -1
implementation_bwd = -1

# Evaluates to True here because implementation_bwd == -1.
this_iter_keep_input = keep_input or (implementation_bwd == -1)

X = Variable(torch.from_numpy(data.copy())).clone()
Ytarget = Variable(torch.from_numpy(target_data.copy())).clone()
Xshape = X.shape
# Private copy of the template sub-network, so Gm itself is never trained.
Gm2 = copy.deepcopy(Gm)
rb = revop.ReversibleBlock(Gm2, coupling=coupling, implementation_fwd=implementation_fwd,
                           implementation_bwd=implementation_bwd, keep_input=this_iter_keep_input)
# Stack the block with a deep copy of itself.
rb = DoubleBlocks(rb)
rb.train()
rb.zero_grad()

optim = torch.optim.RMSprop(rb.parameters())
optim.zero_grad()
# Only the forward direction is exercised; Y / Yrev / Xinv are defined only when bwd is False.
if not bwd:
    Y = rb(X)
    Yrev = Y.clone()
    Xinv = rb.inverse(Yrev)
loss = torch.nn.MSELoss()(Y, Ytarget)

if this_iter_keep_input:
    # Inputs are kept in this configuration, so the shapes must be unchanged.
    assert (X.data.shape == Xshape)
    assert (Y.data.shape == Yrev.shape)
else:
    # Not reached with the settings above (this_iter_keep_input is True).
    try:
        assert (len(X.data.shape) == 0 or (len(X.data.shape) == 1 and X.data.shape[0] == 0))
        assert (len(Yrev.data.shape) == 0 or (len(Yrev.data.shape) == 1
                                              and Yrev.data.shape[0] == 0))
    except AssertionError:
        print("Test would have failed")


optim.zero_grad()
loss.backward()
optim.step()


# X must still hold the original input data after the backward pass.
assert (Y.shape == Xshape)
assert (X.data.numpy().shape == data.shape)
np.testing.assert_allclose(X.data.numpy(), data, atol=1e-06)

# Snapshot of every parameter gradient; later cells compare their own
# gradients against these. Deep-copied so the arrays are independent of the
# tensors' memory.
correct_grads = [copy.deepcopy(p.grad.numpy()) for p in rb.parameters()]

### Fails with keep_input=False and implementation_fwd=implementation_bwd=0

In [None]:
# Two stacked blocks with implementation 0 and keep_input=False; the resulting
# gradients are compared against `correct_grads` from the reference run.
bwd = False
coupling = 'additive'
keep_input = False
implementation_fwd = 0
implementation_bwd = 0

# Evaluates to False here: keep_input is False and implementation_bwd != -1.
this_iter_keep_input = keep_input or (implementation_bwd == -1)

X = Variable(torch.from_numpy(data.copy())).clone()
Ytarget = Variable(torch.from_numpy(target_data.copy())).clone()
Xshape = X.shape
# Private copy of the template sub-network, so this run starts from the same
# weights as the reference run.
Gm2 = copy.deepcopy(Gm)
rb = revop.ReversibleBlock(Gm2, coupling=coupling, implementation_fwd=implementation_fwd,
                           implementation_bwd=implementation_bwd, keep_input=this_iter_keep_input)
rb = DoubleBlocks(rb)
rb.train()
rb.zero_grad()

optim = torch.optim.RMSprop(rb.parameters())
optim.zero_grad()
# Only the forward direction is exercised; Y / Yrev / Xinv are defined only when bwd is False.
if not bwd:
    Y = rb(X)
    Yrev = Y.clone()
    Xinv = rb.inverse(Yrev)
loss = torch.nn.MSELoss()(Y, Ytarget)

if this_iter_keep_input:
    assert (X.data.shape == Xshape)
    assert (Y.data.shape == Yrev.shape)
else:
    # Check that X and Yrev are now empty (0-dim, or 1-dim with 0 elements);
    # a failure is reported but does not abort the cell.
    try:
        assert (len(X.data.shape) == 0 or (len(X.data.shape) == 1 and X.data.shape[0] == 0))
        assert (len(Yrev.data.shape) == 0 or (len(Yrev.data.shape) == 1
                                              and Yrev.data.shape[0] == 0))
    except AssertionError:
        print("Test would have failed")


optim.zero_grad()
loss.backward()
optim.step()


assert (Y.shape == Xshape)
assert (X.data.numpy().shape == data.shape)
# The value check on X is disabled in this cell; only gradients are compared.
#np.testing.assert_allclose(X.data.numpy(), data, atol=1e-06)
cur_grads = [copy.deepcopy(p.grad.numpy()) for p in rb.parameters()]
# Pairwise comparison in parameter order against the reference gradients.
for g1, g2 in zip(correct_grads, cur_grads):
    np.testing.assert_allclose(g1, g2, atol=1e-06)

### Everything is fine with keep_input=True and implementation_fwd=implementation_bwd=0

In [None]:
# Two stacked blocks with implementation 0 and keep_input=True; the resulting
# gradients are compared against `correct_grads` from the reference run.
bwd = False
coupling = 'additive'
keep_input = True
implementation_fwd = 0
implementation_bwd = 0

# Evaluates to True here because keep_input is True.
this_iter_keep_input = keep_input or (implementation_bwd == -1)

X = Variable(torch.from_numpy(data.copy())).clone()
Ytarget = Variable(torch.from_numpy(target_data.copy())).clone()
Xshape = X.shape
# Private copy of the template sub-network, so this run starts from the same
# weights as the reference run.
Gm2 = copy.deepcopy(Gm)
rb = revop.ReversibleBlock(Gm2, coupling=coupling, implementation_fwd=implementation_fwd,
                           implementation_bwd=implementation_bwd, keep_input=this_iter_keep_input)
rb = DoubleBlocks(rb)
rb.train()
rb.zero_grad()

optim = torch.optim.RMSprop(rb.parameters())
optim.zero_grad()
# Only the forward direction is exercised; Y / Yrev / Xinv are defined only when bwd is False.
if not bwd:
    Y = rb(X)
    Yrev = Y.clone()
    Xinv = rb.inverse(Yrev)
loss = torch.nn.MSELoss()(Y, Ytarget)

if this_iter_keep_input:
    # Inputs are kept in this configuration, so the shapes must be unchanged.
    assert (X.data.shape == Xshape)
    assert (Y.data.shape == Yrev.shape)
else:
    # Not reached with the settings above (this_iter_keep_input is True).
    try:
        assert (len(X.data.shape) == 0 or (len(X.data.shape) == 1 and X.data.shape[0] == 0))
        assert (len(Yrev.data.shape) == 0 or (len(Yrev.data.shape) == 1
                                              and Yrev.data.shape[0] == 0))
    except AssertionError:
        print("Test would have failed")


optim.zero_grad()
loss.backward()
optim.step()


assert (Y.shape == Xshape)
assert (X.data.numpy().shape == data.shape)
# The value check on X is disabled in this cell; only gradients are compared.
#np.testing.assert_allclose(X.data.numpy(), data, atol=1e-06)
cur_grads = [copy.deepcopy(p.grad.numpy()) for p in rb.parameters()]
# Pairwise comparison in parameter order against the reference gradients.
for g1, g2 in zip(correct_grads, cur_grads):
    np.testing.assert_allclose(g1, g2, atol=1e-06)