In [1]:
using StarStats

Load the entire database of grids 

In [2]:
using Printf
using BenchmarkTools
function path_constructor(strings::Vector{String})
    DATA_FOLDER = ENV["STARSTATS_TEST_DATA_FOLDER"]
    return DATA_FOLDER*"/LMC/LMC_$(strings[1])_$(strings[2])_$(strings[3]).track.gz"
end
masses = [@sprintf("%.3f", x) for x in range(0.9,2.1,step=0.025)]
rotation = [@sprintf("%.2f", x) for x in range(0.0,0.9,step=0.1)]
overshoot = [@sprintf("%.2f", x) for x in range(0.5,4.5,step=0.5)]
#change name grid to e.g. star_grid
star_grid = ModelDataGrid([rotation,masses,overshoot],
    [:rotation,:logM,:overshoot])
#@benchmark 
load_grid(star_grid,path_constructor,gz_dataframe_loader_with_Teff_and_star_age_fix); 
compute_distances_and_EEPs(star_grid)

In [3]:
using Turing, Distributions

@model function star_model(logTeff_obs, logTeff_err, logL_obs, logL_err, vrot_obs, vrot_err, star_grid)
  x ~ Uniform(0,3)
  logM ~ Uniform(0.9, 1.5)
  rotation ~ Uniform(0,0.9)
  overshoot ~ Uniform(0.5,1.5)
  logTeff = interpolate_grid_quantity(star_grid,[rotation, logM, overshoot],:logTeff,x)
  logL = interpolate_grid_quantity(star_grid,[rotation, logM, overshoot],:logL,x)
  vrot = interpolate_grid_quantity(star_grid,[rotation, logM, overshoot],:vrot,x)
  

  logTeff_obs ~ Normal(logTeff, logTeff_err)
  logL_obs ~ Normal(logL, logL_err)
  vrot_obs ~ Normal(vrot, vrot_err)
  return logTeff_obs, logL_obs, vrot_obs
end

star_model (generic function with 2 methods)

In [4]:
using Logging
Logging.disable_logging(Logging.Warn)
## Here needs more analysis for optimization
num_chains=4

observed_star_model = star_model(4.51974, 0.2, 4.289877, 0.2, 70.7195, 20, star_grid)
star_chains = mapreduce(c -> sample(observed_star_model, NUTS(500,0.9), 20000;stream=false, progress=true), chainscat, 1:num_chains)

Chains MCMC chain (20000×16×4 Array{Float64, 3}):

Iterations        = 501:1:20500
Number of chains  = 4
Samples per chain = 20000
Wall duration     = 92.68 seconds
Compute duration  = 86.12 seconds
parameters        = x, logM, rotation, overshoot
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
 [1m parameters [0m [1m    mean [0m [1m     std [0m [1m naive_se [0m [1m    mcse [0m [1m        ess [0m [1m    rhat [0m [1m[0m ⋯
 [90m     Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m    Float64 [0m [90m Float64 [0m [90m[0m ⋯

           x    1.0398    0.6309     0.0022    0.0047   15907.7094    1.0003   ⋯
        logM    1.0860    0.0841     0.0003    0.0006   21580.6023    1.0003   ⋯
    rotation    0.1214    0.0468     0.0002    0.0005    7666.9398 

In [6]:
using QuadGK
function chain_credible_interval(values, fraction,num_points,weights)
    maxval= maximum(logM)
    minval = minimum(logM)
    my_kde = kde(logM, weights=weights)
    kde_function = x-> pdf(my_kde, x)
    
    integral = quadgk(kde_function,minval,maxval, rtol=1e-8)[1]
    
    xvalues = collect(LinRange(minval,maxval,num_points))
    yvalues = kde_function(xvalues)
    max_valy = findmax(yvalues)
    maxx = xvalues[max_valy[2]]
    maxy = max_valy[1]
    
    function kde_function2(x, kde_function, bound)
        value = kde_function(x)
        if value>bound
            return value
        else
            return 0
        end
    end
    minbound = 0
    maxbound = maxy
    newbound = 0
    for i in 1:15
        newbound = 0.5*(minbound+maxbound)
        integral2 = quadgk(x -> kde_function2(x,kde_function,newbound), minval, maxval, rtol=1e-8)[1]
        newfraction = integral2/integral
        if newfraction>fraction
            minbound = newbound
        else
            maxbound = newbound
        end
    end
    filtered_xvalues = xvalues[yvalues.>newbound]
    return ([maximum(filtered_xvalues), maxx, minimum(filtered_xvalues)], newbound)
end

chain_credible_interval (generic function with 1 method)

In [7]:
num_points = 10000
values = logM
fraction = 0.9


CI = chain_credible_interval(values, fraction, num_points,dtdx)
#@show dtdx

UndefVarError: UndefVarError: logM not defined

In [8]:
observable_names = [:logTeff, :logL, :vrot]
observable_values = [4.51974, 4.289877, 70.7195]
observable_errors = [0.2, 0.2, 20]
# potentially parallelize this
grid_likelihood = ModelDataGridLikelihood(grid, observable_names, observable_values, observable_errors);
CI2 = credible_interval(grid_likelihood, :logM,0.9,1_000_000)

([1.2283107283107284, 1.125000225000225, 0.9642552642552642], 4192.989242514728)

In [9]:
#using Plots
plot()
ml= marginalized_likelihood(grid_likelihood,[:logM])
plot!(grid.input_values[2], ml./maximum(ml))
hline!([CI2[2]/maximum(ml)])
vline!(CI2[1])

my_kde = kde(logM, weights=dtdx)
plot!(x->pdf(my_kde,x)/pdf(my_kde,CI[1][2]), xrange=[0.8, 2.1])

hline!([CI[2]/pdf(my_kde,CI[1][2])])
vline!(CI[1])

UndefVarError: UndefVarError: plot not defined

To do :

1. MCMC code into its own  file within package
    - make sampling code work with arbitrary observables
    - make credible interval code flexible 
    - optimize the credible interval code 
    - create  pretty corner plots 
    - 
2. grid computed 
3. create simple interpolation interface 
4. change longM with values in credible interval bla
5. make ticks consistent in corner plot
6. use interpolation to get bounds on 1D marginalized distribution
7. get proper parsing for report of CIs on top od 1D marginalized distribution
8. use eachindex to iterate instead of for i in 1:length...

In [12]:
using LaTeXStrings
using CairoMakie

function concatenate_chains(star_chains)
   concatenated_chains = Dict()
   for name in Base.names(star_chains)
       concatenated_chains[name] = star_chains[:,name,:][:]
   end
   return concatenated_chains
end

function compute_chain_weights(cchains)
    dtdx=zeros(length(cchains[:x]))
    for i in 1:length(dtdx)
        dtdx[i]= interpolate_grid_quantity(grid,[cchains[:rotation][i], cchains[:logM][i], cchains[:overshoot][i]],:dtdx,cchains[:x][i])
    end
    return dtdx .* cchains[:logM].^-1.35
end

function get_star_corner_plot(star_chains)
    names = [:logM, :rotation, :overshoot]
    cchains = concatenate_chains(star_chains)
    chain_weights = compute_chain_weights(cchains)
    fractions =[0.68,0.95, 0.997]
    fraction_1D = 0.68
    figure= Figure()
    label_names = [L"\log(M/M_{\odot})", L"\omega/\omega_{crit}", L"\alpha_\mathrm{ov}" ]

    create_corner_plot(cchains,names,label_names, chain_weights, fractions, fraction_1D, figure, show_CIs = false)
end

get_star_corner_plot (generic function with 1 method)

In [6]:
using Makie
figure = get_star_corner_plot(star_grid,star_chains)
save("corner_plot.png",figure)

$\log(M/M_{\odot})$=1.0925^0.07000000000000006_0.09000000000000008
$\omega/\omega_{crit}$=0.10250000000000001^0.035_0.020000000000000004


$\alpha_\mathrm{ov}$=1.3225^0.17500000000000004_0.705


CairoMakie.Screen{IMAGE}


In [8]:
using LaTeXStrings
using Plots
Plots.plot(legend=false,
xflip=true, 
xlabel=L"log (T$_{eff}$/K)",
ylabel=L"log (L/L$_{\odot}$)")

xvals = LinRange(0,5, 1000)
rotation = 0.13
logM = 1.21
overshoot = 1.05
logTeff = interpolate_grid_quantity.(Ref(star_grid),Ref([rotation,logM,overshoot]),:logTeff, xvals)
logL = interpolate_grid_quantity.(Ref(star_grid),Ref([0.13,1.21,1.05]),:logL, xvals)

Plots.plot!(logTeff, logL)
savefig("HR.png")

"/media/alina/20FCD8DF125DE0B0/work/work/project_pablo/StarStats.jl/test_notebook/HR.png"