In [2]:
# Use the package versions bundled with this repository.
import Pkg, Random
Pkg.activate(@__DIR__)
Pkg.instantiate()

# Load Flux and PlotlyJS for sweet interactive graphics
using Flux, PlotlyJS

  Updating registry at `~/.julia/registries/General`
  Updating git-repo `https://github.com/JuliaRegistries/General.git`

# Flux By Example: Differentiable Programming

At the heart of Flux is the idea of "differentiable programming": a technique that allows us to identify how a particular piece of a computation affects the end result.  As an example, we will build a polynomial approximator.  Any continuous function on a closed interval can be approximated arbitrarily well by a polynomial of high enough degree, although a low-degree polynomial fits some functions far less successfully than others.  We will use differentiable programming to build polynomial approximations for a variety of functions within this notebook, explaining fundamental concepts of Flux as we do so.

Mathematically, we define our polynomial function $f^{(N)}(x)$ as:

$$
    f^{(N)}(x) = \sum_{i=1}^N \left( x^{i-1} \cdot w_i \right)
$$

Where $x$ represents the (scalar) input to function $f$ and $w$ represents the internal coefficients of $f$.  The structure of the computation performed within $f$ is set; we know we will be calculating a polynomial.  However, there is great freedom in the choice of the values of $w$, which is what will give this equation its ability to approximate other functions.  As a first example, we will simply approximate the function $|x|$:

In [3]:
x = collect(-10:.1:10)
plot([
    scatter(;x=x, y=abs.(x), name="|x|"),
])

We will arbitrarily fix $N = 3$ for now, so that `f(x)` has three coefficients (a quadratic), and we will initialize the coefficients `w` randomly:

In [27]:
# Get predictable random numbers for the sake of notebook reliability
Random.seed!(3)

# Initialize `w` as three (small) random numbers:
w = randn(3)

# Define f(x):
f(x) = sum([x^(i-1) .* w[i] for i in 1:3])

f (generic function with 1 method)
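
Written out, the `f(x)` defined above is the three-coefficient ($N = 3$) case of the general formula, i.e. a quadratic:

$$
    f^{(3)}(x) = w_1 + w_2 x + w_3 x^2
$$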

Excellent.  Let's plot it and see how we're doing:

In [50]:
plot([
    scatter(;x=x, y=abs.(x), name="|x|"),
    scatter(;x=x, y=f.(x), name="f(x)"),
])

A very poor approximation indeed.  The question now becomes, "how do I tweak $w$ such that $f(x)$ becomes more similar to $|x|$?"  This is the question that differentiable programming answers.  We measure how far $f(x)$ is from $|x|$, take the gradient of that error with respect to $w$, then use it to nudge $f(x)$ closer to $|x|$.

To do so, we will make the following definitions:

$$
    y = |x|
$$

$$
    \hat{y} = f(x)
$$

We will mathematically define our problem as an optimization problem: searching for the set of $w$ parameters that minimizes the difference between $y$ and $\hat{y}$, as measured by the $\ell_2$ norm.

$$
    \underset{w}{\text{minimize}} \,\, \big\|y - \hat{y} \big\|_2
$$

This is, of course, equivalent to:
$$
    \underset{w}{\text{minimize}} \,\, \left\||x| - \sum_{i=1}^N \left( x^{i-1} \cdot w_i \right) \right\|_2
$$

To discover how best to change $w$ to get a better result, we will therefore first calculate a $\hat{y}$ using our current values of $w$, compare it to the true $y$, use that difference to determine how $w$ should change, then update our $w$ values and do the whole thing again.  This is our training loop, and it will form the basis of how we interact with the parameters of our function $f(x)$.
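
One pass through that loop can be sketched in plain Julia, without Flux, if we are willing to write the gradient by hand.  This is only an illustration under stated assumptions: the names `poly_sketch`, `loss_sketch`, `loss_grad_sketch` and `sgd_step_sketch!` are hypothetical helpers that do not appear elsewhere in this notebook, and the starting coefficients are made up.  For a single scalar sample the $\ell_2$ norm is just an absolute value, which is what makes the hand-written gradient so short.

```julia
# Hypothetical helper names; this is plain Julia, not Flux API.

# Evaluate the polynomial sum_i x^(i-1) * w[i]
poly_sketch(x, w) = sum(x^(i-1) * w[i] for i in eachindex(w))

# For a scalar sample, the l2 norm of (y - y_hat) is its absolute value
loss_sketch(y, y_hat) = abs(y - y_hat)

# Hand-derived gradient with respect to w:
#   d/dw_i |y - y_hat| = -sign(y - y_hat) * x^(i-1)
loss_grad_sketch(x, y, w) =
    [-sign(y - poly_sketch(x, w)) * x^(i-1) for i in eachindex(w)]

# One "calculate, compare, update" step with step size η
function sgd_step_sketch!(w, x, y; η = 1e-5)
    w .-= η .* loss_grad_sketch(x, y, w)
    return w
end

# Made-up starting coefficients, evaluated at a single point x = 6
w_demo = [0.5, -0.2, 0.1]
x_demo = 6.0
y_demo = abs(x_demo)

before_loss = loss_sketch(y_demo, poly_sketch(x_demo, w_demo))
sgd_step_sketch!(w_demo, x_demo, y_demo)
after_loss = loss_sketch(y_demo, poly_sketch(x_demo, w_demo))

println("loss before: ", before_loss, "  loss after: ", after_loss)
```

The rest of this notebook lets Flux compute the gradient instead of deriving it by hand, which is the whole point of differentiable programming once the computation becomes more complicated than a polynomial.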

In order for `Flux` to know that $w$ is the part of our function $f(x)$ that should be changed, we must wrap $w$ with the `param()` function.  We do so, and use it within a separate `f_tracked` function to illustrate the slight difference in output datatype between `f(x)` and `f_tracked(x)`:

In [51]:
w_tracked = param(w)
f_tracked(x) = sum([x^(i-1) .* w_tracked[i] for i in 1:3])

x = 1
@show f(x)
@show f_tracked(x)

f(x) = 0.746632461796028
f_tracked(x) = 0.746632461796028 (tracked)


0.746632461796028 (tracked)

Notice how the output of `f_tracked()` says `(tracked)`.  This denotes that Flux knows this output is a function of a parameter, and we can thus calculate how those parameters affect this output.  We now calculate the difference between this calculated output and our desired output:

In [52]:
# Define l2_loss() function to calculate a measure of how far off we are from our target output
l2_loss(y, y_hat) = sqrt(sum((y .- y_hat).^2))

l = l2_loss(abs(x), f_tracked(x))
@show l

l = 0.25336753820397195 (tracked)


0.25336753820397195 (tracked)

This nonzero loss shows that we missed the target a bit (obviously, as 0.746... is not equal to 1, the desired output).  However, the real magic begins when we use the function `Flux.Tracker.back!(l)` to take the loss `l` and push it "back" through the computation, attributing to `w` the changes that must be made in order to reduce this loss:

In [53]:
Flux.Tracker.back!(l)

A little underwhelming, perhaps.  However, if we now inspect `w_tracked`, we can see that there are proposed changes attributed to the values:

In [54]:
w_tracked.grad

3-element Array{Float64,1}:
 -1.0
 -1.0
 -1.0
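
These values can be checked by hand.  For a scalar input the $\ell_2$ norm reduces to an absolute value, so, away from the point $y = \hat{y}$ where the absolute value is not differentiable:

$$
    \frac{\partial}{\partial w_i} \left| y - \hat{y} \right| = -\operatorname{sign}(y - \hat{y}) \cdot x^{i-1}
$$

At $x = 1$ we have $y - \hat{y} \approx 0.253 > 0$ and $x^{i-1} = 1$ for every $i$, which gives $-1$ for all three coefficients.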

Let's use these proposed changes to modify `w_tracked.data` (the actual values used in the calculation), zero out the `.grad` values (so that they are ready to be set by a future `Flux.Tracker.back!()` invocation), and then see how this change has affected our guess.  Note that we will take a conservative step here by multiplying the gradient by a small step size $\eta$, so that the parameters move by only a very small fraction of the gradient and the output changes only slightly.

In [55]:
η = 1e-5
w_tracked.data[:] -= w_tracked.grad .* η
w_tracked.grad[:] .= 0.0

@show l
@show l2_loss(abs(x), f_tracked(x))

l = 0.25336753820397195 (tracked)
l2_loss(abs(x), f_tracked(x)) = 0.25333753820397176 (tracked)


0.25333753820397176 (tracked)
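
The size of this improvement can be predicted.  Because $\hat{y}$ is linear in $w$, a gradient step of size $\eta$ changes the loss by exactly the following amount, as long as the step is small enough not to flip the sign of $y - \hat{y}$:

$$
    \Delta l = -\eta \, \big\| \nabla_w l \big\|_2^2 = -10^{-5} \cdot \left( (-1)^2 + (-1)^2 + (-1)^2 \right) = -3 \times 10^{-5}
$$

This matches the drop from $0.25336753\ldots$ to $0.25333753\ldots$ shown above.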

Excellent!  We're a little closer to our goal.  Let's plot it and see what the function looks like:

In [56]:
# We need to build a helper function to drop the "tracked" part of f_tracked(),
# because plotting functions don't know how to deal with tracked values:
f_tracked_data(x) = f_tracked(x).data

function plot_x_fx_ftrackedx()
    x = collect(-10:.1:10)
    plot([
        scatter(;x=x, y=abs.(x), name="|x|"),
        scatter(;x=x, y=f.(x), name="f(x)"),
        scatter(;x=x, y=f_tracked_data.(x), name="f_tracked_data(x)"),
    ])
end

plot_x_fx_ftrackedx()

Hmmm.  Not that much of a difference.  Let's do an iteration evaluating at $x = 6$:

In [57]:
x = 6

# Perform forward pass, calculating loss
l = l2_loss(abs(x), f_tracked(x))

@show l

# Attribute loss back onto w
Flux.Tracker.back!(l)

# Update w
w_tracked.data[:] -= w_tracked.grad .* η
w_tracked.grad[:] .= 0.0

# Recalculate l, showing that loss has improved
l = l2_loss(abs(x), f_tracked(x))
@show l

l = 54.766739169892546 (tracked)
l = 54.753409169892535 (tracked)


54.753409169892535 (tracked)
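
The loss dropped by far more here than it did at $x = 1$, even though $\eta$ is unchanged.  A rough check, independent of Flux, shows why: since $\hat{y}$ is linear in $w$, each gradient component of $|y - \hat{y}|$ has magnitude $|x|^{i-1}$, so the expected drop is $\eta$ times the squared norm of the gradient.  The variable names below are illustrative only, and the two loss values are copied from the output above.

```julia
# Hypothetical check, independent of Flux; names are illustrative only.
η_check = 1e-5
x_check = 6

# Each gradient component has magnitude |x|^(i-1); the common sign
# cancels when the components are squared.
grad_sq_norm = sum(abs2, [x_check^(i-1) for i in 1:3])   # 1 + 36 + 1296

# Predicted change in loss for one small gradient step
predicted_drop = η_check * grad_sq_norm

# Loss values printed by the cell above
observed_drop = 54.766739169892546 - 54.753409169892535

println("predicted: ", predicted_drop, "  observed: ", observed_drop)
```

The higher powers of $x$ dominate the gradient at $x = 6$, so the same $\eta$ moves the output much further than it did at $x = 1$.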

In [58]:
plot_x_fx_ftrackedx()

Alright, a bit more of a difference!  Now let's write a loop that runs one iteration for every x from -10 to 10 in increments of 0.1, and returns the average loss.  We'll call this a `training_epoch()`:

In [59]:
function training_epoch(true_model, model)
    avg_loss = 0.0
    xs = -10:0.1:10
    for x in xs
        # Perform forward pass, calculating loss
        l = l2_loss(true_model(x), model(x))

        # Attribute loss back onto w
        Flux.Tracker.back!(l)

        # Accumulate average loss
        avg_loss += l.data / length(xs)

        # Update w (note: this modifies the global `w_tracked` used by `f_tracked`)
        w_tracked.data[:] -= w_tracked.grad .* η
        w_tracked.grad[:] .= 0.0
    end
    return avg_loss
end

training_epoch(abs, f_tracked)

64.89673077450595

In [60]:
plot_x_fx_ftrackedx()

Even more improvement!  Now let's do that over and over again, 500 times, keeping track of the losses, to show that we eventually converge to a point at which our loss no longer improves, then examine the final product:

In [61]:
losses = Float64[]
for epoch in 1:500
    l = training_epoch(abs, f_tracked)
    push!(losses, l)
end
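
The repeated-epoch pattern does not depend on Flux itself.  As a rough, Flux-free sketch of the same idea, the version below writes the gradient by hand instead of calling `Flux.Tracker.back!()`.  The names `poly_epoch_sketch`, `epoch_sketch!`, `w_sketch` and `losses_sketch` are hypothetical, and the coefficients start from zero rather than from the seeded random values used above, so the numbers it produces will not match the notebook's.

```julia
# Hypothetical, Flux-free sketch; none of these names are part of the notebook.
poly_epoch_sketch(x, w) = sum(x^(i-1) * w[i] for i in eachindex(w))

function epoch_sketch!(w, target; η = 1e-5, xs = -10:0.1:10)
    avg_loss = 0.0
    for x in xs
        # Forward pass: signed error between the target and the polynomial
        residual = target(x) - poly_epoch_sketch(x, w)

        # Accumulate the average of the per-sample loss |residual|
        avg_loss += abs(residual) / length(xs)

        # Hand-derived gradient of |residual| w.r.t. w[i] is
        # -sign(residual) * x^(i-1); step against it
        for i in eachindex(w)
            w[i] += η * sign(residual) * x^(i-1)
        end
    end
    return avg_loss
end

# Start from all-zero coefficients and run 500 epochs against |x|
w_sketch = zeros(3)
losses_sketch = [epoch_sketch!(w_sketch, abs) for _ in 1:500]

println("first epoch: ", losses_sketch[1], "  last epoch: ", losses_sketch[end])
```

Keeping the coefficients as an explicit argument, rather than a global, is a design choice of this sketch only; the notebook's `training_epoch()` updates the global `w_tracked` directly.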

In [66]:
x = collect(-10:.1:10)
plot([
    scatter(;x=x, y=abs.(x), name="|x|"),
    scatter(;x=x, y=f_tracked_data.(x), name="f_tracked_data(x)", marker_color="green"),
])

Say, that's not bad at all.  And if we look at the loss versus training epoch, we can convince ourselves that we have indeed converged close to the best this model can deliver:

In [80]:
plot(losses; name="Training epochs")

Congratulations, you have just written your first differentiable program in `Flux`.  To recap: we set out to use a polynomial to approximate $|x|$, and used differentiable programming to learn the polynomial coefficients.  We did this by starting with a random guess for the coefficients, calculating values of $f(x)$ with those junk coefficients, then iteratively refining them using the gradients with respect to $w$ that `Flux` is able to calculate with the `Flux.Tracker.back!()` function.

This is, of course, an extremely low-level way of working with `Flux`.  `Flux` provides much higher-level ways of dealing with models of significant complexity, which will be addressed in later notebooks.  For now, let's have some fun trying to fit various nonlinear functions with our new function approximator:

### Fitting the step function

In [85]:
# Start from randomness again
w_tracked.data[:] = randn(3)

# Try to fit the step function:
u(x) = Float64(x > 0)

# Run training loop, keeping track of losses
losses = Float64[]
for epoch in 1:500
    l = training_epoch(u, f_tracked)
    push!(losses, l)
end

x = collect(-10:.1:10)
plot([
    scatter(;x=x, y=u.(x), name="step(x)"),
    scatter(;x=x, y=Flux.Tracker.data.(f_tracked.(x)), name="f(x)", marker_color="green"),
])
# Plot function and tracked loss
#display(p)
#p = lineplot(losses; title="Loss versus training epochs")
#display(p)

In [None]:
plot(losses; name="Training epochs")

### Fitting a piecewise linear function

In [17]:
# Start from randomness again
w_tracked.data[:] = randn(3)

# Run training loop, keeping track of losses
losses = Float64[]
for epoch in 1:500
    l = training_epoch(relu, f_tracked)
    push!(losses, l)
end

# Plot function and tracked loss
# (`lineplot` is not provided by PlotlyJS; this assumes UnicodePlots is available in the environment)
using UnicodePlots: lineplot
p = lineplot([relu, f_tracked_data], -10, 10)
display(p)
p = lineplot(losses; title="Loss versus training epochs")
display(p)

[Text plot: `NNlib.relu(x)` and the fitted polynomial for x from -10 to 10]

[Text plot: "Loss versus training epochs"]