In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using GLMakie

In [None]:
# Load the committee members stored under the "members" key of ace_cmte1.jld2.
ensemble_members = load("ace_cmte1.jld2", "members")

In [None]:
# Revise-tracked includes. These files presumably define CommitteePotential, CmteEnergy,
# CmteEnergyCov and `compute`, which are used below but not defined in this notebook —
# verify against the files.
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")

In [None]:
# Committee potential built from the loaded members (eV / Å). The per-configuration
# energy UQ is the committee standard deviation, returned with units stripped.
my_cmte = CommitteePotential(ensemble_members; energy_units = u"eV", length_units = u"Å")
cmte_energy = CmteEnergy(Statistics.std; strip_units = true)

In [None]:
# Load all precomputed datasets (with descriptors) from a single JLD2 file, then pull
# out the pristine / Frenkel calibration and test splits by key.
datasets = load("datasets_with_descriptors.jld2")
pristine_base_calib_ds, pristine_base_test_ds, frenkel_base_calib_ds, frenkel_base_test_ds =
    map(key -> datasets[key], (
        "pristine_base_calib_ds",
        "pristine_base_test_ds",
        "frenkel_base_calib_ds",
        "frenkel_base_test_ds",
    ))



First, calibrate a single conformal quantile (`qhat`) for per-configuration energies (single energies, not energy differences).

In [None]:
# Revise-tracked include. It presumably provides calibrate, generate_predicted_alphas,
# make_calibration_plot, compute_miscalibration_area and uncertainty_vs_residuals,
# which are used below but not defined here — verify against the file.
includet("../files/conformal_prediction_utils.jl")

In [None]:
# Adapted from subsampling_dpp.jl in the PotentialLearning.jl examples.
"""
    concat_dataset(confs::Vector{DataSet})

Concatenate the configurations of every `DataSet` in `confs`, in order (all of
`confs[1]`, then all of `confs[2]`, ...), and return them as one new `DataSet`.
"""
function concat_dataset(confs::Vector{DataSet})
    per_dataset = map(confs) do ds
        [ds[k] for k in 1:length(ds)]
    end
    return DataSet(reduce(vcat, per_dataset))
end

In [None]:
# Merge the pristine and Frenkel splits (pristine first) into one calibration set and
# one test set.
combined_calib_ds, combined_test_ds = (
    concat_dataset([pristine_base_calib_ds; frenkel_base_calib_ds]),
    concat_dataset([pristine_base_test_ds; frenkel_base_test_ds]),
)

In [None]:
# Calibration-set committee energy predictions, reference energies and per-config UQ
# (committee std). All three are unit-stripped, as on the test set and in the pair
# loops, so they can be subtracted and divided directly.
ecalib_pred = [
    ustrip(PotentialLearning.potential_energy(config, my_cmte)) for config in combined_calib_ds
]
ecalib_ref = [ustrip(get_values(get_energy(config))) for config in combined_calib_ds]
calib_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_calib_ds]

In [None]:
# Conformal quantile at miscoverage level 0.05 (fourth argument). The score is defined
# inside `calibrate`; it is presumably |pred - ref| / uq like calib_scores below — verify.
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq, 0.05)

In [None]:
# Test-set committee predictions, unit-stripped reference energies, per-config UQ,
# and the absolute prediction residuals.
etest_pred = [ustrip(PotentialLearning.potential_energy(cfg, my_cmte)) for cfg in combined_test_ds]
etest_ref = [ustrip(get_values(get_energy(cfg))) for cfg in combined_test_ds]
test_uq = [ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in combined_test_ds]

num_test = length(etest_pred)
test_abs_residuals = @. abs(etest_pred - etest_ref)

In [None]:
# Conformal interval half-widths on the test set, and the empirical coverage: the
# fraction of test residuals that fall INSIDE their interval. With the 0.05 passed to
# `calibrate` above, the nominal target is 1 - 0.05.
qhat_scores = qhat * test_uq
coverage = count(test_abs_residuals .<= qhat_scores) / num_test

In [None]:
# Distribution of the per-configuration conformal half-widths on the test set.
hist(qhat_scores; bins = 500)

In [None]:
# Plot test UQ against test absolute residuals (helper from the included utils file).
# NOTE(review): `limits` presumably means (xmin, xmax, ymin, ymax) — confirm.
uncertainty_vs_residuals(test_uq, test_abs_residuals; limits = (0.0, 0.05, -0.001, 0.5))

In [None]:
# Nonconformity scores on the calibration set: |residual| / UQ.
calib_scores = @. abs(ecalib_pred - ecalib_ref) / calib_uq

# Confidence levels 0.01, 0.02, ..., 0.99 and the matching reference alphas.
alpha_complements = collect(range(0.01, 0.99; step = 0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(calib_scores, test_uq, test_abs_residuals)

In [None]:
# Calibration curve: predicted alphas against the reference alphas.
make_calibration_plot(alpha_refs,alpha_pred)

In [None]:
# Scalar miscalibration summary of the curve above (definition in the included utils).
compute_miscalibration_area(alpha_refs,alpha_pred)

In [None]:
# Committee covariance QoI for pairs of configurations (used with flip_second_sign below).
# NOTE(review): the positional `true` presumably mirrors strip_units=true of
# CmteEnergy — confirm in committee_qois.jl.
cmte_cov_energy = CmteEnergyCov(true)

In [None]:
# Keep handles to the unshuffled combined sets (plain rebinding, no copy) so the
# next cell can permute them.
combined_calib_ds_orig, combined_test_ds_orig = combined_calib_ds, combined_test_ds

In [None]:
# Shuffle the combined sets so that the consecutive pairs formed below include
# (pristine, Frenkel) pairs. Unshuffled, concat_dataset lists all pristine configs
# before all Frenkel ones, which yields mostly (pristine, pristine) pairs.
# Each set gets its own permutation sized to its own length, so neither set has to
# contain exactly 1500 configurations.
# NOTE(review): the RNG is not seeded, so the pairs differ between runs.
rand_idxs = randperm(length(combined_calib_ds_orig))
combined_calib_ds = combined_calib_ds_orig[rand_idxs]
test_rand_idxs = randperm(length(combined_test_ds_orig))
combined_test_ds = combined_test_ds_orig[test_rand_idxs]

In [None]:
# The calibration set is already shuffled, so pair consecutive configurations:
# (1, 2), (3, 4), ... For each pair, record:
#   - reference and predicted energy differences E2 - E1,
#   - each config's predicted energy and committee UQ (std),
#   - the sum of the two UQs (ediff_combined_calib_uq, mirroring the test loop),
#   - the pair's committee covariance-based UQ.
ediff_combined_calib_ref = Float64[]
ediff_combined_calib_pred = Float64[]
ediff_combined_calib_uq = Float64[]
sys1_combined_calib_uqs = Float64[]
sys2_combined_calib_uqs = Float64[]
ediff_combined_calib_cov_uq = Float64[]

sys1_combined_calib_epreds = Float64[]
sys2_combined_calib_epreds = Float64[]

# With an odd number of configurations `i + 1` would run past the end, so the last,
# unpaired configuration is skipped instead.
isodd(length(combined_calib_ds)) &&
    @warn "Odd number of calibration configurations; the last one is left unpaired."
for i in 1:2:(length(combined_calib_ds) - 1)
    sys1 = combined_calib_ds[i]
    sys2 = combined_calib_ds[i + 1]

    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_combined_calib_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys1_combined_calib_epreds, e1_pred)
    push!(sys2_combined_calib_epreds, e2_pred)
    push!(ediff_combined_calib_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys1_combined_calib_uqs, sys1_uq)
    push!(sys2_combined_calib_uqs, sys2_uq)
    push!(ediff_combined_calib_uq, sys1_uq + sys2_uq)

    # NOTE(review): flip_second_sign=true presumably negates the second energy so the
    # result is the cross term of Var(E2 - E1) — confirm in committee_qois.jl.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_calib_cov_uq, cov_uq)
end

In [None]:
# The test set is already shuffled, so pair consecutive configurations:
# (1, 2), (3, 4), ... For each pair, record:
#   - reference and predicted energy differences E2 - E1,
#   - each config's predicted energy and committee UQ (std),
#   - the sum of the two UQs,
#   - the pair's committee covariance-based UQ.
ediff_combined_test_ref = Float64[]
ediff_combined_test_pred = Float64[]
ediff_combined_test_uq = Float64[]
ediff_combined_test_cov_uq = Float64[]
sys1_combined_test_uqs = Float64[]
sys2_combined_test_uqs = Float64[]

sys1_combined_test_epreds = Float64[]
sys2_combined_test_epreds = Float64[]

# With an odd number of configurations `i + 1` would run past the end, so the last,
# unpaired configuration is skipped instead.
isodd(length(combined_test_ds)) &&
    @warn "Odd number of test configurations; the last one is left unpaired."
for i in 1:2:(length(combined_test_ds) - 1)
    sys1 = combined_test_ds[i]
    sys2 = combined_test_ds[i + 1]

    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_combined_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys1_combined_test_epreds, e1_pred)
    push!(sys2_combined_test_epreds, e2_pred)
    push!(ediff_combined_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys1_combined_test_uqs, sys1_uq)
    push!(sys2_combined_test_uqs, sys2_uq)
    push!(ediff_combined_test_uq, sys1_uq + sys2_uq)

    # NOTE(review): flip_second_sign=true presumably negates the second energy so the
    # result is the cross term of Var(E2 - E1) — confirm in committee_qois.jl.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Two candidate UQs for the energy difference E2 - E1 of each pair:
#   uq1 = std(E1) + std(E2); the test version should equal ediff_combined_test_uq.
#   uq2 = uq1 + 2 * (committee covariance term of the pair).
# NOTE(review): cmte_energy was built with Statistics.std, so the per-config UQs are
# standard deviations. If CmteEnergyCov returns a covariance (squared energy units),
# uq2 mixes a std with a covariance. In that case sqrt(std1^2 + std2^2 + 2cov) would
# be the consistent combination — confirm what CmteEnergyCov returns.
test_ediff_uq1 = @. sys1_combined_test_uqs + sys2_combined_test_uqs
calib_ediff_uq1 = @. sys1_combined_calib_uqs + sys2_combined_calib_uqs

test_ediff_uq2 = @. sys1_combined_test_uqs + sys2_combined_test_uqs + 2 * ediff_combined_test_cov_uq
calib_ediff_uq2 = @. sys1_combined_calib_uqs + sys2_combined_calib_uqs + 2 * ediff_combined_calib_cov_uq

In [None]:
# Calibration curve inputs for energy differences: calibration scores |ΔE residual| / uq1,
# test absolute ΔE residuals, and the alpha grid.
ediff_combined_calib_scores = @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq1
test_abs_residuals_combined = @. abs(ediff_combined_test_pred - ediff_combined_test_ref)
alpha_complements = collect(range(0.01, 0.99; step = 0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(
    ediff_combined_calib_scores,
    test_ediff_uq1,
    test_abs_residuals_combined,
)

In [None]:
# Calibration curve for energy differences (test pairs, uq1).
make_calibration_plot(alpha_refs,alpha_pred)

In [None]:
# Scalar miscalibration summary for the energy-difference curve above.
compute_miscalibration_area(alpha_refs,alpha_pred)

In [None]:
# Cached large-cell pristine dataset (presumably already carrying descriptors, cf. the
# commented-out construction below).
large_pristine_ds = load("large_pristine_ds.jld2", "large_pristine_ds")

In [None]:
# Cached large-cell 8x-Frenkel dataset (presumably already carrying descriptors).
large_8x_frenkel_ds = load("large_8x_frenkel_ds.jld2", "large_8x_frenkel_ds")

In [None]:
# Disabled, kept for reference: reading the large-cell configurations from ExtXYZ
# files (eV / Å), presumably how the data cached in the JLD2 files above was produced.
#large_pristine_configs1 = load_data("./data/pod_Hf_frenkel_large/large_pristine_1.xyz", ExtXYZ(u"eV", u"Å"))
#large_pristine_configs2 = load_data("./data/pod_Hf_frenkel_large/large_pristine_2.xyz", ExtXYZ(u"eV", u"Å"))
#large_pristine_configs = concat_dataset([large_pristine_configs1; large_pristine_configs2])
#
#large_8x_frenkel_configs1 = load_data("./data/pod_Hf_frenkel_large/large_8x_frenkel_1.xyz", ExtXYZ(u"eV", u"Å"))
#large_8x_frenkel_configs2 = load_data("./data/pod_Hf_frenkel_large/large_8x_frenkel_2.xyz", ExtXYZ(u"eV", u"Å"))
#large_8x_frenkel_configs = concat_dataset([large_8x_frenkel_configs1; large_8x_frenkel_configs2])
#
#large_dilute_frenkel_configs1 = load_data("./data/pod_Hf_frenkel_large/large_dilute_frenkel_1.xyz", ExtXYZ(u"eV", u"Å"))
#large_dilute_frenkel_configs2 = load_data("./data/pod_Hf_frenkel_large/large_dilute_frenkel_2.xyz", ExtXYZ(u"eV", u"Å"))
#large_dilute_frenkel_configs = concat_dataset([large_dilute_frenkel_configs1; large_dilute_frenkel_configs2])

In [None]:
# Disabled, kept for reference: attaching local descriptors (from the first committee
# member's basis) to each large-cell configuration to build the DataSets.
#large_pristine_edescrs = compute_local_descriptors(large_pristine_configs, my_cmte.members[1].basis)
#large_pristine_ds = DataSet(large_pristine_configs .+ large_pristine_edescrs)
#
#large_8x_frenkel_edescrs = compute_local_descriptors(large_8x_frenkel_configs, my_cmte.members[1].basis)
#large_8x_frenkel_ds = DataSet(large_8x_frenkel_configs .+ large_8x_frenkel_edescrs)
#
#large_dilute_frenkel_edescrs = compute_local_descriptors(large_dilute_frenkel_configs, my_cmte.members[1].basis)
#large_dilute_frenkel_ds = DataSet(large_dilute_frenkel_configs .+ large_dilute_frenkel_edescrs)

In [None]:
# Large-cell pristine configs followed by large-cell 8x-Frenkel configs (unshuffled).
large_combined_orig_ds = concat_dataset([large_pristine_ds;large_8x_frenkel_ds])

In [None]:
# Shuffle the large-cell set so the consecutive pairs formed below mix pristine and
# 8x-Frenkel configurations. The permutation is sized to the dataset itself.
# NOTE(review): the RNG is not seeded, so the pairs differ between runs.
rand_idxs_large = randperm(length(large_combined_orig_ds))
large_combined_ds = large_combined_orig_ds[rand_idxs_large]

In [None]:
# Pair consecutive configurations of the shuffled large-cell set: (1, 2), (3, 4), ...
# For each pair, record the reference and predicted energy differences E2 - E1, each
# config's committee UQ (std), and the pair's committee covariance-based UQ.
large_ediff_ref = Float64[]
large_ediff_pred = Float64[]
large_sys1_uq = Float64[]
large_sys2_uq = Float64[]
large_cov_uq = Float64[]

# With an odd number of configurations `i + 1` would run past the end, so the last,
# unpaired configuration is skipped instead.
isodd(length(large_combined_ds)) &&
    @warn "Odd number of large-cell configurations; the last one is left unpaired."
for i in 1:2:(length(large_combined_ds) - 1)
    sys1 = large_combined_ds[i]
    sys2 = large_combined_ds[i + 1]

    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(large_ediff_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(large_ediff_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(large_sys1_uq, sys1_uq)
    push!(large_sys2_uq, sys2_uq)

    # NOTE(review): flip_second_sign=true presumably negates the second energy so the
    # result is the cross term of Var(E2 - E1) — confirm in committee_qois.jl.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(large_cov_uq, cov_uq)
end

In [None]:
# Calibration curve for large-cell energy differences. The calibration scores come from
# the (non-large) combined calibration pairs (|ΔE residual| / uq1). The test side uses
# the large-cell pairs, with uq1 = sum of the two committee stds.
large_ediff_uq1 = @. large_sys1_uq + large_sys2_uq
ediff_combined_calib_scores = @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq1
test_abs_residuals_combined = @. abs(large_ediff_pred - large_ediff_ref)
alpha_complements = collect(range(0.01, 0.99; step = 0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(
    ediff_combined_calib_scores,
    large_ediff_uq1,
    test_abs_residuals_combined,
)

In [None]:
# Calibration curve for large-cell energy differences.
make_calibration_plot(alpha_refs,alpha_pred)

In [None]:
# Scalar miscalibration summary for the large-cell curve above.
compute_miscalibration_area(alpha_refs,alpha_pred)

In [None]:
# Conformal quantile (miscoverage 0.05) calibrated on the combined energy-difference
# pairs with uq1, then applied as interval half-widths to the large-cell pairs.
qhat = calibrate(
    ediff_combined_calib_pred,
    ediff_combined_calib_ref,
    calib_ediff_uq1,
    0.05,
)
qhat_scores = qhat * large_ediff_uq1

In [None]:
# Empirical coverage on the large-cell pairs: the fraction of |ΔE residuals| that fall
# INSIDE their conformal half-width. The nominal target is 1 - 0.05.
coverage = count(test_abs_residuals_combined .<= qhat_scores) / length(large_ediff_pred)