In [None]:
using Revise

using Statistics
using Symbolics
using ArrayOperations

import Base: getindex,
             @propagate_inbounds

# Custom unary operator

In [None]:
"""
    MyUn

Example unary primitive: element `i` of `MyUn()(x)` is `x[i-1]^2`, i.e. the
square of the point one step behind `i` (for a 1-based `x`, index 1 has no
left neighbour).

NOTE(review): the binary example further down subtypes `Primitive{Arity{2}}`;
confirm that `Ary{1}` is the matching spelling for arity 1.
"""
struct MyUn <: Ary{1} end

# Element access on the lazy result: read the left neighbour and square it.
@inline @propagate_inbounds function getindex(this::Ret{MyUn}, i)
    (x,) = arguments(this)
    left = x[i - 1]
    return left^2
end

In [None]:
# Operator instance plus a random input vector of length n.
f = MyUn()
n = 32
x = rand(n)

# f(x)[2] must be the squared left neighbour x[1]^2.
@assert isequal(f(x)[2], x[1]^2)

In [None]:
using BenchmarkTools

i = 2

# Time one element of f(x). f(x) is built once, outside the timed expression...
let fx = f(x)
    @btime $fx[$i]
end
# ...and compared with a plain array read at the same shifted index.
@btime $x[$i - 1]

In [None]:
# Backward/forward differences built from binary `-` (the type parameter `1`
# is presumably the dimension — confirm).
δ⁻ = Backward{1}(-)
δ⁺ = Forward{1}(-)

# The backward difference at index 3 is x[3] - x[2].
@assert isequal(δ⁻(x)[3], x[3] - x[2])

In [None]:
# Backward/forward two-point averages using `Statistics.middle`.
σ⁻, σ⁺ = Backward{1}(middle), Forward{1}(middle)

In [None]:
# Built but not inspected — presumably a smoke check that the composition
# can be constructed.
_ = σ⁺(δ⁻(f(x)))

# Backward difference of the shifted square: f(x)[3] - f(x)[2] = x[2]^2 - x[1]^2.
@assert isequal(δ⁻(f(x))[3], x[2]^2 - x[1]^2)

In [None]:
# Forward average (σ⁺) of the backward difference (δ⁻) of f(x); displayed scaled in the next cell.
h = σ⁺(δ⁻(f(x)))

In [None]:
# Scalar multiple of the composed operator result.
2 * h

In [None]:
# Symbolic scalars: integer indices i, j and a real λ. Note this rebinds the
# global `i` (previously the integer 2 used in the benchmark cell) to a symbol.
@syms i::Int j::Int λ::Real

In [None]:
# Symbolic arrays of length n (the global n set earlier), used to build `grad` below.
@variables ω[1:n] γ[1:n] A[1:n] B[1:n] W[1:n]

In [None]:
# grad = (δ⁻(ω)·B + ((σ⁺(A) - B)·γ - σ⁻(δ⁺(A)·γ))) / W, where each `⊙(op, …)`
# applies `op` (presumably element-wise) to the operand arrays.
grad = let
    ωB = ⊙(*, δ⁻(ω), B)
    shifted = ⊙(*, ⊙(-, σ⁺(A), B), γ)
    averaged = σ⁻(⊙(*, δ⁺(A), γ))
    ⊙(/, ⊙(+, ωB, ⊙(-, shifted, averaged)), W)
end

# Index at the symbolic `i` (from `@syms`) and simplify the scalar expression.
expr = simplify(grad[i])

In [None]:
# List the symbolic variables that occur in the simplified expression.
Symbolics.get_variables(expr)

In [None]:
# Plain-text rendering of the simplified expression.
string(expr)

In [None]:
# Inspect the type produced by `@variables` for a symbolic array.
typeof(B)

# Custom binary operator

In [None]:
"""
    MyBin

Example binary primitive: element `i` of its lazy result over `(x, y)` is
`x[i] * y[i-1]`, i.e. `x` is read at `i` and `y` one point behind.
"""
struct MyBin <: Primitive{Arity{2}} end

# Element access on the lazy result: product of x at i and y at i - 1.
@inline @propagate_inbounds function getindex(this::Ret{MyBin}, i)
    a, b = arguments(this)
    return a[i] * b[i - 1]
end

In [None]:
# Fresh random inputs of length n.
x = rand(n)
y = rand(n)

g = MyBin()

# Backward difference of the binary operator's result.
# NOTE(review): later cells call binary operators with a single tuple,
# `f((x, y))`; confirm which calling convention the current API expects.
δ⁻(g(x, y))

In [None]:
# Partial-application operators indexed by argument position (Tuple form).
# NOTE(review): the Derivatives section later rebinds these names as ∂{1}()/∂{2}().
∂₁, ∂₂ = ∂{Tuple{1}}(), ∂{Tuple{2}}()

In [None]:
# binary operator
f = MyBin()

# fix all arguments
h = f((x, y))

# fix all but first, then fix first and check that result is same
g₁ = ∂₂(f, (y,))
h′ = g₁((x,))
#=
NOTE(review): the consistency checks below are disabled (the first assertion is
additionally line-commented inside this block), so h′ is only inspected by hand
in the next cell. Confirm whether they still fail before re-enabling or deleting.
#@assert isequal(h, h′)

# fix all but second, then fix second and check that result is same
g₂ = ∂₁(f, (x,))
h″ = g₂((y,))

@assert isequal(h, h″)
=#

In [None]:
# Inspect h′, obtained by fixing y first and then x via partial application.
h′

In [None]:
# account for non-locality by hand for now: MyBin reads y[i-1], so the
# comparison starts at index 2 (y[0] would be out of bounds).

rng = 2:n

@assert isequal(h[rng], x[rng] .* y[rng .- 1])

# Derivatives

In [None]:
# Gradient operators with respect to argument 1 and argument 2.
const ∇₁ = ∇{Tuple{1}}()
const ∇₂ = ∇{Tuple{2}}()

# `∂₁`/`∂₂` were already assigned as ordinary globals in an earlier cell
# (as ∂{Tuple{1}}()/∂{Tuple{2}}()). Julia (at least through 1.11) refuses to
# declare an already-assigned global `const`, so running the notebook top to
# bottom would fail here. Rebind them as plain globals instead.
∂₁ = ∂{1}()
∂₂ = ∂{2}()

## First order

In [None]:
# binary operator
f = MyBin()

# Gradient of f with respect to its first argument.
d₁f = ∇₁(f)

# Partially apply f with y (presumably keeping argument 1 free), differentiate
# the result, and check it agrees with partially applying the gradient d₁f.
f₁ = ∂₁(f, (y,))
df₁ = ∇₁(f₁)

@assert isequal(df₁, ∂₁(d₁f, (y,)))

# Same check for the second argument, partially applying with x.
d₂f = ∇₂(f)

f₂ = ∂₂(f, (x,))
# f₂ presumably has a single free argument, hence ∇₁ rather than ∇₂.
df₂ = ∇₁(f₂)

@assert isequal(df₂, ∂₂(d₂f, (x,)))

## Second order

In [None]:
# Gradient operator over arguments 1 and 2 (presumably the second-order/mixed case — confirm).
const ∇₁₂ = ∇{Tuple{1,2}}()

In [None]:
# Apply it to the binary operator f = MyBin() from the first-order cell.
h = ∇₁₂(f)

# Operator support and stencil traits

In [None]:
"""
    LocalOp

Example point-wise binary primitive: `LocalOp()((x, y), i) == x[i] * y[i]`.
Its support traits and per-argument Jacobians are declared by hand below.
"""
struct LocalOp <: Primitive{Arity{2}} end

(::LocalOp)((x, y)::NTuple{2,AbstractVector}, i::Int) = x[i] * y[i]

# The trait methods must extend the library's own functions. These names are
# only available via `using ArrayOperations`, so an unqualified definition
# would either error ("must be explicitly imported to be extended") or create
# unrelated `Main` functions that the library never dispatches to.
# (Assumes both traits are owned by or reachable through ArrayOperations.)
ArrayOperations.OperatorSupport(::Type{<:LocalOp}, ::Dim) = HasStencil()

ArrayOperations.Stencil(::Type{<:LocalOp}, ::Dim) = PointWise()

# ∂(x[i] * y[i]) / ∂x[i] = y[i]
(::Jac{1,S,O})((_, y)::NTuple{2,AbstractVector}, i::Int) where {S,O<:LocalOp} = y[i]
# ∂(x[i] * y[i]) / ∂y[i] = x[i]
(::Jac{2,S,O})((x, _)::NTuple{2,AbstractVector}, i::Int) where {S,O<:LocalOp} = x[i]

In [None]:
f = LocalOp()

# Gradients w.r.t. each argument. For the point-wise product x[i] * y[i], the
# expected partials at index 2 are y[2] (argument 1) and x[2] (argument 2).
d₁f = ∇₁(f)
d₂f = ∇₂(f)

@assert isequal(d₁f((x, y), 2), y[2])
@assert isequal(d₂f((x, y), 2), x[2])