In [10]:
using Pkg
# Pkg.add("ReactionNetworkImporters")
# Pkg.add("Dictionaries")
# Pkg.add("LaTeXStrings")
# Pkg.add("Statistics")
# Pkg.add("ColorSchemes")
# Pkg.add("IterTools");

using DifferentialEquations
using Random
using Plots
using Formatting
using LinearAlgebra
using Noise
using ReactionNetworkImporters
using Dictionaries
using LaTeXStrings
using Statistics
using ColorSchemes
using Catalyst
using IterTools

include("datasets.jl")
include("utils.jl")
include("reactions2D.jl")
include("neuralode.jl")


neuralode (generic function with 1 method)

In [11]:
function _convert_species2var(sp)
    ret = string(sp)
    ret = replace(ret, "(t)" => "")
    return ret
end


# Verified: @show _index2param("P", 3, -3.0)
function _index2Dvar(sym, index, val; dims=2)
    second = (index - 1) % dims + 1
    first = (index - 1) ÷ dims + 1
    return Dict(
        "$(sym)$(first)$(second)p" => max(0, val),
        "$(sym)$(first)$(second)m" => max(0, -val)
    )
end


function _index1Dvar(sym, index, val; dims=2)
    return Dict(
        "$sym$(index)p" => max(0.0, val),
        "$sym$(index)m" => max(0.0, -val)
    )
end


function _prepare_u(rn, vars)
    ss = species(rn)
    uvalues = [vars[_convert_species2var(sp)] for sp in ss]
    u = Pair.(ss, uvalues)
    return u
end


function _print_vars(vars, prefix; title="")
    println(title, "---------------")
    Xvars = [k for k in keys(vars) if startswith(k, prefix)]
    sort!(Xvars)
    for xvp in Xvars
        if endswith(xvp, "p")
            xvm = replace(xvp, "p" => "m")
            println("$xvp: $(vars[xvp]) | $xvm: $(vars[xvm]) | $(vars[xvp] - vars[xvm])")
        end

    end
    # println("CRN | Title: ", title, " | Prefix: ", prefix)
    # Pspecies = [k for k in keys(vars) if startswith(k, prefix)]
    # sort!(Pspecies)
    # for psp in Pspecies
    #     println(psp, ": ", vars[psp], " | ")
    # end
    # println("-------------")
end



function _filter_rn_species(rn; prefix="Z")
    ss = species(rn)
    xs = filter(x->startswith(string(x), prefix), ss)
    return xs
end

_filter_rn_species (generic function with 1 method)

In [14]:
function crn_dual_node_fwd(vars; tspan=(0.0, 1.0))
    u = _prepare_u(rn_dual_node_fwd, vars)
    p = []
    @show u
    sol = simulate_reaction_network(rn_dual_node_fwd, u, p, tspan=tspan)
    # First update the Z species 
    Zspecies = _filter_rn_species(rn_dual_node_fwd, prefix="Z")
    for i in eachindex(Zspecies)
        vars[_convert_species2var(Zspecies[i])] = sol[end][i]
    end
    _print_vars(vars, "Z", title="CRN | z at t=T |")
    return sol
end


crn_dual_node_fwd (generic function with 1 method)

In [19]:
function crn_main(params, train, val; dims=2, EPOCHS=10, LR=0.001, tspan=(0.0, 1.0))
    # Initialize a dictionary to track concentrations of all the species
    vars = Dict();

    # Get all the involved CRNs and add their species to the vars
    crns = [rn_dual_node_fwd, rn_dual_backprop, rn_param_update, rn_final_layer_update, rn_annihilation_reactions]    
    for crn in crns
        for sp in species(crn)
            get!(vars, _convert_species2var(sp), 0.0)
        end
    end
    
    # Assign the values of the parameters
    for param_index in 1:dims^2
        d = _index2Dvar("P", param_index, params[param_index], dims=dims)
        for (k,v) in d
            vars[k] = v
        end
    end

    # Adding time species, although we don't manipulate them now
    vars["T0"] = 0.0
    vars["T1"] = 1.0

    # Assign the weight parameters
    offset = dims^2 + 2  # 2 for t0 and t1
    for param_index in (offset+1):length(params)
        d = _index1Dvar("W", param_index-offset, params[param_index], dims=dims)
        for (k,v) in d
            vars[k] = v
        end
    end

    tr_losses = []
    for epoch in 1:EPOCHS
        tr_epoch_loss = 0.0
        for i in eachindex(train)
            x, y = get_one(train, i)
            x = augment(x, dims-length(x))
            yvec = [y 1-y]
            @show x
            for i in eachindex(x)
                d = _index1Dvar("Z", i, x[i], dims=dims)
                for (k,v) in d
                    vars[k] = v
                end
            end
            # Forward stage
            sol = crn_dual_node_fwd(vars, tspan=tspan)
            
            ss = species(rn_dual_node_fwd)
            for si in eachindex(ss)
                println(si)
                if startswith(string(ss[si]), "Z")
                    plt = plot!(sol[si])
                    png(plt, "Z.png")
                end
            end
            
        end
    end
end

function neuralcrn(;DIMS=2)
    train = create_linearly_separable_dataset(100, linear, threshold=0.0)
    val = create_linearly_separable_dataset(40, linear, threshold=0.0)
    params_orig = create_node_params(DIMS, t0=0.0, t1=1.0)
    node_main(params_orig, train[1:1], [], EPOCHS=1)
    println("===============================")
    crn_main(params_orig, train[1:1], [], EPOCHS=1, tspan=(0.0, 1.0))
end

neuralcrn()

nneg: 44 npos: 56nneg: 22 npos: 18ODE | Input: [-0.1077658803024611, -0.23979584149967573] | Target: 0.0
params before | Any[0.0170054417, 2.244159622, -0.2516677402, 0.2833236587, 0.0, 1.0, -0.7535499601, 0.226244532]
ODE | w at t=0 | Any[-0.7535499601, 0.226244532]
ODE | z at t=T | [-0.6471685752434628; -0.21093638376668616;;]
ODE | yhat at t=T | 0.43995065062561894
ODE | Adjoint at t=T | [-0.3315247952249042; 0.09953642905388867;;]
ODE | Gradients at t=T | [0.0; 0.0; 0.0; 0.0;;]
ODE | Adjoint at t=0 | [-0.26436324536693945, -0.6883657201251144]
ODE | Gradients at t=0 |  [0.1232698484870147, 0.07461591383895802, 0.07395890522780879, 0.07028672651061718]
ODE | Final layer gradients | [-0.2847222357428163; -0.09280159927876883;;]
ODE | gradients | [0.1232698484870147, 0.07461591383895802, 0.07395890522780879, 0.07028672651061718, -0.17084710120263694, -0.17084710120263694, -0.2847222357428163, -0.09280159927876883]
params | Any[0.016882171851512983, 2.244085006086161, -0.25174169910522


3
4
5
6
7
8
9
10
11
12
