In [1]:
using Revise

In [3]:
using Plots; gr()

Plots.GRBackend()

In [189]:
import Nets2
reload("Nets2")



In [190]:
# Target function and training set: 20 evenly spaced samples of sin on [-π, π].
f = function (x)
    sin(x)
end
xs = linspace(-pi, pi, 20)
ys = map(f, xs)

20-element Array{Float64,1}:
 -1.22465e-16
 -0.324699   
 -0.614213   
 -0.837166   
 -0.9694     
 -0.996584   
 -0.915773   
 -0.735724   
 -0.475947   
 -0.164595   
  0.164595   
  0.475947   
  0.735724   
  0.915773   
  0.996584   
  0.9694     
  0.837166   
  0.614213   
  0.324699   
  1.22465e-16

In [236]:
# 1 -> 10 -> 10 -> 1 MLP with an ELU after every affine layer, including the output.
# Affine's arguments look like (outputs, inputs) judging by the chain shapes — TODO confirm.
net = Nets2.Chain(
    Nets2.Affine(10, 1),  Nets2.activation(Nets2.elu),
    Nets2.Affine(10, 10), Nets2.activation(Nets2.elu),
    Nets2.Affine(1, 10),  Nets2.activation(Nets2.elu),
)
# Every parameter is drawn independently as 0.1 * N(0, 1).
Nets2.initialize!(net) do
    0.1 * randn()
end

In [237]:
# Dense grid to compare the network against the true function, with the
# training samples overlaid as markers.
xx = linspace(-pi, pi, 100)
yy = map(f, xx)
yhat = map(net, xx)    # each prediction is a collection, so take its first entry below
plt = plot(xx, yy)
plot!(plt, xx, map(first, yhat))
plot!(plt, xs, ys; line = nothing, marker = :dot)
plt

In [238]:
batches = let pairs = zip(xs, ys)   # one batch containing every (x, y) training pair
    [pairs]
end

1-element Array{Base.Iterators.Zip2{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float64,1}},1}:
 Base.Iterators.Zip2{StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float64,1}}(-3.141592653589793:0.3306939635357677:3.141592653589793, [-1.22465e-16, -0.324699, -0.614213, -0.837166, -0.9694, -0.996584, -0.915773, -0.735724, -0.475947, -0.164595, 0.164595, 0.475947, 0.735724, 0.915773, 0.996584, 0.9694, 0.837166, 0.614213, 0.324699, 1.22465e-16])

In [239]:
# Batch loss for ReverseDiff: rebuild `net` from the parameters `params` and add up
# the squared error of its prediction over every (x, y) pair in `batch`.
loss = function (params, batch)
    model = Nets2.with_params(net, params)
    return sum(sum(abs2, model(x) .- y) for (x, y) in batch)
end

(::#206) (generic function with 1 method)

In [243]:
# Full-batch gradient descent on `loss`. `param_vector` is a CatView over the arrays
# returned by `Nets2.params(net)`. Assuming those are `net`'s live parameter arrays
# (TODO confirm), updating it in place trains `net` directly.
# NOTE(review): `ReverseDiff` and `CatView` are not imported in any visible cell.
param_vector = CatView(Nets2.params(net))
let learning_rate = 0.01, n_epochs = 1000, ∇ = zeros(length(param_vector))
    # One gradient buffer is reused for every step instead of allocating a new array.
    for epoch in 1:n_epochs, batch in batches
        ReverseDiff.gradient!(∇, params -> loss(params, batch), param_vector)
        param_vector .-= learning_rate .* ∇
    end
end

In [244]:
# Re-plot after training: true function vs. the network's prediction.
yhat = map(net, xx)
plt = plot(xx, yy)
plot!(plt, xx, map(first, yhat))
plt

In [246]:
# Inspect the trained parameters as one flat vector (the CatView built for training).
param_vector

141-element CatViews.CatView{6,Float64}:
  0.314799  
  0.0493019 
  0.561192  
  0.711589  
  0.313391  
 -1.04097   
 -1.27727   
  0.0658407 
  0.0139828 
 -0.0703404 
 -0.40991   
 -0.105425  
 -0.77582   
  ⋮         
 -0.524718  
  0.00384847
  0.372765  
 -0.0119575 
  0.0727551 
  0.701289  
  0.392373  
  0.328673  
  0.360447  
  0.59701   
  0.852871  
 -0.368602  

In [245]:
# Gradient of the batch loss at the current parameters, as a check on training.
# The batch is taken from `batches` explicitly rather than from a global `batch`,
# which is only assigned in a cell further down and so depended on execution order.
ReverseDiff.gradient(params -> loss(params, first(batches)), param_vector)

141-element Array{Float64,1}:
 -1.13765   
 -1.89267e-5
 -2.38438   
 -3.08864   
 -1.21771   
 -0.907063  
  0.449463  
 -0.0217478 
 -0.180902  
  0.0549902 
 -0.651718  
 -0.132547  
 -1.26654   
  ⋮         
  0.424123  
 -0.0777788 
  1.55742   
 -0.036854  
  0.0924701 
  4.17959   
  2.26533   
  1.60159   
  1.37286   
  3.50048   
 -4.20599   
  3.11689   

In [181]:
# Per-sample criterion for `update!`, which calls it as `loss(ŷ, y)`: the squared
# error between a prediction ŷ and its target y. Rebuilding the network from a flat
# parameter vector is `update!`'s job, so no parameter unpacking happens here.
loss = (ŷ, y) -> sum(abs2, ŷ .- y)

"""
    update!(loss, net, batches)

Return the gradient of the per-sample criterion `loss(ŷ, y)` with respect to all of
`net`'s parameters. The gradient is flattened into one `Vector{Float64}`, ordered as in
`CatView(Nets2.params(net)...)`.

Each batch contributes the mean of its per-sample gradients. The result is the average
of those batch means over every batch in `batches`; if `batches` is empty it is all
zeros. Despite the `!`, `net` is not modified — the caller applies the step.
"""
function update!(loss, net, batches)
    param_sizes = size.(Nets2.params(net))

    # Rebuild `net` from a flat parameter vector so ReverseDiff can differentiate the
    # per-sample loss with respect to that vector.
    function inner(param_vector, x, y)
        params = first(splitview(param_vector, param_sizes...))
        n = Nets2.with_params(net, params...)
        return loss(n(x), y)
    end

    param_vector = CatView(Nets2.params(net)...)
    ∇ = zeros(size(param_vector))
    grad_result = similar(∇)
    nbatches = 0
    for batch in batches
        nbatches += 1
        for (x, y) in batch
            ReverseDiff.gradient!(grad_result, params -> inner(params, x, y), param_vector)
            ∇ .+= grad_result ./ length(batch)
        end
    end
    # Gradients are accumulated across every batch (not only the last one), then averaged.
    if nbatches > 1
        ∇ ./= nbatches
    end
    return ∇
end
    

update! (generic function with 1 method)

In [184]:
# 100 gradient-descent steps with step size 0.1. `update!` only returns the gradient;
# the step is applied through `param_vector`, a CatView over the arrays from
# `Nets2.params(net)` (presumably `net`'s live parameters — TODO confirm).
param_vector = CatView(Nets2.params(net)...)
for epoch = 1:100
    grad = update!(loss, net, batches)
    @. param_vector -= 0.1 * grad
end

In [186]:
# Inspect the gradient left after the training loop above (this does not take a step).
update!(loss, net, batches)

31-element Array{Float64,1}:
 -0.00248287 
 -0.000319703
  0.00156972 
  0.000656617
 -0.000565252
 -0.000204491
  0.000226818
 -0.000722073
 -0.00507269 
 -4.57776e-5 
 -0.00123884 
  8.47473e-7 
 -0.000770405
  ⋮          
  0.000219476
  0.000470093
  0.000356932
  0.000676513
  0.00166739 
 -0.00419664 
  0.000179729
  0.000551915
 -0.00391531 
  0.000260779
  0.000404994
  0.000419991

In [81]:
# Squared-error criterion wrapped by `Nets2.sample_loss` so it can be evaluated against
# `net`'s parameters (see the `loss(FlatParams(net), x, y)` call below).
loss = Nets2.sample_loss((ŷ, y) -> sum(abs2, ŷ .- y), net)

(::#5) (generic function with 1 method)

In [83]:
# Spot-check the sample loss at x = [0.0], y = [0.0] using `net`'s current parameters
# (argument order presumably (params, x, y) — TODO confirm against Nets2.sample_loss).
loss(Nets2.FlatParams(net), [0.0], [0.0])

0.929818787910725

In [84]:
import Nets

In [None]:
# Set up an Adam optimizer over `net`'s flattened parameters, with one batch-size
# worth of samples per step, plus a gradient buffer shaped like the flat parameters.
# NOTE(review): the original cell left the `AdamOptimizer(` call unclosed; the
# `Adam(Float64)` rule is assumed to be its second argument — confirm with Nets.
options = Nets.AdamOpts(batch_size = length(first(batches)))
flat = Nets2.FlatParams(net)
optimizer = Nets.AdamOptimizer(options, Nets.StochasticOptimization.Adam(Float64))
∇ = zero(flat.flat)   # same eltype and shape as the flat parameters
    

In [27]:
# Whole-batch loss; compare with the per-sample sum in the next cell (both print
# 246.15923873699808).
# NOTE(review): `loss2` is not defined in any visible cell — relies on stale notebook state.
batch = first(batches)
loss2(Nets2.params(net))

246.15923873699808

In [29]:
# Same loss accumulated one (x, y) sample at a time; should equal the batch total above.
# NOTE(review): `loss1` is not defined in any visible cell — relies on stale notebook state.
sum(loss1(Nets2.params(net), x, y) for (x, y) in batch)

246.15923873699808

In [30]:
using CatViews

LoadError: [91mArgumentError: Module CatViews not found in current path.
Run `Pkg.add("CatViews")` to install the CatViews package.[39m