In [None]:
using DrWatson
@quickactivate "masterarbeit"

In [None]:
#using CUDA
using BenchmarkTools
using ProgressMeter
using Flux
using LaTeXStrings
using Flux: train!
using GLMakie
using Printf
using Dates
using JLD2
using TOML
using StatsBase # for fit(histogram)
using Random

In [None]:
# Render figures inline in the notebook and use a large default font size.
Makie.inline!(true)
fontsize_theme = Theme(fontsize=35)
set_theme!(fontsize_theme)

# Named shortcuts into Makie's Wong colour palette, in palette order.
wblue = Makie.wong_colors()[1]
worange = Makie.wong_colors()[2]
wgreen = Makie.wong_colors()[3]
wpink = Makie.wong_colors()[4]
wlblue = Makie.wong_colors()[5]
# Palette entry 6 was previously assigned to `worange` a second time, which
# silently replaced entry 2. It gets its own name so `worange` keeps entry 2.
wvermillion = Makie.wong_colors()[6]
wyellow = Makie.wong_colors()[7];

In [None]:
using Revise

In [None]:
using masterarbeit

In [None]:
# Value returned by the project helper `get_mldir`. In the visible cells it is
# only referenced by the commented-out `savefig(..., joinpath(dir, ...))` call.
# NOTE(review): presumably the output directory for this run — confirm.
dir=get_mldir()

In [None]:
# Run switches for this notebook.
# `progress_plots` selects between training in `steps` rounds (true) and one
# single training call (false) in the training cell.
progress_plots=true
# `save_samples` is not read anywhere in the visible cells.
save_samples=true
# NOTE(review): a commented-out `save("trident_nis_loss.png", fig)` call exists in
# the loss-plot cell; with this Bool bound to `save`, that call would hit the Bool
# instead of a plotting `save` function — confirm before re-enabling it.
save=true
comment="4cl, notebook subnet"

# Inputs of the integrand. NOTE(review): units and meaning are not visible here.
# `omega` is passed to `trident_phasespace` and, through `f`, to `dσpT_multithreaded`.
omega = 5.12
# `dphi` is not read anywhere in the visible cells.
dphi = 10.0

# Mapping object built from `omega`; called as `ytozmap(x)` and passed as the
# `cm::ChannelMapping` argument of `jacobian_trident`, `lossf` and `weights4cl`.
ytozmap = trident_phasespace(omega)

"""
    f(mapped)

Evaluate `dσpT_multithreaded` at the global `omega` with the elements of `mapped`
splatted in as the remaining arguments, and scale the result by `1f7`.
"""
function f(mapped)
    scale = 1.0f7
    values = dσpT_multithreaded((omega,), mapped...)
    return values .* scale
end

"""
    jacobian_trident(m::Chain, cm::ChannelMapping, x)

Return `abs(cmdet(cm))` multiplied elementwise with the absolute `cldet_cpu` values
of the four coupling layers `m[1]`, `m[3]`, `m[5]` and `m[7]`.

Each coupling layer is evaluated at the input it sees in the forward pass: `x` for
the first one, and the output of the preceding layer (`m[2]`, `m[4]`, `m[6]`) for
the others. The seven-entry layout of `m` is hard-coded.
"""
function jacobian_trident(m::Chain, cm::ChannelMapping, x::T) where {T <: AbstractArray{F}} where F<:Real
    coupling1, mask1, coupling2, mask2, coupling3, mask3, coupling4 =
        (m[1], m[2], m[3], m[4], m[5], m[6], m[7])

    # Forward pass up to the input of the last coupling layer.
    in1 = x
    in2 = mask1(coupling1(in1))
    in3 = mask2(coupling2(in2))
    in4 = mask3(coupling3(in3))

    # Absolute `cldet_cpu` value of one coupling layer at its own input: rows
    # 1:dimA go through the layer's subnet `cl.m`, whose outputs are splatted in
    # after the remaining rows dimA+1:d.
    function layerdet(cl, inp)
        conditioning = inp[1:cl.dimA, :]
        remaining = inp[cl.dimA+1:cl.d, :]
        return abs.(masterarbeit.cldet_cpu(cl, remaining, cl.m(conditioning)...))
    end

    d1 = layerdet(coupling1, in1)
    d2 = layerdet(coupling2, in2)
    d3 = layerdet(coupling3, in3)
    d4 = layerdet(coupling4, in4)
    return abs(cmdet(cm)) .* d1 .* d2 .* d3 .* d4
end

"""
    lossf(m::Chain, cm::ChannelMapping, f::Function, x)

Training loss for the flow `m`.

With `target = f(cm(m(x)))` and `density = 1 ./ jacobian_trident(m, cm, x)`, return
`sum(abs.(target .- density).^2 ./ target) / size(x, 2)`. The exponent is converted
to the element type of `x`.
"""
function lossf(m::Chain, cm::ChannelMapping, f::Function, x::T) where T<:AbstractArray{F} where F<:Real
    mapped = cm(m(x))
    density = 1 ./ jacobian_trident(m, cm, x)
    target = f(mapped)
    terms = abs.(target .- density) .^ F(2.0) ./ target
    ncols = size(x, 2)
    return sum(terms) / ncols
end

# Model and sampling configuration.
# `dim` is the number of rows of the random inputs (`Random.rand(dim, ...)`);
# `dim`, `dimA` and `bins` are handed to every `CouplingLayerCPU` constructor.
dim = 5
dimA = 3
# Both are passed as keyword arguments to `train_NN_cpu`; `activation` is also
# read as a global inside `subnet`.
optimizer = Adam
activation = relu
batchsize = 1024
N_samples = 2^18  # = 260k
sample_batchsize = batchsize

# Training schedule. The `#parse(..., ARGS[i])` remnants show these were once read
# from the command line.
epochs = 600#parse(Int, ARGS[1])#60
steps = 3#parse(Int, ARGS[2])#2
learning_rate = 0.01#parse(Float64, ARGS[3])#0.01
# Copy of the initial value: the training cell overwrites `learning_rate` after
# every round.
start_learning_rate = learning_rate
# Factor applied to `learning_rate` after each round; also passed to `train_NN_cpu`.
decay = 0.7 #parse(Float64, ARGS[4])#0.7
bins = 20 #parse(Int, ARGS[5])#10#20

In [None]:
"""
    subnet(dimA::Signed, dimB::Signed, bins::Signed, width=32)

Build a `Chain` holding one `Split` of two structurally identical branches.

Each branch takes `dimA` features through three `BatchNorm` + `Dense` stages with
hidden size `width`. The hidden `Dense` layers use the global `activation`; the
last `Dense` has none. The first branch outputs `dimB*(bins+1)` values, the second
`dimB*bins`.
"""
function subnet(dimA::Signed, dimB::Signed, bins::Signed, width=32)
    # One branch; only the size of the final Dense layer differs between the two.
    branch(nout) = Chain(
        BatchNorm(dimA),
        Dense(dimA => width, activation),
        BatchNorm(width),
        Dense(width => width, activation),
        BatchNorm(width),
        Dense(width => nout),
    )

    first_head = branch(dimB * (bins + 1))
    second_head = branch(dimB * bins)
    return Chain(Split(first_head, second_head))
end

# Flow model: four `CouplingLayerCPU` blocks with a `MaskLayerCPU` between each
# pair, moved through `cpu` and converted with `Flux.f64`.
# `jacobian_trident` indexes this chain as m[1]..m[7], so the seven-entry
# alternating layout must be kept in sync with that function.
# NOTE(review): the meaning of the Bool mask entries is defined in `masterarbeit`
# and is not visible here.
model = Flux.f64(Chain(
    CouplingLayerCPU(dim, dimA, bins, subnet),
    masterarbeit.MaskLayerCPU([false, false, true, true, true]),
    CouplingLayerCPU(dim, dimA, bins, subnet),
    masterarbeit.MaskLayerCPU([true, false, false, true, true]),
    CouplingLayerCPU(dim, dimA, bins, subnet),
    masterarbeit.MaskLayerCPU([false, true, false, true, true]),
    CouplingLayerCPU(dim, dimA, bins, subnet)
) |> cpu );

In [None]:
# Smoke test: one random batch of shape (dim, batchsize), pushed through the
# mapping and the integrand. The result is discarded (trailing `;`).
xtest = Random.rand(dim, batchsize)
f(ytozmap(xtest));

In [None]:
# Smoke test: forward pass of the model on the test batch; output discarded.
model(xtest);

In [None]:
# Smoke test: evaluate the loss once on the test batch; value discarded.
lossf(model, ytozmap, f, xtest);

In [None]:
# Smoke test: check that the loss can be differentiated with respect to the model
# before starting the real training; result discarded.
Flux.withgradient(m -> lossf(m, ytozmap, f, xtest), model);

In [None]:
# Train the model and collect the loss history in `losses`.
# With `progress_plots` the epoch budget is split over `steps` rounds and the
# global `learning_rate` is multiplied by `decay` after every round; otherwise a
# single call trains for all `epochs`.
losses = Float64[]
if progress_plots
    # Integer split of the epoch budget. `Int(epochs / steps)` threw InexactError
    # whenever `epochs` was not a multiple of `steps`; the leftover epochs are now
    # handed one each to the first rounds, so the total stays `epochs`.
    loopslength, leftover = divrem(epochs, steps)
    for i in 1:steps
        @info "Training step $i/$steps"
        step_epochs = loopslength + (i <= leftover ? 1 : 0)
        global losses = train_NN_cpu(model, dim, lossf, losses, ytozmap, f, epochs=step_epochs, batchsize=batchsize, optimizer=optimizer, learning_rate=learning_rate, decay=decay)
        #Ea_samples, cta_samples, phia_samples, Eb_samples, ctb_samples = sample_trident(model, ytozmap, dim, N_samples, batchsize)
        #savefig(plot_samples(Ea_samples, cta_samples, "Ea", "cos(theta_a)"), joinpath(dir, "epoch$(i*loopslength)_samples.png"))
        global learning_rate = learning_rate * decay
    end
else
    losses = train_NN_cpu(model, dim, lossf, losses, ytozmap, f, epochs=epochs, batchsize=batchsize, optimizer=optimizer, learning_rate=learning_rate, decay=decay)
end

In [None]:
# Loss history on a log-scaled y axis, overlaid with an `n`-entry moving average.
fig = Figure(size=(1500, 1000))
ax = Axis(
    fig[1, 1];
    xlabel="epoch",
    ylabel="loss",
    xlabelsize=50,
    ylabelsize=50,
    yscale=log10,
)
lines!(ax, eachindex(losses), losses; linewidth=3, color=wblue, label="loss")
n = 10
# The averaged curve starts at index `n`.
lines!(ax, n:lastindex(losses), moving_average(losses, n); linewidth=4, color=worange, label="$n epoch \n moving average")
fig[1, 2] = Legend(fig, ax)
#save("trident_nis_loss.png", fig)
fig

In [None]:
# Draw samples from the trained model with the project helper `sample_nomap_cpu`.
# NOTE(review): by its name the helper returns samples before the channel mapping
# is applied; the next cell applies `ytozmap` — confirm.
samples = sample_nomap_cpu(model, dim, N_samples, sample_batchsize)

In [None]:
# Apply the channel mapping to the samples and unpack its five outputs.
Ea_samples, cta_samples, phia_samples, Eb_samples, ctb_samples = ytozmap(samples)

In [None]:
"""
    makie_samples(samplesx, samplesy, xname, yname)

Return a figure with a 2D histogram heatmap of `samplesx` against `samplesy`.

The histogram is fitted with `nbins=100` and normalized with `mode=:pdf`. `xname`
and `yname` are turned into the axis labels through `latexstring`. A colorbar is
placed to the right of the heatmap.
"""
function makie_samples(samplesx, samplesy, xname, yname)
    histo = fit(Histogram, (samplesx, samplesy), nbins=100)
    histo_n = StatsBase.normalize(histo, mode=:pdf)
    fig = Figure(size=(1200,1000), figure_padding=40)
    # The y label previously reused `xname`, so both axes carried the same label.
    ax = Axis(fig[1,1], xlabel=latexstring(xname), ylabel=latexstring(yname),
        aspect=1, xlabelsize=50, ylabelsize=50)
    # Draw into `ax` explicitly instead of relying on the current axis.
    hm = heatmap!(ax, histo.edges[1], histo.edges[2], histo_n.weights)#, colorrange=(0,5), highclip=cgrad(:viridis)[end])
    # Narrow empty column used as a spacer between heatmap and colorbar.
    fig[1, 2] = GridLayout(width = 20)
    Colorbar(fig[1,3], hm, width=40)
    #ylims=(0.0,1.0)
    return fig
end

In [None]:
# 2D histogram of the first against the fourth mapped sample component.
makie_samples(Ea_samples, Eb_samples, "Ea", "Eb") 

In [None]:
"""
    weights4cl(m::Chain, cm::ChannelMapping, f::Function, x)

Return `jacobian_trident(m, cm, x)` multiplied elementwise with the adjoint of
`f(cm(m(x)))`.
"""
function weights4cl(m::Chain, cm::ChannelMapping, f::Function, x::T) where {T <: AbstractArray}
    jac = jacobian_trident(m, cm, x)
    fvals = f(cm(m(x)))
    return jac .* fvals'
end

"""
    weights4cl_chunked(m, dim, cm, f, N, batchsize)

Evaluate `weights4cl` on `N` uniform random inputs of `dim` rows, at most
`batchsize` columns at a time, and return the results concatenated with `hcat`.

If `N` is not a multiple of `batchsize`, the remainder chunk is evaluated first and
the full-size chunks follow.

# Throws
- `ArgumentError` if `N` or `batchsize` is not positive. Previously `N == 0`
  silently returned one full batch of weights.
"""
function weights4cl_chunked(m, dim, cm, f, N, batchsize)
    N > 0 || throw(ArgumentError("N must be positive, got $N"))
    batchsize > 0 || throw(ArgumentError("batchsize must be positive, got $batchsize"))

    nfull, remainder = divrem(N, batchsize)
    chunksizes = fill(batchsize, nfull)
    remainder == 0 || pushfirst!(chunksizes, remainder)

    chunks = [weights4cl(m, cm, f, Random.rand(dim, ncols)) for ncols in chunksizes]
    # One concatenation at the end instead of growing the result with `hcat`
    # inside the loop, which recopied all earlier weights on every chunk.
    return reduce(hcat, chunks)
end

# Weights for `N_samples` fresh random inputs; only the first row of the
# concatenated result is kept. These inputs are drawn independently of `samples`.
wi = weights4cl_chunked(model, dim, ytozmap, f, N_samples, batchsize)[1, :]

In [None]:
# Monte Carlo estimate from the weights: mean and sample standard deviation.
#f_evals = f(ytozmap(samples))[1,:]
f_over_g = wi
# Count the weights themselves. The previous `size(samples,2)` only gave the right
# divisor as long as `samples` happened to have as many columns as `wi` has entries.
n_weights = length(f_over_g)
mc_int = sum(f_over_g) / n_weights #* cmdet(ytozmap)
# Spread of the individual weights (n-1 normalization), not the error of `mc_int`.
mcerror = sqrt(sum((f_over_g .- mc_int).^2) / (n_weights-1))
println("mc integral = $mc_int")
println("standard deviation = $mcerror")

In [None]:
# Show the weights above 0.1.
wi[wi .> 0.1]

In [None]:
# Monte Carlo estimate from the weights: mean and sample standard deviation.
f_over_g = wi
# Count the weights themselves. The previous `size(samples,2)` only gave the right
# divisor as long as `samples` happened to have as many columns as `wi` has entries.
n_weights = length(f_over_g)
mc_int = sum(f_over_g) / n_weights
# Spread of the individual weights (n-1 normalization), not the error of `mc_int`.
mcerror = sqrt(sum((f_over_g .- mc_int).^2) / (n_weights-1))
println("mc integral = $mc_int")
println("standard deviation = $mcerror")

In [None]:
# Weights normalized by the Monte Carlo estimate, and the ratio of their mean to
# their maximum. The last line is the value displayed by the cell.
wi_n = wi ./ mc_int
w_avg = mean(wi_n)
w_max = maximum(wi_n)
uw_eff = w_avg / w_max