# FluxML Example

Flux is Julia's machine learning library. Its website offers plenty of resources: https://fluxml.ai/
Let's jump straight to working code.

In [None]:
import Flux
import Statistics

import Plots;
using LaTeXStrings

# Pick the line colour for the fit snapshot drawn at epoch `iter`, or `nothing` when no
# snapshot is due: every 10 epochs up to 100 (blue), every 200 epochs up to 1000
# (orange), and every 500 epochs after that (red).
function _snapshot_color(iter)
    iter % 10 == 0 || return nothing
    iter <= 100 && return :blue
    iter <= 1000 && return iter % 200 == 0 ? :orange : nothing
    return iter % 500 == 0 ? :red : nothing
end

"""
    _flux_linear_fit(nepocs, X, Y, W0, b0, plot)

Fit the linear model `y = W x + b` (a `Flux.Dense(1, 1)` layer) to the samples `X`, `Y`
by minimising the mean squared error with Adam, printing progress along the way.

# Arguments
- `nepocs`: number of `Flux.train!` calls; each call is one pass over the single batch.
- `X`, `Y`: inputs and targets. Each is concatenated with `reduce(hcat, …)`, which for
  a vector of scalars yields a `1×N` matrix.
- `W0`, `b0`: values the weight and bias are initialised to before training.
- `plot`: when `true`, draw the data, intermediate fits, and the initial fit.

# Returns
The `Plots` figure when `plot` is `true`, otherwise `nothing`.
"""
function _flux_linear_fit(nepocs, X, Y, W0, b0, plot)
    # Prepare the training data (horizontal concatenation): one batch with all samples.
    Xd = reduce(hcat, X)
    Yd = reduce(hcat, Y)
    data = [(Xd, Yd)]
    println("Data", data)

    # Define the model (model = W x + b) and overwrite its parameters with the
    # requested starting point. Training adjusts them in place.
    model = Flux.Dense(1, 1)
    model.weight .= W0
    model.bias .= b0

    # Mean squared error between the model output and the targets.
    loss(m, x, y) = Statistics.mean((m(x) .- y) .^ 2)

    # `Flux.Adam` is the current spelling of the optimiser (`ADAM` is the legacy name);
    # it pairs with the explicit-state API used below (`Flux.setup` / `Flux.train!`).
    optimizer = Flux.Adam(1.0, (0.99, 0.999))
    opt = Flux.setup(optimizer, model)

    # Keep the untrained prediction so it can be drawn after training.
    Yd_0 = model(Xd)
    println("Initial solution: ", Yd_0)

    # Hold on to the figure and pass it to every `plot!` call, rather than relying on
    # Plots' implicit "current plot".
    plt = plot ? Plots.plot(X, Y, st=:scatter, label="y", legend=:topleft) : nothing

    for iter in 1:nepocs
        Flux.train!(loss, model, data, opt)

        # Progress is reported (and possibly drawn) every 10 epochs only.
        iter % 10 == 0 || continue
        if plot
            color = _snapshot_color(iter)
            if color !== nothing
                Plots.plot!(plt, Xd', model(Xd)', lc=color, label=L"y_{%$iter}")
            end
        end
        println("Epoch: ", iter, " Loss: ", loss(model, Xd, Yd))
    end

    Yd_nE = model(Xd)
    println("Final solution: ", Yd_nE)

    plot || return nothing

    # Draw the initial fit last, then label the figure.
    Plots.plot!(plt, Xd', Yd_0', lc=:green, label=L"y_0")
    Plots.plot!(plt, xlabel="x", ylabel="y", title="Flux.jl Linear Fit")
    return plt
end

"""
    test()

Run `_flux_linear_fit` on six hard-coded `(x, y)` samples for 2000 epochs, starting
from weight 20 and bias 500, with plotting enabled. Returns whatever
`_flux_linear_fit` returns.
"""
function test()
    # Sample inputs and the corresponding observed outputs, stored as Float32.
    inputs = Float32[256, 2_100, 4_096, 512, 5_500, 8_192]
    targets = Float32[435_867, 2_959_963, 1_475_489, 1_485_569, 2_592_234, 4_518_030]

    # Starting point for the weight and bias of the linear model.
    w0, b0 = 20.0, 500.0

    epochs = 2_000
    show_plot = true

    return _flux_linear_fit(epochs, inputs, targets, w0, b0, show_plot)
end

# Run the example when this cell/script is executed.
test()
