# Forward-Mode Automatic Differentiation Implementation

In [2]:
# Initialization
using LinearAlgebra, StaticArrays
import Base: +, -, *, /

Let us define a struct that tracks a value along with its derivative.

In [7]:
"""
    Dual{T<:Real}

A dual number: the value `x` paired with the derivative part `ϵ`, both of the
same real element type `T`. Printed as `x + ϵϵ`.
"""
struct Dual{T<:Real} <: Real
    x::T   # value
    ϵ::T   # derivative (coefficient of ϵ)
end

# Mixed-type pair: promote value and derivative to a common real type first.
Dual(x::Real, d::Real) = Dual{promote_type(typeof(x), typeof(d))}(x, d)

# A plain real lifts to a dual whose derivative part is zero.
Dual(x::Real) = Dual(x, zero(x))
Dual{T}(x::Real) where {T} = Dual(T(x), zero(T))

# Render as "a + bϵ", switching to "a - |b|ϵ" when b carries a sign bit.
function Base.show(io::IO, d::Dual)
    isneg = signbit(d.ϵ)
    op = isneg ? " - " : " + "
    coeff = isneg ? -d.ϵ : d.ϵ
    print(io, d.x, op, coeff, "ϵ")
end

In [5]:
# Elementary algebraic operations on dual numbers.
# The ϵ part of each result applies the matching differentiation rule
# (sum, difference, product and quotient rules) to the operands' ϵ parts.
Base.:+(a::Dual, b::Dual) = Dual(a.x + b.x, a.ϵ + b.ϵ)
Base.:-(a::Dual, b::Dual) = Dual(a.x - b.x, a.ϵ - b.ϵ)
Base.:*(a::Dual, b::Dual) = Dual(a.x * b.x, b.x * a.ϵ + a.x * b.ϵ)
Base.:/(a::Dual, b::Dual) = Dual(a.x / b.x, (a.ϵ * b.x - a.x * b.ϵ) / b.x^2)

# Unary negation: negate both the value and the derivative part.
Base.:-(a::Dual) = Dual(-a.x, -a.ϵ)

# Comparisons look only at the value part, so branching code such as
# `x < 5 ? x : 2x` can run on duals. Comparisons against plain reals reach these
# methods through promotion to a common Dual type. The same-type signatures keep
# them strictly more specific than Base's generic `<(x::T, y::T) where {T<:Real}`.
Base.:<(a::Dual{T}, b::Dual{T}) where {T} = a.x < b.x
Base.:<=(a::Dual{T}, b::Dual{T}) where {T} = a.x <= b.x


/ (generic function with 112 methods)

In [6]:
# Primitive rules for elementary functions, via the chain rule:
#   sin(u)' = cos(u) u',   cos(u)' = -sin(u) u'
function Base.sin(d::Dual)
    u, du = d.x, d.ϵ
    return Dual(sin(u), du * cos(u))
end

function Base.cos(d::Dual)
    u, du = d.x, d.ϵ
    return Dual(cos(u), (-du) * sin(u))
end

In [8]:
# Conversion and promotion rules, so that duals of different element types,
# and duals mixed with plain reals, reach the Dual-Dual arithmetic methods.

# Re-type an existing dual by converting both of its parts.
Base.convert(::Type{Dual{T}}, d::Dual) where {T} = Dual(convert(T, d.x), convert(T, d.ϵ))
# Embed a real constant as a dual with zero derivative.
Base.convert(::Type{Dual{T}}, d::Real) where {T} = Dual(convert(T, d), zero(T))
# Dual with a plain real: promote the element type. R is restricted to Real so
# non-real types (e.g. Missing, Complex) do not build an invalid Dual{NonReal}
# here, and Base's own promotion rules for those types apply instead.
Base.promote_rule(::Type{Dual{T}}, ::Type{R}) where {T<:Real, R<:Real} =
    Dual{promote_type(T, R)}
# Dual with another dual: promote the element types.
Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{R}}) where {T<:Real, R<:Real} =
    Dual{promote_type(T, R)}

Now let us define an arbitrary function in Julia and see what happens when we call it first with an ordinary number and then with a dual number.

In [11]:
# A sample quadratic with f'(x) = 10x + 10, so at x = 1 we expect
# f(1) = 25 and f'(1) = 20.
function f(x)
    return 5x^2 + 10x + 10
end
f(1.0)               # ordinary evaluation
f(Dual(1.0, 1.0))    # seed ϵ = 1: the result is f(1) + f'(1)ϵ

25.0 + 20.0ϵ

In [30]:
# Derivative operators built on dual numbers.
#   E(f, x): evaluate f at the dual x + 1ϵ and return the whole dual result
#            (value and derivative together).
#   D(f, x): return only the derivative f'(x), i.e. the ϵ part of E(f, x).
# Seeding with `one(x)` keeps the perturbation in the same numeric type as x.

# ϵ part of a result. A function whose output does not depend on its input
# (e.g. `x -> 3`) returns a plain real instead of a Dual; its derivative is zero.
_dualpart(d::Dual) = d.ϵ
_dualpart(r::Real) = zero(r)

E(f, x) = f(Dual(x, one(x)))
D(f, x) = _dualpart(E(f, x))
D(f, 1.0)

20.0

In [15]:
# Inspect the type-inferred code. Only the result of the last line appears as
# the cell output below: the code for D(f, 1.0), which is nothing but Float64
# multiplications and additions (no Dual objects remain).
@code_typed f(1.0)
@code_typed D(f, 1.0)

CodeInfo(
1 ─ %1  = Base.mul_float(x, x)::Float64
│   %2  = Base.mul_float(x, 1.0)::Float64
│   %3  = Base.mul_float(x, 1.0)::Float64
│   %4  = Base.add_float(%2, %3)::Float64
│   %5  = Base.mul_float(%1, 0.0)::Float64
│   %6  = Base.mul_float(5.0, %4)::Float64
│   %7  = Base.add_float(%5, %6)::Float64
│   %8  = Base.mul_float(x, 0.0)::Float64
│   %9  = Base.add_float(%8, 10.0)::Float64
│   %10 = Base.add_float(%7, %9)::Float64
│   %11 = Base.add_float(%10, 0.0)::Float64
└──       return %11
) => Float64

In [16]:
# Piecewise-linear test function: identity below 5, doubled from 5 upward.
# Its branch on `x < 5` exercises comparison support for dual inputs.
g(x) = x < 5 ? x : 2x

g (generic function with 1 method)

In [None]:
Dual(1.0, 1.0) |> g   # needs `<` for Dual (the literal 5 is promoted to a Dual)

In [21]:
# Nested derivative: mathematically d/dx[x · d/dy(x + y)] = d/dx[x] = 1, but this
# naive D returns 2, because the outer x already carries a perturbation that
# leaks into the inner derivative (perturbation confusion).
D(1) do x
    x * D(y -> x + y, 1)
end

2

In [28]:
# Same nested example again (its result is discarded; only D(h, 1) is shown).
D(x -> x*D(y -> x+y, 1), 1) # == 1

# Trace what the inner function sees when h receives the outer dual x.
# Dual(1) has a zero ϵ part, yet f(Dual(1)) = x + Dual(1) prints with ϵ = 1:
# the perturbation comes entirely from the captured outer x.
function h(x)
    f(y) = x + y
    println(f(Dual(1)))
    @show x*f(Dual(1)).ϵ
end

D(h,1)

2 + 1ϵ
x * (f(Dual(1))).ϵ = 1 + 1ϵ


1

In [38]:
# Nested derivative: d/dx[x · d/dy(x·y)] = d/dx[x²] = 2x = 8 at x = 4, but the
# naive D returns 5 (perturbation confusion again).
D(4) do x
    x * D(y -> x * y, 1)
end

5

In [None]:
D(1) do x
    x * D(y -> x + y, 1)
end   # mathematically 1; this naive D gives 2
D(4) do x
    x * D(y -> x * y, 1)
end   # mathematically 8; this naive D gives 5

### For functions with multiple parameters

In [42]:
# Two-argument test function: ∂f1/∂x = 2x + y and ∂f1/∂y = x.
function f1(x, y)
    return x^2 + x * y
end

2 + 1ϵ

We have to take the partial derivative with respect to each parameter separately, seeding a perturbation in one input at a time.

In [None]:
# Seed one input at a time: ϵ = 1 on the variable being differentiated, 0 on the other.
(x, y) = (1, 1)
f1(Dual(x, 1), Dual(y, 0))   # ϵ part is ∂f1/∂x = 2x + y = 3
f1(Dual(x, 0), Dual(y, 1))   # ϵ part is ∂f1/∂y = x = 1

Let's use StaticArrays and multi-dimensional dual numbers to do this.