Explored the per-atom version of the heuristic uncertainty, going all the way from regressing it to producing the calibration curve.

In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes
using LinearAlgebra

In [None]:
includet("../files/conformal_prediction_utils.jl")
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")
includet("../files/conformal_prediction_plots.jl")

In [None]:
# Training subsets are stored in one JLD2 file, one dictionary entry per subset.
training_data_dict = load("training_data.jld2")
frenkel_train_ds, pristine_train_ds = map(("frenkel_train_ds", "pristine_train_ds")) do key
    training_data_dict[key]
end
pristine_train_ds

In [None]:
# Pool both training subsets into a single dataset used by the regressions below.
total_train = [frenkel_train_ds, pristine_train_ds] |> concat_dataset

In [None]:
# Calibration and test splits for the pristine and Frenkel subsets
# (the file name suggests descriptors are already attached — confirm).
calibtest_datasets = load("../cesmix_prez/datasets_with_descriptors.jld2")
pristine_base_calib_ds, pristine_base_test_ds, frenkel_base_calib_ds, frenkel_base_test_ds =
    map(("pristine_base_calib_ds", "pristine_base_test_ds",
         "frenkel_base_calib_ds", "frenkel_base_test_ds")) do key
        calibtest_datasets[key]
    end
frenkel_base_test_ds

In [None]:
# Merge the pristine and Frenkel splits (pristine first) into combined calib/test sets.
combined_calib_ds = concat_dataset(vcat(pristine_base_calib_ds, frenkel_base_calib_ds))
combined_test_ds = concat_dataset(vcat(pristine_base_test_ds, frenkel_base_test_ds))

In [None]:
# Committee ensemble members, stored under the "members" key of the JLD2 file.
ace_cmte_dict = "../cesmix_prez/ace_cmte1.jld2" |> load
ensemble_members = getindex(ace_cmte_dict, "members")

In [None]:
# ACE basis definition for Hf.
# NOTE(review): `ace` is not referenced anywhere else in this notebook.
ace = ACE(
    species = [:Hf],
    body_order = 4,
    polynomial_degree = 10,
    wL = 1.5,
    csp = 1.0,
    r0 = 2.15,
    rcutoff = 5.0,
)

In [None]:
# Committee potential assembled from the loaded ensemble members.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Committee spread (std across members) of the total energy, units stripped.
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)
# Per-atom committee spread. Uses std rather than var so the spread is in energy
# units (var would be squared) and matches the `_std` naming used for its results.
# This matters: the per-atom values are later used as descriptor weights.
# NOTE(review): the positional `true` presumably mirrors `strip_units=true` above — confirm.
atomic_energies_qoi = CmteAtomicEnergies(Statistics.std, true)

Obtain the conformal `qhat` score and run a quick coverage evaluation.

In [None]:
# Calibration split: predicted energy, reference energy, and committee energy spread.
ecalib_pred = [ustrip(PotentialLearning.potential_energy(sys, my_cmte)) for sys in combined_calib_ds]
# `ustrip` for consistency with `etest_ref` below (it is a no-op on plain numbers).
ecalib_ref = [ustrip(get_values(get_energy(config))) for config in combined_calib_ds]
calib_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_calib_ds]

# NOTE(review): 0.1 is presumably the miscoverage level alpha (target 90% coverage) —
# confirm against `calibrate`.
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq, 0.1)

# Same quantities on the held-out test split.
etest_pred = [ustrip(PotentialLearning.potential_energy(config, my_cmte)) for config in combined_test_ds]
etest_ref = [ustrip(get_values(get_energy(config))) for config in combined_test_ds]
test_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_test_ds]

num_test = length(etest_pred)
test_abs_residuals = abs.(etest_pred .- etest_ref)

# Interval half-width per test configuration: qhat scaled by its committee spread.
qhat_scores = qhat .* test_uq
# Empirical coverage = fraction of test residuals that fall INSIDE the interval.
# Counting `>` here would instead give the miss rate (1 - coverage).
coverage = count(test_abs_residuals .<= qhat_scores) / num_test

In [None]:
# Committee energy spread for every training configuration, and the matching
# conformal interval half-widths (spread scaled by qhat).
train_uq = collect(ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in total_train)
train_ci = @. qhat * train_uq

In [None]:
"""
    simple_regression(xvecs, yvec, lambda=0.01)

Ridge (Tikhonov-regularized) linear least squares, solved via the normal equations.

Each element of `xvecs` is one sample's feature vector (all of the same length `d`),
and `yvec` holds the matching targets. Returns the length-`d` coefficient vector `c`
solving `(X * X' + lambda*I) * c = X * yvec`, where the columns of `X` are the entries
of `xvecs`. No intercept term is fitted.
"""
function simple_regression(xvecs, yvec, lambda=0.01)
    features = reduce(hcat, xvecs)              # d × n, one column per sample
    gram = features * features' + lambda * I    # regularized Gram matrix
    rhs = features * yvec
    return gram \ rhs
end

In [None]:
# Peek at the summed local descriptors of the first training configuration.
total_train[1] |> get_local_descriptors |> get_values |> sum

In [None]:
# Global descriptor of each training configuration: the sum over its local descriptors
# (presumably a vector of per-atom descriptor vectors). An earlier version broadcast
# `sum.` here, which reduced each atom's vector separately instead of summing across
# atoms — that was wrong.
total_train_gds = collect(sum(get_values(get_local_descriptors(cfg))) for cfg in total_train)

In [None]:
# Regress log(UQ) on the global descriptors so that predictions exp(coeffs' * gd) stay positive.
new_coeffs = simple_regression(total_train_gds, map(log, train_uq))

In [None]:
# Map back from log space: regressed UQ for each training configuration.
preds = map(total_train_gds) do gd
    exp(new_coeffs' * gd)
end

In [None]:
# Parity plot: committee UQ vs. descriptor-regressed UQ
# (min_val/max_val presumably fix the plotted range — confirm in basic_parity_plot).
basic_parity_plot(train_uq, preds;
                  marker_size=4, min_val=0.0, max_val=0.23)

In [None]:
# Training-set fit quality of the regressed UQ against the committee UQ:
# mean absolute error, then root-mean-square error.
@show mean(abs.(train_uq .- preds))
@show sqrt(mean((train_uq .- preds).^2))

In [None]:
# Local (per-atom) descriptors for every training configuration.
total_train_lds = collect(get_values(get_local_descriptors(cfg)) for cfg in total_train)

In [None]:
# Per-atom committee spread for every training configuration, using whichever
# statistic `atomic_energies_qoi` was constructed with.
total_train_atomic_energies_std = collect(compute(atomic_energies_qoi, cfg, my_cmte) for cfg in total_train)

In [None]:
# Inspect the per-atom spreads of the first training configuration.
first(total_train_atomic_energies_std)

In [None]:
# Inspect the local descriptors of the first training configuration.
first(total_train_lds)

In [None]:
# Weight each atom's local descriptor by that atom's committee spread
# (presumably one spread value per atom), configuration by configuration.
intermediate = map(eachindex(total_train_lds)) do i
    vec(total_train_atomic_energies_std[i]) .* total_train_lds[i]
end

In [None]:
# Spread-weighted global descriptor of the first configuration.
intermediate |> first |> sum

In [None]:
# Raw spread-weighted local descriptors of the first configuration.
first(intermediate)

In [None]:
# Sanity check: accumulate the first component of every weighted atom descriptor in
# `intermediate[1]` by hand; this should equal the first component of
# `sum(intermediate[1])`.
my_sum =0.0
for arr in intermediate[1]
    @show arr[1]
    # Updating the global from a top-level loop relies on interactive (soft) scope,
    # as in a notebook; in a non-interactive script this would not update `my_sum`.
    my_sum += arr[1]
end
my_sum

In [None]:
# Spread-weighted global descriptor per configuration: sum over its weighted atoms.
adjusted_gds_total_train = map(sum, intermediate)

In [None]:
# Same log-UQ regression, now on the spread-weighted global descriptors.
alt_coeffs = simple_regression(adjusted_gds_total_train, map(log, train_uq))

In [None]:
# Regressed UQ from the weighted-descriptor model, mapped back from log space.
alt_preds = map(gd -> exp(alt_coeffs' * gd), adjusted_gds_total_train)

In [None]:
# Parity plot: committee UQ vs. weighted-descriptor regressed UQ.
basic_parity_plot(train_uq, alt_preds;
                  marker_size=4, min_val=0.0, max_val=0.23)

In [None]:
# Training-set fit quality of the weighted-descriptor model against the committee UQ:
# mean absolute error, then root-mean-square error.
@show mean(abs.(train_uq .- alt_preds))
@show sqrt(mean((train_uq .- alt_preds).^2))

In [None]:
"""
    compute_basic_estimated_uqs(configs, coeffs)

Predict a heuristic UQ value for every configuration in `configs` from its global
descriptor (the sum of its local descriptors), using log-linear coefficients `coeffs`
(for example from `simple_regression` fitted on `log` targets):
`uq = exp(coeffs' * global_descriptor)`.

Returns a vector with one estimate per configuration.
"""
function compute_basic_estimated_uqs(configs, coeffs)
    return [exp(coeffs' * sum(get_values(get_local_descriptors(cfg)))) for cfg in configs]
end

In [None]:
# Quick look at the descriptor-regressed UQ estimates on the combined test set.
compute_basic_estimated_uqs(combined_test_ds, new_coeffs)

In [None]:
# Committee energy spread for each test configuration (same quantity as in the calibration cell).
test_uq = collect(ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in combined_test_ds)

In [None]:
# Nonconformity scores on the calibration split: absolute residual over committee spread.
calib_scores = abs.(ecalib_pred .- ecalib_ref) ./ calib_uq
test_abs_residuals = abs.(etest_pred .- etest_ref)

# Descriptor-regressed UQ for the test split. The committee-based `test_uq` was
# already computed earlier in the notebook, so the committee is not re-evaluated here.
est_test_uq = compute_basic_estimated_uqs(combined_test_ds, new_coeffs)

# Grid 0.01:0.01:0.99 (presumably confidence levels) and its complements.
alpha_complements = collect(range(0.01, 0.99, step=0.01))
alpha_refs = 1 .- alpha_complements

alpha_pred = generate_predicted_alphas(calib_scores, test_uq, test_abs_residuals)

In [None]:
# Calibration curve using the committee UQ.
make_custom_calibration_plot1(alpha_refs, alpha_pred;
                              label_size=28, text_size=24)

In [None]:
# Miscalibration area for the committee UQ (presumably the area between the
# reference and predicted calibration curves — confirm in the helper).
compute_miscalibration_area(alpha_refs, alpha_pred)

In [None]:
# Same calibration sweep, but with the descriptor-regressed UQ in place of the committee UQ on the test set.
# NOTE(review): `calib_scores` were computed with the committee spread, not with the
# regressed UQ. A conformal guarantee for the regressed UQ would need calibration
# scores computed with that same UQ (e.g. from compute_basic_estimated_uqs on the calib set).
est_alpha_pred = generate_predicted_alphas(calib_scores,est_test_uq, test_abs_residuals)

In [None]:
# Calibration curve using the descriptor-regressed UQ.
make_custom_calibration_plot1(alpha_refs, est_alpha_pred;
                              label_size=28, text_size=24)

In [None]:
# Miscalibration area for the descriptor-regressed UQ, for comparison with the committee UQ.
compute_miscalibration_area(alpha_refs, est_alpha_pred)

In [None]:
# Persist both fitted coefficient vectors (plain and spread-weighted descriptors).
let coeff_dict = Dict("basic_coeffs" => new_coeffs, "alt_coeffs" => alt_coeffs)
    save("dumb_coeffs_for_heuristic_uq.jld2", coeff_dict)
end