# Dependencies

In [1]:
using Base

using BenchmarkTools

# Variable

In [2]:
"""
    Variable(value::Number)
    Variable(value::Number, parents::Vector{Variable}, chain_rules::Vector)

Scalar node of a computation graph used for reverse-mode differentiation.

The one-argument form creates an input (leaf) node with no parents. The
three-argument form is what the arithmetic operators use to create a node computed
from `parents`; `chain_rules[i]` maps the gradient with respect to this node to the
gradient contribution with respect to `parents[i]`.

The type is `mutable`, so (absent custom `==`/`hash` methods) two nodes compare
equal only when they are the same object; a `Dict` keyed by `Variable` therefore
keeps nodes that merely hold the same value apart.
"""
mutable struct Variable
    # Scalar carried by this node.
    value::Number
    # Nodes this one was computed from (empty for inputs).
    parents::Vector
    # One closure per parent, in the same order as `parents`.
    chain_rules::Vector

    # Input (leaf) node: no parents and no chain rules.
    Variable(value::Number) = new(value, Any[], Any[])

    # Intermediate/output node produced by an operation.
    Variable(value::Number, parents::Vector{Variable}, chain_rules::Vector) =
        new(value, parents, chain_rules)
end

In [3]:
# Addition of two Variables: d(a + b)/da = d(a + b)/db = 1.
# Each chain rule scales the upstream gradient by 1 through `Variable * Number`,
# so the gradient handed to each parent is a fresh node, not the upstream one.
function Base.:+(var1::Variable, var2::Variable)
    passthrough(upstream::Variable) = upstream * 1
    total = var1.value + var2.value
    return Variable(total, [var1, var2], [passthrough, passthrough])
end

In [4]:
# Multiplication of two Variables (product rule): d(a*b)/da = b, d(a*b)/db = a.
# The chain rules multiply the upstream gradient by the *other* operand, so the
# resulting gradients are themselves Variables connected to the forward graph.
function Base.:*(var1::Variable, var2::Variable)
    rules = [
        (upstream::Variable) -> upstream * var2,  # contribution for var1
        (upstream::Variable) -> upstream * var1,  # contribution for var2
    ]
    return Variable(var1.value * var2.value, [var1, var2], rules)
end

# Multiplication of a Variable by a plain number on the right: d(a*k)/da = k.
# Only `var1` is recorded as a parent; the constant is not tracked.
function Base.:*(var1::Variable, var2::Number)
    scale_by_constant(upstream::Variable) = upstream * var2
    product = var1.value * var2
    return Variable(product, [var1], [scale_by_constant])
end

# Multiplication of a plain number by a Variable on the right: d(k*b)/db = k.
# Only `var2` is recorded as a parent; the constant is not tracked.
function Base.:*(var1::Number, var2::Variable)
    scale_by_constant(upstream::Variable) = upstream * var1
    product = var1 * var2.value
    return Variable(product, [var2], [scale_by_constant])
end

In [14]:
# Return the nodes reachable from `root` through (parent, chain_rule) pairs, ordered
# so that every node precedes all of its parents (`root` comes first).
# An explicit stack is used instead of recursion so deep graphs cannot overflow.
function _reverse_topological_order(root::Variable)
    postorder = Variable[]
    visited = Set{Variable}()  # identity-based for the mutable `Variable` by default
    stack = Tuple{Variable,Bool}[(root, false)]
    while !isempty(stack)
        node, expanded = pop!(stack)
        if expanded
            # Every parent of `node` has already been appended to `postorder`.
            push!(postorder, node)
        elseif !(node in visited)
            push!(visited, node)
            push!(stack, (node, true))
            # Follow the same pairing `autograd` uses, so the set of visited nodes
            # is exactly the set that can receive a gradient.
            for (parent, _) in zip(node.parents, node.chain_rules)
                parent in visited || push!(stack, (parent, false))
            end
        end
    end
    return reverse!(postorder)
end

"""
    autograd(T::Variable)

Return a `Dict` mapping every node reachable from `T` through `parents` (excluding
`T` itself) to the gradient of `T` with respect to that node. Each gradient is a
`Variable` built with the same overloaded operators as the forward pass.

Nodes are processed in reverse topological order. A node's gradient is fully
accumulated from all of its consumers before it is pushed to its parents, so every
chain rule runs once per edge. Propagating once per *path* would re-walk shared
sub-graphs repeatedly, at a cost exponential in the amount of reuse.
"""
function autograd(T::Variable)
    order = _reverse_topological_order(T)

    # Seed: dT/dT = 1.
    grads = Dict{Variable,Variable}(T => Variable(1))
    for node in order
        # Complete at this point: every consumer of `node` appears earlier in `order`.
        node_grad = grads[node]
        for (parent, chain_rule) in zip(node.parents, node.chain_rules)
            contribution = chain_rule(node_grad)
            if haskey(grads, parent)
                grads[parent] += contribution
            else
                grads[parent] = contribution
            end
        end
    end

    # The result covers T's ancestors only, not T itself.
    delete!(grads, T)
    return grads
end

autograd (generic function with 1 method)

# Stencil Computation

In [30]:
arr_in = rand(9, 1)

# One input Variable per sample. `vec` flattens the 9×1 matrix, so `var_in` is a
# `Vector{Variable}` rather than an untyped `Vector{Any}`.
var_in = map(Variable, vec(arr_in))

# Three-point stencil. With Variable arguments, the overloaded `*` records a
# computation graph that `autograd` can differentiate.
stencil(a, b, c) = 2 * a * 3 * b * 4 * c

# Apply the stencil at every interior point; the first and last points lack a
# left/right neighbour.
outputs = Variable[
    stencil(var_in[i - 1], var_in[i], var_in[i + 1]) for i in 2:length(var_in) - 1
]