In [None]:
using Revise

using PotentialLearning, InteratomicPotentials
using Unitful
using Random
using AtomsBase
using DelimitedFiles
using Statistics: mean, var
using StatsBase
using Clustering, Distances, NearestNeighbors
using Trapz
using LinearAlgebra: Symmetric, eigen, mul!, svd, cond, dot, norm

using MultivariateStats, StatsAPI

using JLD2

#using CairoMakie CairoMakie.activate!()
using GLMakie; GLMakie.activate!(inline=false)

In [None]:
# ACE descriptor basis for C/H/O/N systems, wrapped in a linear basis potential.
ace = ACE(
    species = [:C, :H, :O, :N],
    body_order = 3,
    polynomial_degree = 10,
    wL = 2.0,
    csp = 1.0,
    r0 = 1.43,
    rcutoff = 4.4,
)
lb = LBasisPotential(ace)
length(ace)  # basis size; only visible when inspected interactively

# Load the QM9 configurations (energies in eV, lengths in Å) and keep only
# those whose atomic symbols do not include fluorine.
qm9_file = "../files/QM9/qm9_fullset_alldata.xyz"
raw_data = load_data(qm9_file, ExtXYZ(u"eV", u"Å"))
raw_data = DataSet([
    config for config in raw_data if :F ∉ atomic_symbol(get_system(config))
])

# The stored index list is split once: the leading `max_num_train` entries
# form the training pool, everything after them the held-out pool.
max_num_train = 120_001
master_perm_idxs = readdlm("./primary_permutation.txt", Int64)
possible_training_idxs = master_perm_idxs[1:max_num_train]
possible_test_idxs = master_perm_idxs[(max_num_train + 1):end]

# Training subset actually used: the first `num_train` entries of the pool.
num_train = 40_000
train_idxs = possible_training_idxs[1:num_train]

# Overwrite the potential's coefficients in place with the values from file.
lb.β .= readdlm("qm9_4elem_3body_poly10_fit40K.txt", Float64)

In [None]:
# Reference vs. predicted energies on the held-out pool.
test_configs = raw_data[possible_test_idxs]
etest_ref = get_all_energies(test_configs)

# Predictions need the local descriptors attached to each configuration.
etest_local_descrs = compute_local_descriptors(test_configs, lb.basis)
ds_test = DataSet(test_configs .+ etest_local_descrs)
etest_pred = get_all_energies(ds_test, lb)

# Atom counts, used to express energies per atom.
num_atoms_test = length.(get_system.(test_configs))

@show e_mae, e_rmse, e_rsq = calc_metrics(etest_pred./num_atoms_test,etest_ref./num_atoms_test)

In [None]:
# Reference vs. predicted energies on the training subset.
train_configs = raw_data[train_idxs]
etrain_ref = get_all_energies(train_configs)

# Predictions need the local descriptors attached to each configuration.
etrain_local_descrs = compute_local_descriptors(train_configs, lb.basis)
ds_train = DataSet(train_configs .+ etrain_local_descrs)
etrain_pred = get_all_energies(ds_train, lb)

# Atom counts, used to express energies per atom.
num_atoms_train = length.(get_system.(train_configs))

@show etrain_mae, etrain_rmse, etrain_rsq = calc_metrics(etrain_pred./num_atoms_train,etrain_ref./num_atoms_train)

Note: strictly speaking, the standard deviation of the residuals should be computed on a separate validation data set (perhaps the calibration data set could serve this purpose) rather than on the training set.

In [None]:
# Per-atom signed errors on the training subset, and their spread
# (square root of the sample variance).
train_residuals = @. (etrain_pred - etrain_ref) / num_atoms_train
trainset_std = train_residuals |> var |> sqrt

In [None]:
"""
    compute_mean_features(ds; basis=lb.basis)

Return a matrix with one column per configuration in `ds`, where each column is
the mean of the local descriptors computed for that configuration's system.

# Keywords
- `basis`: descriptor basis passed to `compute_local_descriptors`. It defaults to
  the notebook-level `lb.basis`, so existing calls behave as before; passing it
  explicitly removes the hidden dependency on the global `lb`.

The running configuration index is printed every 100 configurations as a
progress indicator.
"""
function compute_mean_features(ds; basis=lb.basis)
    mean_feature_perconfig = Vector{Float64}[]
    for (i, config) in enumerate(ds)
        if i % 100 == 0
            println(i)
        end
        # Qualified call: the same name is used unqualified elsewhere in the
        # notebook on whole datasets; here it is applied to a single system.
        local_descriptors =
            InteratomicPotentials.compute_local_descriptors(get_system(config), basis)
        # presumably one descriptor vector per atom, so this averages over atoms
        # -- TODO confirm against the package documentation
        push!(mean_feature_perconfig, mean(local_descriptors))
    end

    # Stack the per-configuration vectors as columns.
    return reduce(hcat, mean_feature_perconfig)
end

"""
    normdists2centers(feature_vec, km)

Compute the Euclidean distance from `feature_vec` to every column of
`km.centers`, then rescale the distances so that they sum to one.
"""
function normdists2centers(feature_vec, km)
    center_dists = mapslices(km.centers; dims=1) do center
        Distances.euclidean(feature_vec, center)
    end
    return center_dists ./ sum(center_dists)
end

In [None]:
# Mean descriptor per configuration (one column each) for both data splits.
mean_train_features = compute_mean_features(raw_data[train_idxs])
mean_test_features = compute_mean_features(raw_data[possible_test_idxs])

# The z-score transform is fitted on the training features only and then
# applied unchanged to both splits.
dt = StatsBase.fit(ZScoreTransform, mean_train_features; dims=2)
std_mean_train_features, std_mean_test_features = map(
    features -> StatsBase.transform(dt, features),
    (mean_train_features, mean_test_features),
)

In [None]:
# Split-conformal calibration with a constant uncertainty estimate.
# Every residual is scaled by the same number: the training-set residual std.
uncertainty = trainset_std

fraction_calib = 0.1   # share of the held-out pool used for calibration
peratom = true         # score per-atom residuals instead of total-energy residuals
alpha = 0.05           # target miscoverage rate
num_calib = floor(Int64, fraction_calib * length(possible_test_idxs))
num_test = length(possible_test_idxs) - num_calib
num_calib > 0 || throw(ArgumentError("calibration set is empty; increase fraction_calib"))

# Positions are relative to `possible_test_idxs` (and therefore to `etest_pred`,
# `etest_ref` and `num_atoms_test`). No extra shuffle is applied here; swap in
# Random.randperm(length(possible_test_idxs)) to randomize the split.
idxs_wrt_test = collect(1:length(possible_test_idxs))

calib_idxs_wrt_test = idxs_wrt_test[1:num_calib]
test_idxs_wrt_test = idxs_wrt_test[num_calib+1:end]

# Absolute residuals for the whole pool, optionally normalized per atom, so the
# calibration and test subsets are guaranteed to be treated identically.
abs_residuals = abs.(etest_pred .- etest_ref)
if peratom
    abs_residuals = abs_residuals ./ num_atoms_test
end

calib_scores = abs_residuals[calib_idxs_wrt_test] ./ uncertainty
test_abs_residuals = abs_residuals[test_idxs_wrt_test]

# Finite-sample corrected quantile level. For small calibration sets the
# corrected level can exceed 1, which `quantile` rejects, so it is capped
# (the alpha sweep later in the notebook clamps the same expression).
q_level = min(ceil((num_calib + 1) * (1 - alpha)) / num_calib, 1.0)
q_hat = quantile(calib_scores, q_level)

In [None]:
# Interval half-width for every test configuration. The calibration scores were
# residuals divided by `uncertainty`, so the quantile is scaled back by the same
# quantity (the previous `stdev` is not defined anywhere in this notebook).
qhat_scores = fill(q_hat * uncertainty, num_test)

# NOTE: despite its name, this is the fraction of test residuals that fall
# OUTSIDE the interval, i.e. the empirical miscoverage to compare with `alpha`.
coverage = sum(test_abs_residuals .> qhat_scores) / num_test

0.06161905949856555 # randomly not good here

In [None]:
# Coverage broken down by the size of the absolute test residual.
abs_res = test_abs_residuals
@show length(abs_res)

# Consecutive bins of width 0.001 starting at 0.000, 0.001, ..., 0.015.
for lo_edge in 0.000:0.001:0.015
    hi_edge = lo_edge + 0.001
    in_bin = findall(r -> lo_edge <= r < hi_edge, abs_res)
    bin_coverage = 1 - count(abs_res[in_bin] .> qhat_scores[in_bin]) / length(in_bin)
    println("$(lo_edge)-$(hi_edge) : $(length(in_bin)) configs with coverage $(bin_coverage)")
end

# One wider bin collecting the tail of large residuals.
low = 0.012
high = 0.02
idxs = findall(r -> low <= r < high, abs_res)
local_coverage = 1 - count(abs_res[idxs] .> qhat_scores[idxs]) / length(idxs)
println("$(low)-$(high) : $(length(idxs)) configs with coverage $(local_coverage)")

# Coverage over the whole test subset.
local_coverage = 1 - count(abs_res .> qhat_scores) / length(abs_res)
println("overall coverage is $(local_coverage)")

length(abs_res) = 8017
0.0-0.001 : 2512 configs with coverage 1.0
0.001-0.002 : 1896 configs with coverage 1.0
0.002-0.003 : 1222 configs with coverage 1.0
0.003-0.004 : 805 configs with coverage 1.0
0.004-0.005 : 550 configs with coverage 1.0
0.005-0.006 : 337 configs with coverage 1.0
0.006-0.007 : 185 configs with coverage 1.0
0.007-0.008 : 147 configs with coverage 0.108843537414966
0.008-0.009000000000000001 : 108 configs with coverage 0.0
0.009-0.009999999999999998 : 61 configs with coverage 0.0
0.01-0.011 : 44 configs with coverage 0.0
0.011-0.012 : 40 configs with coverage 0.0
0.012-0.013000000000000001 : 16 configs with coverage 0.0
0.013-0.013999999999999999 : 16 configs with coverage 0.0
0.014-0.015 : 18 configs with coverage 0.0
0.015-0.016 : 12 configs with coverage 0.0
0.012-0.02 : 87 configs with coverage 0.0
overall coverage is 0.9383809405014345

In [None]:
# Sweep the nominal miscoverage rate and record the miscoverage observed on the
# test subset for each value.
alpha_complements = collect(range(0.01, 0.99, step=0.01))  # not used below; kept in case other cells read it
alpha_refs = collect(range(0.01, 0.99, step=0.01))

predicted_alphas = map(alpha_refs) do alpha
    # Finite-sample corrected quantile level, clamped to the range `quantile` accepts.
    level = clamp(ceil((num_calib + 1) * (1 - alpha)) / num_calib, 0.0, 1.0)
    qh = quantile(calib_scores, level)

    # Scores were residuals divided by `uncertainty`, so scale the quantile back
    # by the same quantity (the previous `stdev` is not defined in this notebook).
    qh_scores = fill(qh * uncertainty, num_test)

    # Fraction of test residuals falling outside the interval.
    sum(test_abs_residuals .> qh_scores) / num_test
end

In [None]:
"""
    compute_miscalibration_area(expected_ps, observed_ps)

Return the area between the calibration curve `(expected_ps, observed_ps)` and
the ideal diagonal `(expected_ps, expected_ps)`.

For each pair of neighbouring points the trapezoid-rule area under the observed
curve and under the diagonal are computed, and the absolute value of their
difference is accumulated.

# Arguments
- `expected_ps`: abscissae of the calibration curve.
- `observed_ps`: observed values, one per entry of `expected_ps`.

# Returns
The accumulated area as a `Float64`; `0.0` when fewer than two points are given.

# Throws
- `DimensionMismatch` if the two inputs differ in length.
"""
function compute_miscalibration_area(expected_ps, observed_ps)
    if length(expected_ps) != length(observed_ps)
        throw(DimensionMismatch(
            "expected_ps has length $(length(expected_ps)) but observed_ps has " *
            "length $(length(observed_ps))",
        ))
    end

    area = 0.0
    for i in 2:length(expected_ps)
        # Two-point trapezoid rule written out directly; avoids allocating four
        # slices and calling an integrator twice per segment.
        dx = expected_ps[i] - expected_ps[i-1]
        observed_strip = 0.5 * dx * (observed_ps[i-1] + observed_ps[i])
        ideal_strip = 0.5 * dx * (expected_ps[i-1] + expected_ps[i])
        area += abs(observed_strip - ideal_strip)
    end
    return area
end

# converted from Medford jupyter notebook via Claude
"""
    make_calibration_plot(expected_ps, observed_ps; width=600, as_miscoverage=true)

Build a square calibration figure with three layers drawn on shared axes
(both running from 0 to 100 percent): the observed calibration curve, a dashed
diagonal reference, and a shaded band between the two.

# Arguments
- `expected_ps`: nominal rates as fractions in `[0, 1]`.
- `observed_ps`: empirically observed rates as fractions, one per expected value.

# Keywords
- `width`: side length of the square figure.
- `as_miscoverage`: when `true` (default) the inputs are treated as miscoverage
  rates (alpha) and are converted to confidence levels `1 - alpha` before
  plotting, matching the axis labels. Pass `false` if the inputs already are
  confidence levels.

# Returns
The `Figure`.
"""
function make_calibration_plot(expected_ps, observed_ps; width=600, as_miscoverage=true)
    # Work in percent. The complement must be taken on the percent scale
    # (100 - p), and on BOTH series, so that the curve, the diagonal and the
    # band all live in the same coordinates inside the 0-100 axis limits.
    to_percent_conf(ps) = as_miscoverage ? 100 .- 100 .* ps : 100 .* ps
    expected_conf = to_percent_conf(expected_ps)
    observed_conf = to_percent_conf(observed_ps)

    # NOTE(review): newer Makie releases appear to prefer `size` over
    # `resolution` -- confirm against the installed Makie version.
    fig = Figure(resolution=(width, width))
    ax = Axis(fig[1, 1],
        aspect=DataAspect(),
        xlabel="Expected conf. level",
        ylabel="Observed conf. level",
        limits=(0, 100, 0, 100)
    )

    # Main line
    lines!(ax, expected_conf, observed_conf)

    # Diagonal reference line
    lines!(ax, expected_conf, expected_conf, linestyle=:dash, alpha=0.4)

    # Filled area between the diagonal and the observed curve
    band!(ax, expected_conf, expected_conf, observed_conf, color=(:blue, 0.2))

    # One tick every 10 percent on both axes
    ax.xticks = 0:10:100
    ax.yticks = 0:10:100

    # Add percentage signs to ticks (round guards against non-integer tick values)
    ax.xtickformat = xs -> ["$(round(Int, x))%" for x in xs]
    ax.ytickformat = xs -> ["$(round(Int, x))%" for x in xs]

    # The miscalibration area is not annotated on the plot; it is computed
    # separately with compute_miscalibration_area.

    return fig
end

In [None]:
# Calibration figure for the alpha sweep: nominal rates vs. the rates observed on the test subset.
make_calibration_plot(alpha_refs,predicted_alphas)

In [None]:
# Scalar summary of the same sweep: area between the observed curve and the diagonal.
compute_miscalibration_area(alpha_refs, predicted_alphas)

0.010918336035923657

In [None]:
# Open the saved clustering results read-only. The handle `f` is read from in the next cell.
# NOTE(review): no matching close(f) is visible in this part of the notebook.
f = jldopen("./spencer_clustering.jld2", "r")

In [None]:
# Centroids stored under the "K10_bandwidth10" group of the clustering file.
centroids = f["K10_bandwidth10"]["centroids"]

In [None]:
"""
    alt_normdists2centers(feature_vec, centroids)

Compute the Euclidean distance from `feature_vec` to every column of
`centroids`, then rescale the distances so that they sum to one.
"""
function alt_normdists2centers(feature_vec, centroids)
    centroid_dists = mapslices(centroids; dims=1) do centroid
        Distances.euclidean(feature_vec, centroid)
    end
    return centroid_dists ./ sum(centroid_dists)
end

In [None]:
# PCA model fitted on the standardized training features only.
# `mean=0` is passed to the fit -- presumably because the features were already
# centered by the z-score transform (TODO confirm against MultivariateStats docs).
M1 = StatsAPI.fit(MultivariateStats.PCA, std_mean_train_features; mean=0)

# Project both splits with the training-fitted model.
pca_std_train_features = StatsAPI.predict(M1, std_mean_train_features)
pca_std_test_features = StatsAPI.predict(M1, std_mean_test_features)

In [None]:
# Normalized distances from every training configuration (columns of the PCA
# features) to every centroid: one row per centroid, one column per configuration.
train_dist2centers = mapslices(x->reshape(alt_normdists2centers(x,centroids),:,1), pca_std_train_features; dims=1)

# Assign each configuration to its NEAREST centroid, i.e. the smallest distance.
# (Taking the largest entry instead would pick the farthest centroid.)
train_assignments = vec(mapslices(argmin, train_dist2centers; dims=1))

# Cluster sizes; the number of clusters is read from the data rather than hard-coded.
num_inclusters = [count(==(i), train_assignments) for i in 1:size(train_dist2centers, 1)]
