In [1]:
#import Libraries
using LinearAlgebra
using Random
using Zygote

In [2]:
#basic model of Linear
"""
    Linear(W, b)

Trainable affine model: calling it on an input `x` returns `W * x .+ b`.

The struct is mutable so its fields can be updated/reassigned after construction
(e.g. by an optimizer step). Fields are parametric so each instance stores concretely
typed `W` and `b`, instead of the `Any`-typed fields an unannotated struct gets.
"""
mutable struct Linear{TW,Tb}
    W::TW
    b::Tb
end

# Forward pass: affine map of the input; `.+` broadcasts the bias over the output.
(l::Linear)(x) = l.W * x .+ l.b

# Loss against target `y`: Euclidean (2-)norm of the residual `l(x) .- y`.
# Note: this is the plain L2 norm, not the squared error.
(l::Linear)(x, y) = norm(l(x) .- y, 2)

In [3]:
#generate fake data points: targets follow y = w'x + b exactly, then small Gaussian
#noise is added to the inputs X afterwards (so the model sees slightly perturbed features)
Random.seed!(1234)  # reproducible synthetic data (and subsequent global-RNG draws)
n_samples = 10_000  # number of (x, y) pairs, one per column of X
noise_std = 0.001   # standard deviation of the input perturbation
weights_gt = [1.1, 2.8, 1.8, 0.3]'  # 1×4 row (adjoint) so `weights_gt * X` is 1×n
bias_gt = 0.4
X = randn(length(weights_gt), n_samples)
Y = weights_gt * X .+ bias_gt
X .+= noise_std .* randn(size(X))

4×10000 Array{Float64,2}:
  1.01179     1.10688   -1.30755   …  -0.928794  -0.203992  -1.96328 
  0.347247   -0.253724  -0.655867     -0.265327   0.19086    0.982526
 -0.0663809  -1.71468   -0.573365      0.544118  -0.935005   0.702351
  0.438522    1.20973    1.27484      -0.415798  -1.43292    1.83048 

In [4]:
# Random initial parameters; the input dimension is taken from the data (rows of X)
# instead of being hard-coded, so it stays in sync with `weights_gt`.
model = Linear(rand(1, size(X, 1)), rand(1))

Linear([0.613273 0.578343 0.736697 0.299476], [0.116939])

In [5]:
"""
    sgd_update!(model::Linear, dmodel, η = 0.001)

Take one vanilla SGD step on `model`, subtracting `η` times the gradient held in
`dmodel` (any object exposing `W` and `b` properties) from the matching parameters.

`model.W` is overwritten element-wise in its existing array; `model.b` is replaced by
a newly computed value. Returns that new bias value.
"""
function sgd_update!(model::Linear, dmodel, η = 0.001)
    # fused broadcast assignment: writes straight into the existing weight array
    model.W .= model.W .- η .* dmodel.W
    # out-of-place update, then rebind the field
    new_bias = model.b - η * dmodel.b
    model.b = new_bias
    return new_bias
end

sgd_update! (generic function with 2 methods)

In [6]:
#training: 10000 single-sample SGD steps, each on a column of X drawn uniformly at random
for step in 1:10000
    j = rand(1:size(X, 2))
    xj, yj = X[:, j], Y[j]
    # `[1]` selects the gradient of the first (only) argument; the extra `[]` presumably
    # unwraps the Ref Zygote uses for mutable-struct gradients — confirm for the
    # installed Zygote version.
    dmodel = gradient(m -> m(xj, yj), model)[1][]
    sgd_update!(model, dmodel)
end

In [7]:
# Log the ground-truth parameters next to the learned ones (learned values rounded to
# 3 decimals for readability).
let learned_W = round.(model.W; digits=3), learned_b = round.(model.b; digits=3)
    @info "Ground truth weights: $(weights_gt)"
    @info "Learned weights: $(learned_W)"
    @info "Ground truth bias: $(bias_gt)"
    @info "Learned bias: $(learned_b)"
end

┌ Info: Ground truth weights: [1.1 2.8 1.8 0.3]
└ @ Main In[7]:1
┌ Info: Learned weights: [1.103 2.797 1.804 0.299]
└ @ Main In[7]:2
┌ Info: Ground truth bias: 0.4
└ @ Main In[7]:3
┌ Info: Learned bias: [0.399]
└ @ Main In[7]:4
