In [None]:
include("../src/celerite.jl")
include("../TransitModel/transit.jl")
using Plots
using Statistics
using ForwardDiff
using LaTeXStrings
using LinearAlgebra
using DelimitedFiles

In [None]:
"""
    trapezoidal_transit(t, t0, r, d, tin)

Evaluate a trapezoid-shaped transit signal at every time in `t`.

The signal is `0` outside the transit, `-(r^2)` on the flat bottom
`t0 - d/2 <= t <= t0 + d/2`, and varies linearly over the ingress and egress
ramps, each of which lasts `tin`.

The result is a `Vector{Real}` of length `length(t)`. The abstract element type
lets the entries hold whatever number type the arithmetic on the inputs produces.
"""
function trapezoidal_transit(t, t0, r, d, tin)
    depth = r^2
    ingress_start = t0 - d/2 - tin
    flat_start = t0 - d/2
    flat_end = t0 + d/2
    egress_end = t0 + d/2 + tin

    flux = zeros(Real, length(t))
    for (k, tk) in enumerate(t)
        # Egress is tested before ingress so that, should the two ramps ever
        # overlap (negative d), the egress value wins.
        if (tk < egress_end) & (tk > flat_end)
            flux[k] = (1 - depth) + (tk - flat_end) * depth / tin - 1
        elseif (tk < flat_start) & (tk > ingress_start)
            flux[k] = -(tk - ingress_start) * depth / tin
        elseif (tk <= flat_end) & (tk >= flat_start)
            flux[k] = -depth
        end
    end
    return flux
end

"""
    get_derivative(t, x0, p)

Return the full transit model, or its derivative with respect to one parameter.

# Arguments
- `t`: times passed through to `transit`.
- `x0`: parameter values in the order `(t0, b0, rp, d, u1, u2)`.
- `p`: `nothing` to evaluate the model itself, or one of the symbols
  `:t0`, `:b0`, `:rp`, `:d`, `:u1`, `:u2` to differentiate with respect to that
  parameter at the value stored in `x0`.

# Throws
- `ArgumentError` if `p` is not `nothing` and not one of the symbols above.
- `DimensionMismatch` if `x0` does not hold exactly six values.
"""
function get_derivative(t, x0, p)
    if p === nothing
        # NOTE(review): this call splats six scalars, whereas the derivative path
        # below passes the limb-darkening pair as one vector `[u1, u2]`; confirm
        # that `transit` accepts both call shapes.
        return transit(t, x0...)[1]
    end

    names = [:t0, :b0, :rp, :d, :u1, :u2]
    k = findfirst(isequal(p), names)
    k === nothing &&
        throw(ArgumentError("unknown transit parameter $(repr(p)); expected one of $(names)"))
    length(x0) == length(names) ||
        throw(DimensionMismatch("x0 must hold $(length(names)) values, got $(length(x0))"))

    # A closure over the selected slot replaces the former runtime `@eval` +
    # `invokelatest`: only the k-th argument is varied, the rest come from x0.
    model = v -> begin
        t0, b0, rp, d, u1, u2 = ntuple(j -> j == k ? v : x0[j], length(names))
        transit(t, t0, b0, rp, d, [u1, u2])[1]
    end
    return ForwardDiff.derivative(model, x0[k])
end

"""
    get_derivative_trapezoid(t, x0, p)

Return the trapezoidal transit model, or its derivative with respect to one
parameter.

# Arguments
- `t`: times passed through to `trapezoidal_transit`.
- `x0`: parameter values in the order `(t0, rp, d, tin)`.
- `p`: `nothing` to evaluate the model itself, or one of the symbols
  `:t0`, `:rp`, `:d`, `:tin` to differentiate with respect to that parameter at
  the value stored in `x0`.

# Throws
- `ArgumentError` if `p` is not `nothing` and not one of the symbols above.
- `DimensionMismatch` if `x0` does not hold exactly four values.
"""
function get_derivative_trapezoid(t, x0, p)
    if p === nothing
        # Evaluate the trapezoid itself. This branch used to call the full
        # `transit` model with the trapezoid parameter vector.
        return trapezoidal_transit(t, x0...)
    end

    names = [:t0, :rp, :d, :tin]
    k = findfirst(isequal(p), names)
    k === nothing &&
        throw(ArgumentError("unknown trapezoid parameter $(repr(p)); expected one of $(names)"))
    length(x0) == length(names) ||
        throw(DimensionMismatch("x0 must hold $(length(names)) values, got $(length(x0))"))

    # A closure over the selected slot replaces the former runtime `@eval` +
    # `invokelatest`: only the k-th argument is varied, the rest come from x0.
    model = v -> trapezoidal_transit(t, ntuple(j -> j == k ? v : x0[j], length(names))...)
    return ForwardDiff.derivative(model, x0[k])
end

"""
    precompute_derivatives(t, params, x0; transit_model="trapezoid")

Build the matrix of model derivatives: row `i` holds the derivative of the
transit model with respect to `params[i]`, evaluated at every time in `t`.

# Arguments
- `t`: times at which the model is evaluated.
- `params`: parameter names, each forwarded as the last argument of the selected
  derivative function.
- `x0`: parameter values forwarded to the selected derivative function.

# Keywords
- `transit_model`: `"trapezoid"` selects `get_derivative_trapezoid`, `"full"`
  selects `get_derivative`.

# Returns
A `length(params) × length(t)` matrix.

# Throws
- `ArgumentError` if `transit_model` is neither `"trapezoid"` nor `"full"`.
"""
function precompute_derivatives(t, params, x0; transit_model="trapezoid")
    if transit_model == "trapezoid"
        func = get_derivative_trapezoid
    elseif transit_model == "full"
        func = get_derivative
    else
        # Previously `func` was left unassigned here and the failure surfaced
        # later as an UndefVarError.
        throw(ArgumentError(
            "transit_model must be \"trapezoid\" or \"full\", got $(repr(transit_model))"))
    end

    D = zeros(length(params), length(t))
    for (i, p) in enumerate(params)
        D[i, :] = func(t, x0, p)
    end
    return D
end

"""
    fisher(t, gp, params, x0; transit_model="trapezoid")

Return the inverse of the Fisher information matrix of the transit parameters
`params` under the noise model `gp`.

Entry `(i, j)` of the information matrix is `dᵢ' * K⁻¹ dⱼ`, where `dᵢ` is the
derivative of the transit model with respect to `params[i]` (from
`precompute_derivatives`) and `K⁻¹ dⱼ` is obtained from
`celerite.apply_inverse(gp, dⱼ)`.

When `gp.Q` has `m > 1` columns, each derivative vector is expanded with
`kron(d, ones(m))`, i.e. every time sample is repeated `m` times consecutively,
before it is passed to the GP.

# Arguments
- `t`: times at which the model derivatives are evaluated.
- `gp`: object with a matrix field `Q`, accepted by `celerite.apply_inverse`.
- `params`, `x0`, `transit_model`: forwarded to `precompute_derivatives`.

# Returns
The `length(params) × length(params)` inverse of the information matrix.
"""
function fisher(t, gp, params, x0; transit_model="trapezoid")
    D = precompute_derivatives(t, params, x0, transit_model=transit_model)
    m = size(gp.Q, 2)
    nparams = length(params)

    # Expand each derivative once instead of inside the double loop.
    expand(v) = m > 1 ? kron(v, ones(m)) : v
    derivs = [expand(D[i, :]) for i in 1:nparams]

    # One solve per parameter rather than one per (i, j) pair. A copy is handed
    # to the solver so `derivs` stays intact whatever `apply_inverse` does with
    # its argument.
    solved = [celerite.apply_inverse(gp, copy(dv)) for dv in derivs]

    # Named `info` rather than `I` to avoid shadowing `LinearAlgebra.I`.
    info = zeros(nparams, nparams)
    for j in 1:nparams, i in 1:nparams
        info[i, j] = derivs[i]' * solved[j]
    end
    return inv(info)
end

"""
    make_gp(t, log_s, log_w0, log_var; bins=nothing, A0=nothing, ds0dlam=1, c=nothing)

Build a `celerite.Celerite` object with a single `celerite.SHOTerm(log_s, log_q,
log_w0)` kernel, where `log_q = log(1/sqrt(2))`, call `celerite.compute!` on it
for the times `t`, and return it.

# Arguments
- `t`: observation times passed to `celerite.compute!`.
- `log_s`, `log_w0`: forwarded to `celerite.SHOTerm`.
- `log_var`: log of the white-noise variance. Without `bins` the value
  `exp.(0.5*log_var)` is passed to `celerite.compute!`.

# Keywords
- `bins`: when given, a multi-band GP is built. Band `i` gets white-noise
  variance `exp(log_var)/bins[i]`, and the per-band noise vector handed to
  `celerite.compute!` interleaves the bands (entries `i:m:end` belong to band `i`).
- `c`: per-band scaling; it is divided by its mean and `Q = c*c'` is passed to
  `celerite.Celerite`. Must have the same length as `bins`.
- `A0`, `ds0dlam`: used only when `bins` is given and `c` is not, to derive `c`
  from the cumulative sums of `bins`. `A0` is required in that case.

# Throws
- `ArgumentError` if `bins` is given but neither `c` nor `A0` is.
- `DimensionMismatch` if `c` and `bins` have different lengths.
"""
function make_gp(t, log_s, log_w0, log_var; bins=nothing, A0=nothing, ds0dlam=1, c=nothing)
    log_q = log(1/sqrt(2))

    if bins === nothing
        gp = celerite.Celerite(celerite.SHOTerm(log_s, log_q, log_w0))
        celerite.compute!(gp, t, exp.(0.5*log_var))
        return gp
    end

    if c === nothing
        A0 === nothing &&
            throw(ArgumentError("A0 must be provided when bins is given and c is not"))
        # f[k] is the sum of the first k-1 bins, so f[1] == 0.
        f = [sum(bins[1:i]) for i in 0:length(bins)]
        c = [(A0+ds0dlam*(f[i+1]+f[i]))/(A0+ds0dlam*(2*sum(f[2:end-1])+1)) for i in 1:length(bins)]*length(bins)
    else
        length(c) == length(bins) ||
            throw(DimensionMismatch(
                "c has $(length(c)) entries but bins has $(length(bins))"))
        c = c./mean(c)
    end

    m = length(c)
    Q = broadcast(*, c, c')

    # Named `band_var` rather than `var` to avoid shadowing `Statistics.var`.
    band_var = [exp(log_var)/b for b in bins]
    wn_vec = zeros(length(t)*m)
    for i in 1:m
        wn_vec[i:m:end] .= sqrt(band_var[i])
    end

    gp = celerite.Celerite(celerite.SHOTerm(log_s, log_q, log_w0), Q)
    celerite.compute!(gp, t, wn_vec)
    return gp
end

In [None]:
# Setup for the figures fisher_uncertainty_params_1 and fisher_uncertainty_params_2,
# which show the Fisher uncertainty on rp, d and t0 as a function of the ratio of
# correlated-noise amplitude to white-noise amplitude.
# Change w0T to move to a different noise regime.

# Time grid.
t = collect(-4:0.01:4)

# Transit parameters; x0 is the parameter vector of the trapezoidal model,
# in the order (t0, rp, d, tin).
t0 = 0
b0 = 0
rp = 0.1
d = 0.5
tin = d/3
u = [0.5, 0.5]
c2 = 2
x0 = [t0, rp, d, tin]

# Grid of amplitude ratios (natural log derived from the log10 grid).
log10_amp_ratio = collect(-3:0.1:3)
log_amp_ratio = @. log(10^log10_amp_ratio)

# Split the fixed total variance between the white and the correlated component:
# exp(2*log_sig) + exp(2*log_corr_amp) == exp(log_total_var) for every ratio.
log_total_var = -14
log_sig = @. 0.5*log_total_var - 0.5*log(1 + exp(2*log_amp_ratio))
log_corr_amp = log_amp_ratio .+ log_sig

# w0T fixes log_w0 relative to the transit duration d.
w0T = 10
log_w0 = log(w0T) - log(d)
log_s0 = @. 2*log_corr_amp - log_w0

In [None]:
# For every amplitude ratio, store the 4×4 matrix returned by `fisher` for the
# trapezoidal model, once for a GP built without bins ("mono") and once for a GP
# built with two equal bins and c = [1, 2] ("poly").
monofish = zeros(4, 4, length(log_amp_ratio))
polyfish = zeros(size(monofish))
for i in eachindex(log_amp_ratio)
    println("computing model $i")
    fit_params = [:t0, :rp, :d, :tin]
    white_log_var = 2*log_sig[i]

    monogp = make_gp(t, log_s0[i], log_w0, white_log_var)
    monofish[:, :, i] = fisher(t, monogp, fit_params, x0, transit_model="trapezoid")

    polygp = make_gp(t, log_s0[i], log_w0, white_log_var, bins=[0.5, 0.5], c=[1, 2])
    polyfish[:, :, i] = fisher(t, polygp, fit_params, x0, transit_model="trapezoid")
end

In [None]:
# Save the Fisher results and the matching abscissa so the plotting cell can
# reload them without recomputing.
x = log10_amp_ratio
let outdir = "plot_data/fisher_information/correct"
    # writedlm does not create missing directories.
    mkpath(outdir)
    writedlm(joinpath(outdir, "monofish_shortt.dat"), monofish)
    writedlm(joinpath(outdir, "polyfish_shortt.dat"), polyfish)
    # Written as `x_shortt.dat` (previously `x_longt.dat`) so it matches the two
    # files above and the name the plotting cell reads back.
    writedlm(joinpath(outdir, "x_shortt.dat"), x)
end

In [None]:
# Reload the saved Fisher results and plot the parameter uncertainties against
# the amplitude ratio. Dashed curves: `monofish`; solid curves: `polyfish`.
dt = t[2]-t[1]
x_mcmc = collect(-3.5:0.5:2.5)

# The files hold one value per line; the trailing `:` lets reshape infer the
# number of amplitude ratios instead of hard-coding it.
monofish = reshape(readdlm("plot_data/fisher_information/correct/monofish_shortt.dat"), 4, 4, :)
polyfish = reshape(readdlm("plot_data/fisher_information/correct/polyfish_shortt.dat"), 4, 4, :)
x = readdlm("plot_data/fisher_information/correct/x_shortt.dat")

# Uncomment to load the MCMC uncertainties; the overplot below is skipped when
# `monomcmc` has not been defined.
#monomcmc = [readdlm("plot_data/mcmc_uncertainty/$(label)_10_mono.dat") for label in ["rp", "d", "t0"]]
#polymcmc = [readdlm("plot_data/mcmc_uncertainty/$(label)_10_poly.dat") for label in ["rp", "d", "t0"]]

let
    # `fisher` returns the inverse information matrix, so half the log10 of a
    # diagonal entry is the log10 of that parameter's standard deviation.
    log10_sigma(cov, k) = [0.5*log10(cov[k, k, i]) for i in 1:size(cov, 3)]
    dark = palette(:dark)

    # (diagonal index, palette index, legend label), in drawing order.
    curves = [(1, 2, L"t_0"), (2, 1, L"R_p/R_*"), (3, 5, L"\Delta T")]

    plot(xlabel=L"\mathrm{log}(A_\mathrm{corr}/A_\mathrm{wn})",
        ylabel=L"\sigma",
        framestyle=:box,
        legend=:bottomleft)
    for (k, ci, _) in curves
        plot!(x, log10_sigma(monofish, k),
            linewidth=3, linestyle=:dash, color=dark[ci], label="")
    end
    for (k, ci, lbl) in curves
        plot!(x, log10_sigma(polyfish, k),
            linewidth=3, color=dark[ci], label=lbl)
    end

    # overplot mcmc results
    # `monomcmc` is ordered (rp, d, t0); each entry is paired with the palette
    # index of the matching Fisher curve.
    if @isdefined(monomcmc)
        for (k, ci) in [(1, 1), (3, 2), (2, 5)]
            plot!(x_mcmc, log10.(monomcmc[k]),
                linewidth=0, marker=:circle, markersize=3, color=dark[ci], label="")
        end
    end

    # Return the figure so the notebook displays it as the cell's value.
    current()
end