In [None]:
import importlib
import logging
import sys

# logging.basicConfig does nothing when the root logger already has handlers.
# Reloading the module recreates the root logger, so the basicConfig call below
# takes effect inside the notebook kernel
# (see https://stackoverflow.com/a/21475297/1469195).
importlib.reload(logging)

LOG_FORMAT = '%(asctime)s %(levelname)s : %(message)s'

log = logging.getLogger()  # the (fresh) root logger
log.setLevel('INFO')
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=LOG_FORMAT,
)

In [None]:
import os
# Make the local memcnn checkout importable. Set the MEMCNN_DIR environment
# variable to point elsewhere instead of editing this hard-coded default.
import sys
MEMCNN_DIR = os.environ.get('MEMCNN_DIR', '/home/schirrmr/code/memcnn/')
# Drop any earlier copies first so re-running this cell does not pile up
# duplicate entries, then put the checkout at the front of the search path.
while MEMCNN_DIR in sys.path:
    sys.path.remove(MEMCNN_DIR)
sys.path.insert(0, MEMCNN_DIR)

import unittest
import torch
import torch.nn
from torch.autograd import Variable
import memcnn.models.revop as revop
import numpy as np
import copy

In [None]:
# Seed both RNGs so the random inputs/targets and the initial weights of Gm are
# identical on every Restart & Run All.
SEED = 42
np.random.seed(SEED)     # fixes `data` and `target_data`
torch.manual_seed(SEED)  # fixes the initial weights of `Gm`

# Input and regression target, laid out as (N, C, H, W) for the Conv2d/BatchNorm2d below.
dims = (2, 10, 8, 8)
data = np.random.random(dims).astype(np.float32)
target_data = np.random.random(dims).astype(np.float32)


class SubModule(torch.nn.Module):
    """Conv(3x3, padding 1) followed by BatchNorm, preserving channels and spatial size.

    Args:
        channels: number of input (and output) channels. Defaults to half of the
            channel dimension of ``dims``. Presumably the reversible coupling splits
            its input into two halves along dim 1 (TODO confirm against memcnn).
    """

    def __init__(self, channels=dims[1] // 2):
        super(SubModule, self).__init__()
        # Registration order (bn before conv) fixes the order of .parameters();
        # later cells compare against the first parameter, so keep it.
        self.bn = torch.nn.BatchNorm2d(channels)
        self.conv = torch.nn.Conv2d(channels, channels, (3, 3), padding=1)

    def forward(self, x):
        return self.bn(self.conv(x))


Gm = SubModule()

In [None]:
# Snapshot of Gm's *initial parameter values* (not gradients, despite the name),
# taken before any training. The test case compares the first entry against the
# parameters after an optimizer step to check that training changed them.
s_grad = []
for param in Gm.parameters():
    s_grad.append(param.data.numpy().copy())

In [None]:
class DoubleRb(torch.nn.Module):
    """Two reversible blocks applied in sequence: ``rb`` followed by a deep copy of it.

    ``rb2`` starts with the same weights as ``rb`` but is an independent module,
    so an optimizer updates the two blocks separately.
    """

    def __init__(self, rb):
        super(DoubleRb, self).__init__()
        self.rb = rb
        self.rb2 = copy.deepcopy(self.rb)

    def forward(self, x):
        return self.rb2(self.rb(x))

    def inverse(self, y):
        # Undo the blocks in reverse order of application: rb2 first, then rb.
        return self.rb.inverse(self.rb2.inverse(y))


class ReversibleOperationsTestCase(unittest.TestCase):
    """Tests memcnn ``ReversibleBlock`` (wrapped twice in ``DoubleRb``) across implementations.

    Relies on the notebook globals ``data``, ``target_data``, ``Gm`` and ``s_grad``
    defined in earlier cells.
    """

    # Sweep configuration. This notebook runs a reduced sweep; the wider sweep that
    # was previously commented out used
    #   N_REPEATS = 10, BWD_MODES = (False, True), COUPLINGS = ('additive', 'affine'),
    #   KEEP_INPUT_MODES = (False, True), IMPLEMENTATIONS_FWD = (-1, 0, 1, 1),
    #   IMPLEMENTATIONS_BWD = (-1, 0, 1)
    N_REPEATS = 1
    BWD_MODES = (False,)
    COUPLINGS = ('additive',)
    KEEP_INPUT_MODES = (False,)
    IMPLEMENTATIONS_FWD = (-1, 0)
    IMPLEMENTATIONS_BWD = (-1, 0)
    # Absolute tolerance when comparing X against the original data / its reconstruction.
    ATOL = 1e-06

    def test_reversible_block_fwd_bwd(self):
        """Run every configuration of the sweep.

        For each configuration (see ``_run_single_config``) the input is checked to be
        retained or discarded as requested, inverse(forward(X)) must reconstruct X, and
        one RMSprop step must change the parameters. Afterwards, outputs and first
        updated parameters must agree across all implementations of the sweep.
        """
        for _ in range(self.N_REPEATS):
            for bwd in self.BWD_MODES:
                for coupling in self.COUPLINGS:
                    impl_out, impl_grad = [], []
                    for keep_input in self.KEEP_INPUT_MODES:
                        for implementation_fwd in self.IMPLEMENTATIONS_FWD:
                            for implementation_bwd in self.IMPLEMENTATIONS_BWD:
                                out, params = self._run_single_config(
                                    bwd, coupling, keep_input,
                                    implementation_fwd, implementation_bwd)
                                impl_out.append(out)
                                impl_grad.append(params)

                        # output and updated parameters similar over all implementations?
                        for i in range(len(impl_grad) - 1):
                            self.assertTrue(np.allclose(impl_grad[i][0], impl_grad[i + 1][0]))
                            self.assertTrue(np.allclose(impl_out[i], impl_out[i + 1]))

    @staticmethod
    def _is_freed(tensor):
        """Return True if ``tensor`` has been emptied (0-dim, or 1-D of length 0)."""
        shape = tensor.data.shape
        return len(shape) == 0 or (len(shape) == 1 and shape[0] == 0)

    def _run_single_config(self, bwd, coupling, keep_input, implementation_fwd, implementation_bwd):
        """Build a DoubleRb for one configuration, run it, take one RMSprop step and check it.

        Args:
            bwd: if True, feed X through ``inverse`` first and reconstruct with ``forward``.
            coupling: coupling name passed to ``revop.ReversibleBlock``.
            keep_input: requested ``keep_input`` setting (forced on when
                ``implementation_bwd == -1``).
            implementation_fwd, implementation_bwd: implementation ids passed to
                ``revop.ReversibleBlock``.

        Returns:
            (out, params): the first pass output as a numpy array, and copies of all
            DoubleRb parameter values after the optimizer step.
        """
        print("Running imp_fwd: {:d}, imp_bwd: {:d}".format(
            implementation_fwd, implementation_bwd))
        # The sweep always keeps the input when implementation_bwd == -1. This is
        # computed per configuration under a separate name: rebinding the sweep's loop
        # variable leaked keep_input=True into every later configuration, so the
        # "input discarded" branch below was never exercised.
        keep_input_eff = keep_input or (implementation_bwd == -1)

        X = Variable(torch.from_numpy(data.copy())).clone()
        Ytarget = Variable(torch.from_numpy(target_data.copy())).clone()
        Xshape = X.shape
        # Every configuration starts from the same initial weights (a copy of Gm).
        block = revop.ReversibleBlock(copy.deepcopy(Gm), coupling=coupling,
                                      implementation_fwd=implementation_fwd,
                                      implementation_bwd=implementation_bwd,
                                      keep_input=keep_input_eff)
        rb = DoubleRb(block)
        rb.train()
        rb.zero_grad()

        optim = torch.optim.RMSprop(rb.parameters())
        optim.zero_grad()

        if not bwd:
            Y = rb(X)
            Yrev = Y.clone()
            Xinv = rb.inverse(Yrev)
        else:
            Y = rb.inverse(X)
            Yrev = Y.clone()
            Xinv = rb(Yrev)
        loss = torch.nn.MSELoss()(Y, Ytarget)

        # has input been retained/discarded after forward (and backward) passes?
        if keep_input_eff:
            self.assertEqual(X.data.shape, Xshape)
            self.assertEqual(Y.data.shape, Yrev.shape)
        else:
            self.assertTrue(self._is_freed(X))
            self.assertTrue(self._is_freed(Yrev))

        optim.zero_grad()
        loss.backward()
        optim.step()

        # After backward, X must hold the original data again and match the reconstruction.
        self.assertEqual(Y.shape, Xshape)
        self.assertEqual(X.data.numpy().shape, data.shape)
        self.assertTrue(np.allclose(X.data.numpy(), data, atol=self.ATOL))
        self.assertTrue(np.allclose(X.data.numpy(), Xinv.data.numpy(), atol=self.ATOL))

        out = Y.data.numpy().copy()
        params = [p.data.numpy().copy() for p in rb.parameters()]
        # The optimizer step must have moved the first parameter away from its
        # initial value snapshotted in s_grad.
        self.assertFalse(np.allclose(params[0], s_grad[0]))
        return out, params

In [None]:
# Instantiate the test case outside a unittest runner so its test method can be
# called by hand; 'runTest' is the constructor's default method name.
t = ReversibleOperationsTestCase(methodName='runTest')

In [None]:
# Call the test method directly: a failing assertion is raised as an ordinary
# exception, with its traceback shown in the notebook output.
run_sweep = t.test_reversible_block_fwd_bwd
run_sweep()