<a href="https://colab.research.google.com/github/vifirsanova/100-days-of-code/blob/main/day16/Reversible_Residual_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**2 base features:**

1. *Locality-Sensitive Hashing (LSH) attention* reduces the compute cost of dot-product attention.
2. *Reversible Residual Networks (RevNets)* reduce the memory required during backpropagation.

In [3]:
#!pip install trax

import trax
from trax import layers as tl
import numpy as np
from trax.layers.reversible import ReversibleHalfResidual
from trax import fastmath
from trax import shapes
from trax.fastmath import numpy as jnp
from trax.shapes import ShapeDtype
from trax.shapes import signature

# Residual Model

In [9]:
def Residual(layer):
    """Wrap `layer` in a residual (skip) connection.

    Returns a trax Serial of two steps:
      1. a Branch that feeds the stack top to `layer` and, in parallel,
         forwards the same stack top unchanged (the `None` branch);
      2. an Add that sums the two branch outputs.
    For a layer that takes and returns a single tensor this computes
    layer(x) + x.
    """
    # `None` in a Branch is the pass-through path: the stack top goes to the output as-is.
    two_paths = tl.Branch(layer, None)
    merge = tl.Add()
    return tl.Serial(two_paths, merge)

# Reversible Residual Network (RevNet)

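**Concept**: trade additional computation for memory. A RevNet does not store every layer's activations for backpropagation. Instead, it recomputes each layer's inputs from that layer's outputs (e.g. $x_2 = y_2 - G(y_1)$, $x_1 = y_1 - F(x_2)$).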

In [18]:
# F and G are the two residual functions of the RevNet. They are weight-free
# Fn layers (simple scalings), so the result is easy to follow by hand.
F_layer = tl.Fn("F", lambda x0: (2 * x0), n_out=1)
G_layer = tl.Fn("G", lambda x0: (8 * x0), n_out=1)

# One RevNet block acting on the two-element stack (x1, x2):
#   ReversibleHalfResidual(F): (x1, x2) -> (x1 + F(x2), x2)
#   ReversibleSwap:            (a, b)   -> (b, a)
# Together the four layers compute y1 = x1 + F(x2), y2 = x2 + G(y1)
# and leave (y1, y2) on the stack (consistent with the printed result below).
block = [
    ReversibleHalfResidual(F_layer),
    tl.ReversibleSwap(),  # swaps the top two elements: (y1, x2) -> (x2, y1)
    ReversibleHalfResidual(G_layer),
    tl.ReversibleSwap(),  # swaps back: (y2, y1) -> (y1, y2)
]
N_BLOCKS = 2  # number of stacked RevNet blocks
# List repetition reuses the same layer objects in every block; this is harmless
# here because F and G carry no weights.
blocks = [block] * N_BLOCKS

model = tl.Serial(
    tl.Dup(),  # duplicates the stack top: x -> (x, x), the two RevNet halves
    blocks,
    tl.Concatenate(),  # joins the two halves into a single output array
)
model

Serial[
  Dup_out2
  ReversibleHalfResidual_in2_out2[
    Serial[
      F
    ]
  ]
  ReversibleSwap_in2_out2
  ReversibleHalfResidual_in2_out2[
    Serial[
      G
    ]
  ]
  ReversibleSwap_in2_out2
  ReversibleHalfResidual_in2_out2[
    Serial[
      F
    ]
  ]
  ReversibleSwap_in2_out2
  ReversibleHalfResidual_in2_out2[
    Serial[
      G
    ]
  ]
  ReversibleSwap_in2_out2
  Concatenate_in2
]

In [19]:
x1 = np.array([1])
model.init(shapes.signature(x1))
out = model(x1)

# Cross-check the model against the RevNet recurrence computed by hand.
# Each block maps (a, b) -> (a + F(b), b + G(a + F(b))), with F(x) = 2x and
# G(x) = 8x as defined in the model cell; Dup() feeds x1 to both halves.
expected_a, expected_b = x1, x1
for _ in range(len(blocks)):
    expected_a = expected_a + 2 * expected_b
    expected_b = expected_b + 8 * expected_a
np.testing.assert_array_equal(out, np.concatenate([expected_a, expected_b]))
out

DeviceArray([ 53, 449], dtype=int32)