In [1]:
using Pkg;
Pkg.add("ReactionNetworkImporters");
Pkg.add("Dictionaries");
Pkg.add("LaTeXStrings");
Pkg.add("Statistics");
Pkg.add("ColorSchemes");
Pkg.add("IterTools"); 
Pkg.add("NNlib"); 
Pkg.add("DifferentialEquations");
# Pkg.add("Plots");
Pkg.add("Formatting");
Pkg.add("LinearAlgebra");
Pkg.add("Noise");
Pkg.add("Catalyst");

using DifferentialEquations;
using Random;
# using Plots;
using Formatting;
using LinearAlgebra;
using Noise;
using ReactionNetworkImporters;
using Dictionaries;
using LaTeXStrings;
using Statistics;
using ColorSchemes;
using Catalyst;
using IterTools;
using NNlib;

nothing

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.8/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.ju

In [7]:
include("datasets.jl")
include("utils.jl")
nothing

In [8]:
"""
    f(u, p, x, t)

Right-hand side of the hidden-state ODE:

    du/dt = (θ x + β) ⊙ u − u ⊙ u + h

`θ`, `β` and the bias `h` are unpacked from the flat parameter vector `p`
with `sequester_params`.

# Arguments
- `u`: hidden state (a vector, or a `dims×1` matrix when called from the
  reverse pass).
- `p`: flat parameter vector, laid out as `[θ, β, w, h, t0, t1]`.
- `x`: input that gates the linear term; its length fixes `dims`.
- `t`: time; not referenced (the dynamics have no explicit time dependence).

# Returns
An array with `length(x)` entries (same shape as `u`).
"""
function f(u, p, x, t)
    # Pass `dims` explicitly (as `forward_step` does) so `p` is split for the
    # actual hidden size instead of whatever `sequester_params` defaults to.
    theta, beta, _, h, _, _ = sequester_params(p, dims=length(x))

    # The original `ones(length(w)) * h` only works for a scalar `h`, so a
    # broadcast add is equivalent and avoids the temporary vector.
    fmat = (theta * x + beta) .* u - u .* u .+ h
    @assert length(fmat) == length(x)
    return fmat
end


"""
    forward!(du, u, xAndp, t)

In-place ODE right-hand side for the forward pass.

`xAndp` is the concatenation `[x, theta, beta, w, h, t0, t1]`. The first
`length(u)` entries are the input `x`, and the remainder is the flat
parameter vector handed to `f`. The result of `f` is copied into `du`.
"""
function forward!(du, u, xAndp, t)
    n = length(u)
    input = xAndp[1:n]
    params = xAndp[(n + 1):end]

    copyto!(du, f(u, params, input, t))
    return nothing
end


# Calculates the final hidden state of the neural ode
"""
    forward_node(u0, xAndp, tspan)

Integrate the forward hidden-state ODE (`forward!`) from `u0` over `tspan`.
`xAndp = [x; p]` is passed as the ODE parameters. The solver is `Tsit5` with
`reltol = abstol = 1e-8`.

Returns the full solution object; the final hidden state is `sol.u[end]`.
"""
function forward_node(u0, xAndp, tspan)
    tol = 1e-8
    problem = ODEProblem(forward!, u0, tspan, xAndp)
    return solve(problem, Tsit5(); reltol=tol, abstol=tol)
end


# Final feedforward layer similar to a perceptron
"""
    forward_ffnet(z, w; threshold=nothing)

Final linear read-out layer: return the inner product `dot(w, z)` of the
weights `w` with the hidden state `z`.

`threshold` is accepted for interface compatibility but is currently ignored.
"""
forward_ffnet(z, w; threshold=nothing) = dot(w, z)


"""
    forward_step(u0, p, tspan; threshold=nothing)

Run the full forward pass for one sample.

`u0` is both the ODE initial state and the input `x` seen by `f`, so the ODE
parameters are the concatenation `[u0; p]`. After integrating over `tspan`,
the first `length(u0)` entries of the final state form `z`. That state is
projected with the read-out weights `w` unpacked from `p`.

Returns the tuple `(z, yhat)`.
"""
function forward_step(u0, p, tspan; threshold=nothing)
    n = length(u0)
    _, _, w, _, _, _ = sequester_params(p, dims=n)

    xAndp = append!(append!(Any[], u0), p)
    solution = forward_node(u0, xAndp, tspan)
    z = solution.u[end][1:n]

    yhat = forward_ffnet(z, w; threshold=threshold)
    return (z, yhat)
end

forward_step (generic function with 1 method)

In [13]:
"""
    aug_dynamics!(du, u, sAndp, t)

In-place right-hand side of the reverse-time ("augmented") ODE.

`sAndp` is the concatenation `[s0; p]` of the initial state and the flat
parameter vector. Only the hidden-state part is active. It evolves as `-f`,
so integrating over `(t0, t1)` runs the forward dynamics backwards. The
adjoint and parameter-gradient dynamics below are disabled work in progress.
"""
function aug_dynamics!(du, u, sAndp, t)
    # Split `[s0; p]` by the actual state length (the ODE's u0 is `s0`) instead
    # of the hard-coded 3, mirroring `forward!`; otherwise `p` is misaligned
    # whenever the state length is not 3.
    nstate = length(u)
    p = sAndp[(nstate + 1):end]
    theta, beta, w, h, _, _ = sequester_params(p)

    # Hidden size from the dims×dims weight matrix, without a float round-trip.
    dims = isqrt(length(theta))
    dims^2 == length(theta) || throw(DimensionMismatch(
        "theta has $(length(theta)) entries, which is not a perfect square"))

    # Time dynamics for the hidden state
    offset = 0
    z = u[1:dims]
    # NOTE(review): the forward pass evaluates `f` with the *input* x, but here
    # `z` is passed in its place because x is not part of `sAndp`. The state
    # reconstructed at t=0 therefore need not equal the input. Confirm whether
    # x should be threaded through `backpropagation_step`.
    func = -f(u, p, z, t)

    @assert length(func) == dims

    for i in 1:dims
        du[offset + i] = func[i]
    end

#     offset += dims


#     # Time dynamics for the adjoint
#     a = u[dims+1:2*dims]
#     a = reshape(a, (dims, 1))

#     # ∂f/∂z = 𝜃
#     dfdz = theta
#     @assert size(theta) == (dims, dims)

#     dadt = reshape(-transpose(a) * dfdz, dims)
#     for i in eachindex(dadt)
#         du[offset+i] = dadt[i]
#     end

#     offset += length(dadt)
#     @assert offset == 2 * dims # offset after adding dzdt and dady


#     # Time dynamics for gradients
#     dfdtheta = zeros(dims, dims^2)

#     for i in 1:dims
#         for j in 1:dims
#             dfdtheta[i, (i-1)*dims+j] = z[j]
#         end
#     end

#     dgrads = -transpose(a) * dfdtheta

#     @assert size(dgrads) == (1, dims^2)

#     for i in eachindex(dgrads)
#         du[offset+i] = dgrads[i]
#     end
#     offset += length(dgrads)

#     # Time dynamics of time(!!): Not used though sigh.
#     # TODO: Might wanna change this in future if things don't work
#     dfdt = zeros(dims, 1)
#     tgrads = -transpose(a) * dfdt

#     @assert length(tgrads) == 1

#     # currently not changing time!
#     du[offset+1] = 0

    return nothing
end


"""
    backpropagation_step(s0, p, tspan; dims=3)

Integrate the reverse-time dynamics (`aug_dynamics!`) from the initial state
`s0` over `tspan`. The ODE parameters are the concatenation `[s0; p]`. The
solver is `Tsit5` with `reltol = abstol = 1e-8`.

`dims` is only used to unpack `p` with `sequester_params`. The unpacked values
are not used further; the call is kept as in the original, along with any
error it raises.

Returns the full ODE solution object.
"""
function backpropagation_step(s0, p, tspan; dims=3)
    sequester_params(p, dims=dims)

    sAndp = append!(append!(Any[], s0), p)
    # NOTE: the label says sAndp, but this prints s0 and p back to back.
    println("sAndp: ", s0, p)

    tol = 1e-8
    problem = ODEProblem(aug_dynamics!, s0, tspan, sAndp)
    return solve(problem, Tsit5(); reltol=tol, abstol=tol)
end

backpropagation_step (generic function with 1 method)

In [16]:
"""
    training_step(x, y, p; threshold=nothing)

Run the forward pass and the (work-in-progress) reverse pass for one sample.

# Arguments
- `x`: augmented input; its length is taken as the hidden size `dims`.
- `y`: target output.
- `p`: flat parameter vector of the entire network, unpacked with
  `sequester_params(p, dims=dims)`.
- `threshold`: forwarded to `forward_step`.

# Returns
`(z, yhat, loss, gradients)`:
- `z`: the final hidden state as a `dims×1` matrix.
- `yhat`: the read-out.
- `loss`: the squared error `0.5 * (yhat - y)^2`.
- `gradients`: `nothing` while the adjoint computation is disabled.
"""
function training_step(x, y, p; threshold=nothing)
    dims = length(x)

    _, _, w, _, t0, t1 = sequester_params(p, dims=dims)
    tspan = (t0, t1)

    @assert length(w) == dims

    # Forward & Hidden state calculation
    println("ODE | w at t=0 | ", w)

    z, yhat = forward_step(x, p, tspan, threshold=threshold)
    z = reshape(z, (dims, 1)) # dims×1 matrix, i.e. a column vector

    println("ODE | z at t=T | ", z)
    # Loss
    loss = 0.5*(yhat-y)^2

#     # Adjoint calculation
#     a = (yhat-y)*w
#     a = reshape(a, (dims, 1))
#     println("ODE | yhat at t=T | ", yhat)
#     println("ODE | Adjoint at t=T | ", a)

#     # Initial theta gradients
#     gtheta = zeros(dims^2, 1)
#     println("ODE | Theta gradients at t=T | ", gtheta)

#     # Initial time gradients
#     func = f(z, theta, beta, x)

#     dldt1 = -transpose(a)*f(z, theta, beta, x)
#     dldt1 = convert(Array{Float64}, dldt1)

    # Initial state for the reverse time ODE (hidden state only for now)
#     s0 = vcat(z, a, gtheta, dldt1)
    s0 = z # Just for check

#     rtspan = reverse(tspan)
    # Pass `dims` through; `backpropagation_step` would otherwise use its
    # default of 3 regardless of the input size.
    backward = backpropagation_step(s0, p, tspan; dims=dims)
    println("ODE | z at t=0 | ", backward.u[end][1:dims])
    gradients = nothing
#     println("ODE | Adjoint at t=0 | ", backward.u[end][dims+1:2*dims])

#     gradients = backward.u[end][2*dims+1:end]
#     gradients = reshape(gradients, size(gradients)[1])

#     # Note that gradients[end] already contains gradient for t0
#     append!(gradients, dldt1) # gradient for t1

#     # Gradients wrt w
#     wgrads = (yhat-y)*z
#     println("ODE | error: ", yhat-y)
#     # println("ODE | wgrads, z: ", wgrads, z)
#     for i in eachindex(wgrads)
#         push!(gradients, wgrads[i])
#     end
#     println("ODE | M: Final layer gradients | ", wgrads)
#     println("ODE | G: Gradients at t=0 | ", gradients)
    return z, yhat, loss, gradients
end


"""
    one_step_node(x, y, params, LR, dims)

Run one `training_step` on `(x, y)` and apply a plain gradient-descent update
`params[i] -= LR * gradients[i]` in place.

The integration endpoints `t0` and `t1` are kept fixed. With the
`[θ, β, w, h, t0, t1]` layout they are the last two entries of `params`.
`dims` is accepted for interface compatibility and is no longer needed.

Returns the (mutated) `params`.
"""
function one_step_node(x, y, params, LR, dims)
    println("=======ODE==================")
    println("ODE | Input: $x | Target: $y")
    println("params before | ", params)
    _, _, _, gradients = training_step(x, y, params)

    # `training_step` currently returns `gradients = nothing` (adjoint disabled).
    # `eachindex(nothing)` would throw, so skip the update in that case.
    if gradients !== nothing
        # Freeze the time span at its current values. The old hard-coded
        # `params[dims^2+1]` / `params[dims^2+2]` now point at β[1] and β[2].
        t0, t1 = params[end - 1], params[end]

        # Parameter update
        for param_index in eachindex(gradients)
            params[param_index] -= LR * gradients[param_index]
        end

        params[end - 1] = t0
        params[end] = t1
    end
    println("==============ODE END=============")
    return params
end


################################################################## 

"""
    node_main(params, train, val; dims=3, EPOCHS=20, LR=0.001, threshold=nothing)

Training driver for the neural ODE.

For every epoch and every sample in `train`, the input is processed as
follows:
- It is passed through `augment(x, dims - length(x))`. This presumably pads
  it to `dims` entries — confirm against `augment`.
- It is replaced by its absolute values.
- It is fed to `training_step`, with diagnostics printed along the way.

The parameter update (which would use `LR`) and the validation loop over `val`
are currently disabled, so `params` is read but never modified.

# Returns
A `Vector{Float64}` with the mean training loss of each epoch. The entry is
`NaN` when `train` is empty.
"""
function node_main(params, train, val; dims=3, EPOCHS=20, LR=0.001, threshold=nothing)
    # Begin the training process
    losses = Float64[]
    for epoch in 1:EPOCHS
        epoch_loss = 0.0
        for i in eachindex(train)
            println("=========EPOCH: $epoch | ITERATION: $i ===========")
            x, y = get_one(train, i)

            # Augment, then take absolute values. `abs.` allocates a new array
            # instead of overwriting the one `augment` returned, which may share
            # memory with the stored dataset sample.
            x = abs.(augment(x, dims - length(x)))
            println("ODE | Input: $x | Target: $y")
            # Unpack with the same `dims` that `training_step` derives from `x`.
            theta, beta, _, _, _, _ = sequester_params(params, dims=length(x))
            println("Ideal ReLU | ", relu.(theta*x + beta))
            println("params before | ", params)
            _, _, loss, _ = training_step(x, y, params, threshold=threshold)
            epoch_loss += loss

#             for param_index in eachindex(gradients)
#                 params[param_index] -= LR * gradients[param_index]
#             end
#             params[dims^2+1] = 0.0
#             params[dims^2+2] = 1.0
#             println("params at t=0 after update | ", params)
        end
        # Record the mean epoch loss (previously accumulated and then discarded).
        push!(losses, isempty(train) ? NaN : epoch_loss / length(train))
#         lossplts = plot(losses)
#         # png(lossplts, "trainlossplts.png")
#         accuracy = 0.0
#         val_epoch_loss = 0.0
#         before = []
#         after = []
#         yhats = []
#         for v in eachindex(val)
#             println("=======VAL Epoch: $epoch | ITERATION: $v")
#             x, y = get_one(val, v)

#             # Augment
#             x = augment(x, dims - length(x))
#             dims = length(x)

#             theta, beta, w, t0, t1= sequester_params(params, dims)
#             tspan = (t0, t1)
#             @assert length(w) == dims

#             # Forward & Hidden state calculation
#             println("ODE | w at t=0 | ", w)

#             before_tmp = []
#             append!(before_tmp, x)
#             push!(before_tmp, y)
#             push!(before, before_tmp)

#             println("ODE | Input: $x | Target: $y")
#             println("params before | ", params)
#             z, yhat = forward_step(x, theta, w, tspan, threshold=threshold)
#             loss = 0.5 * (yhat - y)^2

#             class = math.ceil(yhat)
#             after_tmp = []
#             append!(after_tmp, z)
#             push!(after_tmp, y)
#             push!(after, after_tmp)


#             val_epoch_loss += loss
#             println("params | ", params)

#             yhats_tmp = []
#             append!(yhats_tmp, x)
#             push!(yhats_tmp, class)
#             push!(yhats, yhats_tmp)

#             if class == y
#                 accuracy += 1
#             end
#         end
        # if dims == 2
            # beforeplt = scatter(getindex.(before, 1), getindex.(before, 2), group=getindex.(before, 3))
            # afterplot = scatter(getindex.(after, 1), getindex.(after, 2), group=getindex.(after, 3))
            # yhatplt = scatter(getindex.(yhats, 1), getindex.(yhats, 2), group=getindex.(yhats, 3))
        # end
        # if dims==3
            # beforeplt = scatter3d(getindex.(before, 1), getindex.(before, 2), getindex.(before, 3), group=getindex.(before, 4))
            # afterplot = scatter3d(getindex.(after, 1), getindex.(after, 2), getindex.(after, 3), group=getindex.(after, 4))
            # yhatplt = scatter3d(getindex.(yhats, 1), getindex.(yhats, 2), getindex.(yhats, 3), group=getindex.(yhats, 4))
        # end
        # png(beforeplt, "before.png")
        # png(afterplot, "after.png")
#         println("accuracy: ", accuracy / length(val))

        # png(yhatplt, "yhats.png")

    end
    return losses
end

node_main (generic function with 1 method)

In [17]:
 """
    neuralode(; DIMS=3)

Set up and run a short neural-ODE training experiment:
- Build annular-ring train (150) and validation (50) datasets.
- Create node parameters with time span `(0.0, 4.0)` and make every
  parameter non-negative.
- Call `node_main` on the first sample of each set for a single epoch.

Returns whatever `node_main` returns.
"""
function neuralode(; DIMS=3)
    # Alternative datasets, kept for quick switching:
    # train = create_linearly_separable_dataset(100, linear, threshold=0.0)
    # val = create_linearly_separable_dataset(40, linear, threshold=0.0)
    # val = train
    # Order matters: datasets are created before the parameters, as in the
    # original, so random draws happen in the same sequence.
    train = create_annular_rings_dataset(150)
    val = create_annular_rings_dataset(50)

    params_orig = create_node_params(DIMS, t0=0.0, t1=4.0)
    params_orig .= abs.(params_orig)

    # Unpacked values are unused; the call is kept as a layout sanity check.
    _, _, _, _, _, _ = sequester_params(params_orig)

    return node_main(params_orig, train[1:1], val[1:1];
                     dims=DIMS, EPOCHS=1, threshold=0.0, LR=0.001)
end

# Run the experiment with the default DIMS=3; node_main prints per-iteration diagnostics.
neuralode()

ODE | Input: [0.11241656541824341, 0.4358801543712616, 0.0] | Target: -1.0
Ideal ReLU | [0.26884228960859996, 0.6384834862007127, 0.8997631288451635]
params before | Any[0.8334208454, 0.0103126134, 1.8675020091, 0.1686510643, 0.9794816587, 1.5727820922, 0.0643065187, 0.2435356233, 0.8191115101, 0.1706569171, 0.1925876962, 0.7863816658, 0.4702697478, 0.2403057774, 1.1513489769, 0.1, 0.0, 4.0]
ODE | w at t=0 | Any[0.4702697478, 0.2403057774, 1.1513489769]
ODE | z at t=T | [0.43142417728770766; 0.7543119449420583; 0.8794857907479267;;]
sAndp: [0.43142417728770766; 0.7543119449420583; 0.8794857907479267;;]Any[0.8334208454, 0.0103126134, 1.8675020091, 0.1686510643, 0.9794816587, 1.5727820922, 0.0643065187, 0.2435356233, 0.8191115101, 0.1706569171, 0.1925876962, 0.7863816658, 0.4702697478, 0.2403057774, 1.1513489769, 0.1, 0.0, 4.0]
ODE | z at t=0 | [-0.21120420472677606, -0.159424385202493, -0.08508218892106033]
