In [None]:
using Revise

using PotentialLearning, InteratomicPotentials
using Unitful
using Random
using AtomsBase
using DelimitedFiles
using Statistics: mean, var
using StatsBase
using Clustering, Distances
using Trapz
using LinearAlgebra: Symmetric, eigen

#using CairoMakie CairoMakie.activate!()
using GLMakie; GLMakie.activate!(inline=false)

In [None]:
"""
    compute_miscalibration_area(expected_ps, observed_ps)

Return the area between the piecewise-linear calibration curve
`(expected_ps[i], observed_ps[i])` and the diagonal `observed == expected`,
i.e. ∫ |observed − expected| d(expected) under linear interpolation.

How each segment (width `h`, gaps `d₁`, `d₂` at its ends) contributes:
- If the gap keeps its sign, the segment adds the trapezoid `|h (d₁ + d₂)| / 2`.
- If the curve crosses the diagonal inside the segment, the two triangles on
  either side are added instead: `|h| (d₁² + d₂²) / (2 (|d₁| + |d₂|))`.
- Taking `|∫observed − ∫expected|` per segment would let those triangles
  cancel, under-reporting the area of crossing curves.

Returns `0.0` for fewer than two points.

Throws `DimensionMismatch` if the inputs have different lengths.
"""
function compute_miscalibration_area(expected_ps, observed_ps)
    length(expected_ps) == length(observed_ps) ||
        throw(DimensionMismatch("expected_ps and observed_ps must have the same length, " *
                                "got $(length(expected_ps)) and $(length(observed_ps))"))
    area = 0.0
    for i in (firstindex(expected_ps) + 1):lastindex(expected_ps)
        h = expected_ps[i] - expected_ps[i-1]
        d1 = observed_ps[i-1] - expected_ps[i-1]
        d2 = observed_ps[i] - expected_ps[i]
        if d1 * d2 >= 0
            # Same side of the diagonal (or touching it): plain trapezoid.
            area += abs(h * (d1 + d2)) / 2
        else
            # Crossing inside the segment: sum of the two triangles.
            area += abs(h) * (d1^2 + d2^2) / (2 * (abs(d1) + abs(d2)))
        end
    end
    return area
end

# converted from Medford jupyter notebook via Claude
"""
    make_calibration_plot(expected_ps, observed_ps; width=600, area=nothing)

Plot an observed-vs-expected calibration curve on percentage axes and return
the `Figure`.

`expected_ps` and `observed_ps` are given as fractions; they are scaled by 100
for display on the fixed 0–100 % axes. The dashed line is the diagonal
`observed == expected` drawn over the range of `expected_ps`. The shaded band
is the gap between the curve and that diagonal.

# Keywords
- `width`: side length of the square figure, in pixels.
- `area`: if given (e.g. the result of `compute_miscalibration_area`), the value
  is written in the lower-left corner; `nothing` (default) omits it.
"""
function make_calibration_plot(expected_ps, observed_ps; width=600, area=nothing)
    # Broadcasting allocates new arrays, so the caller's vectors are untouched.
    expected_pct = expected_ps .* 100
    observed_pct = observed_ps .* 100

    # `size` replaces Makie's deprecated `resolution` keyword.
    fig = Figure(size=(width, width))
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel="Expected conf. level",
        ylabel="Observed conf. level",
        limits=(0, 100, 0, 100)
    )

    # Calibration curve, the ideal diagonal, and the shaded gap between them.
    lines!(ax, expected_pct, observed_pct)
    lines!(ax, expected_pct, expected_pct, linestyle=:dash, alpha=0.4)
    band!(ax, expected_pct, expected_pct, observed_pct, color=(:blue, 0.2))

    # 11 ticks per axis, one every 10 percentage points, labelled with "%".
    ax.xticks = 0:10:100
    ax.yticks = 0:10:100
    ax.xtickformat = xs -> ["$(Int(x))%" for x in xs]
    ax.ytickformat = xs -> ["$(Int(x))%" for x in xs]

    if !isnothing(area)
        text!(ax, 8, 2;
              text="miscalc. area = $(round(area, digits=3))",
              align=(:left, :bottom))
    end

    return fig
end

# Claude
"""
    parity_plot(etest_ref, etest_pred, qhat_scored; title, xlabel, ylabel, figsize, limits)

Scatter predicted against reference values and return the `Figure`.

The plot has:
- vertical error bars of half-width `qhat_scored`;
- a dashed red `y = x` line labelled "Perfect Prediction".

# Keywords
- `title`, `xlabel`, `ylabel`: axis annotations.
- `figsize`: figure size in pixels.
- `limits`: `(xmin, xmax, ymin, ymax)` for the axis.
  - The default is the fixed window this helper has always used.
  - Pass `nothing` to fit a square window to the combined data range.
"""
function parity_plot(etest_ref, etest_pred, qhat_scored;
                     title="Parity Plot",
                     xlabel="Reference Values",
                     ylabel="Predicted Values",
                     figsize=(600, 600),
                     limits=(-5.0, -4.0, -5.0, -4.0))
    # Joint range of both series. It spans the y = x line and, with
    # `limits === nothing`, gives a square window so the diagonal stays at 45°.
    min_val = min(minimum(etest_pred), minimum(etest_ref))
    max_val = max(maximum(etest_pred), maximum(etest_ref))
    axis_limits = isnothing(limits) ? (min_val, max_val, min_val, max_val) : limits

    fig = Figure(size=figsize)
    ax = Axis(fig[1, 1],
              title=title,
              xlabel=xlabel,
              ylabel=ylabel,
              limits=axis_limits)

    lines!(ax, [min_val, max_val], [min_val, max_val],
           color=:red,
           linestyle=:dash,
           label="Perfect Prediction")

    # Symmetric vertical error bars: predicted ± qhat_scored.
    errorbars!(ax, etest_ref, etest_pred, qhat_scored,
               whiskerwidth=1,
               color=:cyan3)

    scatter!(ax, etest_ref, etest_pred,
             color=:teal,
             markersize=10)

    axislegend(ax)

    return fig
end

"""
    uncertainty_vs_residuals(uncertainty, residuals; title, xlabel, ylabel, figsize)

Draw `residuals` against `uncertainty` as a scatter of small (size 1) markers
on a new figure, with a red dashed line at zero residual. Returns the `Figure`.
"""
function uncertainty_vs_residuals(uncertainty, residuals;
                                 title="Uncertainty vs. residuals",
                                 xlabel="Distance",
                                 ylabel="Residuals",
                                 figsize=(600, 600))
    fig = Figure(size=figsize)
    ax = Axis(fig[1, 1]; title=title, xlabel=xlabel, ylabel=ylabel)

    # Zero-residual reference line.
    hlines!(ax, 0.0; color=:red, linestyle=:dash)
    scatter!(ax, uncertainty, residuals; markersize=1)

    return fig
end

In [None]:
# ACE basis over C, H, O and N (the elements left after the fluorine filter in
# the next cell), wrapped in a linear-basis potential `lb`.
ace = ACE(
    species           = [:C, :H, :O, :N],
    body_order        = 3,
    polynomial_degree = 10,
    wL                = 2.0,
    csp               = 1.0,
    r0                = 1.43,
    rcutoff           = 4.4,
)
lb = LBasisPotential(ace)
length(ace)   # size of the basis (shown as the cell output)

In [None]:
# Load QM9 (energies in eV, positions in Å).
qm9_file = "../files/QM9/qm9_fullset_alldata.xyz"
raw_data = load_data(qm9_file, ExtXYZ(u"eV", u"Å"))

# Keep only configurations with no fluorine atom.
raw_data = DataSet([cfg for cfg in raw_data if :F ∉ atomic_symbol(get_system(cfg))])

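Removing structures that contain fluorine (F) removes 1,1923 configurations (count as originally recorded; the digit grouping looks garbled, so verify the exact number).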

In [None]:
# Fixed permutation of dataset indices, read from disk (rather than drawn here)
# so every run uses the same train/test split. `readdlm` returns a matrix; the
# following cells slice it with linear indexing.
master_perm_idxs = readdlm("./primary_permutation.txt", Int64)

In [None]:
# The first `max_num_train` permuted indices form the candidate training pool;
# every index after them is held out as the test pool.
max_num_train = 120_001
possible_training_idxs = master_perm_idxs[begin:max_num_train]
possible_test_idxs = master_perm_idxs[(max_num_train + 1):end]

In [None]:
# Training set: a prefix of the candidate training pool.
num_train = 40_000
train_idxs = possible_training_idxs[begin:num_train]

In [None]:
#param_file = "dummy_param.txt"

#_AtWA, _AtWb = PotentialLearning.ooc_learn_eonly!(lb, raw_data[train_idxs];symmetrize=false, λ=0.01, pbar=false)
#
#open(param_file, "w") do io
#    writedlm(io, lb.β)
#end


In [None]:
# Load previously fitted coefficients into the potential in place (`.=` writes
# into the existing β array). NOTE(review): `readdlm` returns a matrix, so this
# works when the file holds one column of length(lb.β) values — confirm the
# file format. The file name suggests a fit on the 40K-config training set.
lb.β .= readdlm("qm9_4elem_3body_poly10_fit40K.txt", Float64)

In [None]:
# Held-out test pool, indexed once and reused for every quantity below
# (the original re-indexed `raw_data[possible_test_idxs]` four times).
test_configs = raw_data[possible_test_idxs]

# Reference energies for the test pool.
etest_ref = get_all_energies(test_configs)

# Attach local descriptors, then predict energies with the fitted potential.
etest_local_descrs = compute_local_descriptors(test_configs, lb.basis)
ds_test = DataSet(test_configs .+ etest_local_descrs)
etest_pred = get_all_energies(ds_test, lb)

# Atom count per test configuration (used later for per-atom normalization).
num_atoms_test = length.(get_system.(test_configs))

In [None]:
# Accuracy of the test-pool energy predictions (whole-configuration energies,
# not per-atom). `@show` prints the assignment and binds the three values
# returned by calc_metrics to e_mae, e_rmse and e_rsq.
@show e_mae, e_rmse, e_rsq = calc_metrics(etest_pred,etest_ref)

(e_mae, e_rmse, e_rsq) = calc_metrics(etest_pred, etest_ref) = (0.04422691459021048, 0.06343019540788734, 0.9999609908582189)

Notes on how the Medford notebook computes the distance vector:
- the feature vector for each config is averaged over atoms (not summed);
- the feature vectors are standardized before fitting the k-means clusters;
- k-means does appear to use Euclidean distance (they pass "Minkowski", which seems to default to p = 2);
- the final distance metric is the average of the distances to all cluster centers.

In [None]:
"""
    compute_mean_features(ds; basis=lb.basis, progress_every=100)

Return a matrix with one column per configuration in `ds`. Column `i` is the
mean, over atoms, of the local descriptors of the `i`-th configuration.

# Keywords
- `basis`: descriptor basis. It defaults to the global `lb.basis`, which the
  original version read implicitly; passing it makes that dependency explicit
  and allows other bases.
- `progress_every`: print the running configuration count every this many
  configs; `0` disables the progress output.
"""
function compute_mean_features(ds; basis=lb.basis, progress_every=100)
    mean_feature_perconfig = Vector{Float64}[]
    for (i, config) in enumerate(ds)
        if progress_every > 0 && i % progress_every == 0
            println(i)
        end
        descriptors = InteratomicPotentials.compute_local_descriptors(get_system(config), basis)
        push!(mean_feature_perconfig, mean(descriptors))
    end

    # Stack per-config vectors as columns: descriptor components × configs.
    return reduce(hcat, mean_feature_perconfig)
end

# Distance-to-centroid uncertainty heuristics. Variant 1 uses the nearest
# centroid; variant 2 averages over all centroids. The metric could also be varied.

# Euclidean distance from one feature vector to every k-means centroid
# (`km.centers` holds one centroid per column).
_center_distances(mean_feature_vec, km) =
    [Distances.euclidean(mean_feature_vec, c) for c in eachcol(km.centers)]

"""
    heuristic_uncertainty1(mean_feature_vec, km)

Return the distance from `mean_feature_vec` to its nearest k-means centroid.
"""
function heuristic_uncertainty1(mean_feature_vec, km)
    return minimum(_center_distances(mean_feature_vec, km))
end

"""
    heuristic_uncertainty2(mean_feature_vec, km)

Return the mean distance from `mean_feature_vec` to all k-means centroids
(the averaging variant described in the notes above).
"""
function heuristic_uncertainty2(mean_feature_vec, km)
    return mean(_center_distances(mean_feature_vec, km))
end

In [None]:
# Atom-averaged descriptors, one column per configuration: training set first,
# then the held-out test pool.
mean_train_features, mean_test_features = (
    compute_mean_features(raw_data[train_idxs]),
    compute_mean_features(raw_data[possible_test_idxs]),
)

# z-score each descriptor component (rows, hence dims=2) with training-set
# statistics only, then apply that same transform to the test features.
# NOTE(review): a component constant across training configs has zero std —
# check how ZScoreTransform handles it.
dt = StatsBase.fit(ZScoreTransform, mean_train_features; dims=2)
std_mean_train_features = StatsBase.transform(dt, mean_train_features)
std_mean_test_features = StatsBase.transform(dt, mean_test_features)

In [None]:
# Uncentered second-moment matrix of the standardized training features.
# The features matrix is descriptor components × configurations, so each column
# x_n is one configuration and
#   Q = (1/N) Σ_n x_n x_nᵀ = X Xᵀ / N   (components × components).
# Iterating `eachrow` here would form N×N (40,000 × 40,000) outer products
# instead.
Q = let X = std_mean_train_features
    Symmetric(X * X' ./ size(X, 2))
end

In [None]:
# k-means on the standardized *training* features (an earlier version of this
# cell clustered the test features by mistake). Each model gets its own
# Xoshiro(1), so results are reproducible and independent of call order.
num_neighbors = 20   # number of clusters for the main model `km`
km    = kmeans(std_mean_train_features, num_neighbors; distance=Distances.Euclidean(), rng=Xoshiro(1))
km_10 = kmeans(std_mean_train_features, 10;            distance=Distances.Euclidean(), rng=Xoshiro(1))
km_50 = kmeans(std_mean_train_features, 50;            distance=Distances.Euclidean(), rng=Xoshiro(1))

In [None]:
# Pairwise Euclidean distances between the centroids of `km`. `km.centers`
# stores one centroid per column, so pass `dims=2` explicitly: Distances.jl
# deprecates the implicit default.
show(stdout, "text/plain", pairwise(Distances.Euclidean(), km.centers; dims=2))

In [None]:
# Distance heuristics for every test configuration (one column of the
# standardized test features each). Results are kept as 1×N row matrices so
# later cells can use either `d[1, :]` or linear indexing `d[idxs]`.
test_feature_distances1 = permutedims([heuristic_uncertainty1(x, km)
                                       for x in eachcol(std_mean_test_features)])
test_feature_distances2 = permutedims([heuristic_uncertainty2(x, km)
                                       for x in eachcol(std_mean_test_features)])

test_feature_distances3 = permutedims([heuristic_uncertainty1(x, km_10)
                                       for x in eachcol(std_mean_test_features)])
test_feature_distances4 = permutedims([heuristic_uncertainty1(x, km_50)
                                       for x in eachcol(std_mean_test_features)])

In [None]:

# Histogram of nearest-centroid distances (k = 50) over the test pool, zoomed to
# 0–150. `xlims!` must target this histogram's own axis; calling it before
# `hist` acted on whichever axis was current beforehand (if any).
fig_dist, ax_dist, _ = hist(test_feature_distances4[1, :], bins=1000)
xlims!(ax_dist, 0, 150)
fig_dist

In [None]:
# Restrict the x-range of the current axis to 0–150.
xlims!(current_axis(), 0, 150)

In [None]:
# Split-conformal calibration with a distance-normalized score:
#   score = |pred − ref| / scale / distance,
# where scale is the atom count in per-atom mode and 1 otherwise.
test_feature_distances = test_feature_distances4
fraction_calib = 0.1
peratom = true   # score per-atom energy errors (as the Medford paper does)
alpha = 0.05     # target miscoverage; the band should cover ≥ 1 − alpha
num_calib = floor(Int64, fraction_calib * length(possible_test_idxs))
num_test = length(possible_test_idxs) - num_calib

# Calibration = first `num_calib` test-pool configs, evaluation = the rest.
# The pool already comes from a fixed permutation; an extra shuffle is optional:
#idxs_wrt_test = Random.randperm(length(possible_test_idxs))
idxs_wrt_test = collect(1:length(possible_test_idxs))

calib_idxs_wrt_test = idxs_wrt_test[1:num_calib]
test_idxs_wrt_test = idxs_wrt_test[num_calib+1:end]

# One normalizer for both splits keeps calibration scores and evaluation
# residuals on the same scale (dividing by 1 leaves the values unchanged).
residual_scale = peratom ? num_atoms_test : ones(Int, length(num_atoms_test))

calib_scores = abs.(etest_pred[calib_idxs_wrt_test] .- etest_ref[calib_idxs_wrt_test]) ./
               residual_scale[calib_idxs_wrt_test] ./
               test_feature_distances[calib_idxs_wrt_test]
test_abs_residuals = abs.(etest_pred[test_idxs_wrt_test] .- etest_ref[test_idxs_wrt_test]) ./
                     residual_scale[test_idxs_wrt_test]

# Finite-sample-corrected level ⌈(n+1)(1−α)⌉/n, clamped to [0, 1] as in the
# calibration-curve cell so `quantile` cannot throw for small n or tiny α.
# (When the level exceeds 1 the exact q̂ would be +∞; clamping uses the
# largest score instead.)
q_level = clamp(ceil((num_calib + 1) * (1 - alpha)) / num_calib, 0.0, 1.0)
q_hat = quantile(calib_scores, q_level)

The Medford paper takes the energy normalized by the number of atoms (energy per atom) as the quantity of interest, rather than the raw total energy. I suspect this doesn't make much difference for this dataset, since the molecules have similar numbers of atoms, but it probably starts to matter when atom counts vary widely.

In [None]:
# Half-width of the conformal band for each evaluation config: q̂ times its distance.
qhat_scores = q_hat * test_feature_distances[test_idxs_wrt_test]

# `miscoverage` is the fraction of residuals OUTSIDE the band (≈ alpha when
# calibrated). `coverage` is the fraction inside (≈ 1 − alpha).
miscoverage = count(test_abs_residuals .> qhat_scores) / num_test
coverage = 1 - miscoverage

In [None]:
# Parity plot on the same scale as the conformal band `qhat_scores`:
# per-atom energies when `peratom` is true, total energies otherwise.
energy_scale = peratom ? num_atoms_test[test_idxs_wrt_test] : 1
f = parity_plot(etest_ref[test_idxs_wrt_test] ./ energy_scale,
                etest_pred[test_idxs_wrt_test] ./ energy_scale,
                qhat_scores)
# TODO: a residual plot (prediction − reference vs. reference) may read better.

In [None]:
# Signed per-atom residuals (prediction − reference) on the evaluation split.
test_residuals = let t = test_idxs_wrt_test
    (etest_pred[t] .- etest_ref[t]) ./ num_atoms_test[t]
end
hist(test_residuals; bins=1000)

In [None]:
#fig = Figure(size=(600,600))
#ax  = Axis(fig[1,1],
#title="Residuals vs distances",
#xlabel="Distances",
#ylabel="Residuals",
#limits=(0,100,-0.005,0.03))
#hlines!(ax, 0.0, color=:red, linestyle=:dash)
##scatter!(ax, test_feature_distances[test_idxs_wrt_test], test_abs_residuals, markersize=1)
#scatter!(ax, test_feature_distances2[test_idxs_wrt_test], test_abs_residuals, markersize=5)
##ax.aspect=DataAspect()

"""
    uncertainty_vs_residuals(uncertainty, residuals; title, xlabel, ylabel, figsize, limits, markersize)

Scatter `residuals` against `uncertainty`, add a red dashed zero line, and
return the `Figure`.

This definition replaces the earlier `uncertainty_vs_residuals` in this
notebook (same positional signature). Every keyword is honoured, and the
defaults reproduce the figure this method previously drew with hard-coded
values.

Note that the default `limits` window `(0, 100, -0.01, 0.03)` clips every
point outside it. Pass `limits=(nothing, nothing)` for automatic axis limits.
"""
function uncertainty_vs_residuals(uncertainty, residuals;
                                  title  = "Residuals vs distances",
                                  xlabel = "Distances",
                                  ylabel = "Residuals",
                                  figsize = (600, 600),
                                  limits = (0, 100, -0.01, 0.03),
                                  markersize = 5)

    fig = Figure(size=figsize)
    ax  = Axis(fig[1, 1],
               title=title,
               xlabel=xlabel,
               ylabel=ylabel,
               limits=limits)

    hlines!(ax, 0.0, color=:red, linestyle=:dash)

    scatter!(ax, uncertainty, residuals, markersize=markersize)

    return fig
end



# Absolute evaluation residuals against their distance heuristic.
# NOTE(author): this did not plot as expected. One candidate cause (unverified)
# is the plotting function's fixed axis window clipping points outside it.
let distances = test_feature_distances[test_idxs_wrt_test]
    uncertainty_vs_residuals(distances, test_abs_residuals)
end


In [None]:
# Ordinary least squares of |residual| on the distance heuristic, and its R².
xvals = test_feature_distances4[test_idxs_wrt_test]
yvals = test_abs_residuals
A = [ones(length(xvals)) xvals]    # design matrix: intercept column, distance column
coeffs = A \ yvals                  # coeffs[1] = intercept, coeffs[2] = slope

y_pred = @. coeffs[2] * xvals + coeffs[1]

ymean = mean(yvals)
ss_total = sum(abs2, yvals .- ymean)
ss_residual = sum(abs2, yvals .- y_pred)
r_squared = 1 - ss_residual / ss_total

In [None]:
# Empirical coverage of the conformal band within bins of absolute residual.
# Caveat: bins are defined by the residual itself, so low-residual bins are
# covered almost by construction. Binning on the band width (`qhat_scores`) or
# on the distance heuristic would probe adaptivity more directly.
abs_res = test_abs_residuals
@show length(abs_res)

# Indices with `low <= abs_res < high`, plus the fraction of them whose residual
# lies inside the band `band`. An empty bin yields NaN (as 0/0 did before).
function _coverage_in_range(abs_res, band, low, high)
    idxs = findall(r -> low <= r < high, abs_res)
    isempty(idxs) && return idxs, NaN
    return idxs, 1 - count(abs_res[idxs] .> band[idxs]) / length(idxs)
end

for bin_start in 0.000:0.001:0.015
    bin_end = bin_start + 0.001
    bin_idxs, bin_coverage = _coverage_in_range(abs_res, qhat_scores, bin_start, bin_end)
    println("$(bin_start)-$(bin_end) : $(length(bin_idxs)) configs with coverage $(bin_coverage)")
end

# Wider tail bucket (overlaps the last fine bins on purpose).
low = 0.012
high = 0.02
idxs, local_coverage = _coverage_in_range(abs_res, qhat_scores, low, high)
println("$(low)-$(high) : $(length(idxs)) configs with coverage $(local_coverage)")

local_coverage = 1 - count(abs_res .> qhat_scores) / length(abs_res)
println("overall coverage is $(local_coverage)")


In [None]:
# Sharpness = mean full width of the two-sided band (qhat_scores is the half-width).
sharpness = mean(q -> 2q, qhat_scores)

In [None]:
# Sample variance (n − 1 denominator) of the band half-widths.
var(qhat_scores; corrected=true)

So I think the sharpness here should be doubled, because `qhat_scores` is only one side of the symmetric uncertainty interval.
Regardless, for 1 − alpha = 0.68 the mean here is 0.0029635544980581414 eV, or about 3 meV (per atom).
In contrast, https://github.com/medford-group/conformal_prediction_in_latent_space/blob/master/uncertainty/analyze_conformal_feature.ipynb reports
np.mean(test_uncertainty1) = 1.6777220081885507 meV

So I'm getting a sharpness almost 2× as large.

They also plot a histogram of their equivalent of `qhat_scores` (though it doesn't show up). I don't really want to try to set …

In [None]:
# Distribution of band half-widths: a wide spread indicates an adaptive band.
hist(qhat_scores; bins=1000)
# How adaptive is "sufficiently adaptive" is not settled quantitatively here;
# it likely depends on the dataset as well as on the score function.

In [None]:
# Calibration curve. For each nominal miscoverage α, recompute q̂ from the
# calibration scores and measure the observed miscoverage on the evaluation split.
alpha_refs = collect(range(0.01, 0.99, step=0.01))   # nominal α values
alpha_complements = 1 .- alpha_refs                  # matching confidence levels 1 − α

# `predicted_alphas` are observed MISCOVERAGE rates (fraction of residuals outside
# the band), directly comparable to `alpha_refs`. For confidence-level axes, plot
# `1 .- alpha_refs` against `1 .- predicted_alphas`. Plotting the α values on such
# axes is why a point that belongs at 30% appears at 70%.
predicted_alphas = let test_distances = test_feature_distances[test_idxs_wrt_test]
    map(alpha_refs) do a
        # Same clamped finite-sample quantile level as the single-α cell.
        level = clamp(ceil((num_calib + 1) * (1 - a)) / num_calib, 0.0, 1.0)
        qh = quantile(calib_scores, level)
        count(test_abs_residuals .> qh * test_distances) / num_test
    end
end


In [None]:
# Miscalibration area of the observed-vs-nominal miscoverage curve.
# Reflecting both axes (α → 1 − α) only flips the sign of each segment width and
# each observed − expected gap, so this equals the area of the
# confidence-level version of the curve.
compute_miscalibration_area(alpha_refs, predicted_alphas)

In [None]:
# The plot axes are labelled as confidence levels, so plot 1 − α (expected)
# against 1 − observed miscoverage (observed), not the α values themselves.
# Both vectors are reversed so the x values increase left to right.
make_calibration_plot(reverse(1 .- alpha_refs), reverse(1 .- predicted_alphas))

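TODO: fix the calibration plot. A point shown at 70% should be at 30%, because the plotted values are miscoverage rates (α and the observed fraction outside the band) while the axes are labelled as confidence levels.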