In [None]:
# Load the project helpers from utils.jl (expected to bring the functions and
# packages used by the cells below into scope).
include("utils.jl")

# Make sure the directory that the trained models are written to exists.
mkpath("models")

#### Choose system and specify model

In [None]:
# Name of the system; it selects the data directory and tags the saved model files.
system = "3D-LJ-Tvar"
datadir = string("data/", system)

# Bin width and half-width of the density window.
dx = 0.01
window_width = 3.5  # in sigma from center of window
# Bins on both sides of the center plus the center bin itself.
window_bins = round(Int, 2 * window_width / dx) + 1

# Options forwarded to the `Model` constructor.
c1_with_T = true
hidden_nodes = [128, 64, 32]
# Quantities to be trained; the commented line is the alternative that trains c1 only.
trainable = (:c1, :μ)
#trainable=(:c1,)

#### Plot a random simulation result from the dataset

In [None]:
# Pick one simulation file at random and plot its density profile together with
# the external potential.
sim_files = readdir(datadir, join=true)
# Fail with a clear message instead of the generic error `rand` raises on an
# empty collection.
isempty(sim_files) && error("no simulation files found in $(datadir)")

# The do-block form of `jldopen` closes the file handle when the block finishes,
# also if reading or plotting throws. The value of the block (the plot) is the
# value of the cell.
jldopen(rand(sim_files), "r") do file
    # μ is only printed when it is stored in the file; T is read unconditionally.
    if haskey(file, "μ")
        println("μ = ", file["μ"])
    end
    println("T = ", file["T"])
    plot(file["xs"], file["ρ"], label="ρ (sim)")
    plot!(file["xs"], file["Vext"], label="Vext (sim)")
end

#### Prepare data

In [None]:
# Read the simulation results stored in `datadir`.
ρ_profiles, Vext_profiles, μ_values, T_values = read_sim_data(datadir)

# Turn the profiles into the arrays used for training; `window_bins` is passed
# through as the window size.
ρ_windows, sim_onehots, ρ_values, Vext_values = generate_inout(
    ρ_profiles, Vext_profiles; window_bins=window_bins
)

# Show the shapes of the generated arrays as the cell output.
map(size, (ρ_windows, sim_onehots, ρ_values, Vext_values))

#### Train model

In [None]:
# Initial values for μ and T, one entry per element of `μ_values` / `T_values`.
# A trainable quantity starts from a neutral guess (0 for μ, 1 for T); otherwise
# the simulation values are used. Both branches are Float32 so that the parameter
# element type does not depend on `trainable` (plain `zeros(dims)` is Float64).
μ_init = :μ in trainable ? zeros(Float32, size(μ_values)) : Float32.(μ_values)
T_init = :T in trainable ? ones(Float32, size(T_values)) : Float32.(T_values)

# The model takes the first dimension of the density windows and of the one-hot
# array as its two positional size arguments.
model = Model(size(ρ_windows, 1), size(sim_onehots, 1); T_init, μ_init, c1_with_T, trainable, hidden_nodes) |> gpu
display(model)

# Move the full training set to the same device as the model.
ρ_windows, sim_onehots, ρ_values, Vext_values = (ρ_windows, sim_onehots, ρ_values, Vext_values) |> gpu

opt = Flux.setup(Adam(), model)

batchsize = 128
# `partial=false` drops the last incomplete batch, so every batch has `batchsize` samples.
loader = Flux.DataLoader((ρ_windows, sim_onehots, ρ_values, Vext_values), batchsize=batchsize, shuffle=true, partial=false)
# All-zero density inputs with the shape of one batch and of the full data set.
# NOTE(review): not referenced in this part of the notebook; kept in case later cells use them.
ρ0_batch = zeros(size(ρ_windows, 1), batchsize) |> gpu
ρ0_windows = zero(ρ_windows) |> gpu

"""
    loss_EL(c1, μ, T, ρ_values, Vext_values)

Return the mean squared value of the residual
`c1 + (μ - Vext_values) / T - log(ρ_values)`, evaluated elementwise with
broadcasting. The loss is zero when the residual vanishes for every sample.
"""
function loss_EL(c1, μ, T, ρ_values, Vext_values)
    residual = @. c1 + (μ - Vext_values) / T - log(ρ_values)
    return Flux.mse(residual, 0)
end

"""
    get_learning_rate(epoch; initial=0.001, rate=0.05)

Return the exponentially decayed learning rate `initial * (1 - rate)^epoch`.

# Keywords
- `initial`: learning rate at `epoch == 0`.
- `rate`: fraction by which the learning rate shrinks per epoch.
"""
function get_learning_rate(epoch; initial=0.001, rate=0.05)
    decay = 1 - rate
    return initial * decay^epoch
end

# Start time of the run; reused below to tag the output files.
timestamp = now()
epochs = 100

# One snapshot of the model state is taken before every epoch and one more after
# the last epoch, giving `epochs + 1` entries in total.
model_state_history = []
for epoch in 1:epochs
    push!(model_state_history, Flux.state(cpu(model)))

    # The learning rate for this epoch comes from the decay schedule.
    lr = get_learning_rate(epoch)
    Flux.adjust!(opt, lr)
    @printf "Epoch: %3i (learning_rate: %.2e)..." epoch lr
    flush(stdout)

    # One pass over the shuffled mini-batches.
    Flux.train!(model, loader, opt) do m, ρ_win, onehots, ρ_target, Vext_target
        c1_batch, μ_batch, T_batch = m(ρ_win, onehots)
        loss_EL(c1_batch, μ_batch, T_batch, ρ_target, Vext_target)
    end

    # Report the loss on the complete data set after the epoch.
    c1_all, μ_all, T_all = model(ρ_windows, sim_onehots)
    @printf " loss_EL: %.5f\n" loss_EL(c1_all, μ_all, T_all, ρ_values, Vext_values)
    flush(stdout)
end
push!(model_state_history, Flux.state(cpu(model)))

# Store the trained c1 network (moved back to the CPU) as BSON. The variable name
# `c1_model` is the key under which `BSON.@save` stores it.
c1_model = cpu(model.c1)
c1_model_savefile = string("models/c1_model_", system, "_", timestamp, ".bson")
BSON.@save c1_model_savefile c1_model

# Store the per-epoch state snapshots as JLD2 under the key `model_state_history`.
model_state_history_savefile = string("models/model_state_history_", system, "_", timestamp, ".jld2")
jldsave(model_state_history_savefile; model_state_history=model_state_history)