In [None]:
using Omega
using Flux, DiffEqFlux, DifferentialEquations, Plots, DiffEqNoiseProcess
# Output directory for figures: `../figures` relative to this notebook's directory
# (`@__DIR__`). `makeplots` saves `allfigs.pdf` here by default.
PLOTSPATH = joinpath(@__DIR__, "..", "figures")

┌ Info: Recompiling stale cache file /home/zenna/.julia/compiled/v1.1/Omega/cIe5P.ji for Omega [1af16e33-887a-59b3-8344-18f1671b3ade]
└ @ Base loading.jl:1184


### The prior: differential equations model for population dynamics

The Lotka–Volterra equations model the dynamics of a predator (wolf) population and a prey (rabbit) population over time.

In [None]:
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra system, for use with `ODEProblem`.
The state is `u = (x, y)` (prey, predator) and the parameters are `p = (α, β, δ, γ)`:

    dx/dt =  α x − β x y
    dy/dt = −δ y + γ x y

The derivatives are written into `du`. `t` is unused because the system is autonomous.
The returned value (the predator derivative) is incidental.
"""
function lotka_volterra(du, u, p, t)
  prey, predator = u
  α, β, δ, γ = p
  dprey = α * prey - β * prey * predator
  dpredator = -δ * predator + γ * prey * predator
  du[1] = dprey
  du[2] = dpredator
  return dpredator
end

The initial conditions of the dynamical system are uniformly distributed.

In [None]:
# Random initial state `u = (x, y)` for `lotka_volterra` (x = prey, y = predator).
# NOTE(review): assumes Omega's `uniform(a, b, dims)` draws each of the 2 components
# from Uniform(0.5, 1.5) — confirm against the Omega version in use.
u0 = uniform(0.5, 1.5, (2,))

At time `t_now` we observe the populations.

In [None]:
# Observation time; the prior model integrates over `(0.0, t_now)`.
t_now = 20.0

The time span of the ODE integration runs from 0 to `t_now`. We make it a constant random variable so that it can easily be intervened on (e.g. replaced by a longer time span with `replace`).

In [None]:
# Integration time span as a constant random variable, so that it can be swapped
# out with `replace(..., tspan => ...)` to integrate over a different horizon.
tspan = constant((0.0, t_now))

The prior over the Lotka–Volterra parameters `(α, β, δ, γ)`.

In [None]:
# Prior over the Lotka–Volterra parameters, in the order `lotka_volterra` unpacks
# them: α ∈ [1.3, 1.7], β ∈ [0.7, 1.3], δ ∈ [2.7, 3.3], γ ∈ [0.7, 1.3].
# NOTE(review): presumably each `uniform(ω, …)` call draws an independent component
# from ω — confirm with the Omega documentation.
p = ciid(ω -> [uniform(ω, 1.3, 1.7), uniform(ω, 0.7, 1.3), uniform(ω, 2.7, 3.3), uniform(ω, 0.7, 1.3)])

A distribution over ODE problems, and the corresponding distribution over their solutions.

In [None]:
# Random ODE problem: each ω yields a Lotka–Volterra problem with its own sampled
# initial state, time span and parameters.
prob = ciid(ω -> ODEProblem(lotka_volterra, u0(ω), tspan(ω), p(ω)))
# Random solution: `lift(solve)` applies `solve` (default algorithm) to each sampled problem.
sol = lift(solve)(prob)

Plot a time series sampled from the prior.

In [None]:
"Draw a single trajectory from the prior solution `sol` and plot it."
function plotfig1()
  trajectory = rand(sol)
  return plot(trajectory)
end

In [None]:
# Draw and display one prior trajectory.
plotfig1()

### Counterfactual model

In [None]:
"""
    gencf(; affect! = integrator -> integrator.u[2] /= 2.0,
            tspan = tspan,
            t_int = uniform(tspan[1], tspan[2]/2.0),
            alg = EM())

Generate a counterfactual model. The result is a random variable over solutions of
`lotka_volterra` that share the sampled `u0` and `p` with the prior model. In these
solutions, `affect!` is applied to the integrator exactly at the intervention time `t_int`.

# Keywords
- `affect!`: mutates the integrator at the intervention. By default it halves the
  predator population `u[2]`.
- `tspan`: random variable over the integration time span. Defaults to the global `tspan`.
- `t_int`: random variable over the intervention time. By default it is uniform over the
  first half of the `tspan` keyword, so it follows a user-supplied `tspan`.
- `alg`: algorithm passed to `solve` (default `EM()`, unchanged from before).
"""
function gencf(; affect! = integrator -> integrator.u[2] /= 2.0,
                 tspan = tspan,
                 t_int = uniform(tspan[1], tspan[2]/2.0),
                 alg = EM())
  # The callback fires only when the integrator sits exactly on the sampled time.
  # `tstops` below forces a step onto that time, so the equality can actually hold.
  condition = ciid(ω -> (u, t, integrator) -> t == t_int(ω))
  # NOTE(review): the prior `sol` is solved with `solve`'s default algorithm, so
  # comparing it with this model also compares two different solvers. Confirm that
  # `EM()` is intended here.
  sol_int = ciid(ω -> solve(ODEProblem(lotka_volterra, u0(ω), tspan(ω), p(ω)),
                            alg,
                            callback = DiscreteCallback(condition(ω), affect!),
                            tstops = t_int(ω)))
  return sol_int
end

Plot a solution from an intervened model

In [None]:
"""
    sampleint(; t_int = uniform(tspan[1], tspan[2]/2.0),
                sol_int = gencf(; t_int = t_int))

Sample one solution from an intervened (counterfactual) model, print the
intervention time it was sampled with, and plot the solution.

By default, a fresh counterfactual model is built with `gencf`. Its intervention time
is uniform over the first half of `tspan`. If you pass `sol_int` without the matching
`t_int`, the printed time is not the one used by `sol_int`.
"""
function sampleint(; t_int = uniform(tspan[1], tspan[2]/2.0),
                     sol_int = gencf(; t_int = t_int))
  # Sample jointly (same ω) so `t` is the intervention time used by `sol_int_`.
  t, sol_int_ = rand((t_int, sol_int))
  println("intervention occured at time $t")
  plot(sol_int_)
end

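Suppose we observe that there are no rabbits. We summarize the rabbit (prey) population by summing it over the last few saved states of the solution.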

In [None]:
"""
    totalrabbits_(ω; ndays = 10)

Sum of the prey (rabbit) population `u[1]` over the last `ndays` saved states
of the prior solution `sol(ω)`. If the solution has fewer saved states, all of them
are used.

NOTE(review): the saved states are solver time points, not days. With an adaptive
solver they are not evenly spaced in time.
"""
function totalrabbits_(ω; ndays = 10)
  sol_ = sol(ω)
  n = length(sol_)
  # Exactly the last `ndays` states, clamped so short solutions never index ≤ 0.
  window = max(1, n - ndays + 1):n
  return sum(sol_[i][1] for i in window)
end

totalrabbits = ciid(totalrabbits_)

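There are no rabbits if the summed prey population is 0. `==ₛ` is a soft equality, so conditioning on it favours samples close to the target value. We also define an alternative observation in which the summed prey population is 5.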

In [None]:
# Soft-equality observations (`==ₛ`, presumably Omega's soft equality) on the summed
# prey population. `norabbits` is not used by the code below; the later
# conditioning uses `toomanyrabbits`.
norabbits = totalrabbits ==ₛ 0.0
toomanyrabbits = totalrabbits ==ₛ 5.0

### Effect of an action

In [None]:
# Action model: add 2.0 to the prey population `u[1]` at exactly `t_now`. Integrate
# over twice the original horizon, `(0, 2t_now)`, so the effect after `t_now` is visible.
# NOTE(review): the time span mixes an `Int` start (0) with a `Float64` end.
sol_inc_rab = gencf(; affect! = integrator -> integrator.u[1] += 2.0,
                      t_int = constant(t_now),
                      tspan = constant((0, t_now * 2)))

"""
    plot_effect_action(; n = 100, alg = SSMH, kwargs...)

Draw `n` samples conditioned on `toomanyrabbits` with inference algorithm `alg`.
Plot the last sampled trajectory of the conditioned model and of the `sol_inc_rab`
action model. Both plots are displayed and returned as a tuple.
"""
function plot_effect_action(; n = 100, alg = SSMH, kwargs...)
  draws = rand((toomanyrabbits, sol, sol_inc_rab), toomanyrabbits, n; alg = alg, kwargs...)
  _, conditioned, acted = ntranspose(draws)
  # NOTE(review): `sol_inc_rab` *adds* 2.0 to the prey population, yet the title
  # says "Cull Prey" — confirm which one is intended.
  panels = (plot(conditioned[end], title = "Conditioned Model"),
            plot(acted[end], title = "Action: Cull Prey"))
  foreach(display, panels)
  return panels
end

"""
    plot_treatment_action(; n = 10000, alg = SSMH, kwargs...)

Effect of the `sol_inc_rab` action, which adds 2.0 to the prey population `u[1]` at `t_now`.

The procedure:
1. Draw `n` samples conditioned on `toomanyrabbits`.
2. Discard the first half of the samples as burn-in.
3. Plot a histogram of the per-sample difference in summed prey population over
   `20.0 < t < 40.0` between the action model and the unmodified model. The
   unmodified model is integrated over `(0, 2t_now)`.
"""
function plot_treatment_action(; n = 10000, alg = SSMH, kwargs...)
  # Unmodified model integrated over the same horizon as the action model
  sol_long = replace(sol, tspan => constant((0, t_now * 2)))
  samples = rand((toomanyrabbits, sol_long, sol_inc_rab), toomanyrabbits, n; alg = alg, kwargs...)
  _, sol_, sol_inc_rab_ = ntranspose(samples)
  # Keep the second half of the chain; `max` keeps the range valid for tiny `n`.
  m = length(sol_)
  kept = max(1, div(m, 2)):m
  preysum(s) = sum(extractvals(s, 1, 20.0, 40.0))
  a = [preysum(s) for s in sol_[kept]]
  b = [preysum(s) for s in sol_inc_rab_[kept]]
  # NOTE(review): the two models may be solved with different algorithms (see `gencf`),
  # which gives different saved time points. Summing values at saved points may then
  # not be a like-for-like comparison.
  # NOTE(review): the action adds prey, but the title says "Prey Cull" — confirm.
  histogram(b .- a, title = "Prey Cull Treatment Effect", yaxis = false)
end

### Counterfactual

In [None]:
# Random intervention time, uniform over the first half of the prior time span.
t_int = uniform(tspan[1], tspan[2]/2.0)
# Counterfactual model: add 2.0 to the predator population `u[2]` at time `t_int`.
sol_inc_pred = gencf(; t_int = t_int,
                       affect! = integrator -> integrator.u[2] += 2.0)

"""
    plot_inc_pred(; n = 100, alg = SSMH, kwargs...)

Draw `n` samples conditioned on `toomanyrabbits` and print the last sampled
intervention time. Plot the last conditioned trajectory next to its counterfactual
(`sol_inc_pred`, predators increased) on a shared y-axis range. Both plots are
displayed and returned as a tuple.
"""
function plot_inc_pred(; n = 100, alg = SSMH, kwargs...)
  draws = rand((t_int, toomanyrabbits, sol, sol_inc_pred), toomanyrabbits, n; alg = alg, kwargs...)
  tints, _, conditioned, counterfactual = ntranspose(draws)
  println("intervention occured at time $(tints[end])")
  factual_sol = conditioned[end]
  cf_sol = counterfactual[end]
  prey_f, pred_f = ntranspose(factual_sol.u)
  prey_cf, pred_cf = ntranspose(cf_sol.u)
  # Common upper y-limit so the two panels can be compared directly
  ymax = maximum(maximum, (prey_f, pred_f, prey_cf, pred_cf))
  panels = (plot(factual_sol, title = "Conditioned Model", ylim = [0, ymax]),
            plot(cf_sol, title = "Counterfactual: Inc Predators", ylim = [0, ymax]))
  foreach(display, panels)
  return panels
end

"""
    plot_treatment(; n = 1000, alg = Replica, kwargs...)

Effect of increasing the number of predators (the `sol_inc_pred` counterfactual,
which adds 2.0 to `u[2]` at time `t_int`).

The procedure:
1. Draw `n` samples conditioned on `toomanyrabbits`.
2. Discard the first half of the samples as burn-in.
3. Plot a histogram of the per-sample difference in summed prey population over
   `0.0 < t < 10.0` between the counterfactual and the conditioned model.
"""
function plot_treatment(; n = 1000, alg = Replica, kwargs...)
  samples = rand((t_int, toomanyrabbits, sol, sol_inc_pred), toomanyrabbits, n; alg = alg, kwargs...)
  _, _, sol_, sol_inc_pred_ = ntranspose(samples)
  # Second half of the chain. This was previously hard-coded as 500:1000, which ignored `n`.
  m = length(sol_)
  kept = max(1, div(m, 2)):m
  preysum(s) = sum(extractvals(s, 1, 0.0, 10.0))
  a = [preysum(s) for s in sol_[kept]]
  b = [preysum(s) for s in sol_inc_pred_[kept]]
  # NOTE(review): `sol` and `sol_inc_pred` may be solved with different algorithms,
  # which gives different saved time points; confirm the sums are comparable.
  histogram(b .- a, title = "Pred Inc Treatment effect", yaxis = false)
end

"""
    extractvals(v, id, a, b, ::Type{T} = Float64) where T

Return, as a `Vector{T}`, component `id` of every saved state of solution `v` whose
time lies strictly between `a` and `b` (`a < t < b`). Values appear in the order they
are stored. `v` must expose parallel collections `v.t` (times) and `v.u` (states),
as an ODE solution does.
"""
function extractvals(v, id, a, b, ::Type{T} = Float64) where T
  return T[u[id] for (t, u) in zip(v.t, v.u) if a < t < b]
end

Generate and save all figures

In [None]:
"""
    makeplots(; save = true, fname = joinpath(PLOTSPATH, "allfigs.pdf"))

Build a 3×2 grid from the figures of the following functions, called in this order:
`plot_effect_action`, `plot_inc_pred`, `plot_treatment_action` and `plot_treatment`.
The grid is displayed. If `save` is true, it is written to `fname`, creating the
target directory if needed, and `savefig`'s result is returned. Otherwise `false`
is returned.
"""
function makeplots(; save = true, fname = joinpath(PLOTSPATH, "allfigs.pdf"))
  @show fname
  @show @__DIR__
  plts_ea = plot_effect_action()
  plts_cf = plot_inc_pred()
  plt = plot(plts_ea..., plts_cf..., plot_treatment_action(), plot_treatment(),
             layout = (3, 2),
             legend = false)
  display(plt)
  save || return false
  # savefig does not create missing directories (e.g. a fresh checkout without `figures/`).
  outdir = dirname(fname)
  isempty(outdir) || mkpath(outdir)
  savefig(plt, fname)
end

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*