In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes
using LinearAlgebra

In [None]:
includet("../files/conformal_prediction_utils.jl")
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")
includet("../files/conformal_prediction_plots.jl")

In [None]:
# Coefficients for the heuristic (descriptor-based) UQ estimate computed by
# `compute_basic_estimated_uqs`, stored in the JLD2 file under key "basic_coeffs".
# NOTE(review): presumably a vector whose length matches the summed global
# descriptor (it is used as `coeffs' * gd`) — confirm.
basic_coeffs = load("dumb_coeffs_for_heuristic_uq.jld2", "basic_coeffs")

In [None]:
# All calibration/test splits (with precomputed descriptors) live in one JLD2 file;
# pull out the pristine and Frenkel splits by key.
calibtest_datasets = load("../cesmix_prez/datasets_with_descriptors.jld2")

# Calibration splits
pristine_base_calib_ds = calibtest_datasets["pristine_base_calib_ds"]
frenkel_base_calib_ds  = calibtest_datasets["frenkel_base_calib_ds"]

# Test splits
pristine_base_test_ds  = calibtest_datasets["pristine_base_test_ds"]
frenkel_base_test_ds   = calibtest_datasets["frenkel_base_test_ds"]

In [None]:
# Pool the pristine and Frenkel splits into single calibration and test datasets
# (pristine configurations first, Frenkel second).
combined_calib_ds = concat_dataset(vcat(pristine_base_calib_ds, frenkel_base_calib_ds))
combined_test_ds  = concat_dataset(vcat(pristine_base_test_ds, frenkel_base_test_ds))

In [None]:
# Load the saved committee (ensemble) of potentials; the "members" entry is what
# gets wrapped in a `CommitteePotential` further down.
ace_cmte_dict = load("../cesmix_prez/ace_cmte1.jld2")
ensemble_members = ace_cmte_dict["members"]

In [None]:
# ACE basis for Hf (body order 4, polynomial degree 10, 5.0 cutoff).
# NOTE(review): `ace` is not referenced anywhere else in this notebook; presumably
# kept to record the basis the committee members were fit with — confirm before
# removing.
ace = ACE(species            = [:Hf],
          body_order         = 4,
          polynomial_degree  = 10,
          wL                 = 1.5,
          csp                = 1.0,
          r0                 = 2.15,
          rcutoff            = 5.0)

In [None]:
# Wrap the ensemble members as one committee potential (eV / Å), and define the
# committee UQ measure: the spread (`std`) of member energies, returned unit-free.
my_cmte = CommitteePotential(ensemble_members; length_units=u"Å", energy_units=u"eV")
cmte_energy = CmteEnergy(Statistics.std; strip_units=true)

In [None]:
# Committee predictions, reference energies and committee UQ for the calibration
# split, plus predictions and references for the test split.
# Every array is unit-stripped so later cells can mix them with plain
# Float64 arithmetic. `ecalib_ref` previously skipped `ustrip`, unlike every other
# reference array; if `get_values` returns Unitful quantities, that made
# `ecalib_pred .- ecalib_ref` throw. (`ustrip` is a no-op on plain numbers.)
ecalib_pred = [ustrip(PotentialLearning.potential_energy(config, my_cmte))
               for config in combined_calib_ds]
ecalib_ref = [ustrip(get_values(get_energy(config))) for config in combined_calib_ds]
calib_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_calib_ds]

etest_pred = [ustrip(PotentialLearning.potential_energy(config, my_cmte))
              for config in combined_test_ds]
etest_ref = [ustrip(get_values(get_energy(config))) for config in combined_test_ds]

In [None]:
"""
    compute_basic_estimated_uqs(configs, coeffs)

Return a heuristic uncertainty estimate for every configuration in `configs`.

For each configuration the local descriptors are summed into one global
descriptor `g`, and the estimate is `exp(coeffs' * g)`. Results are returned
in the iteration order of `configs`.
"""
function compute_basic_estimated_uqs(configs, coeffs)
    estimates = map(configs) do config
        global_descriptor = sum(get_values(get_local_descriptors(config)))
        exp(coeffs' * global_descriptor)
    end
    return estimates
end

In [None]:
# Nonconformity scores on the calibration split: absolute residual scaled by the
# committee UQ of the same configuration.
calib_scores = @. abs(ecalib_pred - ecalib_ref) / calib_uq

# Test split: committee UQ, heuristic UQ estimate, and absolute residuals.
test_uq = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in combined_test_ds]
est_test_uq = compute_basic_estimated_uqs(combined_test_ds, basic_coeffs)
test_abs_residuals = @. abs(etest_pred - etest_ref)

# Reference alpha grid: 1 - (0.01, 0.02, ..., 0.99).
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(calib_scores, test_uq, test_abs_residuals)

In [None]:
# Inspect the committee UQ values for the test split.
test_uq

In [None]:
# Inspect the local descriptors of the first test configuration (these are what
# `compute_basic_estimated_uqs` sums into a global descriptor).
get_values(get_local_descriptors(combined_test_ds[1]))

In [None]:
# Calibration plot for the combined test split; presumably predicted vs reference
# alpha levels (helper defined in conformal_prediction_plots.jl — confirm).
make_custom_calibration_plot1(alpha_refs,alpha_pred; text_size=24, label_size=28)

In [None]:
# Load the three 8x-supercell datasets (pristine, Frenkel, dilute Frenkel) by key.
pristine_8x_ds, frenkel_8x_ds, dilute_8x_ds = load(
    "large_8x_data.jld2",
    "pristine_8x_ds", "frenkel_8x_ds", "dilute_8x_ds",
)

In [None]:
# Committee UQ (member-energy spread) for each 8x-supercell dataset ...
large_pristine_uq       = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in pristine_8x_ds]
large_frenkel_uq        = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in frenkel_8x_ds]
large_dilute_frenkel_uq = [ustrip(compute(cmte_energy, sys, my_cmte)) for sys in dilute_8x_ds]

# ... and the descriptor-based heuristic estimate for the same configurations.
est_large_pristine_uq       = compute_basic_estimated_uqs(pristine_8x_ds, basic_coeffs)
est_large_frenkel_uq        = compute_basic_estimated_uqs(frenkel_8x_ds, basic_coeffs)
est_large_dilute_frenkel_uq = compute_basic_estimated_uqs(dilute_8x_ds, basic_coeffs)

In [None]:
# Unit-stripped committee predictions for each 8x-supercell dataset ...
large_pristine_pred = [ustrip(PotentialLearning.potential_energy(sys, my_cmte))
                       for sys in pristine_8x_ds]
large_frenkel_pred  = [ustrip(PotentialLearning.potential_energy(sys, my_cmte))
                       for sys in frenkel_8x_ds]
large_dilute_pred   = [ustrip(PotentialLearning.potential_energy(sys, my_cmte))
                       for sys in dilute_8x_ds]

# ... and the matching unit-stripped reference energies.
large_pristine_ref = [ustrip(get_values(get_energy(sys))) for sys in pristine_8x_ds]
large_frenkel_ref  = [ustrip(get_values(get_energy(sys))) for sys in frenkel_8x_ds]
large_dilute_ref   = [ustrip(get_values(get_energy(sys))) for sys in dilute_8x_ds]

In [None]:
# Stack the 8x-supercell results in the order pristine, Frenkel, dilute Frenkel.
total_large_uq = vcat(large_pristine_uq, large_frenkel_uq, large_dilute_frenkel_uq)
total_est_large_uq = vcat(est_large_pristine_uq,
                          est_large_frenkel_uq,
                          est_large_dilute_frenkel_uq)

total_large_pred = vcat(large_pristine_pred, large_frenkel_pred, large_dilute_pred)
total_large_ref = vcat(large_pristine_ref, large_frenkel_ref, large_dilute_ref)

In [None]:
# Inspect the concatenated committee UQ over all 8x-supercell datasets.
total_large_uq

In [None]:
# Inspect the concatenated heuristic UQ estimates over all 8x-supercell datasets.
total_est_large_uq

In [None]:
# Calibration check on the 8x pristine supercells, reusing the nonconformity
# scores from the small-cell calibration split.
# Dedicated names are used on purpose: rebinding the shared globals
# `etest_pred`/`etest_ref`/`test_uq`/`test_abs_residuals`/`alpha_pred` here
# would silently corrupt the combined-test analysis in other cells.
large_pristine_abs_residuals = abs.(large_pristine_pred .- large_pristine_ref)
large_pristine_alpha_pred = generate_predicted_alphas(calib_scores, large_pristine_uq,
                                                      large_pristine_abs_residuals)
@show compute_miscalibration_area(alpha_refs, large_pristine_alpha_pred)
make_custom_calibration_plot1(alpha_refs, large_pristine_alpha_pred;
                              text_size=24, label_size=28)

In [None]:
# Rebuild the combined (small-cell) test-split state. Earlier cells may rebind
# these globals (e.g. to evaluate on the 8x pristine set), so recompute every
# test quantity, not only predictions and references. Otherwise `test_uq`,
# `test_abs_residuals` and `alpha_pred` could still hold values from a
# different dataset.
etest_pred = [ustrip(PotentialLearning.potential_energy(config, my_cmte))
              for config in combined_test_ds]
etest_ref = [ustrip(get_values(get_energy(config))) for config in combined_test_ds]
test_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_test_ds]
test_abs_residuals = abs.(etest_pred .- etest_ref)
alpha_pred = generate_predicted_alphas(calib_scores, test_uq, test_abs_residuals)

In [None]:
# Mean absolute energy error of the committee on the combined test split
# (unit-stripped values).
mean(abs.(etest_pred .- etest_ref))

In [None]:
# Mean absolute energy error of the committee on the combined calibration split.
mean(abs.(ecalib_pred .- ecalib_ref))