In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using GLMakie

1. Load the ensemble members and create the committee potential.
2. Read in the 5 + 5 (pristine + Frenkel) calibration/test trials. Again randomly subsample 300 configurations from each trial and record the indices.
3. For both pristine and Frenkel, use 50% for calibration and 50% for testing; combine each into joint calibration/test sets.
4. Conformalize against energies (standard procedure).
5. Conformalize against energy differences (ΔE), exploring different heuristic uncertainty metrics.
6. Assess extensivity issues.

In [None]:
# Load the calibration/test trajectories (trials 6-10) for the pristine and the
# Frenkel-defect systems. From each trajectory, `n_subsample` configurations are
# drawn uniformly at random without replacement (randperm), and the drawn frame
# indices are recorded in `base_calibtest_idxs[(system, trial)]`.
n_subsample = 300

pristine_base_calibtest = Vector{DataSet}()
frenkel_base_calibtest = Vector{DataSet}()
base_calibtest_idxs = Dict{Tuple{Symbol, Int64}, Vector{Int64}}()

# Same procedure for both systems; the symbol doubles as the file-name prefix,
# so "$(system)_$(i).xyz" yields e.g. "pristine_6.xyz".
for (system, system_sets) in ((:pristine, pristine_base_calibtest),
                              (:frenkel, frenkel_base_calibtest))
    for i in 6:10
        println(i)
        configs = load_data("./data/$(system)_$(i).xyz", ExtXYZ(u"eV", u"Å"))

        # Sample from the actual trajectory length instead of assuming 1001 frames.
        # Otherwise a shorter file fails with an opaque BoundsError, and a longer
        # file never has its trailing frames sampled.
        n_frames = length(configs)
        n_frames >= n_subsample || throw(ArgumentError(
            "$(system)_$(i).xyz has $(n_frames) configurations; need at least $(n_subsample)"))
        indxs = randperm(n_frames)[1:n_subsample]

        push!(system_sets, configs[indxs])
        base_calibtest_idxs[(system, i)] = indxs
    end
end

In [None]:
"""
    concat_dataset(confs::Vector{DataSet})

Merge several `DataSet`s into a single `DataSet`, preserving order: every
configuration of `confs[1]` comes first, then those of `confs[2]`, and so on.

Adapted from `subsampling_dpp.jl` in the PotentialLearning.jl examples.
"""
function concat_dataset(confs::Vector{DataSet})
    # Unpack each DataSet into a plain vector of its configurations...
    unpacked = map(ds -> [ds[k] for k in 1:length(ds)], confs)
    # ...then splice those vectors end to end and re-wrap the result.
    return DataSet(reduce(vcat, unpacked))
end

In [None]:
# Flatten the five per-trial DataSets of each system into a single DataSet.
pristine_base_calibtest = concat_dataset(pristine_base_calibtest)
frenkel_base_calibtest = concat_dataset(frenkel_base_calibtest)

In [None]:
# Random 50/50 calibration/test split of each system. Split sizes come from the
# data rather than being hard-coded (previously 1500 total / 750 per half). That
# way, changing the number of trials or the per-trial subsample size cannot
# silently drop configurations or index past the end.
# NOTE: the pairwise ΔE cells further down pair consecutive entries, so they
# assume each half has an even number of configurations.
n_pristine = length(pristine_base_calibtest)
n_frenkel = length(frenkel_base_calibtest)
pristine_idxs = randperm(n_pristine)
frenkel_idxs = randperm(n_frenkel)

n_pristine_calib = n_pristine ÷ 2
n_frenkel_calib = n_frenkel ÷ 2

pristine_base_calib = pristine_base_calibtest[pristine_idxs[1:n_pristine_calib]]
pristine_base_test = pristine_base_calibtest[pristine_idxs[(n_pristine_calib + 1):end]]

frenkel_base_calib = frenkel_base_calibtest[frenkel_idxs[1:n_frenkel_calib]]
frenkel_base_test = frenkel_base_calibtest[frenkel_idxs[(n_frenkel_calib + 1):end]]

In [None]:
# Pool both systems into combined calibration and test sets (pristine first).
combined_calib = concat_dataset([pristine_base_calib; frenkel_base_calib])
combined_test = concat_dataset([pristine_base_test; frenkel_base_test])

In [None]:
# Load the trained committee members stored under the "members" key of ace_cmte1.jld2.
ensemble_members = load("ace_cmte1.jld2", "members")

In [None]:
# Project-local helpers. These presumably define CommitteePotential and the
# committee QoIs (CmteEnergy, CmteEnergyCov) used below — confirm which file
# defines which. They must be included before those names are used.
include("../files/committee_potentials.jl")
include("../files/committee_qois.jl")

In [None]:
# Wrap the members in a committee potential working in eV / Å.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Committee energy QoI reduced with Statistics.std — presumably the spread of the
# members' energy predictions, returned without units (confirm in committee_qois.jl).
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)

In [None]:
# Smoke test on the first 10 calibration configurations: compute local
# descriptors with the first committee member's basis and attach them.
test_edescr = compute_local_descriptors(
    combined_calib[1:10],
    my_cmte.members[1].basis,
)
test_ds = DataSet(combined_calib[1:10] .+ test_edescr)

In [None]:
# Committee energies and the `cmte_energy` UQ metric for the smoke-test set.
test_energies = [
    PotentialLearning.potential_energy(smoke_conf, my_cmte) for smoke_conf in test_ds
]
test_calib = [
    ustrip(compute(cmte_energy, smoke_conf, my_cmte)) for smoke_conf in test_ds
]

In [None]:
"""
    with_local_descriptors(ds, basis)

Compute the local descriptors of every configuration in `ds` with `basis`.
Return `(ds_with_descr, descr)`: a new `DataSet` whose configurations carry the
descriptors, together with the raw descriptors themselves.
"""
function with_local_descriptors(ds, basis)
    descr = compute_local_descriptors(ds, basis)
    return DataSet(ds .+ descr), descr
end

# Every split is featurized with the first committee member's basis (looked up once).
descriptor_basis = my_cmte.members[1].basis

pristine_base_calib_ds, pristine_base_calib_edescr =
    with_local_descriptors(pristine_base_calib, descriptor_basis)

In [None]:
pristine_base_test_ds, pristine_base_test_edescr =
    with_local_descriptors(pristine_base_test, descriptor_basis)

frenkel_base_calib_ds, frenkel_base_calib_edescr =
    with_local_descriptors(frenkel_base_calib, descriptor_basis)

frenkel_base_test_ds, frenkel_base_test_edescr =
    with_local_descriptors(frenkel_base_test, descriptor_basis)

In [None]:
# Persist the four descriptor-augmented splits to datasets_with_descriptors.jld2.
save(
    "datasets_with_descriptors.jld2",
    Dict(
        "pristine_base_calib_ds" => pristine_base_calib_ds,
        "pristine_base_test_ds" => pristine_base_test_ds,
        "frenkel_base_calib_ds" => frenkel_base_calib_ds,
        "frenkel_base_test_ds" => frenkel_base_test_ds,
    ),
)

Compute a single conformal quantile (qhat) for total energies.

In [None]:
include("../files/conformal_prediction_utils.jl")

In [None]:
# Pool the descriptor-augmented splits of both systems (pristine first, then Frenkel).
combined_calib_ds = concat_dataset(DataSet[pristine_base_calib_ds, frenkel_base_calib_ds])
combined_test_ds = concat_dataset(DataSet[pristine_base_test_ds, frenkel_base_test_ds])

Old (slow) way: evaluate the committee on the raw atomic systems through AtomsCalculators.

In [None]:
# Committee predictions on the raw atomic systems (no precomputed descriptors),
# plus the reference energies of the calibration set.
ecalib_pred = [
    ustrip(AtomsCalculators.potential_energy(atoms_sys, my_cmte))
    for atoms_sys in get_system.(combined_calib)
]
ecalib_ref = [get_values(get_energy(cfg)) for cfg in combined_calib]

Faster way: evaluate on the descriptor-augmented datasets, reusing the precomputed descriptors.

In [None]:
# Same quantities, evaluated on the descriptor-augmented calibration set.
ecalib_pred = [
    ustrip(PotentialLearning.potential_energy(cfg, my_cmte)) for cfg in combined_calib_ds
]
ecalib_ref = [get_values(get_energy(cfg)) for cfg in combined_calib_ds]

In [None]:
save(
    "calib_energies.jld2",
    Dict("ecalib_pred" => ecalib_pred, "ecalib_ref" => ecalib_ref),
)

In [None]:
# Heuristic uncertainty (the `cmte_energy` QoI) for every calibration configuration.
calib_uq = [ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in combined_calib_ds]

In [None]:
# Conformal quantile of the UQ-scaled calibration residuals. 0.05 is presumably
# the miscoverage level α (95% target coverage) — confirm in
# conformal_prediction_utils.jl.
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq, 0.05)

In [None]:
# Test-set committee predictions, reference energies and heuristic uncertainties.
etest_pred = [ustrip(PotentialLearning.potential_energy(cfg, my_cmte)) for cfg in combined_test_ds]
etest_ref = [ustrip(get_values(get_energy(cfg))) for cfg in combined_test_ds]
test_uq = [ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in combined_test_ds]

num_test = length(etest_pred)
test_abs_residuals = @. abs(etest_pred - etest_ref)

In [None]:
# Interval half-widths: the conformal quantile scaled by each test uncertainty.
qhat_scores = qhat .* test_uq
# Empirical coverage is the fraction of test residuals that fall INSIDE their
# interval. The previous version counted residuals outside it (`.>`), which is the
# miscoverage rate. That value is still reported below under an honest name.
coverage = count(test_abs_residuals .<= qhat_scores) / num_test
miscoverage = 1 - coverage

In [None]:
# Distribution of the conformal interval half-widths.
hist(qhat_scores; bins = 500)

In [None]:
# Heuristic uncertainty vs. absolute test residual.
uncertainty_vs_residuals(test_uq, test_abs_residuals; limits = (0.0, 0.05, -0.001, 0.5))

In [None]:
# Nonconformity scores on the calibration set: absolute residuals scaled by the UQ.
calib_scores = @. abs(ecalib_pred - ecalib_ref) / calib_uq
# Grid of 1 - α values (0.01, 0.02, ..., 0.99) and the matching α grid.
alpha_complements = collect(range(0.01, 0.99; step = 0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(calib_scores, test_uq, test_abs_residuals)

In [None]:
make_calibration_plot(alpha_refs, alpha_pred)

In [None]:
compute_miscalibration_area(alpha_refs, alpha_pred)

In [None]:
cmte_cov_energy = CmteEnergyCov(true)

In [None]:
# Energy-difference (ΔE = E2 - E1) statistics on the calibration set.
# Within each system the calibration configurations were drawn with randperm, so
# they are already in random order. Consecutive entries (1,2), (3,4), ... are
# therefore used as random pairs.
ediff_combined_calib_ref = Float64[]
ediff_combined_calib_pred = Float64[]
ediff_combined_calib_uq = Float64[]
ediff_combined_calib_cov_uq = Float64[]
sys1_combined_calib_uqs = Float64[]
sys2_combined_calib_uqs = Float64[]

sys1_combined_calib_epreds = Float64[]
sys2_combined_calib_epreds = Float64[]
# Stop at length - 1 so an odd-length dataset leaves its last configuration
# unpaired instead of indexing past the end with `i + 1`.
for i in 1:2:(length(combined_calib_ds) - 1)
    sys1 = combined_calib_ds[i]
    sys2 = combined_calib_ds[i + 1]

    # Reference ΔE.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_combined_calib_ref, e2_ref - e1_ref)

    # Committee-predicted energies and ΔE.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(sys1_combined_calib_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys2_combined_calib_epreds, e2_pred)
    push!(ediff_combined_calib_pred, e2_pred - e1_pred)

    # Per-configuration UQ and their sum (the simplest ΔE heuristic).
    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(sys1_combined_calib_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys2_combined_calib_uqs, sys2_uq)
    push!(ediff_combined_calib_uq, sys1_uq + sys2_uq)

    # Cross term between the two configurations. Presumably the committee
    # covariance with the second energy's sign flipped — confirm in committee_qois.jl.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_calib_cov_uq, cov_uq)
end

In [None]:
# Energy-difference (ΔE = E2 - E1) statistics on the test set.
# Within each system the test configurations were drawn with randperm, so they
# are already in random order. Consecutive entries (1,2), (3,4), ... are
# therefore used as random pairs.
ediff_combined_test_ref = Float64[]
ediff_combined_test_pred = Float64[]
ediff_combined_test_uq = Float64[]
ediff_combined_test_cov_uq = Float64[]
sys1_combined_test_uqs = Float64[]
sys2_combined_test_uqs = Float64[]

sys1_combined_test_epreds = Float64[]
sys2_combined_test_epreds = Float64[]
# Stop at length - 1 so an odd-length dataset leaves its last configuration
# unpaired instead of indexing past the end with `i + 1`.
for i in 1:2:(length(combined_test_ds) - 1)
    sys1 = combined_test_ds[i]
    sys2 = combined_test_ds[i + 1]

    # Reference ΔE.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_combined_test_ref, e2_ref - e1_ref)

    # Committee-predicted energies and ΔE.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(sys1_combined_test_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys2_combined_test_epreds, e2_pred)
    push!(ediff_combined_test_pred, e2_pred - e1_pred)

    # Per-configuration UQ and their sum (the simplest ΔE heuristic).
    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(sys1_combined_test_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys2_combined_test_uqs, sys2_uq)
    push!(ediff_combined_test_uq, sys1_uq + sys2_uq)

    # Cross term between the two configurations. Presumably the committee
    # covariance with the second energy's sign flipped — confirm in committee_qois.jl.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Two heuristic ΔE uncertainties built from the per-configuration committee
# uncertainties (cmte_energy is constructed with Statistics.std).
#
# uq1: std(E1) + std(E2). By the triangle inequality for standard deviations, this
# upper-bounds std(E2 - E1). It equals ediff_combined_test_uq / ediff_combined_calib_uq.
test_ediff_uq1 = sys1_combined_test_uqs .+ sys2_combined_test_uqs
calib_ediff_uq1 = sys1_combined_calib_uqs .+ sys2_combined_calib_uqs

# uq2: committee std of the difference including the cross term,
#     Var(E2 - E1) = Var(E1) + Var(E2) - 2 Cov(E1, E2).
# The per-configuration uqs are standard deviations, so they are squared before
# being combined with the covariance term, and the total is square-rooted back to
# energy units. The previous version added stds to covariances, mixing eV and eV².
# max(0, ·) guards against tiny negative values from floating-point round-off.
# NOTE(review): assumes compute(cmte_cov_energy, ...; flip_second_sign=true)
# returns Cov(E1, -E2) = -Cov(E1, E2), so it enters with a plus sign, and that it
# uses the same normalization as Statistics.std — confirm in committee_qois.jl.
ediff_std(s1, s2, flipped_cov) = sqrt(max(zero(s1), s1^2 + s2^2 + 2 * flipped_cov))

test_ediff_uq2 = ediff_std.(sys1_combined_test_uqs, sys2_combined_test_uqs,
                            ediff_combined_test_cov_uq)
calib_ediff_uq2 = ediff_std.(sys1_combined_calib_uqs, sys2_combined_calib_uqs,
                             ediff_combined_calib_cov_uq)

In [None]:
# Conformal calibration check for ΔE, using uq1 (sum of the two per-configuration
# uncertainties) as the scaling.
ediff_combined_calib_scores =
    @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq1
test_abs_residuals_combined = @. abs(ediff_combined_test_pred - ediff_combined_test_ref)
alpha_complements = collect(range(0.01, 0.99; step = 0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(
    ediff_combined_calib_scores,
    test_ediff_uq1,
    test_abs_residuals_combined,
)

In [None]:
make_calibration_plot(alpha_refs, alpha_pred)

In [None]:
compute_miscalibration_area(alpha_refs, alpha_pred)