Fit a simple (ridge-regularized) linear regression from the global descriptors of the training configurations to the heuristic uncertainty on the training set (i.e., the ensemble standard deviation). The uncertainties had to be log-transformed before fitting. Even after the transform, the fit is not very good.

In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes
using LinearAlgebra

In [None]:
includet("../files/conformal_prediction_utils.jl")
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")

In [None]:
# Load the saved training data and pull out the two training datasets stored
# under the keys "frenkel_train_ds" and "pristine_train_ds".
training_data_dict = load("training_data.jld2")
frenkel_train_ds = training_data_dict["frenkel_train_ds"]
pristine_train_ds = training_data_dict["pristine_train_ds"]

In [None]:
# Load the calibration and test datasets (file name suggests descriptors are
# already attached — TODO confirm) for both the pristine and Frenkel cases.
calibtest_datasets = load("../cesmix_prez/datasets_with_descriptors.jld2")
pristine_base_calib_ds = calibtest_datasets["pristine_base_calib_ds"]
pristine_base_test_ds = calibtest_datasets["pristine_base_test_ds"]
frenkel_base_calib_ds = calibtest_datasets["frenkel_base_calib_ds"]
frenkel_base_test_ds = calibtest_datasets["frenkel_base_test_ds"]

In [None]:
# Merge the pristine and Frenkel splits into one calibration set and one test set.
# NOTE(review): the inputs are joined with `;` (vcat) here, whereas the training
# sets are joined with `,` further down. Presumably both forms hand
# `concat_dataset` a two-element vector of datasets — confirm against its definition.
combined_calib_ds = concat_dataset([pristine_base_calib_ds; frenkel_base_calib_ds])
combined_test_ds = concat_dataset([pristine_base_test_ds; frenkel_base_test_ds])

In [None]:
# Load the saved committee and extract its members (stored under "members").
ace_cmte_dict = load("../cesmix_prez/ace_cmte1.jld2")
ensemble_members = ace_cmte_dict["members"]

In [None]:
# ACE basis specification for hafnium (single species, body order 4,
# polynomial degree 10, cutoff radius 5.0).
# NOTE(review): presumably matches the basis used by the committee members — confirm.
ace = ACE(
    species = [:Hf],
    body_order = 4,
    polynomial_degree = 10,
    wL = 1.5,
    csp = 1.0,
    r0 = 2.15,
    rcutoff = 5.0,
)

In [None]:
# Wrap the loaded ensemble members in a committee potential (energies in eV,
# lengths in Å).
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Committee energy quantity reduced with the standard deviation, i.e. the
# heuristic uncertainty used throughout this notebook; units are stripped.
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)

In [None]:
# Calibration split: committee energy predictions, reference energies, and the
# heuristic uncertainty (committee standard deviation) for every configuration.
# Every quantity is passed through `ustrip`, exactly as in the test-set cell, so
# that `calibrate` receives plain numbers regardless of whether the stored
# reference energies carry units.
ecalib_pred = [ustrip(PotentialLearning.potential_energy(config, my_cmte)) for config in combined_calib_ds]
ecalib_ref = [ustrip(get_values(get_energy(config))) for config in combined_calib_ds]
calib_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in combined_calib_ds]

# Scaling factor for the heuristic uncertainties.
# NOTE(review): presumably 0.1 is the target miscoverage level (90% coverage) —
# confirm against `calibrate` in conformal_prediction_utils.jl.
qhat = calibrate(ecalib_pred, ecalib_ref, calib_uq, 0.1)

In [None]:
# Test split: unit-stripped committee predictions, reference energies and
# heuristic uncertainties, one entry per configuration.
etest_pred = [
    ustrip(PotentialLearning.potential_energy(cfg, my_cmte)) for cfg in combined_test_ds
]
etest_ref = [
    ustrip(get_values(get_energy(cfg))) for cfg in combined_test_ds
]
test_uq = [
    ustrip(compute(cmte_energy, cfg, my_cmte)) for cfg in combined_test_ds
]

# Number of test points and the absolute prediction error of each one.
num_test = length(etest_pred)
test_abs_residuals = @. abs(etest_pred - etest_ref)

In [None]:
# Half-width of the calibrated interval for each test point.
qhat_scores = qhat .* test_uq
# Empirical coverage: the fraction of test points whose absolute residual lies
# INSIDE the interval. (Counting `residual > half-width` instead would give the
# miscoverage rate, i.e. 1 - coverage.)
coverage = count(test_abs_residuals .<= qhat_scores) / num_test

In [None]:
# Combine the Frenkel and pristine training sets into a single training dataset.
total_train = concat_dataset([frenkel_train_ds, pristine_train_ds])

In [None]:
# Heuristic uncertainty (committee standard deviation) of every training configuration.
train_uq = [ustrip(compute(cmte_energy, config, my_cmte)) for config in total_train]

# Calibrated interval half-widths on the training set (uncertainty scaled by qhat).
train_ci = qhat .* train_uq


In [None]:
# Global descriptor of each training configuration: the element-wise sum of its
# per-atom local descriptor vectors, giving one feature vector per configuration.
# A plain `sum` is required here. The broadcast `sum.` would instead collapse
# each atom's descriptor to a scalar and return one number per atom, which is
# not the global descriptor.
# NOTE(review): assumes `get_values` returns one descriptor vector per atom — confirm.
total_train_gds = [sum(get_values(get_local_descriptors(config))) for config in total_train]

In [None]:
# First attempt: ridge regression of the raw (untransformed) training
# uncertainties on the global descriptors, solved via the normal equations.
lambda = 0.01
# Design matrix: one row per training configuration, one column per feature.
A = reduce(hcat,total_train_gds)'
b = train_uq
AtA = A'*A
Atb = A'*b
# Regularized normal equations: (AᵀA + λI) x = Aᵀb.
coeffs = (AtA + lambda*I) \ Atb

In [None]:
# Display the target uncertainties for comparison with the fitted values below.
train_uq

In [None]:
# Fitted values of the untransformed regression, one per training configuration.
[coeffs'*gd for gd in total_train_gds]

In [None]:
"""
    simple_regression(xvecs, yvec, lambda=0.01)

Fit a ridge-regularized linear model without an intercept.

`xvecs` is a collection of equal-length feature vectors (one per sample) and
`yvec` holds the corresponding targets. Returns the coefficient vector solving
the regularized normal equations `(XᵀX + lambda*I) c = Xᵀy`, where the rows of
`X` are the entries of `xvecs`.
"""
function simple_regression(xvecs, yvec, lambda=0.01)
    # Feature-by-sample matrix: each column is one sample's feature vector,
    # so it equals Xᵀ for the usual sample-by-feature design matrix X.
    features = reduce(hcat, xvecs)
    gram = features * features'      # XᵀX
    moment = features * yvec         # Xᵀy
    return (gram + lambda * I) \ moment
end

In [None]:
# Second attempt: regress the log of the uncertainties on the global descriptors
# (default lambda). `log` needs strictly positive inputs: a zero uncertainty
# gives -Inf and a negative one throws a DomainError.
new_coeffs = simple_regression(total_train_gds, log.(train_uq))

In [None]:
# Map the log-space predictions back to the original scale with `exp`.
preds = [exp(new_coeffs'*gd) for gd in total_train_gds]

In [None]:
includet("../files/conformal_prediction_plots.jl")

In [None]:
# Parity plot of the actual training uncertainties against the regression
# predictions. NOTE(review): `basic_parity_plot` is presumably defined in
# conformal_prediction_plots.jl — confirm the meaning of min_val/max_val there.
basic_parity_plot(train_uq, preds; min_val=0.0, max_val=0.2, marker_size=4)