In [1]:
using Catalyst
using ColorSchemes
using Dictionaries
using DifferentialEquations
using Formatting
using IterTools
using LaTeXStrings
using LinearAlgebra
using NNlib
using Noise
using Pkg
using Plots
using Random
using ReactionNetworkImporters
using Statistics

nothing

In [2]:
include("datasets.jl")
include("utils.jl")
include("reactionsReLU.jl")
include("neuralode.jl")
include("neuralcrn.jl")
nothing

LoadError: LoadError: @reaction_network notation where parameters are declared after "end", e.g. like:

```julia
@reaction_network begin
    p, 0 --> X
    d, X --> 0
end p d
```

has been deprecated in favor of a notation where the parameters are inferred, e.g:

```julia
@reaction_network begin
    p, 0 --> X
    d, X --> 0
end
```

Parameters and species can be explicitly indicated using the @parameters and @species
macros, e.g:

```julia
@reaction_network begin
    @parameters p d
    @species X(t)
    p, 0 --> X
    d, X --> 0
end
```

in expression starting at /Users/rajiv/Desktop/PhD/neural-ode/NeuralCRNGen/julia/reactionsReLU.jl:393
in expression starting at /Users/rajiv/Desktop/PhD/neural-ode/NeuralCRNGen/julia/reactionsReLU.jl:393

In [3]:
"""
    _convert_species2var(sp)

Return the string form of species `sp` with every `"(t)"` occurrence removed,
e.g. a species that prints as `X(t)` maps to the variable name `"X"`.
"""
function _convert_species2var(sp)
    return replace(string(sp), "(t)" => "")
end


"""
    _index2Dvar(sym, index, val; dims=2)

Split scalar `val` into non-negative "plus" and "minus" parts and name them after
the 2-D position of the 1-based linear `index` in a `dims`×`dims` grid.

The position is `first = (index - 1) ÷ dims + 1` and
`second = (index - 1) % dims + 1`, so `second` varies fastest.

Returns a `Dict` with two entries:
- key `<sym><first><second>p` → `max(0, val)`
- key `<sym><first><second>m` → `max(0, -val)`

The plus part minus the minus part equals `val`.

Throws `ArgumentError` when `index` is outside `1:dims^2`. Such an index would
otherwise yield meaningless names (e.g. index 0 → `P10p`).

Note: for `dims >= 10` the generated names are ambiguous, e.g. `P111p`.

Example: `_index2Dvar("P", 3, -3.0)` returns `Dict("P21p" => 0.0, "P21m" => 3.0)`.
"""
function _index2Dvar(sym, index, val; dims=2)
    1 <= index <= dims^2 ||
        throw(ArgumentError("index must be in 1:$(dims^2), got $index"))
    second = (index - 1) % dims + 1
    first = (index - 1) ÷ dims + 1
    return Dict(
        "$(sym)$(first)$(second)p" => max(0, val),
        "$(sym)$(first)$(second)m" => max(0, -val)
    )
end


"""
    _index1Dvar(sym, index, val; dims=2)

Split scalar `val` into two non-negative parts keyed by `sym` and `index`:
- `<sym><index>p` → `max(0.0, val)`
- `<sym><index>m` → `max(0.0, -val)`

The `dims` keyword is accepted but not used.
"""
function _index1Dvar(sym, index, val; dims=2)
    stem = string(sym, index)
    plus_part = max(0.0, val)
    minus_part = max(0.0, -val)
    return Dict(stem * "p" => plus_part, stem * "m" => minus_part)
end


"""
    _prepare_u(rn, vars)

Build initial conditions for reaction network `rn`.

Returns one `species => value` pair per entry of `species(rn)`, in that order.
Each value is `vars[name]`, where `name` is the species' string form with
`"(t)"` removed (see `_convert_species2var`). A species missing from `vars`
raises `KeyError`.
"""
function _prepare_u(rn, vars)
    all_species = species(rn)
    initial_values = map(sp -> vars[_convert_species2var(sp)], all_species)
    return Pair.(all_species, initial_values)
end


"""
    _print_vars(vars, prefix; title="", io=stdout)

Print a header line (`title` followed by dashes), then one line per split
variable in `vars` whose key starts with `prefix`.

Each key ending in `"p"` is paired with its partner key: the same key with only
its final `p` replaced by `m`. The printed line is
`<p key>: <p value> | <m key>: <m value> | <p value - m value>`.

Keys are visited in lexicographic order, so `"Z10p"` comes before `"Z2p"`.

# Keywords
- `title`: text printed before the dashes on the header line.
- `io`: stream to print to (defaults to `stdout`).

Throws `KeyError` if a `p` key has no matching `m` key in `vars`.
"""
function _print_vars(vars, prefix; title="", io::IO=stdout)
    println(io, title, "---------------")
    xkeys = sort!([k for k in keys(vars) if startswith(k, prefix)])
    for xvp in xkeys
        endswith(xvp, "p") || continue
        # Swap only the trailing "p": a blanket replace("p" => "m") would also
        # rewrite any other 'p' in the key and then look up a nonexistent key.
        xvm = chop(xvp) * "m"
        println(io, "$xvp: $(vars[xvp]) | $xvm: $(vars[xvm]) | $(vars[xvp] - vars[xvm])")
    end
    return nothing
end



"""
    _filter_rn_species(rn; prefix="Z")

Return the species of reaction network `rn` whose string form starts with
`prefix`, keeping the order in which `species(rn)` lists them.
"""
function _filter_rn_species(rn; prefix="Z")
    has_prefix(sp) = startswith(string(sp), prefix)
    return filter(has_prefix, species(rn))
end



_filter_rn_species (generic function with 1 method)

In [4]:
"""
    crn_dual_node_fwd(vars; tspan=(0.0, 1.0))

Run the forward pass of the global network `rn_dual_node_fwd`.

1. Simulate it over `tspan`, starting from the concentrations in `vars`
   (built with `_prepare_u`).
2. Write the final concentration of every species whose name starts with `"Z"`
   back into `vars`. This mutates `vars`.
3. Print the updated Z entries.

Returns the object produced by `simulate_reaction_network`.
"""
function crn_dual_node_fwd(vars; tspan=(0.0, 1.0))
    rn = rn_dual_node_fwd
    u = _prepare_u(rn, vars)
    p = []  # NOTE(review): untyped; kept as-is since simulate_reaction_network's expectations are not visible here
    @show u
    sol = simulate_reaction_network(rn, u, p, tspan=tspan)
    final_state = sol[end]
    # The state vector is presumably ordered like species(rn), the same order
    # _prepare_u uses for `u`. Look up each Z species' position in that full list.
    # Its index within the filtered Z list is only correct if the Z species come first.
    all_species = species(rn)
    for z in _filter_rn_species(rn, prefix="Z")
        idx = findfirst(isequal(z), all_species)
        vars[_convert_species2var(z)] = final_state[idx]
    end
    _print_vars(vars, "Z", title="CRN | z at t=T |")
    return sol
end


crn_dual_node_fwd (generic function with 1 method)

In [5]:
# function crn_main(params, train, val; dims=3, EPOCHS=10, LR=0.001, tspan=(0.0, 1.0))
#     # Initialize a dictionary to track concentrations of all the species
#     vars = Dict();

#     # Get all the involved CRNs and add their species to the vars 
#     crns = [
#         rn_dual_node_relu_fwd,
#         rn_dual_node_relu_bwd, 
#         rn_param_update,
#         rn_final_layer_update, 
#         rn_dissipate_reactions,
#     ]
    
#     # Zeroing out all the species
#     for crn in crns
#         for sp in species(crn)
#             get!(vars, _convert_species2var(sp), 0.0)
#         end
#     end
    
#     # Assign the values of the parameters
#     for param_index in 1:dims^2
#         d = _index2Dvar("P", param_index, params[param_index], dims=dims)
#         for (k,v) in d
#             vars[k] = v
#         end
#     end

#     # Adding time species, although we don't manipulate them now
#     vars["T0"] = 0.0
#     vars["T1"] = 1.0

#     # Assign the weight parameters
#     offset = dims^2 + 2  # 2 for t0 and t1
#     for param_index in (offset+1):length(params)
#         d = _index1Dvar("W", param_index-offset, params[param_index], dims=dims)
#         for (k,v) in d
#             vars[k] = v
#         end
#     end

#     tr_losses = []
#     for epoch in 1:EPOCHS
#         tr_epoch_loss = 0.0
#         for i in eachindex(train)
#             x, y = get_one(train, i)
#             x = augment(x, dims-length(x))
#             yvec = [y 1-y]
#             @show x
#             for i in eachindex(x)
#                 d = _index1Dvar("Z", i, x[i], dims=dims)
#                 for (k,v) in d
#                     vars[k] = v
#                 end
#             end
#             # Forward stage
#             sol = crn_dual_node_fwd(vars, tspan=tspan)
            
#             ss = species(rn_dual_node_fwd)
#             for si in eachindex(ss)
#                 println(si)
#                 if startswith(string(ss[si]), "Z")
#                     plt = plot!(sol[si])
#                     png(plt, "Z.png")
#                 end
#             end
            
#         end
#     end
# end

"""
    neuralcrn(; DIMS=3)

Notebook driver for one CRN training run.

- Builds a 100-sample training set and a 40-sample validation set with
  `create_linearly_separable_dataset` (threshold 0.0).
- Creates node parameters for `DIMS` dimensions with `t0=0.0`, `t1=1.0`.
- Calls `crn_main` for one epoch on only the first training sample.

The validation set is constructed but not passed on; an empty list is given
instead. Its construction is kept so the sequence of dataset builds (and any
RNG draws they make) is unchanged.

NOTE(review): in the recorded run this raised `UndefVarError: crn_main not
defined`. The local definition is commented out; confirm where it should come from.
"""
function neuralcrn(; DIMS=3)
    train = create_linearly_separable_dataset(100, linear; threshold=0.0)
    val = create_linearly_separable_dataset(40, linear; threshold=0.0)
    params_orig = create_node_params(DIMS; t0=0.0, t1=1.0)
    first_sample = train[1:1]
    return crn_main(params_orig, first_sample, []; dims=DIMS, EPOCHS=1, tspan=(0.0, 1.0))
end

neuralcrn()

nneg: 47 npos: 53nneg: 25 npos: 15

UndefVarError: UndefVarError: `crn_main` not defined

In [52]:
# Draw 3^2 + 3 + 3 + 2 = 17 standard-normal values. In the recorded output,
# sequester_params splits them into a 3×3 matrix, two length-3 vectors and
# two scalars.
p = randn(3^2 + 3 + 3 + 2)
(params, beta, w, t0, t1) = sequester_params(p, 3)

([1.392623084786063 0.381796215414273 0.4499716209092798; -0.8109635966725603 1.0751084502670774 -0.4477546495006805; 0.018576260076363774 -0.892497163225951 0.5805765293097936], [-1.4110319614130438, -0.12452155962808596, -1.2808320700717801], [-0.9356519809539731, -0.31887501959852027, -1.1269265876586663], -0.17032481805714994, -0.3941344707687343)