In [15]:
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D, Activation

In [16]:
class RevBlock(tf.keras.Model):
    """Stack of ``number_res`` reversible residual units applied in sequence.

    Args:
        number_res: number of ``_ResBlock`` units to stack.
        filters: channel count passed to every unit.
        input_shape: per-example input shape passed to every unit.
    """

    def __init__(self, number_res, filters, input_shape):
        super().__init__()
        # NOTE(review): tf.contrib.checkpoint.List is TF 1.x-only; presumably
        # used so the sub-blocks are tracked by the Keras model.
        self.blocks = tf.contrib.checkpoint.List()
        for _ in range(number_res):
            self.blocks.append(_ResBlock(filters, input_shape))

    def call(self, x, training=True):
        """Forward pass through every unit, first to last."""
        for block in self.blocks:
            x = block(x, training=training)
        return x

    def backward_grads_and_vars(self, y, dy, training=True):
        """Backpropagate through all units without stored activations.

        Walks the units last-to-first. Each unit reconstructs its own input
        from its output, so only this stack's final output ``y`` is needed.

        Args:
            y: output of this block stack.
            dy: gradient of the loss w.r.t. ``y``.
            training: forwarded to each unit's recomputation.

        Returns:
            ``(dy, grads, vars)``: the gradient w.r.t. the stack's input, plus
            the gradients and their matching variables from every unit
            (last unit first).
        """
        grads_all = []
        vars_all = []
        # Reverse order: the output of unit i is the input of unit i+1, so
        # inputs must be reconstructed from the last unit backwards.
        for i in reversed(range(len(self.blocks))):
            y, dy, grads, vars_ = self.blocks[i].backward_grads_and_vars(
                y, dy, training=training)
            grads_all += grads
            vars_all += vars_

        return dy, grads_all, vars_all
    

class _ResBlock(tf.keras.Model):
    """Reversible residual unit: ``y1 = x1 + F(x2)``, ``y2 = x2 + G(y1)``.

    The input is split in half along axis 3 into ``x1`` and ``x2``. Because
    ``F(x2)`` is added to ``x1``, the input presumably has ``filters``
    channels. TODO: confirm with callers.

    Args:
        filters: total channel count; F and G each produce ``filters // 2``.
        input_shape: per-example shape ``(height, width, channels)`` as a list
            or tuple. Index 2 is the channel dimension that gets halved.
    """

    def __init__(self, filters, input_shape):
        super().__init__()
        self.filters = filters

        # Index explicitly: `input_shape[:2] + [...]` raised TypeError for a
        # tuple input_shape (tuple + list).
        height, width, channels = input_shape[0], input_shape[1], input_shape[2]
        f_input_shape = [height, width, channels // 2]
        g_input_shape = [height, width, filters // 2]

        self.f = _ResInner(filters=filters // 2,
                           input_shape=f_input_shape)
        self.g = _ResInner(filters=filters // 2,
                           input_shape=g_input_shape)

    def call(self, x, training=True):
        """Forward pass: split channels, apply the coupled residuals, concat."""
        x1, x2 = tf.split(x, num_or_size_splits=2, axis=3)

        y1 = self.f(x2, training=training) + x1
        y2 = self.g(y1, training=training) + x2

        return tf.concat([y1, y2], axis=3)

    def backward_grads_and_vars(self, y, dy, training=True):
        """Reconstruct this unit's input from ``y`` and backpropagate ``dy``.

        Inversion: ``x2 = y2 - G(y1)`` and ``x1 = y1 - F(x2)``.

        Args:
            y: output of this unit.
            dy: gradient of the loss w.r.t. ``y``.
            training: forwarded to F and G during recomputation.

        Returns:
            ``(x, dx, grads, vars_)``: the reconstructed input, the gradient
            w.r.t. it, and the F-then-G parameter gradients with their
            matching variables.
        """
        dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=3)

        # NOTE(review): with training=True, re-running f/g here may update
        # BatchNormalization moving statistics a second time. Confirm intent.
        with tf.GradientTape(persistent=True) as tape:
            y = tf.identity(y)
            tape.watch(y)
            y1, y2 = tf.split(y, num_or_size_splits=2, axis=3)
            z1 = y1
            g_z1 = self.g(z1, training=training)
            x2 = y2 - g_z1
            f_x2 = self.f(x2, training=training)
            x1 = z1 - f_x2

        # Through G: dz1 = dy1 + (dG/dz1)^T dy2; G's param grads use dy2.
        g_grads = tape.gradient(g_z1,
                                [z1] + self.g.trainable_variables,
                                output_gradients=dy2)
        dz1 = dy1 + g_grads[0]
        dg = g_grads[1:]
        dx1 = dz1

        # Through F: dx2 = dy2 + (dF/dx2)^T dz1; F's param grads use dz1.
        f_grads = tape.gradient(f_x2,
                                [x2] + self.f.trainable_variables,
                                output_gradients=dz1)
        dx2 = dy2 + f_grads[0]
        df = f_grads[1:]

        del tape

        grads = df + dg
        vars_ = self.f.trainable_variables + self.g.trainable_variables

        x = tf.concat([x1, x2], axis=3)
        dx = tf.concat([dx1, dx2], axis=3)

        return x, dx, grads, vars_
    
    
def _ResInner(filters, input_shape, kernel_size=3):
    """Build the residual branch BN -> ReLU -> Conv -> BN -> ReLU -> Conv.

    Args:
        filters: output channels of both convolutions.
        input_shape: per-example input shape for the first layer, which
            defines the model's input.
        kernel_size: convolution kernel size (default 3, as before).

    Returns:
        A ``tf.keras.Sequential`` model. The convolutions use "SAME" padding
        and no bias.
    """
    # Only the first layer needs input_shape; Sequential ignores it on later
    # layers, so the old duplicate on the first Conv2D was dropped.
    return tf.keras.Sequential([
        BatchNormalization(axis=3, input_shape=input_shape),
        Activation("relu"),
        Conv2D(filters=filters, kernel_size=kernel_size, use_bias=False,
               padding="SAME"),
        BatchNormalization(axis=3),
        Activation("relu"),
        Conv2D(filters=filters, kernel_size=kernel_size, use_bias=False,
               padding="SAME"),
    ])

In [None]:
class RevNet(tf.keras.Model):
    def __init__(self, n_blocks=1):
        """Assemble the network from its init, body and final sub-blocks.

        Args:
            n_blocks: NOTE(review): not used by this constructor.
                ``_construct_blocks`` is called without it. Confirm whether it
                should be forwarded.
        """
        super(RevNet, self).__init__()
        # Construction order is kept as-is (init, body, final).
        self._init_block = self._construct_init()
        self._block_list = self._construct_blocks()
        self._final_block = self._construct_final()
    
    
    def call(self, inputs, training=True):
        """Forward pass.

        Args:
            inputs: network input.
            training: forwarded to every sub-block. When True, the input and
                every intermediate activation are also collected.

        Returns:
            ``(logits, saved)``. When ``training`` is True, ``saved`` is
            ``[inputs, init_output, block_1_output, ...]``; otherwise it is
            ``None``.
        """
        # Previously `saved` was left unbound for training=False, so the
        # return statement raised NameError.
        saved = [inputs] if training else None

        x = self._init_block(inputs, training=training)
        if training:
            saved.append(x)

        for block in self._block_list:
            x = block(x, training=training)
            if training:
                saved.append(x)

        logits = self._final_block(x, training=training)

        return logits, saved
    
    