In [2]:
# Dependencies
using Base

using BenchmarkTools
using Random

# Reverse Mode Automatic Differentiation

## Introduction
Reverse Mode AD is a little different from Forward Mode AD. In reverse accumulation AD, the dependent variable to be differentiated is fixed and the derivative is computed with respect to each sub-expression recursively. In other words, the derivative of the outer functions is repeatedly substituted in the chain rule:

$\begin{aligned}
\frac{\partial y}{\partial x}   &= \frac{\partial y}{\partial w_1} \cdot \frac{\partial w_1}{\partial x} \\
                                &= \bigg( \frac{\partial y}{\partial w_2} \cdot \frac{\partial w_2}{\partial w_1} \bigg) \cdot \frac{\partial w_1}{\partial x} \\
                                &= \cdots \\
\end{aligned}$

Consider a function $f(x_1, x_2)=\sin(x_1) + (x_1 \cdot x_2)$. To get the derivatives with respect to $x_1$ and $x_2$, Reverse accumulation traverses the chain rule from outside to inside. The example function is scalar-valued, and thus there is only one seed for the derivative computation, and only one sweep of the computational graph is needed to calculate the (two-component) gradient. This is only half the work when compared to forward accumulation (this is only true if the output is scalar), but reverse accumulation requires the storage of the intermediate variables as well as the instructions that produced them in a data structure known as a "tape", which may consume significant memory if the computational graph is large. This can be mitigated to some extent by storing only a subset of the intermediate variables and then reconstructing the necessary work variables by repeating the evaluations, a technique known as rematerialization. Checkpointing is also used to save intermediary states.

In this work, I'll build a simple Reverse Mode AD that records the computation graph during the forward pass and then walks it backwards to evaluate the derivatives. To do this, I'll define a `mutable struct Variable` that holds the value, its parents, and the chain rules used to evaluate derivatives with respect to its parents.

In [3]:
"""
    Variable(value)
    Variable(value, parents, chain_rules)

Node of the computation graph recorded during the forward pass.

The one-argument form creates an input node with no parents. The
three-argument form creates a node produced by an operation: `parents` are the
operand nodes and `chain_rules` holds the functions that `autograd` pairs with
`parents` (position by position) to propagate a gradient to each parent.
"""
mutable struct Variable
    value::Number        # numeric value held by this node
    parents::Vector      # nodes this node was computed from
    chain_rules::Vector  # gradient-propagating functions, paired with `parents`

    # Input node: nothing was recorded to produce it.
    Variable(value::Number) = new(value, [], [])

    # Intermediate/output node created by an operation on `parents`.
    Variable(value::Number, parents::Vector{Variable}, chain_rules::Vector) =
        new(value, parents, chain_rules)
end

Now, I just have to define the operations of `Variable` to store the parents and define the chain rules.

In [4]:
# Addition of two graph nodes
function Base.:+(var1::Variable, var2::Variable)
    # The partial derivative of a sum with respect to either operand is 1, so
    # each rule forwards the incoming gradient scaled by one.
    toward_var1 = (global_grad::Variable) -> global_grad * 1
    toward_var2 = (global_grad::Variable) -> global_grad * 1

    # Forward value of the new node
    total = var1.value + var2.value

    return Variable(total, [var1, var2], [toward_var1, toward_var2])
end

In [5]:
# Multiplication of two graph nodes
function Base.:*(var1::Variable, var2::Variable)
    # Product rule: the gradient flowing to one operand is the incoming
    # gradient times the other operand.
    toward_var1 = (global_grad::Variable) -> global_grad * var2
    toward_var2 = (global_grad::Variable) -> global_grad * var1

    # Forward value of the new node
    product = var1.value * var2.value

    return Variable(product, [var1, var2], [toward_var1, toward_var2])
end

# Multiplication of a graph node by a plain number (node on the left)
function Base.:*(var1::Variable, var2::Number)
    # Only `var1` is part of the graph; the number acts as a constant factor
    # on the incoming gradient.
    toward_var1 = (global_grad::Variable) -> global_grad * var2

    # Forward value of the new node
    scaled = var1.value * var2

    return Variable(scaled, [var1], [toward_var1])
end

# Multiplication of a plain number by a graph node (node on the right)
function Base.:*(var1::Number, var2::Variable)
    # Only `var2` is part of the graph; the number acts as a constant factor
    # on the incoming gradient.
    toward_var2 = (global_grad::Variable) -> global_grad * var1

    # Forward value of the new node
    scaled = var1 * var2.value

    return Variable(scaled, [var2], [toward_var2])
end

In [6]:
# Integer power of a graph node.
#
# Forward value is `var1.value^pow`. The backward rule is the power rule,
# `pow * var1^(pow - 1)`, applied to the incoming gradient.
function Base.:^(var1::Variable, pow::Int)
    # Performing power
    value = var1.value^pow

    # Local Gradients Computations
    global_dvar1 = if pow == 0
        # d(x^0)/dx is zero. Evaluating the general rule here would compute
        # `var1^(-1)`, which throws a DomainError for integer values and
        # yields NaN (0 * Inf) when the value is 0.0.
        (global_grad::Variable) -> global_grad * zero(var1.value)
    else
        (global_grad::Variable) -> global_grad * pow * var1^(pow - 1)
    end

    return Variable(value, [var1], [global_dvar1])
end

In [7]:
# Sine of a graph node
function Base.sin(var1::Variable)
    # d(sin x)/dx = cos x, applied to the incoming gradient
    toward_var1 = (global_grad::Variable) -> global_grad * cos(var1)

    # Forward value of the new node
    sine = sin(var1.value)

    return Variable(sine, [var1], [toward_var1])
end

# Cosine of a graph node
function Base.cos(var1::Variable)
    # d(cos x)/dx = -sin x, applied to the incoming gradient
    toward_var1 = (global_grad::Variable) -> global_grad * -1 * sin(var1)

    # Forward value of the new node
    cosine = cos(var1.value)

    return Variable(cosine, [var1], [toward_var1])
end

Once the computation graph is defined, we have to move backwards and collect the gradients. I've done that by defining the function `autograd`.

In [8]:
# Backward pass: walks the graph from `T` towards its inputs and returns a
# Dict mapping every node reachable through `parents` to the gradient of `T`
# with respect to that node. `T` itself is not a key of the result.
#
# Gradients are themselves `Variable`s, because each chain rule is evaluated
# with `Variable` arithmetic; read the number with `.value`.
function autograd(T::Variable)
    # Keys are graph nodes and values are accumulated gradients. Both are
    # always `Variable`s: `compute` only accepts `Variable` arguments.
    gradients = Dict{Variable,Variable}()

    # Propagates `global_grad` (gradient of T w.r.t. `node`) to the parents of
    # `node`, then recurses into each parent.
    function compute(node::Variable, global_grad::Variable)
        for (parent, chain_rule) in zip(node.parents, node.chain_rules)
            # Chain rule: gradient of T w.r.t. `parent` along this edge
            new_global_grad = chain_rule(global_grad)

            # A node reached along several paths sums the contributions
            if haskey(gradients, parent)
                gradients[parent] += new_global_grad
            else
                gradients[parent] = new_global_grad
            end

            # Recursive call
            compute(parent, new_global_grad)
        end
        return nothing
    end

    # Seed: dT/dT = 1
    dT_dT = Variable(1)
    compute(T, dT_dT)

    return gradients
end

autograd (generic function with 1 method)

## Example: Differentiating a simple equation
Let's differentiate $f(x)=x^2+2x$ at $x=2$.

In [9]:
# Function to differentiate: h(x) = x^2 + 2x
function h(x)
    return x^2 + 2*x
end

# Input node at the evaluation point
x = Variable(2)

# Forward pass: evaluates h and records the graph
y = h(x)

# Backward pass: gradient of y accumulated for the input node x
dy_dx = autograd(y)[x]

println("Function Value: ", y.value)
println("Function Gradient: ", dy_dx.value)

Function Value: 8
Function Gradient: 6


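Reverse mode AD can only be parallelised where the computation graph itself exposes parallelism, i.e. across independent sub-graphs. In the stencil example below, every output element builds its own small graph, so the elements can be evaluated and differentiated concurrently.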

# Parallel Reverse Mode Automatic Differentiation: Stencil Computation

In [10]:
# Three-point stencil kernel: the product of the inputs scaled by 2, 3 and 4.
# Works on anything that supports `*` with integers and with itself.
stencil(a, b, c) = 2 * a * 3 * b * 4 * c

# Multi-threaded stencil sweep.
#
# For every interior index `i` of `arr_in`, evaluates `stencil` on the
# neighbours (i-1, i, i+1) and differentiates it with `autograd`. The value is
# written to `vals_out[i-1]` and the three partial derivatives to row `i-1` of
# `grads_out`. Both output arrays are modified in place and also returned.
function parallel_stencil(arr_in, vals_out, grads_out)
    last_interior = size(arr_in, 1) - 1

    Threads.@threads for i in 2:last_interior
        # Fresh input nodes for this window
        left = Variable(arr_in[i - 1])
        center = Variable(arr_in[i])
        right = Variable(arr_in[i + 1])

        # Forward pass, then backward pass
        result = stencil(left, center, right)
        grads = autograd(result)

        # Output row is shifted by one because the sweep starts at i = 2
        row = i - 1
        vals_out[row] = result.value
        grads_out[row, :] = [grads[left].value, grads[center].value, grads[right].value]
    end

    return vals_out, grads_out
end

parallel_stencil (generic function with 1 method)

In [12]:
# Number of available threads
println("Number of Threads: ", Base.Threads.nthreads())

# Computation: reproducible random input and preallocated outputs
Random.seed!(123)
arr_in = rand(100000,1)
vals_out = zeros(size(arr_in)[1]-2,1)
grads_out = zeros(size(arr_in)[1]-2,3)

vals_out, grads_out = parallel_stencil(arr_in, vals_out, grads_out) 

println("Parallel Evaluation Time")
# Interpolate the globals with `$` so the benchmark times the call itself
# rather than access to untyped global variables.
@btime parallel_stencil($arr_in, $vals_out, $grads_out);

Number of Threads: 1


Parallel Evaluation Time


  683.638 ms (16698654 allocations: 588.96 MiB)


In [13]:
# Single-threaded stencil sweep.
#
# For every interior index `i` of `arr_in`, evaluates `stencil` on the
# neighbours (i-1, i, i+1) and differentiates it with `autograd`. The value is
# written to `vals_out[i-1]` and the three partial derivatives to row `i-1` of
# `grads_out`. Both output arrays are modified in place and also returned.
function sequential_stencil(arr_in, vals_out, grads_out)
    n_rows = size(arr_in, 1)

    for i in 2:(n_rows - 1)
        # Fresh input nodes for this window
        prev_node = Variable(arr_in[i - 1])
        mid_node = Variable(arr_in[i])
        next_node = Variable(arr_in[i + 1])

        # Forward pass, then backward pass
        out_node = stencil(prev_node, mid_node, next_node)
        partials = autograd(out_node)

        vals_out[i - 1] = out_node.value
        grads_out[i - 1, :] = [
            partials[prev_node].value, partials[mid_node].value, partials[next_node].value
        ]
    end

    return vals_out, grads_out
end

sequential_stencil (generic function with 1 method)

In [14]:
println("Sequential Evaluation Time")
# Interpolate the globals with `$` so the benchmark times the call itself
# rather than access to untyped global variables.
@btime sequential_stencil($arr_in, $vals_out, $grads_out);

Sequential Evaluation Time


  680.299 ms (16698645 allocations: 588.96 MiB)


Let's make the stencil computation a little more complicated with other functions.

In [15]:
# Three-point stencil kernel mixing trigonometric and linear terms:
# sin(2a) times cos(3b) times 4 times c.
complex_stencil(a, b, c) = sin(2 * a) * cos(3 * b) * 4 * c

# Multi-threaded sweep of `complex_stencil`.
#
# For every interior index `i` of `arr_in`, evaluates `complex_stencil` on the
# neighbours (i-1, i, i+1) and differentiates it with `autograd`. The value is
# written to `vals_out[i-1]` and the three partial derivatives to row `i-1` of
# `grads_out`. Both output arrays are modified in place and also returned.
function complex_parallel_stencil(arr_in, vals_out, grads_out)
    last_interior = size(arr_in, 1) - 1

    Threads.@threads for i in 2:last_interior
        # Fresh input nodes for this window
        left = Variable(arr_in[i - 1])
        center = Variable(arr_in[i])
        right = Variable(arr_in[i + 1])

        # Forward pass, then backward pass
        result = complex_stencil(left, center, right)
        grads = autograd(result)

        # Output row is shifted by one because the sweep starts at i = 2
        row = i - 1
        vals_out[row] = result.value
        grads_out[row, :] = [grads[left].value, grads[center].value, grads[right].value]
    end

    return vals_out, grads_out
end

complex_parallel_stencil (generic function with 1 method)

In [16]:
# Computation: reproducible random input and preallocated outputs
Random.seed!(123)
arr_in = rand(100000,1)
vals_out = zeros(size(arr_in)[1]-2,1)
grads_out = zeros(size(arr_in)[1]-2,3)

vals_out, grads_out = complex_parallel_stencil(arr_in, vals_out, grads_out) 

println("Parallel Evaluation Time")
# Interpolate the globals with `$` so the benchmark times the call itself
# rather than access to untyped global variables.
@btime complex_parallel_stencil($arr_in, $vals_out, $grads_out);

Parallel Evaluation Time


  1.011 s (21898550 allocations: 773.59 MiB)
