# Dependencies

In [17]:
using Base
using StaticArrays

using BenchmarkTools

# MultiDual

In [2]:
mutable struct MultiDual{N,T}
	val::T
	# SVector is static vector which lives on the stack
	grads::SVector{N,T} 
end

In [3]:
function Base.:+(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(f.val+g.val, f.grads+g.grads)
end

function Base.:*(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(f.val*g.val, f.val*g.grads+f.grads*g.val)
end

function Base.:*(f::Number, g::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(f*g.val, f*g.grads)
end

function Base.:*(f::MultiDual{N,T}, g::Number) where {N,T}
    return MultiDual{N,T}(f.val*g, f.grads*g)
end

function Base.:^(f::MultiDual{N,T}, g::Real) where {N,T}
    return Base.power_by_squaring(f,g)
end

In [19]:
function parallel_stencil(arr_in, vals_out, grads_out)
    Threads.@threads for i=2:length(arr_in)-1
        output = stencil(MultiDual(arr_in[i-1], SVector(1.,0.,0.)), MultiDual(arr_in[i], SVector(0.,1.,0.)), MultiDual(arr_in[i+1], SVector(0.,0.,1.)))
        vals_out[i-1] = output.val
        grads_out[i-1,:] = output.grads
    end

    return vals_out, grads_out
end

function sequential_stencil(arr_in, vals_out, grads_out)
    for i=2:length(arr_in)-1
        output = stencil(MultiDual(arr_in[i-1], SVector(1.,0.,0.)), MultiDual(arr_in[i], SVector(0.,1.,0.)), MultiDual(arr_in[i+1], SVector(0.,0.,1.)))
        vals_out[i-1] = output.val
        grads_out[i-1,:] = output.grads
    end

    return vals_out, grads_out
end

sequential_stencil (generic function with 1 method)

In [22]:
stencil(a,b,c) = 2*a * 3*b * 4*c

arr_in = rand(900,1)
vals_out = zeros(898,1)
grads_out = zeros(898,3)

vals_out, grads_out = parallel_stencil(arr_in, vals_out, grads_out)

@btime _, _ = parallel_stencil(arr_in, vals_out, grads_out)

@btime _, _ = sequential_stencil(arr_in, vals_out, grads_out)

  9.790 μs (3687 allocations: 177.62 KiB)


  14.949 μs (3593 allocations: 168.41 KiB)


([0.7999799006974531; 2.3989998799775543; … ; 0.0010598276077355168; 0.21146823274780807;;], [4.338212984915069 1.8834663309962316 1.8797524599189919; 5.648186283258341 5.637049033108953 4.338212984915069; … ; 0.2459293528902565 0.0011271089312295884 0.09725354488304032; 0.22489292792688995 19.405075990442104 0.2459293528902565])

In [13]:
vals_out


7×1 Matrix{Float64}:
 7.242411135584619
 3.246528690805309
 3.228258736602759
 2.435043244492823
 0.4653338054183609
 0.09377348147807171
 0.3033432584250987

In [14]:
grads_out

7×3 Matrix{Float64}:
 8.22064   7.72401  19.8257
 3.46241   8.8872    8.22064
 8.83719   8.17438   3.46241
 6.16585   2.61166   8.83719
 0.499086  1.68878   6.16585
 0.34032   1.24253   0.499086
 4.01942   1.61447   0.34032

In [7]:
Threads.nthreads()

16