example of automatic differentiation for the purpose of machine learning.

implementing proximal gradient descent for minimization of a squared error loss function

source: https://discourse.julialang.org/t/types-and-gradients-including-forward-gradient/946

In [7]:
using ForwardDiff: gradient, derivative
using LinearAlgebra, Random

ArgumentError: ArgumentError: Package ForwardDiff not found in current path:
- Run `Pkg.add("ForwardDiff")` to install the ForwardDiff package.


In [1]:
# model
"""
    linear_regression(w, b, x)

Affine model: multiply the inputs `x` by the weights `w`, then add the bias `b`
to every entry of the product.
"""
function linear_regression(w, b, x)
    projection = w * x
    return projection .+ b
end

linear_regression (generic function with 1 method)

In [2]:
# loss function
"""
    mean_squared_error(ŷ, y)

Sum of the squared element-wise differences between the predictions `ŷ` and the
targets `y`, divided by `size(y, 2)` (the number of columns of `y`).
"""
function mean_squared_error(ŷ, y)
    residuals = ŷ .- y
    return sum(abs2, residuals) / size(y, 2)
end

mean_squared_error (generic function with 1 method)

In [3]:
# gradient of the loss with respect to `w`
"""
    loss∇w(model, loss, w, b, x, y)

Differentiate `loss(model(w, b, x), y)` with respect to the weights `w` using
`gradient`, holding `b`, `x` and `y` fixed.
"""
function loss∇w(model, loss, w, b, x, y)
    # objective seen as a function of the weights only
    objective(weights) = loss(model(weights, b, x), y)
    return gradient(objective, w)
end

loss∇w (generic function with 1 method)

In [4]:
# derivative of the loss with respect to `b`
"""
    lossdb(model, loss, w, b, x, y)

Differentiate `loss(model(w, b, x), y)` with respect to the bias `b`, holding
`w`, `x` and `y` fixed.

`derivative` is used rather than `gradient` because `b` is a scalar, not an
array.
"""
function lossdb(model, loss, w, b, x, y)
    # objective seen as a function of the bias only
    objective(bias) = loss(model(w, bias, x), y)
    return derivative(objective, b)
end

lossdb (generic function with 1 method)

In [5]:
# optimization algorithm
"""
    proximal_gradient_descent(model, loss, w, b, x, y; lr=0.1)

Take a single gradient-descent step on the parameters `w` (weights) and `b`
(bias) of `model` under `loss`, and return the updated pair `(w, b)`.

Both partial derivatives are evaluated at the *incoming* `(w, b)` before either
parameter is changed, so the step follows the gradient of the loss at one
single point. The arguments themselves are not mutated.

# Keywords
- `lr`: step size multiplying both derivatives.
"""
function proximal_gradient_descent(model, loss, w, b, x, y; lr=.1)
    # Evaluate both derivatives first: updating `w` before differentiating with
    # respect to `b` would evaluate the bias derivative at a different point.
    grad_w = loss∇w(model, loss, w, b, x, y)
    grad_b = lossdb(model, loss, w, b, x, y)

    # Broadcasting builds the new weights in one pass without scaling the
    # gradient array in place.
    new_w = w .- lr .* grad_w
    new_b = b - lr * grad_b
    return new_w, new_b
end

proximal_gradient_descent (generic function with 1 method)

In [6]:
"""
    main(T=Array, n=100000)

Fit `linear_regression` to synthetic data with `proximal_gradient_descent` and
print the `mean_squared_error` after every step.

The global RNG is seeded with `0`. The data has 25 feature rows and `n` sample
columns; the targets are the sum of the first five feature rows plus Gaussian
noise scaled by `0.1`. Features, targets and initial weights are converted with
`T` before training. The whole training loop is timed with `@time`.

# Arguments
- `T`: constructor applied to the feature, target and weight arrays.
- `n`: number of samples to generate.
"""
function main(T=Array, n = 100000)
    Random.seed!(0)

    nfeatures = 25
    features = randn(n, nfeatures)'
    noise = randn(n)' * 0.1
    targets = sum(features[1:5, :]; dims=1) .+ noise

    # small random initial weights, zero initial bias
    weights = 0.0001 * randn(1, nfeatures)
    bias = 0.0

    xs = T(features)
    ys = T(targets)
    weights = T(weights)

    # the number of descent steps reuses the feature count (25)
    @time for _ in 1:nfeatures
        weights, bias = proximal_gradient_descent(
            linear_regression, mean_squared_error, weights, bias, xs, ys
        )
        println(mean_squared_error(linear_regression(weights, bias, xs), ys))
    end
    return nothing
end

main (generic function with 3 methods)

In [9]:
# run the demo, storing features, targets and weights as plain `Array`s
main(Array)

UndefVarError: UndefVarError: Random not defined

# Exercise

- play around with this code and define your own model or loss function
- is this implementation generic enough that it just *works* with different array types?
- try using either `CuArrays` or `DArrays`