In [None]:
using Turing
using DataFrames
using CSV
using Distributions
using StatsFuns
using StatsPlots
using StatsBase
using Random

# Plots.jl global default: series get no legend entry unless a `label` is passed explicitly.
default(label=false);

## Using experimental data (first half of the lecture notes)

In [None]:
d = DataFrame(CSV.File("data/reedfrogs.csv"))
# Only a cell's last value is rendered, so the summary must be shown explicitly.
display(describe(d))

# Integer id 1..nrow(d) for every row (tank); later cells use it to find each tank's
# intercept column "a[tank]" in the posterior draws.
d.tank = 1:nrow(d)
d


### Conventional single-level model

In [None]:
@model function frog_single_level(S, N, tank, ::Type{T}=Float64) where {T}
    # No-pooling model: every tank gets its own log-odds intercept with a fixed
    # Normal(0, 1.5) prior, so tanks share no information.
    #   S    - observed survivor count per tank
    #   N    - number of Binomial trials (tadpoles) per tank
    #   tank - only its length is used, to size the intercept vector
    #
    # `a` is allocated with element type `T` (default Float64) instead of `zeros`,
    # following Turing's performance tips: Turing can supply its AD number type as `T`,
    # so `a` does not have to be widened/re-allocated while gradients are evaluated.
    a = Vector{T}(undef, length(tank))
    for i in eachindex(a)
        a[i] ~ Normal(0, 1.5)
    end

    for i in eachindex(a)
        p = logistic(a[i])  # probability of survival or proportional survival
        S[i] ~ Binomial(N[i], p)
    end
end

#@model function frog_single_level(S, N, tank)
#
#    a ~ filldist(Normal(0, 1.5), length(tank))  # offsets are defined for each tank
#    p = logistic.(a)  # probability of survival or proportional survival
#    S .~ Binomial.(N, p)
#
#end

In [None]:
Random.seed!(1)
# NUTS with 200 adaptation steps, target acceptance 0.65, initial step size 0.5.
frog_single_level_ch = let
    model = frog_single_level(d.surv, d.density, d.tank)
    sampler = NUTS(200, 0.65; init_ϵ=0.5)
    sample(model, sampler, 1000)
end;
frog_single_level_df = DataFrame(frog_single_level_ch)

### Multilevel model

In [None]:
@model function frog_multi_level(S, N, tank, ::Type{T}=Float64) where {T}
    # Partial-pooling model: tank intercepts share a learned Normal(ā, σ) prior,
    # whose mean and spread get their own priors below.
    #   S    - observed survivor count per tank
    #   N    - number of Binomial trials (tadpoles) per tank
    #   tank - only its length is used, to size the intercept vector
    σ ~ Exponential()
    ā ~ Normal(0, 1.5)

    # Element type `T` (default Float64) rather than `zeros`, per Turing's performance
    # tips, so Turing's AD number type can be stored in `a` without widening it.
    a = Vector{T}(undef, length(tank))
    for i in eachindex(a)
        a[i] ~ Normal(ā, σ)
    end

    for i in eachindex(a)
        p = logistic(a[i])  # probability of survival or proportional survival
        S[i] ~ Binomial(N[i], p)
    end
end


# @model function frog_multi_level(S, N, tank)

#     σ ~ Exponential()
#     ā ~ Normal(0, 1.5)

#     a ~ filldist(Normal(ā, σ), length(tank))
#     p = logistic.(a)
#     S .~ Binomial.(N, p)

# end

In [None]:
Random.seed!(1)
# NUTS with 200 adaptation steps, target acceptance 0.65, initial step size 0.2.
frog_multi_level_ch = let
    model = frog_multi_level(d.surv, d.density, d.tank)
    sampler = NUTS(200, 0.65; init_ϵ=0.2)
    sample(model, sampler, 1000)
end
frog_multi_level_df = DataFrame(frog_multi_level_ch);

### Analysis

In [None]:
# Pointwise log-likelihood of every tank's observed survivor count under each posterior
# draw of the single-level model: row j = draw j, column i = tank (row i of `d`).
# (Despite the name, the entries are Binomial log-probabilities, not survival rates.)
single_level_survival = zeros(nrow(frog_single_level_df), nrow(d));

for (i, row) in enumerate(eachrow(d))
    # Fetch this tank's intercept draws once and fill the whole (contiguous) column,
    # instead of looking the column up by name for every (draw, tank) pair.
    a_draws = frog_single_level_df[!, "a[$(row.tank)]"]
    single_level_survival[:, i] .=
        binomlogpdf.(row.density, logistic.(a_draws), row.surv)
end

single_level_survival

In [None]:
# Pointwise log-likelihood of every tank's observed survivor count under each posterior
# draw of the multi-level model: row j = draw j, column i = tank (row i of `d`).
# (Despite the name, the entries are Binomial log-probabilities, not survival rates.)
multi_level_survival = zeros(nrow(frog_multi_level_df), nrow(d));

for (i, row) in enumerate(eachrow(d))
    # Fetch this tank's intercept draws once and fill the whole (contiguous) column,
    # instead of looking the column up by name for every (draw, tank) pair.
    a_draws = frog_multi_level_df[!, "a[$(row.tank)]"]
    multi_level_survival[:, i] .=
        binomlogpdf.(row.density, logistic.(a_draws), row.surv)
end

multi_level_survival

In [None]:
## a "fancier" and more compact way

# link_fun = (chain_df, dr) -> begin
#     a = chain_df[:,"a[$(dr.tank)]"]
#     p = logistic.(a)
#     binomlogpdf.(dr.density, p, dr.surv)
# end

# single_level_survival = map( dr -> link_fun(frog_single_level_df, dr), eachrow(d) )
# single_level_survival = hcat(single_level_survival...)

# multi_level_survival = map( dr -> link_fun(frog_multi_level_df, dr), eachrow(d) )
# multi_level_survival = hcat(multi_level_survival...);


In [None]:
# Resample 10_000 rows from the existing multi-level posterior chain. No new MCMC is run.
# The chain holds only 1000 draws, so rows are necessarily drawn with replacement.
# Seed first so the resampled draws, and every plot built from them, are reproducible.
Random.seed!(1)
post = sample(frog_multi_level_ch, 10000)
post_df = DataFrame(post)


In [None]:
# Point estimate of each tank's survival probability: the posterior-mean log-odds of
# its intercept "a[i]", mapped through the logistic function.
propsurv_est = map(1:nrow(d)) do tank_idx
    a_draws = post_df[!, "a[$tank_idx]"]
    logistic(mean(a_draws))
end

In [None]:
# Model-estimated survival proportion per tank, drawn as hollow markers.
scatter(
    propsurv_est;
    mc=:white,
    label="model",
    legend=:topright,
    xlab="tank",
    ylab="proportion survival",
    ylim=(-0.05, 1.05),
)


In [None]:
# Overlay the raw proportions, the posterior mean of logistic(ā) as a dashed line,
# and dividers/labels for the three tank-size groups (16 tanks each).
scatter!(d.propsurv; mc=:blue, ms=3, label="data")
hline!([mean(logistic.(post_df.ā))]; ls=:dash, c=:black)
vline!([16.5, 32.5]; c=:black)
annotate!([
    (x, 0, (group_name, 10))
    for (x, group_name) in ((8, "small tanks"), (24, "medium tanks"), (40, "large tanks"))
])


In [None]:
# Left panel: the inferred population distribution of tank log-odds, Normal(ā, σ),
# drawn once for each of the first 100 posterior samples.
p1 = plot(xlim=(-3, 4), xlab="Log-odds survival", ylab="Density");

for draw in eachrow(first(post_df, 100))
    plot!(p1, Normal(draw.ā, draw.σ); c=:black, alpha=0.2)
end

# index-based equivalent:
# for j in 1:100
#     plot!(p1, Normal(post_df.ā[j], post_df.σ[j]), c=:black, alpha=0.2)
# end

p1


In [None]:
# Right panel: survival probabilities of simulated new tanks. For each of the first
# 8000 posterior samples, draw one log-odds from Normal(ā, σ) and map it to [0, 1].
p2 = plot(xlab="Probability survival", ylab="Density", xlim=(-0.1, 1.1));

sim_tanks_logistic = [
    logistic(rand(Normal(post_df.ā[j], post_df.σ[j])))
    for j in 1:8000
]
density!(p2, sim_tanks_logistic; lw=2)

# broadcasting alternative:
# sim_tanks = rand.(Normal.(post_df.ā[1:8000], post_df.σ[1:8000]))
# density!(p2, logistic.(sim_tanks), lw=2)

plot(p1, p2, size=(800, 400))


## Using synthetic data

### Generate synthetic data

In [None]:
## Varying effects and the underfitting/overfitting trade-off
# Generate mock data with known "true" parameters to test the models against.

Random.seed!(1)

ā = 1.5   # true population mean of the pond log-odds
σ = 1.5   # true population standard deviation of the pond log-odds
Ni = repeat([3, 10, 25, 35], inner=15);   # tadpoles per pond: 15 ponds of each size
nponds = length(Ni)   # derived from Ni so the two can never disagree (60 here)

a_pond = rand(Normal(ā, σ), nponds);    # mock "true" log-odds of survival per pond

dsim = DataFrame(pond=1:nponds, Ni=Ni, true_a=a_pond);

dsim.true_p = logistic.(dsim.true_a);   # true survival probability per pond

dsim


In [None]:
# Simulate each pond's survivor count from its true probability, then the
# empirical survival proportion.
dsim.Si = [rand(Binomial(n, p)) for (n, p) in zip(dsim.Ni, dsim.true_p)];

dsim.p_sim = dsim.Si ./ dsim.Ni;

In [None]:
# Show the simulated table, now including survivor counts `Si` and proportions `p_sim`.
dsim

### Single-level model

In [None]:
# no pooling (single-level model)
@model function pond_single_level(Si, Ni, ::Type{T}=Float64) where {T}
    # Every pond gets its own log-odds intercept with a fixed Normal(0, 1.5) prior.
    #   Si - observed survivor count per pond
    #   Ni - number of Binomial trials (tadpoles) per pond
    #
    # Element type `T` (default Float64) rather than `zeros`, per Turing's performance
    # tips, so Turing's AD number type can be stored in `a_pond` without widening it.
    a_pond = Vector{T}(undef, length(Ni))
    for i in eachindex(a_pond)
        a_pond[i] ~ Normal(0, 1.5)
    end

    for i in eachindex(a_pond)
        p = logistic(a_pond[i])  # probability of survival or proportional survival
        Si[i] ~ Binomial(Ni[i], p)
    end
end

#@model function pond_single_level(Si, Ni)
#
#    a_pond ~ filldist(Normal(0, 1.5), length(Ni))  # offsets are defined for each pond
#    p = logistic.(a_pond)  # probability of survival or proportional survival
#    Si .~ Binomial.(Ni, p)
#
#end

### Multi-level model

In [None]:
# partial pooling using the multi-level model
@model function pond_multi_level(Si, Ni, ::Type{T}=Float64) where {T}
    # Pond intercepts share a learned Normal(ā, σ) prior (adaptive regularization).
    #   Si - observed survivor count per pond
    #   Ni - number of Binomial trials (tadpoles) per pond
    σ ~ Exponential()
    ā ~ Normal(0, 1.5)

    # Element type `T` (default Float64) rather than `zeros`, per Turing's performance
    # tips, so Turing's AD number type can be stored in `a_pond` without widening it.
    a_pond = Vector{T}(undef, length(Ni))
    for i in eachindex(a_pond)
        a_pond[i] ~ Normal(ā, σ)
    end

    for i in eachindex(a_pond)
        p = logistic(a_pond[i])  # probability of survival or proportional survival
        Si[i] ~ Binomial(Ni[i], p)
    end

    # # a more compact way
    # a_pond ~ filldist(Normal(ā, σ), length(Ni))
    # p = logistic.(a_pond)
    # @. Si ~ Binomial(Ni, p)
end

### Running MCMC

In [None]:
Random.seed!(1)
# Default NUTS settings, 1000 draws.
pond_single_level_ch = let
    model = pond_single_level(dsim.Si, dsim.Ni)
    sample(model, NUTS(), 1000)
end;
pond_single_level_df = DataFrame(pond_single_level_ch)

In [None]:
Random.seed!(1)
# Default NUTS settings, 1000 draws.
pond_multi_level_ch = let
    model = pond_multi_level(dsim.Si, dsim.Ni)
    sample(model, NUTS(), 1000)
end
pond_multi_level_df = DataFrame(pond_multi_level_ch)


### Analysis of errors

In [None]:
# No-pooling estimate for each pond: posterior mean of its survival probability
# (logistic applied to every draw before averaging).
dsim.p_nopool = map(1:nponds) do pond_idx
    a_draws = pond_single_level_df[!, "a_pond[$pond_idx]"]
    mean(logistic.(a_draws))
end

In [None]:
# Partial-pooling estimate for each pond: posterior mean of its survival probability
# (logistic applied to every draw before averaging).
dsim.p_partpool = map(1:nponds) do pond_idx
    a_draws = pond_multi_level_df[!, "a_pond[$pond_idx]"]
    mean(logistic.(a_draws))
end

In [None]:
# Absolute error of each estimator against the known true survival probability.
nopool_error = abs.(dsim.p_nopool .- dsim.true_p)
partpool_error = abs.(dsim.p_partpool .- dsim.true_p);

plt = scatter(nopool_error; xlab="pond", ylab="absolute error", label="no pooling")
scatter!(partpool_error; mc=:white, label="partial pooling")

# Dividers between the four pond-size groups (15 ponds each) and their labels.
vline!([15.5, 30.5, 45.5]; c=:black)
annotate!([
    (x, 0.37, (group_name, 10))
    for (x, group_name) in (
        (7, "small ponds"),
        (23, "mid-small ponds"),
        (38, "mid-large ponds"),
        (52, "large ponds"),
    )
])

In [None]:
# Store the per-pond errors in the table for the grouped summary below.
dsim[!, :nopool_error] = nopool_error;
dsim[!, :partpool_error] = partpool_error;

In [None]:
# Show the table with the per-pond absolute errors attached.
dsim

### Means of the errors in each cluster

In [None]:
# Group ponds by their tadpole count `Ni` (one group per pond size) for the
# per-group error means computed next.
gb = groupby(dsim, :Ni)

In [None]:
# Per-group mean absolute error of both estimators plus each group's pond-number range.
# Resulting columns: nopool_error_mean, partpool_error_mean, pond_minimum, pond_maximum.
pools = combine(
    gb,
    :nopool_error => mean,
    :partpool_error => mean,
    :pond => minimum,
    :pond => maximum,
)

In [None]:
# First and last pond number of each tadpole-count group; used below as the
# x-extent of the per-group mean-error lines.
pools.pond_minimum, pools.pond_maximum


In [None]:
# Overlay each group's mean absolute error as a horizontal segment spanning that
# group's ponds: solid for no pooling, dashed for partial pooling.
# `plt` is targeted explicitly, so the segments land on the error scatter even if
# another plot has become "current" (e.g. when notebook cells are re-run out of order).
for g in eachrow(pools)
    plot!(plt, [g.pond_minimum, g.pond_maximum], [g.nopool_error_mean, g.nopool_error_mean])
end

for g in eachrow(pools)
    plot!(
        plt,
        [g.pond_minimum, g.pond_maximum],
        [g.partpool_error_mean, g.partpool_error_mean];
        line=:dash,
    )
end

plt