In [None]:
using Revise
using Pkg; Pkg.activate(".") # If I'm in new_muller_brown

In [None]:
using Cairn
using LinearAlgebra, Random, Statistics, StatsBase # Do need StatsBase this time
using PotentialLearning
using Molly, AtomsCalculators
using AtomisticQoIs
using SpecialPolynomials
using JLD2

In [None]:
# Plotting helpers are loaded directly from source files of a local Cairn.jl checkout.
# NOTE(review): both paths are absolute and specific to one machine; they have to be
# adjusted before this notebook can run anywhere else.
# makie.jl is loaded with a plain `include`, while plot_contours.jl uses `includet`
# (Revise is imported in the first cell) — TODO confirm the asymmetry is intentional.
include("/Users/swyant/cesmix/dev/Cairn.jl/src/makie/makie.jl")
includet("/Users/swyant/cesmix/dev/Cairn.jl/src/makie/plot_contours.jl")

In [None]:
# Reference potential and the 2-D domain shared by the rest of the notebook.
ref = MullerBrownRot()
# Two [lower, upper] pairs — presumably the x-range and the y-range; TODO confirm.
limits = [[-4.4,1.5],[-2,2]]
# Grid built over `limits` with the argument 0.05 (presumably the spacing — confirm).
coord_grid = coord_grid_2d(limits,0.05) # from the makie.jl file
# Levels from -150 to 1000 in steps of 2 (presumably contour levels for plotting);
# the coarser step of 50 is kept commented out as an alternative.
#ctr_lvls = -150:50:1000
ctr_lvls = -150:2:1000

In [None]:
# Polynomial basis family for the surrogate: Jacobi polynomials with parameters (0.5, 0.5).
# The (2, 2) variant is kept commented out as an alternative.
basisfam = Jacobi{0.5,0.5}
#basisfam = Jacobi{2,2}

# Polynomial order passed to the model below (20 is kept as a commented-out alternative).
#order=20
order=50
# Template polynomial chaos model; later cells deep-copy it when building Ensembles.
# NOTE(review): the literal 2 is presumably the input dimension and `xscl` the
# coordinate scaling range — confirm against PolynomialChaos.
pce0 = PolynomialChaos(order,2,basisfam,xscl=limits)

In [None]:
# Evaluation coordinates generated from the reference potential over `limits`.
# NOTE(review): 0.04 is presumably the grid spacing and `cutoff=800` an upper energy
# bound for kept points — confirm against potential_grid_2d.
coords_eval = potential_grid_2d(ref,limits,0.04,cutoff=800)
#@show coords_eval
#sys_eval = define_ens(ref,coords_eval) # This has been replaced with Ensemble()
# Ensemble of systems built from the reference potential and the evaluation coordinates.
sys_eval = Ensemble(ref,coords_eval)

In [None]:
# Template system on the reference potential, started at [0.5, 0.5]. It carries a
# CoordinateLogger under the `coords` key (constructed with 100 and dims=2; the 100 is
# presumably the logging interval in steps — confirm). Later cells read this logger's
# history from a simulated copy of the system.
sys0 = System(ref,[0.5,0.5], loggers=(coords=CoordinateLogger(100;dims=2),))

In [None]:
# Read the stored coordinate sets (key "coordsets") from the JLD2 file in the working directory.
my_coordsets = load("ten_1500K_trainsets.jld2")["coordsets"]
# One Ensemble per coordinate set; each gets its own deep copy of pce0 so they share no state.
my_trainsets = [Ensemble(deepcopy(pce0),coords) for coords in my_coordsets]

In [None]:
# Read the stored committee members (key "ensemble_members") from the JLD2 file in the
# working directory; they are passed to CommitteePotential below.
ensemble_members = load("ten_pce_ensembles.jld2")["ensemble_members"]

In [None]:
# Local source files, resolved relative to the notebook's location.
# NOTE(review): presumably these define CommitteePotential and CmteEnergy used in the
# following cells — confirm.
include("committee_potentials.jl")
include("committee_qois.jl")

In [None]:
# Committee potential assembled from the loaded members, with energies in kJ/mol and
# lengths in nm.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"kJ*mol^-1", length_units=u"nm")

In [None]:
# Committee energy quantity of interest built with Statistics.std and unit stripping on.
# NOTE(review): presumably this reduces the members' energies to their standard
# deviation, used below as the per-system uncertainty — confirm in committee_qois.jl.
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)

In [None]:
# test data
# Overdamped Langevin dynamics at 1500 K (dt = 0.002 ps, friction = 4 ps^-1).
sim_highT = OverdampedLangevin(
            dt=0.002u"ps",
            temperature=1500.0u"K",
            friction=4.0u"ps^-1",
)
# Simulate a deep copy so that sys0 itself is not modified.
test_sys = deepcopy(sys0)
# 10,000,000 steps; later cells read the logged coordinates from
# test_sys.loggers.coords.history.
simulate!(test_sys, sim_highT, 10_000_000)

In [None]:
# test/calibration split
"""
    obtain_test_cal_indxs(frac, set_size; rng=Random.default_rng())

Randomly split the indices `1:set_size` into a test part and a calibration part.

# Arguments
- `frac`: fraction of the indices assigned to the test part; must lie in `[0, 1]`.
- `set_size`: total number of indices; must be non-negative.

# Keywords
- `rng`: random number generator used for the permutation. Pass a seeded generator to
  make the split reproducible.

# Returns
A tuple `(test_set_idxs, calib_set_idxs)`. The first vector holds
`floor(Int, frac * set_size)` indices and the second holds the remaining ones; together
they form a permutation of `1:set_size`.

# Throws
- `ArgumentError` if `frac` is outside `[0, 1]` (including `NaN`) or `set_size` is negative.
"""
function obtain_test_cal_indxs(frac::Real, set_size::Integer;
                               rng::AbstractRNG=Random.default_rng())
    # Explicit checks instead of @assert: a negative frac would otherwise surface later
    # as an unrelated BoundsError when slicing the permutation.
    0 <= frac <= 1 ||
        throw(ArgumentError("frac must lie in [0, 1], got $frac"))
    set_size >= 0 ||
        throw(ArgumentError("set_size must be non-negative, got $set_size"))

    num_select = floor(Int, frac * set_size)

    perm_idxs = randperm(rng, set_size)
    test_set_idxs = perm_idxs[begin:num_select]
    calib_set_idxs = perm_idxs[num_select+1:end]

    return test_set_idxs, calib_set_idxs
end

In [None]:
# Randomly split the indices of the logged trajectory frames in half: one half is the
# pool for test points, the other the pool for calibration points.
# NOTE(review): no RNG seed is set in the visible cells, so the split presumably changes
# from run to run — confirm whether reproducibility is needed.
possible_test_idxs, possible_cal_idxs = obtain_test_cal_indxs(0.5,length(test_sys.loggers.coords.history))

In [None]:
# Calibration subset: entries 5001 through 15_000 of the shuffled calibration pool.
# NOTE(review): this requires possible_cal_idxs to hold at least 15_000 indices.
ediff_calib_idxs = possible_cal_idxs[5001:15_000]
# `[1]` takes the first element of each logged frame — presumably the coordinates of the
# single particle; TODO confirm.
ediff_calib_coords = [test_sys.loggers.coords.history[i][1] for i in ediff_calib_idxs]
# Wrap the coordinates in an Ensemble that gets its own deep copy of pce0.
ediff_calibset = Ensemble(deepcopy(pce0), ediff_calib_coords)

In [None]:
# Test subset: entries 5001 through 15000 of the shuffled test pool.
# NOTE(review): this requires possible_test_idxs to hold at least 15000 indices.
ediff_test_idxs = possible_test_idxs[5001:15000]
# `[1]` takes the first element of each logged frame — presumably the coordinates of the
# single particle; TODO confirm.
ediff_test_coords = [test_sys.loggers.coords.history[i][1] for i in ediff_test_idxs]
# Wrap the coordinates in an Ensemble that gets its own deep copy of pce0.
ediff_testset = Ensemble(deepcopy(pce0), ediff_test_coords)

In [None]:
# Already random, so just take every two
# Energy differences between consecutive systems (i, i+1) of the test set, from the
# reference potential and from the committee, plus the matching uncertainties.
ediff_test_ref = Float64[]   # reference E2 - E1 for each pair
ediff_test_pred = Float64[]  # committee E2 - E1 for each pair
ediff_test_uq = Float64[]    # pair uncertainty: sum of the two per-system values
sys1_test_uqs = Float64[]    # per-system uncertainty of the first member of each pair
sys2_test_uqs = Float64[]    # per-system uncertainty of the second member of each pair

sys1_test_epreds = Float64[] # committee energy of the first member of each pair
sys2_test_epreds = Float64[] # committee energy of the second member of each pair
# Stop at length - 1 so that i + 1 is always a valid index: with an odd number of
# systems the unpaired last one is skipped instead of raising a BoundsError.
for i in 1:2:(length(ediff_testset) - 1)
    sys1 = ediff_testset[i]
    sys2 = ediff_testset[i+1]
    e1_ref = ustrip.(potential_energy(sys1,ref))
    e2_ref = ustrip.(potential_energy(sys2, ref))

    push!(ediff_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip.(potential_energy(sys1, my_cmte))
    push!(sys1_test_epreds,e1_pred)
    e2_pred = ustrip.(potential_energy(sys2, my_cmte))
    push!(sys2_test_epreds,e2_pred)

    push!(ediff_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy,sys1,my_cmte))
    push!(sys1_test_uqs,sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy,sys2,my_cmte))
    push!(sys2_test_uqs,sys2_uq)
    push!(ediff_test_uq, sys1_uq+sys2_uq)
end

In [None]:
# Already random, so just take every two
# Energy differences between consecutive systems (i, i+1) of the calibration set, from
# the reference potential and from the committee, plus the matching uncertainties.
ediff_calib_ref = Float64[]   # reference E2 - E1 for each pair
ediff_calib_pred = Float64[]  # committee E2 - E1 for each pair
ediff_calib_uq = Float64[]    # pair uncertainty: sum of the two per-system values
sys1_calib_uqs = Float64[]    # per-system uncertainty of the first member of each pair
sys2_calib_uqs = Float64[]    # per-system uncertainty of the second member of each pair

sys1_calib_epreds = Float64[] # committee energy of the first member of each pair
sys2_calib_epreds = Float64[] # committee energy of the second member of each pair
# Stop at length - 1 so that i + 1 is always a valid index: with an odd number of
# systems the unpaired last one is skipped instead of raising a BoundsError.
for i in 1:2:(length(ediff_calibset) - 1)
    sys1 = ediff_calibset[i]
    sys2 = ediff_calibset[i+1]
    e1_ref = ustrip.(potential_energy(sys1,ref))
    e2_ref = ustrip.(potential_energy(sys2, ref))

    push!(ediff_calib_ref, e2_ref - e1_ref)

    e1_pred = ustrip.(potential_energy(sys1, my_cmte))
    push!(sys1_calib_epreds,e1_pred)
    e2_pred = ustrip.(potential_energy(sys2, my_cmte))
    push!(sys2_calib_epreds,e2_pred)

    push!(ediff_calib_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy,sys1,my_cmte))
    push!(sys1_calib_uqs,sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy,sys2,my_cmte))
    push!(sys2_calib_uqs,sys2_uq)
    push!(ediff_calib_uq, sys1_uq+sys2_uq)
end

In [None]:
# Local helper file, loaded with `includet` (Revise is imported in the first cell).
# NOTE(review): presumably it defines generate_predicted_alphas, make_calibration_plot
# and compute_miscalibration_area used in the following cells — confirm.
includet("conformal_prediction_utils.jl")

In [None]:
# Uncertainty assigned to each energy difference: the element-wise sum of the two
# per-system uncertainties. These hold the same values that the pairing loops already
# accumulated in ediff_test_uq and ediff_calib_uq.
test_ediff_uq = sys1_test_uqs .+ sys2_test_uqs
calib_ediff_uq = sys1_calib_uqs .+ sys2_calib_uqs

In [None]:
# Calibration scores: absolute error of each predicted energy difference divided by the
# uncertainty assigned to that pair.
# NOTE(review): a pair with zero uncertainty would produce Inf or NaN here — confirm
# that cannot occur.
calib_scores = abs.(ediff_calib_pred .- ediff_calib_ref) ./ calib_ediff_uq
# Absolute errors of the predicted energy differences on the test pairs.
test_abs_residuals = abs.(ediff_test_pred .- ediff_test_ref)
# Grid from 0.01 to 0.99 in steps of 0.01, and its complement 1 - value.
alpha_complements = collect(range(0.01,0.99,step=0.01))
alpha_refs = 1 .- alpha_complements

# NOTE(review): neither grid above is passed to the helper; presumably it builds a
# matching grid internally so that alpha_pred lines up with alpha_refs — confirm.
alpha_pred = generate_predicted_alphas(calib_scores,test_ediff_uq, test_abs_residuals)

In [None]:
# Plot the predicted levels against the reference levels.
make_calibration_plot(alpha_refs,alpha_pred)

In [None]:
# Scalar summary computed from the same reference/predicted levels as the plot above
# (presumably the area between the calibration curve and the diagonal — confirm).
compute_miscalibration_area(alpha_refs,alpha_pred)

- Estimate the covariance matrix (k-NN of residuals) using some other dataset (they used the training set, but I no longer have access to it; I should have saved it to JLD2)
- Compute e1 residuals
- Compute e2 residuals

In [None]:
# Quantile of the calibration scores at miscoverage level `alpha`.
# The level is ceil((n + 1) * (1 - alpha)) / n, clamped to [0, 1] because the
# finite-sample correction can push it above 1 for small n.
alpha = 0.05
num_calib = length(calib_scores)
q_hat = let corrected_rank = ceil((num_calib + 1) * (1 - alpha))
    quantile(calib_scores, clamp(corrected_rank / num_calib, 0.0, 1.0))
end