In [None]:
using Revise

In [None]:
using Cairn 
using LinearAlgebra, Random, Statistics, StatsBase # Do need StatsBase this time
using PotentialLearning 
using Molly, AtomsCalculators
using AtomisticQoIs 
using SpecialPolynomials


In [None]:
# Load the Makie plotting helpers from the source tree of the Cairn package that
# is actually loaded, instead of a path that only exists on one machine.
# NOTE(review): assumes the loaded Cairn checkout contains src/makie/makie.jl — confirm.
include(joinpath(pkgdir(Cairn), "src", "makie", "makie.jl"))

In [None]:
# Reference model and the 2D plotting domain shared by the cells below.
ref = MullerBrownRot()
limits = [[-4.4,1.5],[-2,2]]

# `coord_grid_2d` is provided by the included makie.jl helper file.
coord_grid = coord_grid_2d(limits, 0.05)

# Contour levels for the plots; a finer alternative is -150:20:400.
ctr_lvls = -150:50:1000

In [None]:
# Unfitted polynomial chaos template; later cells `deepcopy` it before each fit.
basisfam = Jacobi{0.5, 0.5}
order = 20
pce0 = PolynomialChaos(order, 2, basisfam; xscl=limits)

In [None]:
# Evaluation points: a 0.04-spaced grid over `limits`, built with cutoff=800.
coords_eval = potential_grid_2d(ref, limits, 0.04; cutoff=800)

# `Ensemble` replaces the older `define_ens(ref, coords_eval)` helper.
sys_eval = Ensemble(ref, coords_eval)

In [None]:
# Unit-stripped evaluation points, used as quadrature nodes with equal weights 1/N.
ζ = map(coords -> ustrip.(Vector(coords)), coords_eval)
GQint = GaussQuadrature(ζ, fill(1 / length(ζ), length(ζ)))


In [None]:
# Filled contours of the reference with the evaluation ("test") points on top.
# (The upstream example used a different level variable, `ctrl_lvls2`, here.)
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)

# `get_values` has no method for a Vector of SVectors, so it is applied per point;
# the result is transposed so that each row is one point.
coordmat = reduce(hcat, get_values.(coords_eval))'
scatter!(
    ax0, coordmat[:, 1], coordmat[:, 2];
    color=:red, markersize=5, label="test points",
)
axislegend(ax0)
f0

In [None]:
# Plot the density for the reference using the quadrature points; only the first
# returned value (bound to `f`) is kept and displayed.
f, _ = plot_density(ref, coord_grid, GQint)
f

In [None]:
# Fit a fresh copy of the template on the evaluation ensemble itself.
# This `learn!` method computes energies and forces from `ref` rather than assuming
# they were precomputed and reading them back with get_energies / get_forces.
pce = deepcopy(pce0)
lp = learn!(sys_eval, ref, pce, [1000, 1], false; e_flag=true, f_flag=true)

# Gibbs distributions of the reference and of the fitted model
# (the Gibbs struct is defined in AtomisticQoIs).
p = define_gibbs_dist(ref)
q = define_gibbs_dist(pce; θ=lp.β)

# Fisher divergence, parameterized by the quadrature points.
fish = FisherDivergence(GQint)
fd_best = compute_divergence(p, q, fish)


In [None]:
# Training set 1: grid over the main support (0.05 spacing, cutoff=800).
coords1 = potential_grid_2d(ref, limits, 0.05; cutoff=800)

# Built with `Ensemble` rather than the older `define_ens` helper.
trainset1 = Ensemble(deepcopy(pce0), coords1)

In [None]:
# Filled contours of the reference with training set 1 on top.
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)

# `get_values` is applied per point (no method for the whole vector of coordinates).
coordmat = reduce(hcat, get_values.(coords1))'
scatter!(
    ax0, coordmat[:, 1], coordmat[:, 2];
    color=:red, markersize=5, label="train set 1",
)
axislegend(ax0)
f0

In [None]:
# Training set 2: samples from Langevin MD.
# Overdamped Langevin simulator at 500 K.
sim_langevin = OverdampedLangevin(;
    dt=0.002u"ps",
    temperature=500.0u"K",
    friction=4.0u"ps^-1",
)

# Starting system at (0.5, 0.5) with a coordinate logger (constructed with 100, dims=2).
sys0 = System(ref, [0.5, 0.5]; loggers=(coords=CoordinateLogger(100; dims=2),))

In [None]:
# Run the 500 K simulation on a private copy of the starting system.
sys2 = deepcopy(sys0)
simulate!(sys2, sim_langevin, 1_000_000)

# Subsample (without replacement) as many logged frames as train set 1 has points.
id = StatsBase.sample(
    1:length(sys2.loggers.coords.history), length(coords1); replace=false
)

# Keep the first entry of each sampled frame.
coords2 = map(i -> sys2.loggers.coords.history[i][1], id)
trainset2 = Ensemble(deepcopy(pce0), coords2)

In [None]:
# Filled contours of the reference with training set 2 on top.
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values.(coords2))'
scatter!(
    ax, coordmat[:, 1], coordmat[:, 2];
    color=:red, markersize=5, label="train set 2",
)
axislegend(ax)
f

In [None]:
# Training set 3: samples from high-T MD (1500 K overdamped Langevin).
sim_highT = OverdampedLangevin(;
    dt=0.002u"ps",
    temperature=1500.0u"K",
    friction=4.0u"ps^-1",
)

# Simulate from a private copy of the shared starting system.
sys3 = deepcopy(sys0)
simulate!(sys3, sim_highT, 1_000_000)

# Subsample (without replacement) as many logged frames as train set 1 has points.
id = StatsBase.sample(
    1:length(sys3.loggers.coords.history), length(coords1); replace=false
)

# Keep the first entry of each sampled frame.
coords3 = map(i -> sys3.loggers.coords.history[i][1], id)
trainset3 = Ensemble(deepcopy(pce0), coords3)

In [None]:
# Filled contours of the reference with training set 3 on top.
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values.(coords3))'
scatter!(
    ax, coordmat[:, 1], coordmat[:, 2];
    color=:red, markersize=5, label="train set 3",
)
axislegend(ax)
f

In [None]:
# Train with a changing weight λ: the sweep values and the training sets to compare.
λarr = [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4]
trainsets = [trainset1, trainset2, trainset3]

# Reference Gibbs distribution and the Fisher divergence on the quadrature points.
p = define_gibbs_dist(ref)
fish = FisherDivergence(GQint)

In [None]:
# Result containers, keyed by "ts<j>" per training set and then by objective:
# "E" (energy only), "F" (force only), "EF" (one slot per λ in `λarr`).

# Fitted coefficients.
param_dict = Dict(
    "ts$j" => Dict(
        "E" => zeros(length(pce.basis)),
        "F" => zeros(length(pce.basis)),
        "EF" => Vector{Vector}(undef, length(λarr)),
    ) for j in eachindex(trainsets)
)

# Errors (allocated here; not written by the cells visible in this notebook section).
err_dict = Dict(
    "ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j in eachindex(trainsets)
)

# Fisher divergences.
fd_dict = Dict(
    "ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j in eachindex(trainsets)
)

In [None]:
# Train on E or F only (UnivariateLinearProblem).
# For every training set, fit a fresh copy of the template once on energies and once
# on forces, then store the coefficients in `param_dict` and the Fisher divergence
# from the reference distribution `p` in `fd_dict`.
# Only `param_dict` and `fd_dict` are written here; `err_dict` is left untouched.
for (j, ts) in enumerate(trainsets)
    @show j

    # E objective
    println("train set $j, E only")
    pce = deepcopy(pce0)
    lpe = learn!(ts, ref, pce; e_flag=true, f_flag=false)
    q = define_gibbs_dist(pce, θ=lpe.β)
    fd_dict["ts$j"]["E"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["E"] = lpe.β

    println("moving on to F")

    # F objective
    println("train set $j, F only")
    pce = deepcopy(pce0)
    lpf = learn!(ts, ref, pce; e_flag=false, f_flag=true)
    q = define_gibbs_dist(pce, θ=lpf.β)
    fd_dict["ts$j"]["F"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["F"] = lpf.β
end


# NOTE: the cell above currently fails with a SingularException, so the E-only and F-only results are not filled in.

In [None]:
# Train on EF (CovariateLinearProblem).
# Sweep λ in the outer loop and the training sets in the inner loop; each fit starts
# from a fresh copy of the template and is passed the weights [λ, 1].
for (i, λ) in enumerate(λarr), (j, ts) in enumerate(trainsets)
    println("train set $j, EF (λ=$λ)")

    pce = deepcopy(pce0)
    ef_fit = learn!(ts, ref, pce, [λ, 1], false; e_flag=true, f_flag=true)
    q = define_gibbs_dist(pce; θ=ef_fit.β)

    # Record the divergence from the reference and the fitted coefficients.
    fd_dict["ts$j"]["EF"][i] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["EF"][i] = ef_fit.β
end

In [None]:
# Plot Fisher divergence against λ for each training set.
# Only the EF sweep is shown because the E-only and F-only fits failed with the
# singular error. To restore those endpoints, pad `λlab` with 1e-5 and 1e5, add "F"
# and "E" as the first and last tick labels, and concatenate
# [fd["F"]], fd["EF"], [fd["E"]] for each training set.
λlab = [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4]

# NOTE(review): newer Makie releases may expect `size` instead of `resolution` —
# confirm against the installed version.
f = Figure(resolution=(550, 450))
ax = Axis(
    f[1, 1];
    xlabel="λ",
    ylabel="Fisher divergence",
    title="Model Error vs. Weight λ",
    xscale=log10,
    yscale=log10,
    xticks=(λlab, ["1e-4", "1e-3", "1e-2", "1e-1", "1", "1e1", "1e2", "1e3", "1e4"]),
)

# One curve per training set.
for j in 1:3
    scatterlines!(ax, λlab, fd_dict["ts$j"]["EF"]; label="train set $j")
end
axislegend(ax; position=:lb)
f
