In [None]:
using Zygote
using Statistics
using Plots
using StatsPlots

"""
Zygote is used for automatic differentiation
Statistics gives us mean()
We'll use Plots for a line plot and StatsPlots for a violin plot
"""

# learning to XOR

Hi. This is a tutorial about building a very simple multilayer perceptron to approximate the exclusive-or function, also known as the XOR function to its friends. It might also be your introduction to the Julia programming language. Developed for scientific computing, Julia is ostensibly something of a faster Python. One interesting feature of the language is that when you see a mathematical definition for a dense layer in a neural network, like so:

$
f(x) = \sigma(\theta_w x + b)
$

You can actually write code that looks very similar, thanks to Julia's support for unicode characters. It doesn't necessarily save you any time typing (symbols are typed by entering the $\LaTeX$ code, _e.g._ `\sigma` and pressing tab), but it does look pretty cool. 

In [None]:
# Logistic sigmoid 1 / (1 + e^(-x)), applied elementwise to scalars or arrays.
σ(x) = @. 1 / (1 + exp(-x))

# A single dense layer: multiply the batch `x` by the weights `θ[:w]`, add the bias
# row `θ[:b]` to every sample, then squash the result with σ.
function f(x, θ)
    preactivation = x * θ[:w] .+ θ[:b]
    return σ(preactivation)
end

# Demo parameters: 32 inputs feeding 2 units, with small random weights and biases.
θ = Dict(
    :w => randn(32, 2) / 10,
    :b => randn(1, 2) / 100,
)

# A batch of 4 random samples with 32 features each.
x = randn(4, 32)

# Layer output for the batch (one row per sample, one column per unit).
f(x, θ)

The functions below generate model weights and a noisy dataset representing the XOR function. 

In [None]:
# Generate a noisy XOR (parity) dataset.
#
# Returns `(x, y)`:
#   x  a num_samples × dim_x matrix of random bits with Gaussian noise
#      (standard deviation 0.1) added on top
#   y  a num_samples × 1 matrix holding the XOR of the clean bits of each row,
#      stored as 0.0 / 1.0
get_xor = function(num_samples=512, dim_x=3)
    # Clean binary inputs: each entry is true with probability one half.
    bits = rand(num_samples, dim_x) .> 0.5

    # The label of a sample is the XOR (parity) of all its bits.
    labels = zeros(num_samples, 1)
    for (row, sample) in enumerate(eachrow(bits))
        labels[row] = reduce(xor, sample)
    end

    # Perturb the inputs only after the labels were computed from the clean bits.
    noisy_inputs = bits + randn(num_samples, dim_x) / 10

    return noisy_inputs, labels
end

# Create the parameter dictionary for the one-hidden-layer MLP.
#
# Returns a Dict with
#   :wxh  input-to-hidden weights,  dim_in × dim_hid,  standard normal scaled by 1/8
#   :why  hidden-to-output weights, dim_hid × dim_out, standard normal scaled by 1/4
init_weights = function(dim_in=2, dim_out=1, dim_hid=4)
    return Dict(
        :wxh => randn(dim_in, dim_hid) / 8,
        :why => randn(dim_hid, dim_out) / 4,
    )
end


This next bit defines the model we'll be training: a tiny MLP with 1 hidden layer and no biases. We also need to set up a few helper functions to provide loss and other training metrics (accuracy).

In [None]:
# Forward pass of the bias-free MLP: input -> sigmoid hidden layer -> sigmoid output.
function f(x, θ)
    hidden = σ(x * θ[:wxh])
    return σ(hidden * θ[:why])
end


# Fraction of samples whose thresholded prediction (`pred .> boundary`) equals the
# target `y`.
function get_accuracy(y, pred, boundary=0.5)
    predicted_labels = pred .> boundary
    return mean(y .== predicted_labels)
end

# Binary cross-entropy (log loss), summed over all entries and divided by the number
# of rows of `y`.
#
# Arguments
#   y     targets (0 or 1), same shape as `pred`
#   pred  predicted probabilities
#   ϵ     predictions are clamped to [ϵ, 1 - ϵ] before taking logarithms
#
# Returns a scalar loss.
log_loss = function(y, pred, ϵ=1e-12)

    # A saturated sigmoid yields exactly 0.0 or 1.0 in floating point. Without the
    # clamp, log(0) = -Inf makes the loss Inf, and 0 * -Inf = NaN poisons it even for
    # a confidently *correct* prediction.
    p = clamp.(pred, ϵ, 1 - ϵ)

    cross_entropy = @. y * log(p) + (1 - y) * log(1 - p)

    return -sum(cross_entropy) / size(y, 1)

end

# Training objective: log loss of the model's predictions on (x, y) plus an L2
# penalty, weighted by `l2`, on both weight matrices. Returns a scalar, which is what
# `gradient` needs.
get_loss = function(x, θ, y, l2=6e-4)

    data_term = log_loss(y, f(x, θ))

    # Sum of squared weights across both layers.
    weight_penalty = l2 * (sum(abs2, θ[:wxh]) + sum(abs2, θ[:why]))

    return data_term + weight_penalty

end

The `gradient` function from Zygote does as the name suggests. We need to give `gradient` a function that returns a scalar (_i.e._ an objective function in this case), which is why we made an explicit `get_loss` function earlier. We'll store the result in a variable called $d\theta$; because `gradient` returns a tuple with one entry per argument, the dictionary of gradients itself is `dθ[1]`. We then update our model parameters by taking a gradient descent step. We won't be training with stochastic gradient descent, because in this example we're not using minibatches.

In [None]:
# Demo of one gradient-descent step on a 5-input XOR problem. The learning rate is
# deliberately large so the change in the weights is visible in the scatter plot.
lr = 1e1;
x, y = get_xor(64, 5);
θ = init_weights(5);

# Flatten both weight matrices into a single vector for plotting.
old_weights = vcat(vec(θ[:wxh]), vec(θ[:why]))

# `gradient` returns a tuple with one entry per argument; dθ[1] is the Dict of
# gradients with respect to θ.
dθ = gradient((θ) -> get_loss(x, θ, y), θ)
plt = scatter(old_weights, label = "old_weights")

# Plain gradient descent: step each weight matrix against its gradient.
θ[:wxh] = θ[:wxh] .- lr .* dθ[1][:wxh]
θ[:why] = θ[:why] .- lr .* dθ[1][:why]

new_weights = vcat(vec(θ[:wxh]), vec(θ[:why]))

scatter!(new_weights, label="new weights")
display(plt)

Lastly we need to design the training loop. This function takes training data and parameters as inputs, as well as a few hyperparameters for how long and how fast to train. 

In [None]:
# Train the MLP with full-batch gradient descent.
#
# Arguments
#   x, y       training inputs and targets
#   θ          parameter Dict with keys :wxh and :why; its entries are replaced in
#              place on every step
#   max_steps  number of gradient steps to take
#   lr         learning rate
#   l2_reg     L2 penalty coefficient forwarded to `get_loss`
#
# Returns `(θ, losses, acc)`, where `losses` and `acc` hold the training log loss
# (without the L2 term) and training accuracy recorded before each update.
train = function(x, θ, y, max_steps=1000, lr = 1e-2, l2_reg=1e-4)

    # Report about 100 times per run. `÷` is integer division; `//` would build a
    # Rational in Julia, which gives an irregular cadence whenever max_steps is not a
    # multiple of 100. The `max` keeps the interval at least 1 for short runs.
    disp_every = max(1, max_steps ÷ 100)

    losses = zeros(max_steps)
    acc = zeros(max_steps)

    for step = 1:max_steps

        # Record training metrics for the current parameters.
        pred = f(x, θ)
        losses[step] = log_loss(y, pred)
        acc[step] = get_accuracy(y, pred)

        # `gradient` returns a tuple with one entry per argument; the first entry is
        # the Dict of gradients with respect to θ.
        dθ = gradient((θ) -> get_loss(x, θ, y, l2_reg), θ)
        grads = dθ[1]

        θ[:wxh] = θ[:wxh] .- lr .* grads[:wxh]
        θ[:why] = θ[:why] .- lr .* grads[:why]

        if mod(step, disp_every) == 0

            # Evaluate on freshly generated data with the same input width as `x`.
            val_x, val_y = get_xor(512, size(x)[2]);
            val_pred = f(val_x, θ)
            loss = log_loss(val_y, val_pred)
            accuracy = get_accuracy(val_y, val_pred)

            println("loss at step $step = $loss, accuracy = $accuracy")
            # Uncomment to save a decision-surface frame for the training gif:
            #save_frame(θ, step);

        end

    end
    return θ, losses, acc
end


With all our functions defined, it's time to set up the data and model and call the training loop. We'll use `violin` plots from the `StatsPlots` package to show how the distributions of weights change over time. Calling a `plot` function with the `!` in-place modifier allows you to add more plots to the current figure. If we want to display more than 1 figure per notebook cell, and we do, we need to explicitly call `display` on the figure we want to show.

In [None]:
# Model size and training hyperparameters.
dim_x, dim_h, dim_y = 3, 4, 1
l2_reg = 1e-4
lr = 1e-2
max_steps = 1_000_000

# Fresh parameters and a 512-sample training set.
θ = init_weights(dim_x, dim_y, dim_h)
x, y = get_xor(512, dim_x)

println(size(x))

# Weight distributions before training.
plt = violin([" "], vec(θ[:wxh]), label="wxh", title="Weights", alpha = 0.5)
violin!([" "], vec(θ[:why]), label="why", alpha = 0.5)
display(plt)

θ, losses, acc = train(x, θ, y, max_steps, lr, l2_reg)

# Weight distributions after training.
plt = violin([" "], vec(θ[:wxh]), label="wxh", title="Weights", alpha = 0.5)
violin!([" "], vec(θ[:why]), label="why", alpha = 0.5)
display(plt)

# Training curves: loss and accuracy recorded at every step.
steps = 1:length(losses)
plt = plot(steps, losses, title="Training XOR", label="loss")
plot!(steps, acc, label="accuracy")
display(plt)


## plot and gif

Here are some plots from a previous training run.  The violin plot gif shows how the weights changed over time and the accuracy/loss plot shows a typical XOR training curve. 


<img src="violin_weights.gif">
<img src="temp.png">

Finally, it's a good idea to generate a test set to figure out how badly our model is overfitted to the training data. If you're unlucky and get poor performance from your model, try changing some of the hyperparameters like learning rate or l2 regularization. You can also generate a larger training dataset for better performance, or try changing the size of the hidden layer by changing `dim_h`. Heck, you could even modify the code to add l1 regularization or add layers to the MLP. Go wild.

In [None]:

# Evaluate the trained model on freshly generated test data.
# The input width is read from the trained weights instead of being hard-coded, so
# this cell keeps working when `dim_x` is changed in the training cell.
test_x, test_y = get_xor(512, size(θ[:wxh], 1));

pred = f(test_x, θ);
test_accuracy = get_accuracy(test_y, pred);
test_loss = log_loss(test_y, pred);

println("Test loss and accuracy are $test_loss and $test_accuracy")


## testing vs validation

The difference between a test and validation dataset is immaterial when we're generating our data on demand as we are here. But normally you wouldn't want to go back and modify your training algorithm after running your model on a static test dataset. That sort of behavior runs a high risk of data leakage as you can keep tweaking training until you get good performance, but if you stop only when the test score is good you'll actually have settled on a lucky score. That doesn't tell you anything about how the model will behave with actual test data that it hasn't seen before. This happens in the real world when researchers collectively iterate on a few standard datasets. Of course there will be [incremental improvement every year on MNIST](https://arxiv.org/abs/1905.10498) if everyone keeps fitting their research and development strategy to the test set!

In any case, thanks for stopping by and I hope you enjoyed exploring automatic differentiation in Julia as much as I did.

In [None]:
# Bonus, functions for plotting a decision surface and saving to disk for a training gif. 

# Approximate the model's decision boundary (prediction = 0.5) as a cloud of 3-D points.
#
# For every (x, y) pair on `grid`, z is scanned over the same grid and the z whose
# prediction is closest to 0.5 is kept. The model is called with 1×3 inputs, so `θ`
# must describe a 3-input network.
#
# Arguments
#   θ     model parameters accepted by `f`
#   grid  coordinates sampled along each axis (default: -0.25 to 1.25 in steps of 0.005)
#
# Returns three equally long Float64 vectors: the x, y and z coordinates of the points.
get_decision_surface = function(θ, grid=-0.25:0.005:1.25)

    # Typed accumulators: an untyped `[]` would produce Vector{Any}.
    surface_x = Float64[]
    surface_y = Float64[]
    surface_z = Float64[]

    for xx in grid, yy in grid
        smallest_gap = Inf
        # Sentinel far away from the sampled cube; it only survives when no prediction
        # compares below Inf (e.g. every prediction is NaN).
        best_x, best_y, best_z = 100.0, 100.0, 100.0

        for zz in grid
            pred = f([xx yy zz], θ)
            gap = abs(0.5 - pred[1])
            # Strict `<` keeps the first z among ties.
            if gap < smallest_gap
                smallest_gap = gap
                best_x, best_y, best_z = float(xx), float(yy), float(zz)
            end
        end

        push!(surface_x, best_x)
        push!(surface_y, best_y)
        push!(surface_z, best_z)
    end

    return surface_x, surface_y, surface_z

end


# Render one frame for the training gif and save it as "frames/temp<epoch>.png".
#
# The frame shows the eight corners of the unit cube coloured by their true 3-input
# XOR label, the model's approximate decision surface, and 300 noisy samples coloured
# by the model's prediction.
#
# Arguments
#   θ      model parameters accepted by `f` (3-input network)
#   epoch  value interpolated into the output file name
#
# Returns the plot object.
save_frame = function(θ, epoch)

    # One set of axis limits for the whole figure. The original passed (-0.05, 1.05)
    # in the first call and (-0.25, 1.25) in the second, so the latter was the one
    # that ended up in effect.
    lims = (-0.25, 1.25)

    # Cube corners with odd parity (XOR is true).
    plt = scatter3d([0,0,1,1],[0,1,0,1],[1,0,0,1],
        xlim = lims,
        ylim = lims,
        zlim = lims,
        label = "true",
        markercolor= :blue, markersize=15)

    # Cube corners with even parity (XOR is false).
    scatter3d!(plt, [0,0,1,1],[0,1,0,1],[0,1,1,0],
        label = "false",
        markercolor= :red, markersize=15)

    my_surface = get_decision_surface(θ);

    scatter3d!(plt, my_surface[1], my_surface[2], my_surface[3],
        markercolor = :green,
        markerstrokecolor = :green,
        markerstrokealpha = 0.0,
        markeralpha = 0.05,
        label = "decision boundary",
        legend = :outerbottomright )

    # Overlay noisy samples coloured by the model's predicted class.
    for _ in 1:300
        x, _ = get_xor(1, 3)

        pred = f(x, θ)
        my_color = pred[1] > 0.5 ? :blue : :red

        scatter3d!(plt, [x[1]], [x[2]], [x[3]],
            color = my_color,
            alpha = 0.1,
            label = "")
    end

    # Create the output directory on first use; savefig does not create it.
    frame_path = "frames/temp$epoch.png"
    mkpath(dirname(frame_path))
    savefig(plt, frame_path)

    return plt
end