In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes

In [None]:
# Load the committee (ensemble) members stored under the "members" key of the JLD2 file.
ensemble_members = load("ace_cmte1.jld2", "members")

In [None]:
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")

In [None]:
# Wrap the loaded committee members as a single CommitteePotential.
# The length unit is the ångström, written u"Å" (U+00C5); a mis-encoded literal such
# as u"Ã…" is not a Unitful unit and fails to parse.
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Per-configuration committee energy spread.
# IMPORTANT: this uses the VARIANCE (Statistics.var), not the standard deviation,
# because the energy-difference uncertainty later adds variances and covariances
# and only then takes a square root.
cmte_energy = CmteEnergy(Statistics.var, strip_units=true)

In [None]:
# Load the saved datasets (the file name suggests descriptors are already attached
# — confirm) and unpack the pristine/frenkel calibration and test splits by key.
datasets = load("datasets_with_descriptors.jld2")
pristine_base_calib_ds = datasets["pristine_base_calib_ds"]
pristine_base_test_ds = datasets["pristine_base_test_ds"]
frenkel_base_calib_ds = datasets["frenkel_base_calib_ds"]
frenkel_base_test_ds = datasets["frenkel_base_test_ds"]



Here we compute a single conformal quantile, $\hat{q}$, for a single energy quantity.

In [None]:
includet("../files/conformal_prediction_utils.jl")

In [None]:
"""
    concat_dataset(confs::Vector{DataSet})

Concatenate several `DataSet`s into one `DataSet`. The order of `confs` is kept, and
so is the order of configurations within each dataset.

Adapted from `subsampling_dpp.jl` in the PL.jl (PotentialLearning.jl) examples.
"""
function concat_dataset(confs::Vector{DataSet})
    # Materialise each dataset as a plain vector of its configurations, then stack them.
    per_dataset = map(ds -> [ds[k] for k in 1:length(ds)], confs)
    return DataSet(reduce(vcat, per_dataset))
end

In [None]:
# Merge the pristine and frenkel splits into one calibration and one test DataSet
# (pristine configurations first, then frenkel ones).
orig_combined_calib_ds = concat_dataset([pristine_base_calib_ds; frenkel_base_calib_ds])
orig_combined_test_ds = concat_dataset([pristine_base_test_ds; frenkel_base_test_ds])

In [None]:
# Shuffle each combined dataset so that consecutive entries form random pairs.
# Without shuffling, concat_dataset leaves pristine and frenkel configurations in
# two contiguous blocks.
# Each permutation is sized from its own dataset, not from a hard-coded count. That
# way every configuration is used, and no index can fall outside the dataset.
rand_idxs = randperm(length(orig_combined_calib_ds))
test_rand_idxs = randperm(length(orig_combined_test_ds))
combined_calib_ds = orig_combined_calib_ds[rand_idxs]
combined_test_ds = orig_combined_test_ds[test_rand_idxs]

In [None]:
# Committee covariance QoI, evaluated below on pairs of systems via
# compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true).
# NOTE(review): the positional `true` presumably means strip_units, mirroring
# CmteEnergy(...; strip_units=true) above — confirm in committee_qois.jl.
cmte_cov_energy = CmteEnergyCov(true)

In [None]:
# combined_calib_ds was shuffled above, so consecutive entries (1, 2), (3, 4), ...
# form random pairs. For every pair, record:
#   - the reference and predicted energy difference E2 - E1,
#   - each system's predicted energy and committee energy variance (cmte_energy),
#   - the committee covariance term (flip_second_sign=true) that enters the
#     uncertainty of the energy difference.
ediff_combined_calib_ref = Float64[]
ediff_combined_calib_pred = Float64[]
sys1_combined_calib_uqs = Float64[]
sys2_combined_calib_uqs = Float64[]
ediff_combined_calib_cov_uq = Float64[]

sys1_combined_calib_epreds = Float64[]
sys2_combined_calib_epreds = Float64[]
# Stop at length - 1 so that an odd-length dataset skips its last, unpaired entry
# instead of raising a BoundsError on index i + 1.
for i in 1:2:(length(combined_calib_ds) - 1)
    sys1 = combined_calib_ds[i]
    sys2 = combined_calib_ds[i+1]
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))

    push!(ediff_combined_calib_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(sys1_combined_calib_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys2_combined_calib_epreds, e2_pred)

    push!(ediff_combined_calib_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(sys1_combined_calib_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys2_combined_calib_uqs, sys2_uq)

    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_calib_cov_uq, cov_uq)
end

In [None]:
# combined_test_ds was shuffled above, so consecutive entries (1, 2), (3, 4), ...
# form random pairs. For every pair, record:
#   - the reference and predicted energy difference E2 - E1,
#   - each system's predicted energy and committee energy variance (cmte_energy),
#   - the committee covariance term (flip_second_sign=true) that enters the
#     uncertainty of the energy difference.
ediff_combined_test_ref = Float64[]
ediff_combined_test_pred = Float64[]
sys1_combined_test_uqs = Float64[]
sys2_combined_test_uqs = Float64[]
ediff_combined_test_cov_uq = Float64[]

sys1_combined_test_epreds = Float64[]
sys2_combined_test_epreds = Float64[]
# Stop at length - 1 so that an odd-length dataset skips its last, unpaired entry
# instead of raising a BoundsError on index i + 1.
for i in 1:2:(length(combined_test_ds) - 1)
    sys1 = combined_test_ds[i]
    sys2 = combined_test_ds[i+1]
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))

    push!(ediff_combined_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(sys1_combined_test_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys2_combined_test_epreds, e2_pred)

    push!(ediff_combined_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(sys1_combined_test_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys2_combined_test_uqs, sys2_uq)

    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Standard deviation of each pair's energy difference:
#   Var(X1 - X2) = Var(X1) + Var(-X2) + 2*Cov(X1, -X2),   with Var(-X2) = Var(X2).
# The *_uqs arrays hold per-system committee variances. The *_cov_uq arrays hold the
# covariance term computed with flip_second_sign=true.
test_ediff_uq = @. sqrt(sys1_combined_test_uqs + sys2_combined_test_uqs +
                        2 * ediff_combined_test_cov_uq)
calib_ediff_uq = @. sqrt(sys1_combined_calib_uqs + sys2_combined_calib_uqs +
                         2 * ediff_combined_calib_cov_uq)

In [None]:
# Calibration scores: absolute energy-difference error normalised by the committee
# standard deviation of that difference.
ediff_combined_calib_scores =
    @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq
# Unnormalised absolute energy-difference errors on the test pairs.
test_abs_residuals_combined = @. abs(ediff_combined_test_pred - ediff_combined_test_ref)
# Levels 0.01, 0.02, ..., 0.99 and their complements α = 1 - level.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

alpha_pred = generate_predicted_alphas(ediff_combined_calib_scores,
                                       test_ediff_uq,
                                       test_abs_residuals_combined)

In [None]:
using ColorSchemes

"""
    make_custom_calibration_plot1(expected_ps, observed_ps; kwargs...)

Plot a calibration curve of observed vs. expected confidence level. The plot also has a
dashed y = x reference line and a translucent band between the curve and that diagonal.
Both axes have the labels "Expected/Observed Confidence Level".

`expected_ps` and `observed_ps` are α levels in [0, 1]. They are plotted as
`100 * (1 - α)` percent.

# Keywords
- `width=600`: side length of the square figure.
- `colormap=:managua`: ColorSchemes scheme the line/band colour is sampled from.
- `color_value=0.5`: position (0–1) in `colormap` for the line and band colour.
- `main_line_width=3.0`: width of the calibration curve.
- `band_alpha=0.2`: opacity of the shaded band.
- `axis_color=:black`: colour of spines, tick labels and axis labels.
- `text_size=18`: figure font size and tick-label size.
- `label_size=22`: axis-label font size.
- `grid_visible=true`, `grid_color=(:gray, 0.3)`, `grid_width=0.5`: grid styling.
- `figure_padding=30`: padding around the figure.

Returns the `Figure`.
"""
function make_custom_calibration_plot1(expected_ps, observed_ps;
                                       width=600,
                                       colormap=:managua,
                                       color_value=0.5,
                                       main_line_width=3.0,
                                       band_alpha=0.2,
                                       axis_color=:black,
                                       text_size=18,
                                       label_size=22,
                                       grid_visible=true,
                                       grid_color=(:gray, 0.3),
                                       grid_width=0.5,
                                       figure_padding=30)
    # α levels -> confidence levels in percent.
    expected_conf = (1.0 .- expected_ps) .* 100
    observed_conf = (1.0 .- observed_ps) .* 100

    # The line and the band share one colour sampled from the colormap.
    # Previously `colormap` and `color_value` were ignored (hard-coded :managua at 0.5);
    # the new defaults reproduce that output exactly.
    line_color = get(ColorSchemes.colorschemes[colormap], color_value)
    band_color = (line_color, band_alpha)

    fig = Figure(resolution=(width, width), fontsize=text_size,
                 figure_padding=figure_padding)
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel="Expected Confidence Level",
        ylabel="Observed Confidence Level",
        limits=(0, 100, 0, 100),
        xlabelsize=label_size,
        ylabelsize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_width,
        ygridwidth=grid_width
    )

    # Spine, tick-label and axis-label colours.
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color

    # Calibration curve (bold).
    lines!(ax, expected_conf, observed_conf, color=line_color, linewidth=main_line_width)

    # Perfect-calibration diagonal.
    lines!(ax, expected_conf, expected_conf, linestyle=:dash, color=line_color,
           alpha=0.6, linewidth=1.5)

    # Shaded area between the diagonal and the calibration curve.
    band!(ax, expected_conf, expected_conf, observed_conf, color=band_color)

    # Ticks every 20 %, labelled with a percent sign (ticks are integers, so Int is exact).
    ax.xticks = 0:20:100
    ax.yticks = 0:20:100
    ax.xtickformat = xs -> ["$(Int(x))%" for x in xs]
    ax.ytickformat = xs -> ["$(Int(x))%" for x in xs]

    return fig
end

In [None]:
# Calibration curve for energy differences on the combined (pristine + frenkel) test
# pairs, written to an SVG file.
ediff_fig = make_custom_calibration_plot1(alpha_refs,alpha_pred; text_size=24, label_size=28)
save("basic_ediff_calibration.svg", ediff_fig)

In [None]:
# Miscalibration area between the expected (alpha_refs) and observed (alpha_pred)
# levels for the combined test pairs; the helper comes from one of the included files.
compute_miscalibration_area(alpha_refs,alpha_pred)

In [None]:
# Shuffle the pristine and frenkel test sets with ONE shared permutation, so both are
# reordered identically. Because the permutation is applied to both sets, they must
# have the same length. Check this explicitly: otherwise a shorter frenkel set raises
# a BoundsError and a longer one silently loses entries.
length(frenkel_base_test_ds) == length(pristine_base_test_ds) ||
    throw(DimensionMismatch(
        "pristine and frenkel test sets differ in length: " *
        "$(length(pristine_base_test_ds)) vs $(length(frenkel_base_test_ds))"))
new_rand_idxs = randperm(length(pristine_base_test_ds))
pristine_test_ds = pristine_base_test_ds[new_rand_idxs]
frenkel_test_ds = frenkel_base_test_ds[new_rand_idxs]

In [None]:
# pristine-pristine: consecutive entries (1, 2), (3, 4), ... of the shuffled
# pristine test set form random pairs. For each pair, record the reference/predicted
# energy difference E2 - E1, each system's predicted energy and committee variance,
# and the committee covariance term (flip_second_sign=true).
pp_ediff_combined_test_ref = Float64[]
pp_ediff_combined_test_pred = Float64[]
pp_sys1_combined_test_uqs = Float64[]
pp_sys2_combined_test_uqs = Float64[]
pp_ediff_combined_test_cov_uq = Float64[]

pp_sys1_combined_test_epreds = Float64[]
pp_sys2_combined_test_epreds = Float64[]
# Stop at length - 1 so that an odd-length dataset skips its last, unpaired entry
# instead of raising a BoundsError on index i + 1.
for i in 1:2:(length(pristine_test_ds) - 1)
    sys1 = pristine_test_ds[i]
    sys2 = pristine_test_ds[i+1]
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))

    push!(pp_ediff_combined_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(pp_sys1_combined_test_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(pp_sys2_combined_test_epreds, e2_pred)

    push!(pp_ediff_combined_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(pp_sys1_combined_test_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(pp_sys2_combined_test_uqs, sys2_uq)

    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(pp_ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Std of the energy difference per pristine-pristine pair: sqrt(Var1 + Var2 + 2*Cov(X1, -X2)).
pp_test_ediff_uq = @. sqrt(pp_sys1_combined_test_uqs + pp_sys2_combined_test_uqs +
                           2 * pp_ediff_combined_test_cov_uq)

In [None]:
# Calibration scores come from the combined (pristine + frenkel) calibration pairs;
# only the test residuals here are pristine-pristine.
ediff_combined_calib_scores =
    @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq
pp_test_abs_residuals_combined =
    @. abs(pp_ediff_combined_test_pred - pp_ediff_combined_test_ref)
# Levels 0.01, 0.02, ..., 0.99 and their complements α = 1 - level.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

pp_alpha_pred = generate_predicted_alphas(ediff_combined_calib_scores,
                                          pp_test_ediff_uq,
                                          pp_test_abs_residuals_combined)

In [None]:
using ColorSchemes

"""
    make_custom_calibration_plot2(expected_ps, observed_ps; kwargs...)

Plot a calibration curve of observed vs. expected confidence level. The plot also has a
dashed y = x reference line and a translucent band between the curve and that diagonal.
This variant has no axis labels by default and a wider figure padding than
`make_custom_calibration_plot1`.

`expected_ps` and `observed_ps` are α levels in [0, 1]. They are plotted as
`100 * (1 - α)` percent.

# Keywords
- `width=600`: side length of the square figure.
- `colormap=:managua`: ColorSchemes scheme the line/band colour is sampled from.
- `color_value=0.35`: position (0–1) in `colormap` for the line and band colour.
- `main_line_width=3.0`: width of the calibration curve.
- `band_alpha=0.2`: opacity of the shaded band.
- `axis_color=:black`: colour of spines, tick labels and axis labels.
- `text_size=18`: figure font size and tick-label size.
- `label_size=22`: axis-label font size (only visible if labels are given).
- `grid_visible=true`, `grid_color=(:gray, 0.3)`, `grid_width=0.5`: grid styling.
- `figure_padding=40`: padding around the figure.
- `xlabel=""`, `ylabel=""`: optional axis labels (empty by default).

Returns the `Figure`.
"""
function make_custom_calibration_plot2(expected_ps, observed_ps;
                                       width=600,
                                       colormap=:managua,
                                       color_value=0.35,
                                       main_line_width=3.0,
                                       band_alpha=0.2,
                                       axis_color=:black,
                                       text_size=18,
                                       label_size=22,
                                       grid_visible=true,
                                       grid_color=(:gray, 0.3),
                                       grid_width=0.5,
                                       figure_padding=40,
                                       xlabel="",
                                       ylabel="")
    # α levels -> confidence levels in percent.
    expected_conf = (1.0 .- expected_ps) .* 100
    observed_conf = (1.0 .- observed_ps) .* 100

    # The line and the band share one colour sampled from the colormap.
    # Previously `colormap` and `color_value` were ignored (hard-coded :managua at 0.35);
    # the new defaults reproduce that output exactly.
    line_color = get(ColorSchemes.colorschemes[colormap], color_value)
    band_color = (line_color, band_alpha)

    fig = Figure(resolution=(width, width), fontsize=text_size,
                 figure_padding=figure_padding)
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel=xlabel,
        ylabel=ylabel,
        limits=(0, 100, 0, 100),
        xlabelsize=label_size,
        ylabelsize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_width,
        ygridwidth=grid_width
    )

    # Spine, tick-label and axis-label colours.
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color

    # Calibration curve (bold).
    lines!(ax, expected_conf, observed_conf, color=line_color, linewidth=main_line_width)

    # Perfect-calibration diagonal.
    lines!(ax, expected_conf, expected_conf, linestyle=:dash, color=line_color,
           alpha=0.6, linewidth=1.5)

    # Shaded area between the diagonal and the calibration curve.
    band!(ax, expected_conf, expected_conf, observed_conf, color=band_color)

    # Ticks every 20 %, labelled with a percent sign (ticks are integers, so Int is exact).
    ax.xticks = 0:20:100
    ax.yticks = 0:20:100
    ax.xtickformat = xs -> ["$(Int(x))%" for x in xs]
    ax.ytickformat = xs -> ["$(Int(x))%" for x in xs]

    return fig
end

In [None]:
# Calibration curve for pristine-pristine energy differences, written to SVG.
pp_ediff_fig = make_custom_calibration_plot2(alpha_refs,pp_alpha_pred; text_size=28, label_size=28)
save("pp_calibration.svg", pp_ediff_fig)

In [None]:
# Miscalibration area between the expected (alpha_refs) and observed (pp_alpha_pred)
# levels for the pristine-pristine test pairs.
compute_miscalibration_area(alpha_refs,pp_alpha_pred)

In [None]:
# frenkel-frenkel: consecutive entries (1, 2), (3, 4), ... of the shuffled frenkel
# test set form random pairs. For each pair, record the reference/predicted energy
# difference E2 - E1, each system's predicted energy and committee variance, and the
# committee covariance term (flip_second_sign=true).
ff_ediff_combined_test_ref = Float64[]
ff_ediff_combined_test_pred = Float64[]
ff_sys1_combined_test_uqs = Float64[]
ff_sys2_combined_test_uqs = Float64[]
ff_ediff_combined_test_cov_uq = Float64[]

ff_sys1_combined_test_epreds = Float64[]
ff_sys2_combined_test_epreds = Float64[]
# Stop at length - 1 so that an odd-length dataset skips its last, unpaired entry
# instead of raising a BoundsError on index i + 1.
for i in 1:2:(length(frenkel_test_ds) - 1)
    sys1 = frenkel_test_ds[i]
    sys2 = frenkel_test_ds[i+1]
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))

    push!(ff_ediff_combined_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(ff_sys1_combined_test_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(ff_sys2_combined_test_epreds, e2_pred)

    push!(ff_ediff_combined_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(ff_sys1_combined_test_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(ff_sys2_combined_test_uqs, sys2_uq)

    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ff_ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Std of the energy difference per frenkel-frenkel pair: sqrt(Var1 + Var2 + 2*Cov(X1, -X2)).
ff_test_ediff_uq = @. sqrt(ff_sys1_combined_test_uqs + ff_sys2_combined_test_uqs +
                           2 * ff_ediff_combined_test_cov_uq)

In [None]:
# Calibration scores come from the combined (pristine + frenkel) calibration pairs;
# only the test residuals here are frenkel-frenkel.
ediff_combined_calib_scores =
    @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq
ff_test_abs_residuals_combined =
    @. abs(ff_ediff_combined_test_pred - ff_ediff_combined_test_ref)
# Levels 0.01, 0.02, ..., 0.99 and their complements α = 1 - level.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

ff_alpha_pred = generate_predicted_alphas(ediff_combined_calib_scores,
                                          ff_test_ediff_uq,
                                          ff_test_abs_residuals_combined)

In [None]:
# Calibration curve for frenkel-frenkel energy differences, written to SVG.
ff_ediff_fig = make_custom_calibration_plot2(alpha_refs,ff_alpha_pred; text_size=24, label_size=28)
save("ff_calibration.svg", ff_ediff_fig)

In [None]:
# Miscalibration area between the expected (alpha_refs) and observed (ff_alpha_pred)
# levels for the frenkel-frenkel test pairs.
compute_miscalibration_area(alpha_refs,ff_alpha_pred)

In [None]:
# frenkel-pristine: for i = 1, 3, 5, ..., pair pristine entry i (sys1) with frenkel
# entry i + 1 (sys2), so the recorded difference is E_frenkel - E_pristine. For each
# pair, record the reference/predicted energy difference, each system's predicted
# energy and committee variance, and the committee covariance term
# (flip_second_sign=true).
fp_ediff_combined_test_ref = Float64[]
fp_ediff_combined_test_pred = Float64[]
fp_sys1_combined_test_uqs = Float64[]
fp_sys2_combined_test_uqs = Float64[]
fp_ediff_combined_test_cov_uq = Float64[]

fp_sys1_combined_test_epreds = Float64[]
fp_sys2_combined_test_epreds = Float64[]
# Index i must be valid in the pristine set and i + 1 in the frenkel set. Bounding by
# the shorter set and stopping at n - 1 avoids a BoundsError on odd or unequal lengths.
n_fp = min(length(pristine_test_ds), length(frenkel_test_ds))
for i in 1:2:(n_fp - 1)
    sys1 = pristine_test_ds[i]
    sys2 = frenkel_test_ds[i+1]
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))

    push!(fp_ediff_combined_test_ref, e2_ref - e1_ref)

    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    push!(fp_sys1_combined_test_epreds, e1_pred)
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(fp_sys2_combined_test_epreds, e2_pred)

    push!(fp_ediff_combined_test_pred, e2_pred - e1_pred)

    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    push!(fp_sys1_combined_test_uqs, sys1_uq)
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(fp_sys2_combined_test_uqs, sys2_uq)

    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(fp_ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Std of the energy difference per pristine/frenkel pair: sqrt(Var1 + Var2 + 2*Cov(X1, -X2)).
fp_test_ediff_uq = @. sqrt(fp_sys1_combined_test_uqs + fp_sys2_combined_test_uqs +
                           2 * fp_ediff_combined_test_cov_uq)

In [None]:
# Calibration scores come from the combined (pristine + frenkel) calibration pairs;
# only the test residuals here are pristine/frenkel cross pairs.
ediff_combined_calib_scores =
    @. abs(ediff_combined_calib_pred - ediff_combined_calib_ref) / calib_ediff_uq
fp_test_abs_residuals_combined =
    @. abs(fp_ediff_combined_test_pred - fp_ediff_combined_test_ref)
# Levels 0.01, 0.02, ..., 0.99 and their complements α = 1 - level.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

fp_alpha_pred = generate_predicted_alphas(ediff_combined_calib_scores,
                                          fp_test_ediff_uq,
                                          fp_test_abs_residuals_combined)

In [None]:
# Calibration curve for pristine (sys1) / frenkel (sys2) energy differences, written to SVG.
fp_ediff_fig = make_custom_calibration_plot2(alpha_refs,fp_alpha_pred; text_size=24, label_size=28)
save("fp_calibration.svg", fp_ediff_fig)

In [None]:
# Miscalibration area between the expected (alpha_refs) and observed (fp_alpha_pred)
# levels for the pristine/frenkel cross pairs.
compute_miscalibration_area(alpha_refs,fp_alpha_pred)