In [None]:
using Revise
using Pkg; Pkg.activate(".") # If I'm in new_muller_brown

In [None]:
using Cairn
using LinearAlgebra, Random, Statistics, StatsBase # Do need StatsBase this time
using PotentialLearning
using Molly, AtomsCalculators
using AtomisticQoIs
using SpecialPolynomials
using JLD2

In [None]:
include("/Users/swyant/cesmix/dev/Cairn.jl/src/makie/makie.jl")
includet("/Users/swyant/cesmix/dev/Cairn.jl/src/makie/plot_contours.jl")

In [None]:
# Reference potential (presumably a rotated Muller-Brown surface) and the 2D domain
# used throughout: first coordinate in [-4.4, 1.5], second in [-2, 2].
ref = MullerBrownRot()
limits = [[-4.4,1.5],[-2,2]]
# NOTE(review): 0.05 is presumably the grid spacing — confirm in makie.jl.
coord_grid = coord_grid_2d(limits,0.05) # from the makie.jl file
# Contour levels passed as `lvls` to plot_contours_2d.
#ctr_lvls = -150:50:1000
ctr_lvls = -150:2:1000

In [None]:
# Template surrogate: polynomial chaos expansion with a Jacobi{0.5,0.5} basis,
# scaled to `limits` via `xscl`. The positional `2` is presumably the input
# dimension — TODO confirm against PolynomialChaos.
basisfam = Jacobi{0.5,0.5}
#basisfam = Jacobi{2,2}

#order=20
order=50
pce0 = PolynomialChaos(order,2,basisfam,xscl=limits)

In [None]:
# Evaluation ensemble over a grid of the domain; `cutoff=800` presumably discards
# points whose reference energy exceeds 800 — verify in potential_grid_2d.
coords_eval = potential_grid_2d(ref,limits,0.04,cutoff=800)
#@show coords_eval
#sys_eval = define_ens(ref,coords_eval) # This has been replaced with Ensemble()
sys_eval = Ensemble(ref,coords_eval)

In [None]:
# Base system started at [0.5, 0.5]; Molly's CoordinateLogger(100) records the
# coordinates every 100 steps. Every simulation below starts from a deepcopy of it.
sys0 = System(ref,[0.5,0.5], loggers=(coords=CoordinateLogger(100;dims=2),))

In [None]:
# Training set 3: samples from high-T MD
sim_highT = OverdampedLangevin(
            dt=0.002u"ps",
            temperature=1500.0u"K",
            #temperature=1250.0u"K",
            friction=4.0u"ps^-1",
)
# simulate
sys = deepcopy(sys0)
simulate!(sys, sim_highT, 10_000_000)

# Draw 50,000 distinct logged frames (no replacement). `[1]` takes the first
# coordinate entry of each frame (presumably the single particle).
id = StatsBase.sample(1:length(sys.loggers.coords.history), 50_000, replace=false)
coords = [sys.loggers.coords.history[i][1] for i in id]
trainset = Ensemble(deepcopy(pce0), coords)

In [None]:
# Overlay the sampled training coordinates (red) on the reference-potential contours.
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, [get_values(crd) for crd in coords])'
scatter!(ax, coordmat[:,1], coordmat[:,2], color=:red, markersize=1, label="train set 3")
axislegend(ax)
f

In [None]:
# NOTE(review): this global is not used anywhere else in the visible notebook.
trainsets = Vector{eltype(trainset)}()

In [None]:
# Fit a single PCE to energies only (f_flag=false) ...
pce = deepcopy(pce0)
lpe = learn!(trainset, ref, pce; e_flag=true, f_flag=false)

In [None]:
# ... and copy the learned coefficients into the surrogate.
pce.params = deepcopy(lpe.β)

In [None]:
# Spot check: surrogate energy vs reference energy at one training point.
Cairn.potential_pce(pce,coords[5])

In [None]:
Cairn.potential_muller_brown(ref, coords[5])

In [None]:
"""
    generate_trainsets(sim, base_sys, base_pot; num_trainsets=10, num_run=10_000_000,
                       num_sample=50_000)

Build `num_trainsets` independent training ensembles. For each one, a fresh copy of
`base_sys` is simulated for `num_run` steps with `sim`. Then `num_sample` distinct logged
frames are drawn without replacement, and the first coordinate entry of each frame is
wrapped in an `Ensemble` around a fresh copy of `base_pot`.

Returns `(trainsets, coordsets)`: the ensembles and the sampled coordinates, in
generation order.
"""
function generate_trainsets(sim, base_sys, base_pot;
                            num_trainsets=10,
                            num_run=10_000_000,
                            num_sample=50_000)
    ensembles = Any[]
    sampled_coords = Any[]
    for n in 1:num_trainsets
        println("generating trainset $(n)")
        traj_sys = deepcopy(base_sys)
        simulate!(traj_sys, sim, num_run)

        history = traj_sys.loggers.coords.history
        picks = StatsBase.sample(1:length(history), num_sample, replace=false)
        crds = [history[k][1] for k in picks]

        push!(sampled_coords, crds)
        push!(ensembles, Ensemble(deepcopy(base_pot), crds))
    end
    # Broadcasting `identity` narrows the Any-typed vectors to concrete element types.
    return identity.(ensembles), identity.(sampled_coords)
end

In [None]:
# Training set 3: samples from high-T MD
# (Same 1500 K overdamped-Langevin settings as the single-trainset cell above.)
sim_highT = OverdampedLangevin(
            dt=0.002u"ps",
            temperature=1500.0u"K",
            friction=4.0u"ps^-1",
)

# Ten independent 10M-step trajectories with 50,000 samples each (function defaults).
# NOTE(review): expensive; the load cell below overwrites these results anyway.
my_trainsets, my_coordsets = generate_trainsets(sim_highT, sys0, pce0, num_trainsets=10)

In [None]:
#save("ten_1500K_trainsets.jld2", Dict("coordsets"=>my_coordsets))

In [None]:
# Reload cached coordinates and rebuild one Ensemble per coordinate set.
my_coordsets = load("ten_1500K_trainsets.jld2")["coordsets"]
my_trainsets = [Ensemble(deepcopy(pce0),coords) for coords in my_coordsets]

In [None]:
"""
    plot_trainsets(ref_pot, coord_grid, ctr_lvls, coordsets)

Overlay each coordinate set in `coordsets` on a filled contour plot of `ref_pot`.
Each set gets its own colour and a legend entry `"trainset i"`. Returns the figure.
"""
function plot_trainsets(ref_pot, coord_grid, ctr_lvls, coordsets)
    f, ax = plot_contours_2d(ref_pot, coord_grid; fill=true, lvls=ctr_lvls)
    cmap = :seaborn_bright
    # Spread set indices over the colormap according to how many sets there are
    # (identical to the old hard-coded (1, 10) when plotting ten sets). max(2, …) keeps
    # the range non-degenerate when only one set is plotted.
    crange = (1, max(2, length(coordsets)))
    for (i, coords) in enumerate(coordsets)
        coordmat = reduce(hcat, [get_values(crd) for crd in coords])'
        scatter!(ax, coordmat[:, 1], coordmat[:, 2], colormap=cmap, color=i,
                 colorrange=crange, markersize=1, label="trainset $(i)")
    end
    # Build the legend exactly once, after every labelled scatter exists. Calling
    # axislegend inside the loop stacked a new overlapping legend per set.
    isempty(coordsets) || axislegend(ax)
    return f
end

There's a bug here with the legend when I plot all ten trainsets, but it's fine if I plot only the first nine.

In [None]:
plot_trainsets(ref, coord_grid, ctr_lvls, my_coordsets)

In [None]:
# Fit one PCE per trainset (energies only) to form the committee members.
# NOTE(review): in IJulia's soft scope, `pce`/`lpe` here assign to the existing
# globals from the single-fit cells above.
ensemble_members = Vector{typeof(pce0)}()

for trainset in my_trainsets
    pce = deepcopy(pce0)
    lpe = learn!(trainset, ref, pce; e_flag=true, f_flag=false)
    pce.params = deepcopy(lpe.β)
    push!(ensemble_members, pce)
end

In [None]:
#save("ten_pce_ensembles.jld2", Dict("ensemble_members" => ensemble_members))

In [None]:
# Reload cached members; this replaces the freshly trained ones above.
ensemble_members = load("ten_pce_ensembles.jld2")["ensemble_members"]

In [None]:
include("committee_potentials.jl")
include("committee_qois.jl")

In [None]:
# Committee potential over the ten members, in kJ/mol and nm.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"kJ*mol^-1", length_units=u"nm")

In [None]:
ustrip.(compute_all_energies(sys0,my_cmte))

In [None]:
# UQ quantity built from Statistics.std — presumably the spread of the members'
# energies for a system (see committee_qois.jl to confirm).
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)

In [None]:
compute(cmte_energy,sys0,my_cmte)

In [None]:
# test data
# Independent 10M-step trajectory at the same 1500 K settings; its logged frames
# are split into test and calibration pools below.
sim_highT = OverdampedLangevin(
            dt=0.002u"ps",
            temperature=1500.0u"K",
            friction=4.0u"ps^-1",
)
test_sys = deepcopy(sys0)
simulate!(test_sys, sim_highT, 10_000_000)

In [None]:
# test/calibration split
"""
    obtain_test_cal_indxs(frac, set_size; rng=Random.default_rng())

Randomly split the indices `1:set_size` into two disjoint index vectors. The first
`floor(Int, frac * set_size)` entries of a random permutation become the test indices.
The remaining entries become the calibration indices.

# Arguments
- `frac`: fraction of indices assigned to the test split; must lie in `[0, 1]`.
- `set_size`: number of indices to split; must be non-negative.

# Keywords
- `rng`: RNG used for the permutation. Pass a seeded RNG for a reproducible split.

# Returns
`(test_set_idxs, calib_set_idxs)`: together they contain every index in `1:set_size`
exactly once.

# Throws
`ArgumentError` if `frac` is outside `[0, 1]` or `set_size` is negative.
"""
function obtain_test_cal_indxs(frac::Real, set_size::Integer;
                               rng::AbstractRNG=Random.default_rng())
    # Validate explicitly: @assert may be compiled out and never checked frac >= 0.
    0 <= frac <= 1 || throw(ArgumentError("frac must lie in [0, 1], got $frac"))
    set_size >= 0 ||
        throw(ArgumentError("set_size must be non-negative, got $set_size"))
    num_select = floor(Int, frac * set_size)

    perm_idxs = randperm(rng, set_size)
    test_set_idxs = perm_idxs[1:num_select]
    calib_set_idxs = perm_idxs[num_select+1:end]

    return test_set_idxs, calib_set_idxs
end

In [None]:
# Split the test trajectory's logged frames 50/50 into test and calibration pools.
possible_test_idxs, possible_cal_idxs = obtain_test_cal_indxs(0.5,length(test_sys.loggers.coords.history))

In [None]:
# First 5000 test-pool frames form the per-point test set.
test_idxs = possible_test_idxs[1:5000]
test_coords = [test_sys.loggers.coords.history[i][1] for i in test_idxs]
testset = Ensemble(deepcopy(pce0), test_coords)


In [None]:
using BenchmarkTools

In [None]:
# Interpolate the (non-const) globals with `$` so BenchmarkTools times `compute`
# itself rather than untyped global-variable lookups.
@benchmark compute($cmte_energy, $sys0, $my_cmte)

In [None]:
# Committee UQ value (as defined by `cmte_energy`) for every system in the test set.
test_cmte_stds = [compute(cmte_energy,sys,my_cmte) for sys in testset]

In [None]:
"""
    plot_histogram(data; num_bins=500, xlow=0.0, xhigh=0.2)

Plot a histogram of `data` with `num_bins` bins and return the Makie `Figure`.

Only the visible x-range is restricted to `[xlow, xhigh]`. The bins themselves are
chosen by `hist!` from `data`, not from these limits.
"""
function plot_histogram(data; num_bins=500, xlow=0.0, xhigh=0.2)
    fig = Figure()
    ax = Axis(fig[1, 1],
        xlabel = "Value",
        ylabel = "Frequency",
        title = "Histogram")

    hist!(ax, data, bins = num_bins)

    # Set the visible x-range to [xlow, xhigh] (0.0 to 0.2 by default).
    xlims!(ax, xlow, xhigh)

    return fig
end

In [None]:
maximum(test_cmte_stds)

In [None]:
# Histogram of the test UQ values, dropping those >= 0.5 before plotting.
plot_histogram([val for val in test_cmte_stds if val< 0.5]; num_bins=500)

In [None]:
# `sort` already returns a new array, so the deepcopy is not needed to protect the original.
out = sort(deepcopy(test_cmte_stds))

In [None]:
# The 20 largest UQ values.
out[end-19:end]

In [None]:
# First 5000 calibration-pool frames form the calibration set.
calib_idxs = possible_cal_idxs[1:5000]
calib_coords = [test_sys.loggers.coords.history[i][1] for i in calib_idxs]
calib_set = Ensemble(deepcopy(pce0), calib_coords)

In [None]:
# Committee-predicted energies, reference energies and UQ values on the calibration set.
ecalib_pred = [ustrip(potential_energy(sys,my_cmte)) for sys in calib_set]
ecalib_ref = [ustrip(potential_energy(sys,ref)) for sys in calib_set]
calib_uq = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in calib_set]

In [None]:
includet("conformal_prediction_utils.jl")

In [None]:
# Conformal calibration. The 3-argument call presumably uses calibrate's default
# alpha (defined in conformal_prediction_utils.jl — not visible here).
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq)

In [None]:
# Same predictions/reference/UQ on the test set, plus absolute residuals.
etest_pred = [ustrip(potential_energy(sys,my_cmte)) for sys in testset]
etest_ref = [ustrip(potential_energy(sys,ref)) for sys in testset]
test_uq = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in testset]

num_test = length(etest_pred)
test_abs_residuals = abs.(etest_pred .- etest_ref)

In [None]:
# Half-width of each test point's conformal interval: qhat scaled by its UQ value.
qhat_scores = qhat*test_uq
# Empirical coverage = fraction of test residuals that lie INSIDE their interval
# (for a calibrated qhat this should be at least about 1 - alpha). Comparing with
# `.>` instead would count the points outside, i.e. the miscoverage rate.
coverage = count(test_abs_residuals .<= qhat_scores) / num_test

In [None]:
# Distribution of interval half-widths, restricted to values below 0.1.
hist([score for score in qhat_scores if score <0.1], bins=500)

In [None]:
test_uq

In [None]:
uncertainty_vs_residuals(test_uq,test_abs_residuals)

In [None]:
# Normalized nonconformity scores on the calibration set: |residual| / UQ.
calib_scores = abs.(ecalib_pred .- ecalib_ref) ./ calib_uq
# Nominal alphas 0.99, 0.98, ..., 0.01.
alpha_complements = collect(range(0.01,0.99,step=0.01))
alpha_refs = 1 .- alpha_complements

# NOTE(review): alpha_refs is not passed in; this presumably evaluates the same
# 0.01:0.01:0.99 grid in the same order — confirm in conformal_prediction_utils.jl,
# otherwise the calibration plot below pairs mismatched arrays.
alpha_pred = generate_predicted_alphas(calib_scores,test_uq, test_abs_residuals)

In [None]:
make_calibration_plot(alpha_refs,alpha_pred)

In [None]:
compute_miscalibration_area(alpha_refs,alpha_pred)

- [x] take 2x random test samples, randomly pair them, compute ref and predicted energy differences
- [ ] to find conservative CI bounds, take high bound of first, low bound of second; low bound of first, high bound of second
- [ ] assess coverage, i.e. is the residual within those bounds

Next, conformalize directly against energy differences:
- [ ] Double the size of the calibration data, randomly pair it, and conformalize against the energy difference.
     (naive approach: use the sum of each point's std as the UQ metric)
- [ ] redo coverage check

In [None]:
# 10,000 further test-pool frames (disjoint from test_idxs = entries 1:5000),
# later grouped into 5000 pairs for the energy-difference study.
ediff_test_idxs = possible_test_idxs[5001:15000]
ediff_test_coords = [test_sys.loggers.coords.history[i][1] for i in ediff_test_idxs]
ediff_testset = Ensemble(deepcopy(pce0), ediff_test_coords)

In [None]:
# Already random, so just take every two
# Pair consecutive systems (1,2), (3,4), ... and, for each pair, record:
#   - reference and committee-predicted energy differences E2 - E1,
#   - each system's predicted energy and UQ value,
#   - the sum of the two UQ values (naive UQ for the difference).
iseven(length(ediff_testset)) ||
    error("ediff_testset must contain an even number of systems to form pairs")

ediff_test_ref = Float64[]
ediff_test_pred = Float64[]
ediff_test_uq = Float64[]
sys1_uqs = Float64[]
sys2_uqs = Float64[]

sys1_epreds = Float64[]
sys2_epreds = Float64[]
for i in 1:2:length(ediff_testset)
    sys1 = ediff_testset[i]
    sys2 = ediff_testset[i+1]
    e1_ref = ustrip(potential_energy(sys1, ref))
    e2_ref = ustrip(potential_energy(sys2, ref))
    push!(ediff_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(potential_energy(sys1, my_cmte))
    e2_pred = ustrip(potential_energy(sys2, my_cmte))
    # Append only; do not rebind the global (`x = push!(x, …)` in a top-level loop
    # makes `x` loop-local under non-interactive soft scope and then fails).
    push!(sys1_epreds, e1_pred)
    push!(sys2_epreds, e2_pred)
    push!(ediff_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys1_uqs, sys1_uq)
    push!(sys2_uqs, sys2_uq)
    push!(ediff_test_uq, sys1_uq + sys2_uq)
end

In [None]:
"""
    compute_ediff_prediction_sets(sys1_epreds, sys2_epreds, sys1_uqs, sys2_uqs, qhat)

For each (sys1, sys2) pair, return the `(low, high)` prediction set for `E2 - E1`.

Each energy gets a symmetric band of half-width `qhat * uq`. The two candidate bounds
are `(e2 + w2) - (e1 - w1)` and `(e2 - w2) - (e1 + w1)`, returned in ascending order.
"""
function compute_ediff_prediction_sets(sys1_epreds,
                                       sys2_epreds,
                                       sys1_uqs,
                                       sys2_uqs,
                                       qhat)
    sets = Vector{Tuple{Float64,Float64}}()
    for k in eachindex(sys1_epreds)
        e1, e2 = sys1_epreds[k], sys2_epreds[k]
        w1, w2 = qhat * sys1_uqs[k], qhat * sys2_uqs[k]
        widest = (e2 + w2) - (e1 - w1)
        narrowest = (e2 - w2) - (e1 + w1)
        # minmax orders by isless, matching sort on a two-element vector.
        push!(sets, minmax(widest, narrowest))
    end
    return sets
end

In [None]:
"""
    correct_conservative_ediff_pred_sets(sys1_epreds, sys2_epreds, sys1_uqs, sys2_uqs, qhat)

For each (sys1, sys2) pair, return the `(low, high)` prediction set for `E2 - E1`.
The bounds are the extremes over all four corner combinations of the energy bands
`[e1 - w1, e1 + w1]` and `[e2 - w2, e2 + w2]`, where `w = qhat * uq`.
"""
function correct_conservative_ediff_pred_sets(sys1_epreds,
                                            sys2_epreds,
                                            sys1_uqs,
                                            sys2_uqs,
                                            qhat)
    out = Vector{Tuple{Float64,Float64}}()
    for k in eachindex(sys1_epreds)
        w1 = qhat * sys1_uqs[k]
        w2 = qhat * sys2_uqs[k]
        lo1, hi1 = sys1_epreds[k] - w1, sys1_epreds[k] + w1
        lo2, hi2 = sys2_epreds[k] - w2, sys2_epreds[k] + w2

        corners = [hi2 - lo1, hi2 - hi1, lo2 - hi1, lo2 - lo1]
        push!(out, (minimum(corners), maximum(corners)))
    end
    return out
end

In [None]:
# Recalibrate on single-point scores and build energy-difference prediction sets
# for every test pair.
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq)
ediff_pred_sets = compute_ediff_prediction_sets(sys1_epreds,
                                                sys2_epreds,
                                                sys1_uqs,
                                                sys2_uqs,
                                                qhat)

In [None]:
ediff_test_pred

In [None]:
using GLMakie

In [None]:
# Split the (low, high) tuples. The last two parity_plot arguments are the distances
# from the prediction down to `low` and up to `high` (presumably asymmetric error bars).
ediff_low_bounds = [bound[1] for bound in ediff_pred_sets]
ediff_high_bounds = [bound[2] for bound in ediff_pred_sets]
parity_plot(ediff_test_ref, ediff_test_pred, ediff_test_pred .- ediff_low_bounds, ediff_high_bounds .- ediff_test_pred)

In [None]:
ediff_high_bounds[1:10]

In [None]:
ediff_low_bounds

In [None]:
# Empirical coverage of the energy-difference prediction sets: the fraction of
# reference differences E2 - E1 lying inside their [low, high] set.
num_test_ediff = length(ediff_test_ref)
count(ediff_low_bounds .<= ediff_test_ref .<= ediff_high_bounds) / num_test_ediff

In [None]:

# Calibration curve for the two-bound energy-difference sets: for each nominal alpha,
# recalibrate qhat, rebuild the sets, and record 1 - empirical coverage.
alpha_compls = collect(range(0.01,0.99,step=0.01))
alpha_refs = 1 .- alpha_compls # i.e. iterate 0.99..0.01, but will then plot as 1-0.99...1-0.01

bad_predicted_alphas = Float64[]
# NOTE(review): in IJulia's soft scope, ediff_pred_sets / ediff_low_bounds /
# ediff_high_bounds below overwrite the globals of the same name; after the loop they
# hold the alpha = 0.01 results.
for alpha in alpha_refs
    qh = calibrate(ecalib_pred, ecalib_ref, calib_uq, alpha)
    ediff_pred_sets = compute_ediff_prediction_sets(sys1_epreds,
                                                sys2_epreds,
                                                sys1_uqs,
                                                sys2_uqs,
                                                qh)

    ediff_low_bounds = [bound[1] for bound in ediff_pred_sets]
    ediff_high_bounds = [bound[2] for bound in ediff_pred_sets]
    predicted_alpha_compl = sum(ediff_low_bounds .<= ediff_test_ref .<= ediff_high_bounds) / num_test_ediff
    push!(bad_predicted_alphas, 1.0 - predicted_alpha_compl)
end

In [None]:
make_calibration_plot(alpha_refs, bad_predicted_alphas)

In [None]:
compute_miscalibration_area(alpha_refs,bad_predicted_alphas)

In [None]:
"""
    check_alpha(alpha; ecalib_pred, ecalib_ref, calib_uq, sys1_epreds, sys2_epreds,
                sys1_uqs, sys2_uqs, ediff_test_ref)

Calibrate at level `alpha` and build the energy-difference prediction sets with
`compute_ediff_prediction_sets`. Then measure how often the reference differences fall
outside those sets.

Every keyword defaults to the notebook global of the same name.

# Returns
`(ediff_pred_sets, empirical_alpha)`, where `empirical_alpha = 1 - coverage` is the
observed fraction of reference energy differences outside their prediction set.
"""
function check_alpha(alpha;
                    ecalib_pred=ecalib_pred,
                    ecalib_ref=ecalib_ref,
                    calib_uq=calib_uq,
                    sys1_epreds=sys1_epreds,
                    sys2_epreds=sys2_epreds,
                    sys1_uqs=sys1_uqs,   # was `sys2_uqs`: used system 2's UQ for both
                    sys2_uqs=sys2_uqs,
                    ediff_test_ref=ediff_test_ref)
    qh = calibrate(ecalib_pred, ecalib_ref, calib_uq, alpha)
    ediff_pred_sets = compute_ediff_prediction_sets(sys1_epreds,
                                                    sys2_epreds,
                                                    sys1_uqs,
                                                    sys2_uqs,
                                                    qh)

    ediff_low_bounds = [bound[1] for bound in ediff_pred_sets]
    ediff_high_bounds = [bound[2] for bound in ediff_pred_sets]
    # Coverage: fraction of reference differences inside their [low, high] set.
    coverage = count(ediff_low_bounds .<= ediff_test_ref .<= ediff_high_bounds) /
               length(ediff_test_ref)
    return ediff_pred_sets, 1.0 - coverage
end

In [None]:
# Empirical miscoverage of the energy-difference sets at nominal alpha = 0.05.
pred_sets, my_alpha = check_alpha(0.05)

In [None]:
my_alpha

In [None]:
# Same calibration-curve sweep as above, but with the four-corner
# (correct_conservative_ediff_pred_sets) prediction sets.
alpha_compls = collect(range(0.01,0.99,step=0.01))
alpha_refs = 1 .- alpha_compls # i.e. iterate 0.99..0.01, but will then plot as 1-0.99...1-0.01

conservative_predicted_alphas = Float64[]
# NOTE(review): as in the earlier sweep, the loop body overwrites the global
# ediff_pred_sets / ediff_low_bounds / ediff_high_bounds under IJulia's soft scope.
for alpha in alpha_refs
    qh = calibrate(ecalib_pred, ecalib_ref, calib_uq, alpha)
    ediff_pred_sets = correct_conservative_ediff_pred_sets(sys1_epreds,
                                                sys2_epreds,
                                                sys1_uqs,
                                                sys2_uqs,
                                                qh)

    ediff_low_bounds = [bound[1] for bound in ediff_pred_sets]
    ediff_high_bounds = [bound[2] for bound in ediff_pred_sets]
    predicted_alpha_compl = sum(ediff_low_bounds .<= ediff_test_ref .<= ediff_high_bounds) / num_test_ediff
    push!(conservative_predicted_alphas, 1.0 - predicted_alpha_compl)
end

In [None]:
make_calibration_plot(alpha_refs, conservative_predicted_alphas)

In [None]:
compute_miscalibration_area(alpha_refs,conservative_predicted_alphas)

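OK, it's actually the same, so, surprisingly (to me), it's undercovered? (In hindsight, the match is expected: with non-negative band widths, the minimum and maximum over the four corner combinations are `lower_2 - upper_1` and `upper_2 - lower_1`, exactly the two bounds that `compute_ediff_prediction_sets` already uses.)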
