In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes

In [None]:
# Load the object stored under the "members" key of the JLD2 file; it is passed to
# `CommitteePotential` below to build the committee.
ensemble_members = load("ace_cmte1.jld2", "members")

In [None]:
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")

In [None]:
# Wrap the loaded members in a committee potential, declaring eV as the energy unit
# and Å as the length unit.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Committee-energy quantity built from `Statistics.std` with unit stripping enabled;
# it is evaluated with `compute` below to obtain the per-configuration UQ value.
# NOTE(review): presumably the std is taken over the members' energy predictions —
# confirm against committee_qois.jl.
cmte_energy = CmteEnergy(Statistics.std, strip_units=true)

In [None]:
# Load the contents of the JLD2 file and pull out the four datasets used below:
# the calibration and test splits of the `pristine_*` and `frenkel_*` data.
datasets = load("datasets_with_descriptors.jld2")
pristine_base_calib_ds = datasets["pristine_base_calib_ds"]
pristine_base_test_ds = datasets["pristine_base_test_ds"]
frenkel_base_calib_ds = datasets["frenkel_base_calib_ds"]
frenkel_base_test_ds = datasets["frenkel_base_test_ds"]



For now, we compute only a single `qhat` for a single energy.

In [None]:
includet("../files/conformal_prediction_utils.jl")

In [None]:
# from subsampling_dpp.jl in PL.jl examples
"""
    concat_dataset(confs::Vector{DataSet})

Merge several `DataSet`s into one.

Every entry of every dataset in `confs` is gathered by index, in the order the
datasets appear in `confs`, and the flattened collection is wrapped in a new
`DataSet`.
"""
function concat_dataset(confs::Vector{DataSet})
    # One vector of entries per input dataset, each read out by integer index.
    per_dataset = [[ds[k] for k = 1:length(ds)] for ds in confs]
    # Flatten the per-dataset vectors into a single vector and wrap it.
    return DataSet(reduce(vcat, per_dataset))
end

In [None]:
# Merge the pristine and Frenkel splits so that calibration and testing each run on
# a single combined dataset.
# NOTE(review): `[a; b]` is vertical concatenation; this relies on it producing a
# two-element Vector{DataSet} for `concat_dataset` — confirm for the DataSet type.
combined_calib_ds = concat_dataset([pristine_base_calib_ds; frenkel_base_calib_ds])
orig_combined_test_ds = concat_dataset([pristine_base_test_ds; frenkel_base_test_ds])

In [None]:
# Per-configuration quantities on the calibration set; all three are passed to
# `calibrate` further down.
# Committee energy prediction, with units stripped.
ecalib_pred = [ustrip(PotentialLearning.potential_energy(sys,my_cmte)) for sys in combined_calib_ds]
# Reference energy stored with each configuration.
# NOTE(review): unlike the test-set loop below, this value is not passed through
# `ustrip` — confirm that `get_values` already returns a unitless number.
ecalib_ref = [get_values(get_energy(config)) for config in combined_calib_ds]
# Committee UQ value (`cmte_energy`) for each configuration, with units stripped.
calib_uq = [ustrip(compute(cmte_energy,config,my_cmte)) for config in combined_calib_ds]

In [None]:
# Shuffle the combined test set by indexing it with a random permutation, so that
# consecutive entries can be paired up afterwards.
# NOTE(review): `randperm` is called without an explicit RNG or seed here, so the
# pairing (and every downstream number) may change from run to run.
num_test_raw = length(orig_combined_test_ds)
rand_idxs = randperm(num_test_raw)
combined_test_ds = orig_combined_test_ds[rand_idxs]

In [None]:
# The test set was shuffled above, so consecutive entries (1,2), (3,4), ... already
# form random (sys1, sys2) pairs. For each pair we record the reference and predicted
# energy difference (sys2 - sys1) plus each system's predicted energy and UQ value.
ediff_test_ref = Float64[]
ediff_test_pred = Float64[]
ediff_test_uq = Float64[]  # NOTE(review): declared but never filled in this cell
sys1_uqs = Float64[]
sys2_uqs = Float64[]

sys1_epreds = Float64[]
sys2_epreds = Float64[]
# Stop at `length - 1` so that `i + 1` is always a valid index: with an odd number of
# configurations the last, unpaired one is skipped instead of raising a BoundsError.
for i in 1:2:(length(combined_test_ds) - 1)
    sys1 = combined_test_ds[i]
    sys2 = combined_test_ds[i+1]

    # Reference energies stored with the configurations.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_test_ref, e2_ref - e1_ref)

    # Committee predictions for each system and for their difference.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1,my_cmte))
    push!(sys1_epreds,e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2,my_cmte))
    push!(sys2_epreds,e2_pred)
    push!(ediff_test_pred, e2_pred - e1_pred)

    # Committee UQ value for each system, used later to scale the prediction bands.
    sys1_uq = ustrip(compute(cmte_energy,sys1,my_cmte))
    push!(sys1_uqs,sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy,sys2,my_cmte))
    push!(sys2_uqs,sys2_uq)
end

In [None]:
"""
    compute_ediff_prediction_sets(sys1_epreds, sys2_epreds, sys1_uqs, sys2_uqs, qhat)

For each (sys1, sys2) pair, return a `(low, high)` tuple for the prediction set of
the energy difference `sys2 - sys1`.

Each system's band is `qhat` times its UQ value. The two candidate edges are
"upper edge of sys2 minus lower edge of sys1" and "lower edge of sys2 minus upper
edge of sys1"; they are sorted so that the smaller one comes first.

The number of pairs is taken from `length(sys1_epreds)`; the other vectors are
indexed with the same indices. Returns a `Vector{Tuple{Float64,Float64}}`.
"""
function compute_ediff_prediction_sets(sys1_epreds,
    sys2_epreds,
    sys1_uqs,
    sys2_uqs,
    qhat)
    npairs = length(sys1_epreds)
    intervals = Vector{Tuple{Float64,Float64}}(undef, npairs)
    for k in 1:npairs
        halfwidth1 = qhat * sys1_uqs[k]
        halfwidth2 = qhat * sys2_uqs[k]

        # Top of sys2's band minus bottom of sys1's band.
        wide_edge = (sys2_epreds[k] + halfwidth2) - (sys1_epreds[k] - halfwidth1)
        # Bottom of sys2's band minus top of sys1's band.
        narrow_edge = (sys2_epreds[k] - halfwidth2) - (sys1_epreds[k] + halfwidth1)

        # Sort so the tuple is always (low, high), whatever the sign of the bands.
        lo, hi = sort([wide_edge, narrow_edge])
        intervals[k] = (lo, hi)
    end
    return intervals
end

In [None]:
"""
    correct_conservative_ediff_pred_sets(sys1_epreds, sys2_epreds, sys1_uqs, sys2_uqs, qhat)

For each (sys1, sys2) pair, return a `(low, high)` tuple for the prediction set of
the energy difference `sys2 - sys1`.

Each system's band is `qhat` times its UQ value, giving a lower and an upper edge
per system. All four "edge of sys2 minus edge of sys1" differences are formed, and
the set is bounded by their minimum and maximum.

The number of pairs is taken from `length(sys1_epreds)`; the other vectors are
indexed with the same indices. Returns a `Vector{Tuple{Float64,Float64}}`.
"""
function correct_conservative_ediff_pred_sets(sys1_epreds,
    sys2_epreds,
    sys1_uqs,
    sys2_uqs,
    qhat)
    # Prediction set for the pair stored at position `k`.
    function pair_interval(k)
        halfwidth1 = qhat * sys1_uqs[k]
        halfwidth2 = qhat * sys2_uqs[k]

        top2 = sys2_epreds[k] + halfwidth2
        bottom1 = sys1_epreds[k] - halfwidth1
        bottom2 = sys2_epreds[k] - halfwidth2
        top1 = sys1_epreds[k] + halfwidth1

        # Every combination of one edge of sys2 minus one edge of sys1.
        d_top_bottom = top2 - bottom1
        d_top_top = top2 - top1
        d_bottom_top = bottom2 - top1
        d_bottom_bottom = bottom2 - bottom1

        smallest = min(d_top_bottom, d_top_top, d_bottom_top, d_bottom_bottom)
        largest = max(d_top_bottom, d_top_top, d_bottom_top, d_bottom_bottom)
        return (smallest, largest)
    end

    return Tuple{Float64,Float64}[pair_interval(k) for k in 1:length(sys1_epreds)]
end

In [None]:
# Sweep the nominal level alpha and, for each value, record the observed fraction of
# test pairs whose reference energy difference falls OUTSIDE the conservative
# prediction set built with the corresponding calibrated `qh`.
num_test_ediff = length(ediff_test_ref)
alpha_compls = collect(range(0.01, 0.99; step=0.01))
# Runs from 0.99 down to 0.01; the plotting code later works with 1 - alpha.
alpha_refs = 1 .- alpha_compls

conservative_predicted_alphas = Float64[]
for alpha in alpha_refs
    qh = calibrate(ecalib_pred, ecalib_ref, calib_uq, alpha)
    ediff_pred_sets = correct_conservative_ediff_pred_sets(
        sys1_epreds, sys2_epreds, sys1_uqs, sys2_uqs, qh
    )

    # Split the (low, high) tuples into two vectors aligned with `ediff_test_ref`.
    ediff_low_bounds = first.(ediff_pred_sets)
    ediff_high_bounds = last.(ediff_pred_sets)

    # Fraction of pairs whose reference difference lies inside [low, high].
    inside = ediff_low_bounds .<= ediff_test_ref .<= ediff_high_bounds
    predicted_alpha_compl = count(inside) / num_test_ediff
    push!(conservative_predicted_alphas, 1.0 - predicted_alpha_compl)
end

In [None]:
using ColorSchemes

"""
    make_custom_calibration_plot(expected_ps, observed_ps; kwargs...)

Build a square calibration figure comparing expected and observed confidence levels.

Both inputs are transformed as `(1 - p) * 100` before plotting, so callers pass
values on a 0-1 scale and the axes are labelled in percent from 0 to 100. The figure
contains the observed-vs-expected curve, a dashed diagonal reference line, and a
filled band between the two.

# Keywords
- `width`: side length of the figure; the resolution is `(width, width)`.
- `colormap`: name of the entry in `ColorSchemes.colorschemes` from which the line
  and band colors are sampled (at 0.4 and 0.45 respectively).
- `color_value`: accepted for backward compatibility; currently not used.
- `main_line_width`: line width of the observed-vs-expected curve.
- `band_alpha`: opacity of the filled band.
- `axis_color`: color of the spines, tick labels and axis labels.
- `text_size`: figure font size and tick-label size.
- `label_size`: axis-label size.
- `grid_visible`, `grid_color`, `grid_width`: grid appearance on both axes.

# Returns
The constructed `Figure`.
"""
function make_custom_calibration_plot(expected_ps, observed_ps;
                                      width=600,
                                      colormap=:lajolla,
                                      color_value=0.6,  # Value between 0-1 in the colormap
                                      main_line_width=3.0,
                                      band_alpha=0.2,
                                      axis_color=:black,
                                      text_size=18,
                                      label_size=22,
                                      grid_visible=true,
                                      grid_color=(:gray, 0.3),
                                      grid_width=0.5)
    # Convert to the complementary level, expressed in percent.
    expected_pct = (1.0 .- expected_ps) .* 100
    observed_pct = (1.0 .- observed_ps) .* 100

    # Sample the line and band colors from the requested color scheme. The keyword is
    # honored here; its default (:lajolla) matches the scheme that was previously
    # hard-coded, so calls that do not pass `colormap` render exactly as before.
    scheme = ColorSchemes.colorschemes[colormap]
    band_color = (get(scheme, 0.45), band_alpha)
    line_color = get(scheme, 0.4)

    fig = Figure(resolution=(width, width), fontsize=text_size, figure_padding=30)
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel="Expected Confidence Level",
        ylabel="Observed Confidence Level",
        limits=(0, 100, 0, 100),
        xlabelsize=label_size,
        ylabelsize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_width,
        ygridwidth=grid_width
    )

    # Set spine, tick-label and axis-label colors
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color

    # Main line - made bolder
    lines!(ax, expected_pct, observed_pct, color=line_color, linewidth=main_line_width)

    # Diagonal reference line
    lines!(ax, expected_pct, expected_pct, linestyle=:dash, color=line_color, alpha=0.6, linewidth=1.5)

    # Filled area between the diagonal and the observed curve
    band!(ax, expected_pct, expected_pct, observed_pct, color=band_color)

    # Configure ticks
    ax.xticks = 0:20:100
    ax.yticks = 0:20:100

    # Add percentage signs to ticks
    ax.xtickformat = xs -> ["$(Int(x))%" for x in xs]
    ax.ytickformat = xs -> ["$(Int(x))%" for x in xs]

    return fig
end

In [None]:
# Draw the calibration plot for the conservative energy-difference prediction sets
# (with larger tick and axis-label text) and write it to an SVG file.
# NOTE(review): the variable and file are called "naive" although the data plotted
# are `conservative_predicted_alphas` — confirm the intended naming.
naive_fig = make_custom_calibration_plot(alpha_refs,conservative_predicted_alphas; text_size=24, label_size=28)
save("naive_calibration.svg", naive_fig)

In [None]:
# Evaluate the miscalibration area for the same expected/observed alpha curves that
# were plotted above; as the last expression in the cell, the result is displayed.
# NOTE(review): `compute_miscalibration_area` is not defined in this notebook —
# presumably it comes from one of the included utility files; confirm.
compute_miscalibration_area(alpha_refs,conservative_predicted_alphas)