In [None]:
using Catalyst

# Unit step expressed through `sign`: 0 for t < 0, 1 for t > 0, and 0.5 exactly at t = 0.
heaviside(t) = 0.5 * (1 + sign(t))
n=10
function hill(t,τ)
   t^n/(t^n+τ^n)
end
"""
    logistic(t, τ)

Smooth switch from 0 (for `t ≪ τ`) to 1 (for `t ≫ τ`), equal to 0.5 at `t = τ`:
`1 / (1 + 3^(τ - t))`.

A floating-point base is used on purpose. With the integer base `3`, integer
arguments with `t > τ` make the exponent a negative integer, and Julia throws a
`DomainError` for integer-to-negative-integer powers. For floating-point
arguments the result is unchanged, because `τ - t` is exactly `-(t - τ)`.
"""
function logistic(t, τ)
    return 1 / (1 + 3.0^(τ - t))
end
"""
    steps(t)

Input schedule built from three `logistic` switches: it rises by 0.2 around
`t = 20`, by a further 0.4 around `t = 70`, and then falls by 0.6 around
`t = 120`, returning to roughly zero.
"""
function steps(t)
    first_rise = 0.2 * logistic(t, 20)
    second_rise = 0.4 * logistic(t, 70)
    final_drop = 0.6 * logistic(t, 120)
    return first_rise + second_rise - final_drop
end

# Two parallel copies (i = 1, 2) of an input → gene-product motif. The copies
# share k, Kd and γ but have separate input strengths μ1 and μ2. For each copy:
#   - Si is produced at the time-varying rate μi*steps(t) (the input schedule);
#   - Si drives production of Gi with rate k/(Si+Kd). Assuming Catalyst's default
#     mass-action kinetics (rate multiplied by the substrate Si), this gives a
#     saturating k*Si/(Si+Kd) production term. Confirm for the installed version;
#   - Si and Gi are both removed at rate γ.
# The parameter list after `end` is μ1 μ2 k Kd γ. The parameter vectors used
# later in this notebook are written in that same order.
# NOTE(review): the "parameters after `end`" form comes from older Catalyst
# releases; newer releases declare or infer parameters differently. Confirm
# against the installed version.
sigma_model = @reaction_network begin
    μ1*steps(t), ∅ --> S1
# Disabled alternative inputs (they use μ and S, which this network no longer defines):
#     .02*μ*t, ∅ --> S
#     μ, ∅ --> S
    k/(S1+Kd), S1 --> S1 + G1
    γ, S1 --> ∅
    γ, G1 --> ∅
    μ2*steps(t), ∅ --> S2
# Disabled alternative inputs (they use μ and S, which this network no longer defines):
#     .02*μ*t, ∅ --> S
#     μ, ∅ --> S
    k/(S2+Kd), S2 --> S2 + G2
    γ, S2 --> ∅
    γ, G2 --> ∅
    end μ1 μ2 k Kd γ

In [None]:
using Latexify
# Convert the reaction network to its ODE system and render the equations as LaTeX.
odes = convert(ODESystem, sigma_model)
odes |> latexify

In [None]:
using DifferentialEquations, Plots

# 1×2 row of line colours. The plot has four species series, so the colours
# are presumably reused cyclically.
cc = ["#835C3B" "#10DA05"]

## Parameters, positional, in the order declared by the network: [μ1 μ2 k Kd γ]
p = (20, 10, 10, 8, .2)
## Initial conditions: one entry for each of the four species (S1, G1, S2, G2)
u₀ = [1., 1., 1., 1.]
tspan = (0., 150.)

# Create and solve the deterministic ODEProblem, then plot every species.
ds = ODEProblem(sigma_model, u₀, tspan, p)
sol = solve(ds, Tsit5())
plot(sol, lw=4, lc=cc, legend=false)

In [None]:
# Stochastic (SSA) simulation of the same network. The discrete problem needs one
# integer copy number per species. The network has four species (S1, G1, S2, G2),
# so the initial state must have four entries; the previous `[1, 1]` had only two.
ds_discrete = DiscreteProblem(sigma_model, [1, 1, 1, 1], tspan, p)
jump_ds = JumpProblem(sigma_model, ds_discrete, Direct())

# NOTE(review): the production rates μ1*steps(t) and μ2*steps(t) depend explicitly
# on time. An exact SSA with SSAStepper assumes propensities are constant between
# jumps, so depending on the Catalyst/JumpProcesses version this is either an
# approximation or rejected outright. Confirm.
sol = solve(jump_ds, SSAStepper())
plot(sol, lw=4, lc=cc, legend=false)

In [None]:

## Parameter inference: generate noisy synthetic data from known ("true") parameters
x0_inf = [1.; 1.; 1.; 1.]       # initial state, one entry per species
p_inf = [20, 10, 10, 8, .2]     # true values, ordered μ1 μ2 k Kd γ
tspan = (0.0, 150.0)

ds = ODEProblem(sigma_model, x0_inf, tspan, p_inf)

# Observations every 3 time units; `fitmodel` re-solves on this same grid.
sol = solve(ds, Tsit5(), saveat=3)
# Add i.i.d. Gaussian noise (sd 0.5) to every species at every saved time.
# The clean solution matrix (rows = species, columns = saved times) is
# materialized once and reused for its size.
targetdata = let clean = Array(sol)
    clean .+ 0.5 .* randn(size(clean))
end

plot(sol, alpha = 0.5, lc=cc, legend = false); scatter!(sol.t, targetdata', color=cc)

In [None]:
using Turing 
# fitmodel(data, ds)
#
# Turing model for Bayesian inference of the `sigma_model` parameters.
# `data` holds one observation column per saved time (`data[:, i]`). Each column
# is compared with rows 2 and 4 of the solution of `ds`, which is re-solved on a
# `saveat = 3` grid for every parameter draw, under isotropic Gaussian noise with
# scale `σ`.
@model function fitmodel(data, ds)
    # Priors; the truncation keeps the rate-like parameters positive and bounded.
    σ ~ InverseGamma(2, 3)
    μ1 ~ truncated(Normal(15.0, 5.0), 0, 100)
    μ2 ~ truncated(Normal(15.0, 5.0), 0, 100)
    k ~ truncated(Normal(3.0, 5.0), 0, 100)
    Kd ~ truncated(Normal(20.0, 5.0), 0, 100)
    γ ~ truncated(Normal(2.0, 1.0), 0, 10)

    # Same positional order as the network's parameter list: μ1 μ2 k Kd γ.
    p = [μ1, μ2, k, Kd, γ]
    prob = remake(ds, p=p)
    # `saveat` must match the grid on which `data` was recorded.
    predicted = solve(prob, Tsit5(), saveat=3)

    # If the solver stops early (e.g. for extreme parameter draws), there are
    # fewer predicted time points than observations. Looping only over the
    # predicted points would silently drop the remaining likelihood terms, which
    # can *raise* the log density of bad parameters. Reject such draws instead.
    if length(predicted.t) != size(data, 2)
        Turing.@addlogprob! -Inf
        return nothing
    end

    for i in axes(data, 2)
        # Only species 2 and 4 of the solution are compared with the data.
        data[:, i] ~ MvNormal(predicted[[2, 4], i], σ)
    end
end

# Condition the model on rows 2 and 4 of the noisy data (the observed species).
model = fitmodel(targetdata[[2, 4], :], ds)
# Draw two NUTS chains one after the other (target acceptance 0.7, 1000 samples
# each) and concatenate them along the chain dimension.
@time chain = reduce(chainscat, [sample(model, NUTS(0.7), 1000) for _ in 1:2])

In [None]:
using StatsPlots
# Visualize the sampled chains using the StatsPlots recipe for MCMC chains.
chain |> plot

In [None]:
# Overlay 30 posterior-predictive trajectories on the noisy data.
scatter(sol.t, targetdata', color=cc)
# By default this has one row per draw (all chains appended) and one column per parameter.
chain_array = Array(chain)
for k in 1:30
    # Declared `global` so that `resol` (displayed in the next cell) survives the
    # loop. Under Julia ≥ 1.5 scope rules, assigning inside a top-level loop to a
    # name that is not yet a global creates a loop-local variable instead.
    global resol
    # Sample uniformly from every row. The old hard-coded `1:1500` ignored any
    # rows beyond 1500, and two 1000-draw chains presumably give 2000.
    draw = rand(axes(chain_array, 1))
    # NOTE(review): columns 2:6 assume the column order σ, μ1, μ2, k, Kd, γ
    # (declaration order in `fitmodel`). Confirm with `names(chain)`.
    resol = solve(remake(ds, p=chain_array[draw, 2:6]), Tsit5(), saveat=2)
    plot!(resol, w=2, alpha=0.2, color = "#BBBBBB", legend = false)
end
plot!(sol, w=.1, legend = false)

In [None]:
# Display `resol`, the ODE solution for the last posterior draw from the plotting
# loop in the previous cell.
# NOTE(review): this only works if that loop assigns `resol` as a global.
resol