In [None]:
using Revise
using Pkg; Pkg.activate(".")

In [None]:
using Unitful
using PotentialLearning
using Random: randperm
using JLD2
using InteratomicPotentials
using AtomsBase, AtomsCalculators
using Statistics
using CairoMakie, ColorSchemes

In [None]:
# Load the "members" entry of the committee file; it is passed to CommitteePotential below.
ensemble_members = load("ace_cmte1.jld2", "members")

In [None]:
includet("../files/committee_potentials.jl")
includet("../files/committee_qois.jl")

In [None]:
# Wrap the loaded members in a committee potential (energies in eV, lengths in Å).
my_cmte = CommitteePotential(ensemble_members; energy_units=u"eV", length_units=u"Å")
# Committee energy quantity built with `Statistics.var` and unit stripping enabled.
cmte_energy = CmteEnergy(Statistics.var, strip_units=true)
# IMPORTANT: this quantity is a variance, not a standard deviation; the energy-difference
# uncertainties computed later take the square root of the combined variance terms.

In [None]:
# Load every entry of the dataset file and pull out the four base datasets by key.
# NOTE(review): the file name suggests descriptors are precomputed — confirm.
datasets = load("datasets_with_descriptors.jld2")
pristine_base_calib_ds = datasets["pristine_base_calib_ds"]
pristine_base_test_ds = datasets["pristine_base_test_ds"]
frenkel_base_calib_ds = datasets["frenkel_base_calib_ds"]
frenkel_base_test_ds = datasets["frenkel_base_test_ds"]

Just doing a single qhat for a single energy

In [None]:
includet("../files/conformal_prediction_utils.jl")

In [None]:
# from subsampling_dpp.jl in PL.jl examples
"""
    concat_dataset(confs::Vector{DataSet})

Merge several datasets into one `DataSet`, keeping the configurations in their
original order (all of the first dataset, then all of the second, and so on).
"""
function concat_dataset(confs::Vector{DataSet})
    # Materialise each dataset as a plain vector of its configurations.
    per_dataset = map(ds -> [ds[k] for k in 1:length(ds)], confs)
    # Flatten the per-dataset vectors and wrap the result again.
    return DataSet(reduce(vcat, per_dataset))
end

In [None]:
# Merge the pristine and Frenkel base sets (pristine first) into a single calibration
# dataset and a single test dataset.
# NOTE(review): `[a; b]` is expected to yield a two-element Vector{DataSet} here — confirm.
orig_combined_calib_ds = concat_dataset([pristine_base_calib_ds; frenkel_base_calib_ds])
orig_combined_test_ds = concat_dataset([pristine_base_test_ds; frenkel_base_test_ds])

In [None]:
# Shuffle both combined datasets with the same random permutation of 1:1500.
# NOTE(review): 1500 is hard-coded and presumably equals the size of each combined
# dataset — confirm, or derive it from `length(orig_combined_calib_ds)`.
rand_idxs = randperm(1500)
combined_calib_ds = orig_combined_calib_ds[rand_idxs]
combined_test_ds = orig_combined_test_ds[rand_idxs]

In [None]:
# Committee energy covariance quantity, used below on pairs of systems.
# NOTE(review): the positional `true` presumably enables unit stripping — confirm.
cmte_cov_energy = CmteEnergyCov(true)

In [None]:
# The calibration set was shuffled above, so consecutive configurations already form
# random pairs. For every pair (sys1, sys2) collect:
#   - the reference and predicted energy differences E2 - E1,
#   - each system's predicted energy and committee variance,
#   - the committee covariance term with the second sign flipped, i.e. Cov(E1, -E2).
ediff_combined_calib_ref = Float64[]
ediff_combined_calib_pred = Float64[]
sys1_combined_calib_uqs = Float64[]
sys2_combined_calib_uqs = Float64[]
ediff_combined_calib_cov_uq = Float64[]

sys1_combined_calib_epreds = Float64[]
sys2_combined_calib_epreds = Float64[]

# Stop at `length - 1` so the partner index `i + 1` is always in bounds; with an odd
# number of configurations the final one is simply left unpaired.
for i in 1:2:(length(combined_calib_ds) - 1)
    sys1 = combined_calib_ds[i]
    sys2 = combined_calib_ds[i + 1]

    # Reference energy difference.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(ediff_combined_calib_ref, e2_ref - e1_ref)

    # Committee-predicted energies and their difference.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(sys1_combined_calib_epreds, e1_pred)
    push!(sys2_combined_calib_epreds, e2_pred)
    push!(ediff_combined_calib_pred, e2_pred - e1_pred)

    # Per-system committee variances.
    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(sys1_combined_calib_uqs, sys1_uq)
    push!(sys2_combined_calib_uqs, sys2_uq)

    # Covariance term needed for the variance of the difference.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(ediff_combined_calib_cov_uq, cov_uq)
end

In [None]:
# Standard deviation of the energy difference X1 - X2:
#   Var(X1 - X2) = Var(X1) + Var(-X2) + 2*Cov(X1, -X2), with Var(-X2) = Var(X2).
# The covariance vector was computed with the second sign already flipped.
calib_ediff_uq = @. sqrt(
    sys1_combined_calib_uqs + sys2_combined_calib_uqs + 2 * ediff_combined_calib_cov_uq
)

In [None]:
# Load the three "large" datasets; each file stores its dataset under a key that matches
# the file name. They are used below as test systems for energy differences.
large_pristine_ds = load("large_pristine_ds.jld2", "large_pristine_ds")
large_8x_frenkel_ds = load("large_8x_frenkel_ds.jld2", "large_8x_frenkel_ds")
large_dilute_frenkel_ds = load("large_dilute_frenkel_ds.jld2", "large_dilute_frenkel_ds")

In [None]:
using ColorSchemes

"""
    make_custom_calibration_plot1(expected_ps, observed_ps; kwargs...)

Return a square `Figure` comparing expected and observed confidence levels.

Both inputs are fractions; each is converted to `(1 - p) * 100` before plotting, so the
axes read as confidence percentages. The figure shows the observed curve, a dashed
diagonal reference line, and a shaded band between the two.

# Keywords
- `width`: figure width and height.
- `colormap`, `color_value`: ColorSchemes scheme name and position (0-1) of the color
  used for the curve, the diagonal, and the band. The defaults reproduce the color that
  was previously hard-coded inside the function.
- `main_line_width`: line width of the observed curve.
- `band_alpha`: opacity of the shaded band.
- `axis_color`: color of spines, tick labels, and axis labels.
- `text_size`, `label_size`: base/tick font size and axis-label size.
- `grid_visible`, `grid_color`, `grid_width`: grid styling.
"""
function make_custom_calibration_plot1(expected_ps, observed_ps;
                                      width=600,
                                      colormap=:lajolla,
                                      color_value=0.4,  # Value between 0-1 in the colormap
                                      main_line_width=3.0,
                                      band_alpha=0.2,
                                      axis_color=:black,
                                      text_size=18,
                                      label_size=22,
                                      grid_visible=true,
                                      grid_color=(:gray, 0.3),
                                      grid_width=0.5)
    # Plot the complement of each input, expressed in percent.
    expected_conf = (1.0 .- expected_ps) .* 100
    observed_conf = (1.0 .- observed_ps) .* 100

    # One color drives the curve, the diagonal, and (with transparency) the band.
    line_color = get(ColorSchemes.colorschemes[colormap], color_value)
    band_color = (line_color, band_alpha)

    fig = Figure(resolution=(width, width), fontsize=text_size, figure_padding=30)
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel="Expected Confidence Level",
        ylabel="Observed Confidence Level",
        limits=(0, 100, 0, 100),
        xlabelsize=label_size,
        ylabelsize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_width,
        ygridwidth=grid_width
    )

    # Spine, tick-label, and axis-label colors.
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color

    # Observed calibration curve.
    lines!(ax, expected_conf, observed_conf, color=line_color, linewidth=main_line_width)

    # Diagonal reference line (perfect calibration).
    lines!(ax, expected_conf, expected_conf,
           linestyle=:dash, color=line_color, alpha=0.6, linewidth=1.5)

    # Shaded area between the diagonal and the observed curve.
    band!(ax, expected_conf, expected_conf, observed_conf, color=band_color)

    ax.xticks = 0:20:100
    ax.yticks = 0:20:100

    # Percent-sign tick labels; `round` keeps this safe for non-integer tick values.
    ax.xtickformat = xs -> ["$(round(Int, x))%" for x in xs]
    ax.ytickformat = xs -> ["$(round(Int, x))%" for x in xs]

    return fig
end

In [None]:
"""
    custom_histogram(data; kwargs...)

Return a square `Figure` with a styled histogram of `data`, optionally overlaid with a
kernel density estimate. The x-axis is limited to the range of `data`.

# Keywords
- `width`: figure width and height.
- `bins`: forwarded to `hist!`.
- `colormap`, `color_value`: ColorSchemes scheme and position (0-1) of the default bar color.
- `fill_alpha`: opacity of the default bar color.
- `bar_color`: explicit bar color; overrides `colormap`/`color_value`/`fill_alpha`.
- `edge_linewidth`: bar outline width (the outline uses the bar color).
- `xlabel`, `ylabel`: axis labels.
- `axis_color`, `text_size`, `label_size`: axis color and font sizes.
- `grid_visible`, `grid_color`, `grid_linewidth`: grid styling.
- `normalize`: use `:pdf` normalization instead of raw counts.
- `kde`, `kde_linewidth`, `kde_color`: draw a density outline and a legend.
- `title`, `edge_color`: accepted for interface compatibility; currently unused.
"""
function custom_histogram(data;
    width=600,
    bins=500,
    colormap=:viridis,
    color_value=0.6,
    title="Histogram",
    xlabel="Value",
    ylabel="Frequency",
    fill_alpha=0.8,
    edge_linewidth=1.0,
    axis_color=:black,
    text_size=18,
    label_size=22,
    grid_visible=false,
    grid_color=(:gray, 0.3),
    grid_linewidth=0.5,
    bar_color=nothing,
    edge_color=nothing,
    normalize=false,
    kde=false,
    kde_linewidth=3.0,
    kde_color=:black)

    # Use the colormap color unless the caller supplied an explicit bar color.
    base_color = get(ColorSchemes.colorschemes[colormap], color_value)
    fill_color = isnothing(bar_color) ? (base_color, fill_alpha) : bar_color

    fig = Figure(resolution=(width, width), fontsize=text_size)

    # The x-axis spans exactly the data range.
    x_min = minimum(data)
    x_max = maximum(data)

    ax = Axis(fig[1, 1],
        xlabel=xlabel,
        ylabel=ylabel,
        xlabelsize=label_size,
        ylabelsize=label_size,
        titlesize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_linewidth,
        ygridwidth=grid_linewidth
    )

    # Spine, tick-label, axis-label, and title colors.
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color
    ax.titlecolor = axis_color

    hist!(ax, data,
        bins=bins,
        color=fill_color,
        strokecolor=fill_color,
        strokewidth=edge_linewidth,
        normalization=normalize ? :pdf : :none)

    if kde
        # Makie's kernel-density recipe is `density!` (there is no `kde!`). Only the
        # outline is drawn so the bars stay visible. The curve is on a density scale,
        # so it is only comparable to the bars when `normalize=true`.
        density!(ax, data,
            color=(kde_color, 0.0),
            strokecolor=kde_color,
            strokewidth=kde_linewidth,
            label="KDE")

        axislegend(ax, position=:rt, framevisible=true,
            framecolor=(:black, 0.2),
            padding=(10, 10, 10, 10),
            labelsize=text_size - 2)
    end

    # Clamp the x-axis to the data range; the y-axis stays automatic.
    ax.limits = (x_min, x_max, nothing, nothing)

    return fig
end

In [None]:
"""
    custom_histogram2(data1, data2; kwargs...)

Return a square `Figure` with the histograms of `data1` and `data2` overlaid on one axis.
The x-axis is limited to the combined range of both datasets.

When `bins` is an integer, both histograms share the same bin edges spanning the combined
data range, so bar heights are directly comparable. Any other `bins` value is forwarded
to `hist!` unchanged.

# Keywords
- `width`: figure width and height.
- `bins`: number of shared bins, or explicit bin edges.
- `colormap`, `color_value`: ColorSchemes scheme and base position; `data2` uses
  `color_value`, `data1` uses `color_value + 0.3`.
- `fill_alpha`: opacity of the `data2` bars (the `data1` bars use a fixed 0.8).
- `bar_color`: explicit color applied to both histograms.
- `edge_linewidth`: bar outline width (outlines use the bar color).
- `xlabel`, `ylabel`: axis labels.
- `axis_color`, `text_size`, `label_size`: axis color and font sizes.
- `grid_visible`, `grid_color`, `grid_linewidth`: grid styling.
- `normalize`: use `:pdf` normalization instead of raw counts.
- `title`, `edge_color`, `kde`, `kde_linewidth`, `kde_color`: accepted for interface
  compatibility; currently unused.
"""
function custom_histogram2(data1, data2;
    width=600,
    bins=10,
    colormap=:viridis,
    color_value=0.0,
    title="Histogram",
    xlabel="Score Values",
    ylabel="Frequency",
    fill_alpha=0.5,
    edge_linewidth=1.0,
    axis_color=:black,
    text_size=18,
    label_size=22,
    grid_visible=false,
    grid_color=(:gray, 0.3),
    grid_linewidth=0.5,
    bar_color=nothing,
    edge_color=nothing,
    normalize=false,
    kde=false,
    kde_linewidth=3.0,
    kde_color=:black)

    # Two colors from the same colormap unless an explicit bar color was supplied.
    base_color1 = get(ColorSchemes.colorschemes[colormap], color_value + 0.3)
    bar_color1 = isnothing(bar_color) ? (base_color1, 0.8) : bar_color

    base_color2 = get(ColorSchemes.colorschemes[colormap], color_value)
    bar_color2 = isnothing(bar_color) ? (base_color2, fill_alpha) : bar_color

    fig = Figure(resolution=(width, width), fontsize=text_size)

    # Combined data range, computed without concatenating the inputs.
    x_min = min(minimum(data1), minimum(data2))
    x_max = max(maximum(data1), maximum(data2))

    # An integer `bins` would make `hist!` bin each dataset over its own range, giving
    # two different bin widths. Use common edges instead. Bins are right-open, so the
    # upper edge is nudged up to keep the largest value inside the last bin.
    lo, hi = float(x_min), float(x_max)
    shared_bins = if bins isa Integer && lo < hi
        range(lo, nextfloat(hi); length=bins + 1)
    else
        bins
    end

    ax = Axis(fig[1, 1],
        xlabel=xlabel,
        ylabel=ylabel,
        xlabelsize=label_size,
        ylabelsize=label_size,
        titlesize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_linewidth,
        ygridwidth=grid_linewidth
    )

    # Spine, tick-label, axis-label, and title colors.
    ax.bottomspinecolor = axis_color
    ax.leftspinecolor = axis_color
    ax.rightspinecolor = axis_color
    ax.topspinecolor = axis_color

    ax.xticklabelcolor = axis_color
    ax.yticklabelcolor = axis_color
    ax.xlabelcolor = axis_color
    ax.ylabelcolor = axis_color
    ax.titlecolor = axis_color

    # `data1` is drawn first, `data2` on top of it.
    hist!(ax, data1,
        bins=shared_bins,
        color=bar_color1,
        strokecolor=bar_color1,
        strokewidth=edge_linewidth,
        normalization=normalize ? :pdf : :none)

    hist!(ax, data2,
        bins=shared_bins,
        color=bar_color2,
        strokecolor=bar_color2,
        strokewidth=edge_linewidth,
        normalization=normalize ? :pdf : :none)

    # Clamp the x-axis to the combined data range; the y-axis stays automatic.
    ax.limits = (x_min, x_max, nothing, nothing)

    return fig
end

In [None]:
# Absolute error of the predicted energy difference for every calibration pair.
calib_residuals = abs.(ediff_combined_calib_pred .- ediff_combined_calib_ref)


In [None]:
# Build random (pristine, 8x-Frenkel) index pairs: every pristine configuration, visited
# in random order, is matched with up to 10 distinct randomly chosen 8x-Frenkel ones.
# Sizes come from the datasets themselves rather than a hard-coded 102.
n_pristine = length(large_pristine_ds)
n_8x_frenkel = length(large_8x_frenkel_ds)
n_partners_8x = min(10, n_8x_frenkel)

pristine_idxs = randperm(n_pristine)

index_pairs = Tuple{Int64,Int64}[]
# `p_idx`/`f_idx` are used instead of `pi` to avoid shadowing the constant `Base.pi`.
for p_idx in pristine_idxs
    large_8x_idxs = randperm(n_8x_frenkel)
    for f_idx in large_8x_idxs[1:n_partners_8x]
        push!(index_pairs, (p_idx, f_idx))
    end
end



In [None]:
# Test pairs: sys1 is a pristine configuration, sys2 an 8x-Frenkel configuration.
# For every pair in `index_pairs` collect the reference and predicted energy differences
# E2 - E1, each system's predicted energy and committee variance, and the committee
# covariance term with the second sign flipped.
p8only_ediff_combined_test_ref = Float64[]
p8only_ediff_combined_test_pred = Float64[]
p8only_sys1_combined_test_uqs = Float64[]
p8only_sys2_combined_test_uqs = Float64[]
p8only_ediff_combined_test_cov_uq = Float64[]

p8only_sys1_combined_test_epreds = Float64[]
p8only_sys2_combined_test_epreds = Float64[]

# Destructure each pair directly; this also avoids a loop variable named `pi`, which
# would shadow the constant `Base.pi`.
for (p_idx, f_idx) in index_pairs
    sys1 = large_pristine_ds[p_idx]
    sys2 = large_8x_frenkel_ds[f_idx]

    # Reference energy difference.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(p8only_ediff_combined_test_ref, e2_ref - e1_ref)

    # Committee-predicted energies and their difference.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(p8only_sys1_combined_test_epreds, e1_pred)
    push!(p8only_sys2_combined_test_epreds, e2_pred)
    push!(p8only_ediff_combined_test_pred, e2_pred - e1_pred)

    # Per-system committee variances.
    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(p8only_sys1_combined_test_uqs, sys1_uq)
    push!(p8only_sys2_combined_test_uqs, sys2_uq)

    # Covariance term needed for the variance of the difference.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(p8only_ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Heuristic uncertainty of every test energy difference:
#   sqrt(Var(E1) + Var(E2) + 2*Cov(E1, -E2)).
p8only_test_ediff_uq = @. sqrt(
    p8only_sys1_combined_test_uqs +
    p8only_sys2_combined_test_uqs +
    2 * p8only_ediff_combined_test_cov_uq
)

# Calibration scores: absolute residual divided by the heuristic uncertainty.
ediff_combined_calib_scores = @. abs(
    ediff_combined_calib_pred - ediff_combined_calib_ref
) / calib_ediff_uq

# Absolute residuals of the test energy differences.
p8only_test_abs_residuals_combined = @. abs(
    p8only_ediff_combined_test_pred - p8only_ediff_combined_test_ref
)

# Grid 0.01, 0.02, ..., 0.99 and its complement, used as the reference levels.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

# Predicted levels for the pristine/8x test pairs.
p8only_alpha_pred = generate_predicted_alphas(
    ediff_combined_calib_scores,
    p8only_test_ediff_uq,
    p8only_test_abs_residuals_combined,
)

In [None]:
# Calibration figure for the pristine/8x test pairs, written to an SVG file.
p8only_ediff_fig = make_custom_calibration_plot1(alpha_refs,p8only_alpha_pred; text_size=24, label_size=28)
save("p8only_calibration.svg", p8only_ediff_fig)

In [None]:
# Miscalibration area between the reference and predicted levels (pristine/8x pairs).
compute_miscalibration_area(alpha_refs,p8only_alpha_pred)

In [None]:
# Calibrate a single scaling factor on the calibration pairs.
# NOTE(review): the final argument 0.2 is presumably the miscoverage level alpha — confirm.
qhat= calibrate(ediff_combined_calib_pred,ediff_combined_calib_ref,calib_ediff_uq, 0.2)

# Scale the heuristic test uncertainties by the calibrated factor.
qhat_uq = qhat*p8only_test_ediff_uq

In [None]:
# Test scores (absolute residual / heuristic uncertainty). They are not plotted here but
# can be compared against `ediff_combined_calib_scores` with `custom_histogram2`.
p8only_scores = p8only_test_abs_residuals_combined ./ p8only_test_ediff_uq

# Overlay the heuristic uncertainties of the first 750 test pairs on those of the
# calibration pairs, keeping only values below 10 eV in each set.
p8_test_uq_subset = filter(<(10.0), p8only_test_ediff_uq[1:750])
p8_calib_uq_subset = filter(<(10.0), calib_ediff_uq)
p8hist_fig = custom_histogram2(
    p8_test_uq_subset,
    p8_calib_uq_subset;
    bins=50,
    colormap=:roma,
    text_size=28,
    label_size=32,
    xlabel="Uncertainty Estimate (eV)",
)
save("p8_uq_histogram.svg", p8hist_fig)


In [None]:
# Overlay the absolute residuals of the first 750 test pairs on those of the calibration
# pairs, keeping only values below 5 eV in each set.
p8_test_res_subset = filter(<(5.0), p8only_test_abs_residuals_combined[1:750])
p8_calib_res_subset = filter(<(5.0), calib_residuals)
p8_res_hist_fig = custom_histogram2(
    p8_test_res_subset,
    p8_calib_res_subset;
    bins=50,
    colormap=:roma,
    text_size=28,
    label_size=32,
    xlabel="Absolute Residuals (eV)",
)
save("p8_res_hist_fig.svg", p8_res_hist_fig)

In [None]:
"""
    custom_parity_plot(qhat_uq, res; kwargs...)

Return a square `Figure` scattering residuals `res` against uncertainties `qhat_uq`.

Points with `res[i] > qhat_uq[i]` are drawn in red, all others in black. A dashed
diagonal spans the combined range of both inputs. The red points' coordinates are also
printed with `@show`.

Element `i` of `qhat_uq` and `res` must describe the same sample. Several keywords
(`title`, `colormap`, `color_value`, `errorbar_color`, `marker_color`) are accepted but
not used by the current implementation.
"""
function custom_parity_plot(qhat_uq, res;
    title="Parity Plot Subset",
    xlabel="Heuristic Uncertainty (eV)",
    ylabel="Residuals(eV)",
    width=600,
    colormap=:viridis,
    color_value=0.6,
    marker_size=10,
    line_width=3.0,
    axis_color=:black,
    text_size=18,
    label_size=22,
    grid_visible=false,
    grid_color=(:gray, 0.3),
    grid_linewidth=0.5,
    errorbar_color=nothing,
    marker_color=nothing,
    diagonal_color=:red,
    diagonal_alpha=0.6,
    diagonal_style=:dash)

    fig = Figure(resolution=(width, width), fontsize=text_size, figure_padding=30)

    # Combined extent of both quantities; it defines the end points of the diagonal.
    lo = min(minimum(qhat_uq), minimum(res))
    hi = max(maximum(qhat_uq), maximum(res))

    ax = Axis(fig[1, 1];
        xlabel=xlabel,
        ylabel=ylabel,
        titlesize=label_size,
        xlabelsize=label_size,
        ylabelsize=label_size,
        xticklabelsize=text_size,
        yticklabelsize=text_size,
        spinewidth=1.5,
        xgridvisible=grid_visible,
        ygridvisible=grid_visible,
        xgridcolor=grid_color,
        ygridcolor=grid_color,
        xgridwidth=grid_linewidth,
        ygridwidth=grid_linewidth,
    )

    # Apply one color to every spine, tick label, axis label, and the title.
    for attr in (:bottomspinecolor, :leftspinecolor, :rightspinecolor, :topspinecolor,
                 :xticklabelcolor, :yticklabelcolor, :xlabelcolor, :ylabelcolor,
                 :titlecolor)
        setproperty!(ax, attr, axis_color)
    end

    # Diagonal reference line, slightly thinner than `line_width`.
    lines!(ax, [lo, hi], [lo, hi];
        color=diagonal_color,
        linestyle=diagonal_style,
        linewidth=line_width - 1,
        alpha=diagonal_alpha,
        label="Perfect Prediction")

    # Split the points by whether the residual exceeds the uncertainty.
    covered_qhatuq = Float64[]
    covered_res = Float64[]
    uncovered_qhatuq = Float64[]
    uncovered_res = Float64[]
    for i in eachindex(qhat_uq)
        bound, err = qhat_uq[i], res[i]
        xs, ys = err > bound ? (uncovered_qhatuq, uncovered_res) :
                               (covered_qhatuq, covered_res)
        push!(xs, bound)
        push!(ys, err)
    end

    # Print the points whose residual exceeds the uncertainty.
    @show uncovered_qhatuq
    @show uncovered_res

    # Black: residual within the uncertainty. Red: residual above it.
    scatter!(ax, covered_qhatuq, covered_res; color=:black, markersize=marker_size)
    scatter!(ax, uncovered_qhatuq, uncovered_res; color=:red, markersize=marker_size)

    return fig
end

In [None]:
# Drop outliers (values of 10 eV or more) before plotting. The two vectors must be
# filtered with one shared mask: filtering them independently would break the
# element-wise pairing between an uncertainty and its residual, and could leave them
# with different lengths.
p8_keep = (qhat_uq .< 10.0) .& (p8only_test_abs_residuals_combined .< 10.0)
my_qhatuq = qhat_uq[p8_keep]
my_res = p8only_test_abs_residuals_combined[p8_keep]

custom_parity_plot(my_qhatuq, my_res)
#custom_parity_plot(sort(debug1)[1:end-1] , sort(debug2)[1:end-1])

In [None]:
# Build random (pristine, dilute-Frenkel) index pairs: every pristine configuration,
# visited in random order, is matched with up to 10 distinct randomly chosen
# dilute-Frenkel ones. Sizes come from the datasets rather than a hard-coded 102.
n_pristine = length(large_pristine_ds)
n_dilute_frenkel = length(large_dilute_frenkel_ds)
n_partners_dilute = min(10, n_dilute_frenkel)

pristine_idxs = randperm(n_pristine)

index_pairs2 = Tuple{Int64,Int64}[]
# `p_idx`/`f_idx` are used instead of `pi` to avoid shadowing the constant `Base.pi`.
for p_idx in pristine_idxs
    large_dilute_idxs = randperm(n_dilute_frenkel)
    for f_idx in large_dilute_idxs[1:n_partners_dilute]
        push!(index_pairs2, (p_idx, f_idx))
    end
end

In [None]:
# Test pairs: sys1 is a pristine configuration, sys2 a dilute-Frenkel configuration.
# For every pair in `index_pairs2` collect the reference and predicted energy differences
# E2 - E1, each system's predicted energy and committee variance, and the committee
# covariance term with the second sign flipped.
pdonly_ediff_combined_test_ref = Float64[]
pdonly_ediff_combined_test_pred = Float64[]
pdonly_sys1_combined_test_uqs = Float64[]
pdonly_sys2_combined_test_uqs = Float64[]
pdonly_ediff_combined_test_cov_uq = Float64[]

pdonly_sys1_combined_test_epreds = Float64[]
pdonly_sys2_combined_test_epreds = Float64[]

# Destructure each pair directly; this also avoids a loop variable named `pi`, which
# would shadow the constant `Base.pi`.
for (p_idx, f_idx) in index_pairs2
    sys1 = large_pristine_ds[p_idx]
    sys2 = large_dilute_frenkel_ds[f_idx]

    # Reference energy difference.
    e1_ref = ustrip(get_values(get_energy(sys1)))
    e2_ref = ustrip(get_values(get_energy(sys2)))
    push!(pdonly_ediff_combined_test_ref, e2_ref - e1_ref)

    # Committee-predicted energies and their difference.
    e1_pred = ustrip(PotentialLearning.potential_energy(sys1, my_cmte))
    e2_pred = ustrip(PotentialLearning.potential_energy(sys2, my_cmte))
    push!(pdonly_sys1_combined_test_epreds, e1_pred)
    push!(pdonly_sys2_combined_test_epreds, e2_pred)
    push!(pdonly_ediff_combined_test_pred, e2_pred - e1_pred)

    # Per-system committee variances.
    sys1_uq = ustrip(compute(cmte_energy, sys1, my_cmte))
    sys2_uq = ustrip(compute(cmte_energy, sys2, my_cmte))
    push!(pdonly_sys1_combined_test_uqs, sys1_uq)
    push!(pdonly_sys2_combined_test_uqs, sys2_uq)

    # Covariance term needed for the variance of the difference.
    cov_uq = ustrip(compute(cmte_cov_energy, sys1, sys2, my_cmte; flip_second_sign=true))
    push!(pdonly_ediff_combined_test_cov_uq, cov_uq)
end

In [None]:
# Heuristic uncertainty of every test energy difference:
#   sqrt(Var(E1) + Var(E2) + 2*Cov(E1, -E2)).
pdonly_test_ediff_uq = @. sqrt(
    pdonly_sys1_combined_test_uqs +
    pdonly_sys2_combined_test_uqs +
    2 * pdonly_ediff_combined_test_cov_uq
)

# Calibration scores: absolute residual divided by the heuristic uncertainty.
ediff_combined_calib_scores = @. abs(
    ediff_combined_calib_pred - ediff_combined_calib_ref
) / calib_ediff_uq

# Absolute residuals of the test energy differences.
pdonly_test_abs_residuals_combined = @. abs(
    pdonly_ediff_combined_test_pred - pdonly_ediff_combined_test_ref
)

# Grid 0.01, 0.02, ..., 0.99 and its complement, used as the reference levels.
alpha_complements = collect(0.01:0.01:0.99)
alpha_refs = @. 1 - alpha_complements

# Predicted levels for the pristine/dilute test pairs.
pdonly_alpha_pred = generate_predicted_alphas(
    ediff_combined_calib_scores,
    pdonly_test_ediff_uq,
    pdonly_test_abs_residuals_combined,
)

In [None]:
# Calibration figure for the pristine/dilute test pairs (saving is currently disabled).
pdonly_ediff_fig = make_custom_calibration_plot1(alpha_refs,pdonly_alpha_pred; text_size=24, label_size=28)
#save("pdonly_calibration.svg", pdonly_ediff_fig)

In [None]:
# Miscalibration area between the reference and predicted levels (pristine/dilute pairs).
compute_miscalibration_area(alpha_refs,pdonly_alpha_pred)

In [None]:
# Calibrate a single scaling factor on the calibration pairs (same inputs as `qhat`).
# NOTE(review): the final argument 0.2 is presumably the miscoverage level alpha — confirm.
qhat_pd= calibrate(ediff_combined_calib_pred,ediff_combined_calib_ref,calib_ediff_uq, 0.2)

# Scale the heuristic test uncertainties by the calibrated factor.
qhat_pd_uq = qhat_pd*pdonly_test_ediff_uq

In [None]:
# Test scores (absolute residual / heuristic uncertainty). They are not plotted here but
# can be compared against `ediff_combined_calib_scores` with `custom_histogram2`.
pdonly_scores = pdonly_test_abs_residuals_combined ./ pdonly_test_ediff_uq

# Overlay the heuristic uncertainties of the first 750 test pairs on those of the
# calibration pairs, keeping only values below 10 eV in each set.
pd_test_uq_subset = filter(<(10.0), pdonly_test_ediff_uq[1:750])
pd_calib_uq_subset = filter(<(10.0), calib_ediff_uq)
pdhist_fig = custom_histogram2(
    pd_test_uq_subset,
    pd_calib_uq_subset;
    bins=50,
    colormap=:roma,
    text_size=28,
    label_size=32,
    xlabel="Uncertainty Estimate (eV)",
)
save("pdhist_uq.svg", pdhist_fig)

In [None]:
# Overlay the absolute residuals of the first 750 test pairs on those of the calibration
# pairs, keeping only values below 5 eV in each set.
pd_test_res_subset = filter(<(5.0), pdonly_test_abs_residuals_combined[1:750])
pd_calib_res_subset = filter(<(5.0), calib_residuals)
pdhist_res_fig = custom_histogram2(
    pd_test_res_subset,
    pd_calib_res_subset;
    bins=50,
    colormap=:roma,
    text_size=28,
    label_size=32,
    xlabel="Absolute Residuals (eV)",
)
save("pdhist_res.svg", pdhist_res_fig)

In [None]:
# Drop outliers (values of 10 eV or more) before plotting. The two vectors must be
# filtered with one shared mask: filtering them independently would break the
# element-wise pairing between an uncertainty and its residual, and could leave them
# with different lengths.
pd_keep = (qhat_pd_uq .< 10.0) .& (pdonly_test_abs_residuals_combined .< 10.0)
my_qhatuq_pd = qhat_pd_uq[pd_keep]
my_res_pd = pdonly_test_abs_residuals_combined[pd_keep]

custom_parity_plot(my_qhatuq_pd, my_res_pd)
#custom_parity_plot(sort(debug1)[1:end-1] , sort(debug2)[1:end-1])

I suspected that my calibration plots were flipped somehow, and they were: I had been plotting alpha rather than 1 - alpha. Going back and revising.