In [None]:
# Activate the project environment in the current directory and load dependencies.
using Pkg; Pkg.activate(".")
using Cairn
using Molly
using PotentialLearning
using LinearAlgebra
using SpecialPolynomials
# Makie plotting helpers live in Cairn's source tree (they also use ActiveSubspaces).
# The path is resolved from the loaded Cairn package rather than a user-specific
# absolute path, so the notebook runs on any machine where Cairn is available.
include(joinpath(pkgdir(Cairn), "src", "makie", "makie.jl"))
include("./my_misc_utils.jl")

In [None]:
# Show the package status of the currently active environment.
Pkg.status()

In [None]:
# Himmelblau reference model; used later for training, for building the evaluation
# systems and for the evaluation grid.
ref = Himmelblau()
# One [lower, upper] pair per dimension; reused for the plotting grid, as the `xscl`
# argument of the PCE, and for the evaluation grid.
limits = [[-6.25,6.25],[-5.75,5.75]] # boundaries of main support

In [None]:
# Temperature shared by this Langevin run and by the SVGD simulator defined later.
temp = 100.0u"K"

# Overdamped Langevin simulator for a trajectory on the reference model.
sim_langevin = OverdampedLangevin(  dt=0.002u"ps",
                                    temperature=temp,
                                    friction=1.0u"ps^-1")
# Initial system built from `ref` at the point [4.5, -2]; `logstep=100` presumably sets
# the logging interval — confirm against `init_trajectory`.
# NOTE: the name `sys0` is rebound later in this file to the vector of training systems.
sys0 = init_trajectory(ref, [4.5, -2], logstep=100)
# Simulate a deep copy so the initial state held in `sys0` is not modified.
sys  = deepcopy(sys0)
@time simulate!(sys, sim_langevin, 1_000_000)

In [None]:
# Plotting grid: one axis per dimension, spanning `limits` in steps of 0.05 and carrying
# length units.
dist_units = u"nm"
xcoord, ycoord = (collect(lo:0.05:hi) .* dist_units for (lo, hi) in limits)
ctr_grid = [xcoord, ycoord]
ctr_lvls = 0:25:350  # values passed as `lvls` to the trajectory plot

# Trajectory of the Langevin run drawn over the grid (filled)
f0 = plot_md_trajectory(sys, ctr_grid, fill=true, lvls=ctr_lvls)

In [None]:
# Untrained polynomial chaos model; `limits` is passed as its `xscl` keyword.
# NOTE(review): the positional arguments 5 and 2 are presumably the polynomial order and
# the input dimension, with Jacobi{0.5,0.5} as the basis family — confirm against the
# `PolynomialChaos` constructor.
pce = PolynomialChaos(5, 2, Jacobi{0.5,0.5}, xscl=limits)

In [None]:

# initialize system properties
# Both values are reused by every `System` constructed further down.
atom_mass = 1.0u"g/mol"
# Boundary constructed with an infinite length (Inf nm); presumably an unbounded domain.
boundary = RectangularBoundary(Inf*u"nm")

In [None]:
# Training inputs as [x, y] pairs (units of nm are attached when `coords_train` is built).
# Entries 11-20 are entries 1-10 shifted by ±0.03 in each coordinate, and the final entry
# is entry 10 shifted by +0.1 in both coordinates.
x_train = [[3.8464987622491895, -1.7801390420313014],
[4.096079218176701, -1.9623723102484034],
[3.794900754980532, -2.0790635268608757],
[3.2997659303869744, -2.2798944543828574],
[3.3921279139032157, -2.1410288797166183],
[3.5889845930605545, -2.103609154541564],
[3.367066701220913, -1.3277818179304328],
[3.7936780458955686, -2.457270634134617],
[3.4604931855508254, -1.6495593869168982],
[3.591850102112664, -2.390778252852719],
[3.8764987622491895, -1.8101390420313014],
[4.066079218176701, -1.9323723102484034],
[3.824900754980532, -2.1090635268608757],
[3.2697659303869744, -2.2498944543828574],
[3.4221279139032157, -2.1710288797166183],
[3.5589845930605545, -2.133609154541564],
[3.397066701220913, -1.3577818179304328],
[3.8236780458955686, -2.427270634134617],
[3.4304931855508254, -1.6195593869168982],
[3.561850102112664, -2.420778252852719],
[3.691850102112664, -2.290778252852719]]

# Attach length units to each training point and create one atom per point.
coords_train = map(x -> SVector(x[1], x[2]) * u"nm", x_train)
ntrain = length(coords_train)
atoms_train = [Atom(mass=atom_mass, σ=1.0u"nm", ϵ=1.0u"kJ * mol^-1") for _ in 1:ntrain]

# One single-atom System per training point. The untrained `pce` is attached as the
# interaction, but it is not actually used during training: these systems are evaluated
# with the Himmelblau reference instead.
sys0 = map(zip(atoms_train, coords_train)) do (atom, coord)
    System(
        atoms=[atom],
        coords=[coord],
        boundary=boundary,
        general_inters=(pce,),
        # a custom `k = 1.0u"kJ * K^-1 * mol^-1"` keyword is intentionally left disabled
    )
end

In [None]:
# Work on a deep copy so the untrained `pce` remains available.
pce0 = deepcopy(pce)
# Train `pce0` on the training systems against the reference `ref`; the `!` suggests the
# result is stored in `pce0` in place — confirm against `train_potential_e!`.
train_potential_e!(sys0, ref, pce0) # wts=[1e4,1])

In [None]:
# plot
# Contours of the trained `pce0` on the plotting grid, with the training points overlaid.
# NOTE: `f0` is rebound here; it previously held the trajectory figure.
ctr_lvls0 = -125:25:400
f0, ax0 = plot_contours_2D(pce0, xcoord, ycoord; fill=true, lvls=ctr_lvls0)
# Stack the training coordinates so that column 1 / column 2 can be scattered as x / y;
# assumes `get_values` returns one 2-element vector per coordinate — TODO confirm.
coordmat = reduce(hcat, get_values(coords_train))'
scatter!(ax0, coordmat[:,1], coordmat[:,2], color=:red, label="training points")
axislegend(ax0)
f0

In [None]:
# Ensemble of systems to simulate. Here the attached `pce0` does get used. The starting
# points are, arbitrarily, the first ten training configurations; they could have been any
# points and are independent of the training set and of the "fixed" systems defined later.
ens0 = map(zip(atoms_train[1:10], coords_train[1:10])) do (atom, coord)
    System(
        atoms=[atom],
        coords=[coord],
        boundary=boundary,
        general_inters=(pce0,),
        # a custom `k = 1.0u"kJ * K^-1 * mol^-1"` keyword is intentionally left disabled
        loggers=(
            coords=CoordinateLogger(1; dims=2),
            ksd=StepComponentLogger(1; dims=2),
            # a `trigger=TriggerLogger(trigger2, 1)` logger is intentionally left disabled
            params=TrainingLogger(),
        ),
    )
end

In [None]:
# Kernel: RBF built on `Euclidean(2)` with β = 0.2
rbf = RBF(Euclidean(2), β=0.2)

# "Fixed" systems: one per training configuration, each carrying the initial PCE potential
sys_fix = map(zip(atoms_train, coords_train)) do (atom, coord)
    System(
        atoms=[atom],
        coords=[coord],
        boundary=boundary,
        general_inters=(pce0,),
        # a custom `k = 1.0u"kJ * K^-1 * mol^-1"` keyword is intentionally left disabled
    )
end

# Simulator: stochastic SVGD with the same time step, temperature and friction values
# as the Langevin simulator
sim_svgd = StochasticSVGD(;
    dt=0.002u"ps",
    kernel=rbf,
    kernel_bandwidth=median_kernel_bandwidth,
    sys_fix=sys_fix,
    temperature=temp,
    friction=1.0u"ps^-1",
)

In [None]:
# Trigger constructed with a fixed interval of 1000 (presumably the number of simulation
# steps between retrainings — confirm against `TimeInterval`).
trigger1 = TimeInterval(interval=1000)

In [None]:
# Evaluation set: grid over the main support.
# `potential` calls the Himmelblau potential function directly. For Himmelblau,
# `potential_energy` calls out to this same function; it is just intended to do so over a
# set of atoms.
function potential(coords)
    return Cairn.potential_himmelblau(ref, coords)
end
# Only consider points on the grid that are below the cutoff.
coords_eval = potential_grid_points(potential, limits, 0.2; cutoff=400)
sys_eval = define_sys(ref, coords_eval, boundary)

# Use the grid as quadrature points with uniform weights 1/N.
ξ = [ustrip.(Vector(pt)) for pt in coords_eval]
GQint = GaussQuadrature(ξ, fill(inv(length(ξ)), length(ξ)))

In [None]:

# Active-learning routine driven by the fixed-interval trigger `trigger1`.
# NOTE(review): the positional roles below are inferred from the variable names — confirm
# against the `ActiveLearnRoutine` constructor.
al1 = ActiveLearnRoutine(
    ref,       # reference model
    pce0,      # initially trained PCE
    sys_fix,   # systems at the training configurations
    GQint,     # uniform-weight quadrature over the evaluation grid
    trigger1,
    Dict("fd" => [], "rmse_e" => [], "rmse_f" => []), # empty containers, one per error metric key
    train_func = train_potential_e!,
)
     

In [None]:

# Run on a deep copy of the ensemble so `ens0` can be reused for the second experiment.
ens = deepcopy(ens0)
# 12_000 is presumably the number of simulation steps; the second returned value is bound
# to `bwd` — TODO confirm its meaning against `active_learn!`.
al1, bwd = active_learn!(ens, sim_svgd, 12_000, al1)

In [None]:
# Inspect the history stored by the `ksd` logger of the first ensemble member.
ens[1].loggers.ksd.history

In [None]:
# Plot the simulated ensemble together with `al1.sys_train` over unfilled contours at the
# `ctr_lvls0` levels; `showpath=false` presumably hides the trajectory paths.
f = plot_md_trajectory(ens, al1.sys_train, ctr_grid, fill=false, lvls=ctr_lvls0, showpath=false)

In [None]:
# MaxVol
"""
    basis_eval(sys_train, pce)

Evaluate the basis of `pce` at the configurations in `sys_train`.

For every system, the first entry of its coordinates is stripped of units and passed to
`eval_basis`. The per-system results are concatenated horizontally and the adjoint of that
array is returned, so (assuming `eval_basis` returns a vector) row `i` holds the basis
values for system `i`.
"""
function basis_eval(sys_train::Vector{<:System}, pce::PolynomialChaos)
    positions = [ustrip.(c[1]) for c in get_coords(sys_train)]
    columns = [eval_basis(x, pce) for x in positions]
    return reduce(hcat, columns)'
end

# MaxVol trigger with threshold 1.22. This constructor uses the `extrap_grade` function,
# and the assumption is that `x` will be the training set.
trigger2 = MaxVol(thresh=1.22) do x
    basis_eval(x, pce0)
end

In [None]:
# Active-learning routine driven by the MaxVol trigger `trigger2`, with `burnin=100`.
# NOTE(review): `pce0` and `sys_fix` are the very same objects that were passed to `al1`.
# If `active_learn!` / `train_potential_e!` modify them in place, this routine starts from
# the state left by the `al1` run rather than from the initial fit — confirm, and pass
# deep copies if the two experiments are meant to start identically.
al2 = ActiveLearnRoutine(
    ref,
    pce0,
    sys_fix,
    GQint,
    trigger2,
    Dict("fd" => [], "rmse_e" => [], "rmse_f" => []), # empty containers, one per error metric key
    train_func = train_potential_e!,
    burnin=100,
)
     

In [None]:

# Start again from a deep copy of the initial ensemble; this rebinds `ens` (and `bwd`),
# replacing the values left by the previous run.
ens = deepcopy(ens0)
al2, bwd = active_learn!(ens, sim_svgd, 12_000, al2)

In [None]:
# Same plot as for the first routine, now with the ensemble and training set of `al2`.
f = plot_md_trajectory(ens, al2.sys_train, ctr_grid, fill=false, lvls=ctr_lvls0, showpath=false)

In [None]:
# Compare the error history of the two active-learning runs.
# The plotted metric and its axis label / title are derived from one key, so the labels
# always describe the data that is actually drawn.
metric = "rmse_e"  # key of `error_hist` to plot: "fd", "rmse_e" or "rmse_f"
metric_names = Dict(
    "fd" => "Fisher divergence",
    "rmse_e" => "energy RMSE",
    "rmse_f" => "force RMSE",
)

"""
    stepwise_error(err, steps)

Expand per-retraining error values into one value per simulation step: `err[i]` is
repeated `steps[i+1] - steps[i]` times for `i = 1, …, length(err) - 1` (the last error
value has no following step and is not drawn). Returns an empty `Float64` vector when
`err` holds fewer than two values, instead of failing on an empty reduction.
"""
function stepwise_error(err, steps)
    nseg = length(err) - 1
    nseg < 1 && return Float64[]
    return reduce(vcat, [err[i] .* ones(steps[i+1] - steps[i]) for i in 1:nseg])
end

f = Figure(size = (900, 800))
ax1 = Axis(f[1,1],
    xlabel="no. simulation steps (t)",
    ylabel=metric_names[metric],
    title="Surrogate error during active learning ($(metric_names[metric]))",
    xgridvisible=false,
    ygridvisible=false,
    # yscale=log10
    )

err = al1.error_hist[metric]
alsteps = al1.train_steps
err_vec1 = stepwise_error(err, alsteps)
lines!(ax1, eachindex(err_vec1), err_vec1, color=:skyblue1, label="TimeInterval trigger")
# [vlines!(ax1, alsteps[i], color=(:skyblue1,0.5), linestyle=:dash) for i = 1:length(alsteps)]

err = al2.error_hist[metric]
alsteps = al2.train_steps
err_vec2 = stepwise_error(err, alsteps)
lines!(ax1, eachindex(err_vec2), err_vec2, color=:orange, label="MaxVol trigger")
# [vlines!(ax1, alsteps[i], color=(:orange,0.5), linestyle=:dash) for i = 1:length(alsteps)]

axislegend(ax1)
f
     