# Setup

In [None]:
using DifferentialEquations, Plots, Turing, Interpolations

# Three node network

Based on the Lorenz example. Rather than writing out each node's equation by hand, the drift function below loops over the nodes, which interact through the weight matrix `W`.

In [None]:
"""
    phi(a::Number)

Node nonlinearity: the hyperbolic tangent of `a`, i.e. `(exp(2a) - 1) / (exp(2a) + 1)`.

Uses `Base.tanh`, which is mathematically identical to that expression but stays finite
for large `|a|`: the explicit formula overflows to `Inf / Inf = NaN` once `exp(2a)`
exceeds `floatmax` (around `a > 355` for `Float64`).
"""
phi(a::Number) = tanh(a)

In [None]:
"""
    additive_noise!(du, u, p, t)

In-place diffusion term of the network SDE. Every node gets the same constant noise
amplitude `σ`, the fifth entry of `p`. The solver multiplies each `du[i]` by a Wiener
increment drawn from `N(0, dt)`, so the per-step noise standard deviation is
`σ * sqrt(dt)`.
"""
function additive_noise!(du, u, p, t)
    _, _, _, _, σ = p   # only the noise amplitude is needed here
    for k in eachindex(u)
        du[k] = σ   # constant amplitude; set to 0 to switch the noise off
    end
    return nothing
end

"""
    network_model!(du, u, p, t)

In-place drift of the network SDE. For each node `i`,

    du[i] = -u[i] + s * phi(u[i]) + g * Σⱼ W[i, j] * u[j] + I(t)

where `p = (s, g, W, I, σ)`. Row `i` of `W` holds the weights of the inputs to
node `i`, and `I` is called with the current time `t`. The fifth entry `σ` is not used
by the drift.
"""
function network_model!(du, u, p, t)
    s, g, W, I = p   # σ (fifth entry) is not needed for the drift
    # Hoisted out of the node loop: one matrix-vector product per call instead of
    # slicing and broadcasting a fresh row of `W` for every node, and a single
    # evaluation of the external input.
    coupling = W * u
    input = I(t)
    for i in eachindex(u)
        du[i] = -u[i] + s * phi(u[i]) + g * coupling[i] + input
    end
    return nothing
end

In [None]:
# Network size, initial state (all nodes start at 0) and simulated time span.
num_nodes = 3
u0 = zeros(num_nodes)
Tmax = 100.0
tspan = (0.0, Tmax)

# Fixed model parameters: gain `s` on each node's own nonlinearity `phi(u[i])`,
# global coupling gain `g`, and connectivity `W` (row i = weights of node i's inputs).
s = 0.3
g = 0.7
W = [0.0 0.2 0.0; 0.4 0.0 0.0; 0.0 0.3 0.0]

# External input sampled at integer times, linearly interpolated, and 0 outside
# [1, Tmax]. `task` is sized from `ts` so changing `Tmax` keeps the two in sync.
ts = 1:1:Tmax
task = zeros(length(ts))
# task[4] = 1  # uncomment for an input impulse at t = 4; without it this is "resting state"
I = LinearInterpolation(ts, task, extrapolation_bc = 0)
dt = 0.5
σ = 0.1

# A tuple rather than `[s, g, W, I, σ]`: the vector literal would be a `Vector{Any}`
# and force dynamic dispatch on every right-hand-side evaluation.
p = (s, g, W, I, σ)

# Drift and diffusion are the in-place functions `network_model!` / `additive_noise!`.
prob_sde = SDEProblem(network_model!, additive_noise!, u0, tspan, p)


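Note: unless adaptive step-size control is turned off (`adaptive=false`), the solution will not contain the `Tmax/dt + 1` evenly spaced time points (including `t = 0`); instead, the step sizes are chosen by the solver's step-size controller, as described [here](https://diffeq.sciml.ai/stable/basics/common_solver_opts/#Basic-Stepsize-Control).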

In [None]:
# Single realisation on a fixed grid of step `dt` (adaptive stepping off), solved with
# the same SOSRI method used for the ensemble and the Turing model, so results are comparable.
sol = solve(prob_sde, SOSRI(), dt = dt, adaptive = false)

In [None]:
# Plot each node's trajectory from the single realisation against time.
sol |> plot

In [None]:
# 100 independent realisations of the same SDE on the fixed `dt` grid; `data` is
# reused below as the observations for the Turing model.
ensembleprob = EnsembleProblem(prob_sde)
data = solve(ensembleprob, SOSRI(); trajectories = 100, dt = dt, adaptive = false)
# Summary statistics (mean and quantiles) across trajectories, plotted over time.
data |> EnsembleSummary |> plot

In [None]:
Turing.setadbackend(:forwarddiff)

# Bayesian model for the noise level σ of the network SDE.
#
# `data` is expected to be an ensemble of simulated trajectories (such as the
# `EnsembleSolution` above): `data[j][i]` is the state vector of trajectory j at time
# point i. The fixed parameters (s, g, W, I) and the time span come from `prob`; only
# σ is inferred. σ serves both as the SDE noise amplitude and as the scale of the
# Gaussian observation noise.
@model function fit_nmm(data, prob)
    σ ~ LogNormal(-1, 1)

    # Reuse the known parameters stored in `prob` instead of re-declaring them here,
    # so the model cannot drift out of sync with the simulation settings.
    s, g, W, I = prob.p
    p = (s, g, W, I, σ)
    prob = remake(prob, p = p)
    # `dt` is read from global scope; it must match the step used to generate `data`.
    predicted = solve(prob, SOSRI(), dt = dt, adaptive = false)

    # Reject this σ outright if the solver failed. Returning early avoids adding
    # likelihood terms computed from an incomplete or non-finite solution.
    # NOTE(review): newer SciMLBase versions return a `ReturnCode` enum rather than a
    # Symbol; `SciMLBase.successful_retcode(predicted)` is the version-robust check.
    if predicted.retcode != :Success
        Turing.acclogp!(_varinfo, -Inf)
        return nothing
    end

    for j in 1:length(data)            # trajectories in the ensemble
        for i in 1:length(predicted)   # time points
            data[j][i] ~ MvNormal(predicted[i], σ)
        end
    end
end;

In [None]:
model = fit_nmm(data, prob_sde)
# NUTS with target acceptance ratio 0.25 (Turing's `NUTS(δ)` form), 5000 draws.
chain = let sampler = NUTS(0.25), nsamples = 5000
    sample(model, sampler, nsamples)
end

In [None]:
# Diagnostic plots of the sampled chain (the only sampled parameter here is σ).
chain |> plot

In [None]:
# Debugging aid: list the names bound in `Main`, including those imported from
# other modules (presumably to check for clashes with globals such as `I`).
names(Main; imported = true)