In [7]:
using Flux
using Zygote
using BenchmarkTools
using LinearAlgebra
using Pkg;
using Plots;

### This notebook is just testing how to get gradients from a matrix subsampling procedure.

In [2]:
#Pkg.add(url="https://github.com/JuliaDiff/Diffractor.jl")

In [3]:
function test1(x,W)
    return sum(x * W)
end

function test2(x,W,S)
    temp = view(W, :, S)
    return sum(x * temp)
end

function test2b(x,W)
    S = rand(1:dim2,50);
    temp = view(W, :, S)
    return sum(x * temp)
end

test2b (generic function with 1 method)

In [5]:
dim1,dim2 = 784,1000
x = randn(1,dim1);
W1 = fill(0.01,(dim1,dim2));#randn(dim1,dim2);
W2 = fill(0.01,(dim1,dim2));#randn(dim1,dim2);
S = rand(1:dim2,50);

In [None]:
@btime test1(x,W1) #scales approximately linearly with increasing dim2

In [None]:
@btime test2(x,W2,S) #runs in constant time with increasing dim2

In [None]:
@btime Zygote.gradient(w -> test1(x,w),W1) # 1.554 ms -2x-> 4.073 ms -3x-> 7.634 ms  :scales approx linearly

In [10]:
@btime begin
    tempW2 = view(W2,:,:)
    g = gradient(() -> test2(x,W2,S), Flux.params(tempW2))
end # 321.709 μs -2x-> 488.125 μs -3x-> 810.333 μs  :scales approx linearly

  244.167 μs (74 allocations: 6.29 MiB)


Grads(...)

In [None]:
g[1]

In [None]:
Flux.update!(Descent(),tempW2,g[1])

In [None]:
W2

In [None]:
t = [50 * exp(x) for x in LinRange(0,5,100000)]
plot(t)

In [None]:
t[10000:end]

In [34]:
x = randn(1,784)
dim1,dim2,dim3 = 784,20000,250
WM1 = randn(dim1,dim2)
WM2 = randn(dim2,dim3);

In [35]:
function test3_(x,W1,W2)
    S1 = rand(1:size(W1)[2],50)
    layer1 = x * (view(W1,:,S1)) .|> relu
    S2 = rand(1:size(W2)[2],50)
    layer2 = layer1 * (view(W2,S1,S2))
    return layer2, S1, S2
end

function test3(x,W1,W2)
    test3_(x,W1,W2)[1] |> sum
end

test3 (generic function with 1 method)

In [38]:
test3(x,WM1,WM2)

247.08838225957322

In [None]:
@btime begin
    _,S1,S2 = test3_(x,WM1,WM2);
    gradient((w1,w2) -> test3(x,w1,w2), view(WM1,:,S1), view(WM2,S1,S2))
end
#305 us -> 309 us -> 312 us -> 318 -> 320

In [None]:
d = autodiff(test1,x,Active(W1))

In [1]:
mutable struct layer
    W::Matrix
    Wview::SubArray
end

In [134]:
#Flux.trainable(a::layer) = (a.Wview)

In [17]:
function get_view(a::Matrix,r,c)
    #r,c are arrays, if empty [] then return all elements
    d1,d2 = size(a);
    return view(a, isempty(r) ? (:) : r, isempty(c) ? (:) : c)
end

get_view (generic function with 1 method)

In [19]:
dim2 = 10000
x3 = randn(784,dim2)
x4 = randn(dim2,10)
m1 = layer(copy(x3),get_view(x3,[],[]))
m2 = layer(copy(x4),get_view(x3,[],[]))

layer([-0.15800899510212627 0.8970761182800473 … -0.5264042242254652 -0.8207015713133116; 0.2610516939898696 0.3440111315509064 … -0.1345688954087462 0.21881853672282225; … ; 0.672726604511416 -1.2011174046594366 … -0.5026151100871522 0.7011518549147665; -1.9567018081299312 0.6645637701950665 … -0.36057278823286654 -3.0038553376298838], [0.7032383914781773 -0.6129517470046456 … -0.3833587692611754 1.7173563810062265; 0.8447411309354508 -1.2306566192647046 … -0.7940599564025127 -0.19013273081314708; … ; 0.7798234031888709 -0.009314728529346984 … -1.3722459417492678 -0.26865380815019463; -0.6504932870838276 1.7595967280095441 … -0.8361120327792004 -0.04092435401053748])

In [9]:
function test4(x,ms)
    Zygote.ignore() do
        S = rand(1:size(ms[1].W)[2],50)
        ms[1].Wview = view(ms[1].W,:,S)
        ms[2].Wview = view(ms[2].W,S,:)
    end
    y = x * ms[1].Wview
    y = Flux.relu.(y)
    y = y * ms[2].Wview
    sum(y);
end

test4 (generic function with 1 method)

In [10]:
#m1.Wview

In [11]:
x = randn(1,784)

1×784 Matrix{Float64}:
 -0.807171  0.346642  0.377439  -0.0611428  …  0.194407  -0.890997  0.557802

In [20]:
test4(x,[m1,m2])

494.2163675403265

In [66]:
@btime begin
    gradient((w) -> test4(x,w), [m1,m2])
end
# 144.333 μs (148 allocations: 328.52 KiB)
# 150.208 μs (148 allocations: 328.52 KiB)

  150.208 μs (148 allocations: 328.52 KiB)


(Base.RefValue{Any}[Base.RefValue{Any}((W = nothing, Wview = [-0.2697219406547307 -0.7575446939802694 … 0.1870144551527984 1.507406281144613; 0.1894300944407582 0.5320359277240755 … -0.13134328566444042 -1.0586758848934148; … ; -0.8521688315429239 -2.393412916812281 … 0.590859940212289 4.76255156011884; 0.718454481555286 2.0178609832310888 … -0.4981477335288166 -4.015256584555275])), Base.RefValue{Any}((W = nothing, Wview = [-0.6136838467288679 -0.6136838467288679 … -0.6136838467288679 -0.6136838467288679; 20.32797472305305 20.32797472305305 … 20.32797472305305 20.32797472305305; … ; -21.493028086200155 -21.493028086200155 … -21.493028086200155 -21.493028086200155; -22.911043803392282 -22.911043803392282 … -22.911043803392282 -22.911043803392282]))],)

In [151]:
#gs = gradient((w) -> test4(x,w), [m1,m2])

In [13]:
opt = ADAM()

ADAM(0.001, (0.9, 0.999), IdDict{Any, Any}())

In [49]:
ps_ = [m1.Wview,m2.Wview]
ps = Flux.Params(ps_)
gs = gradient((w) -> test4(x,w), [m1,m2])

(Base.RefValue{Any}[Base.RefValue{Any}((W = nothing, Wview = [1.167698129952942 0.0 … -3.3410579443070247 0.0; -0.5014711114221566 0.0 … 1.43482634567976 0.0; … ; 1.2889657896352384 0.0 … -3.688033131965658 0.0; -0.8069476361627049 0.0 … 2.308866256855126 0.0])), Base.RefValue{Any}((W = nothing, Wview = [12.322694373950252 12.322694373950252 … 12.322694373950252 12.322694373950252; 0.0 0.0 … 0.0 0.0; … ; 4.78461171717355 4.78461171717355 … 4.78461171717355 4.78461171717355; 0.0 0.0 … 0.0 0.0]))],)

In [45]:
ps_ = [m1.Wview,m2.Wview]
ps = Flux.Params(ps_)
gs = Flux.gradient(ps) do
   test4(x,[m1,m2]) 
end

Grads(...)

In [50]:
#ps[1]

In [22]:
#ps_[1]

In [51]:
Flux.update!(opt,ps,gs)

LoadError: MethodError: no method matching getindex(::Tuple{Vector{Base.RefValue{Any}}}, ::SubArray{Float64, 2, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false})
[0mClosest candidates are:
[0m  getindex(::Tuple, [91m::Int64[39m) at tuple.jl:29
[0m  getindex(::Tuple, [91m::Real[39m) at tuple.jl:30
[0m  getindex(::Tuple, [91m::Colon[39m) at tuple.jl:33
[0m  ...

In [32]:
gs[keys(gs)[2]]

In [48]:
m1.Wview #0.013585

784×50 view(::Matrix{Float64}, :, [7859, 4214, 5492, 4630, 8724, 4046, 6727, 7046, 6571, 845  …  3590, 5635, 3462, 485, 8231, 8715, 7238, 9142, 9322, 1425]) with eltype Float64:
  0.013585   -1.24705     0.742505   …   3.2127     -1.32154    -1.1915
 -2.68882     0.668496    0.990674      -1.12904    -0.44524     1.53585
 -0.0827462  -0.80611     0.719515      -1.84538    -0.177525   -1.20195
  1.95749     0.0697655   1.19909        0.157741    0.573435   -0.134393
 -0.355637   -0.090916    0.406813      -1.26082    -0.0203429  -0.214067
 -0.373272    1.03781     0.981843   …   0.607604   -0.584948   -1.66621
  0.106565   -0.923595    1.81244        0.281849    1.21781     0.85089
  0.89297    -0.892133   -0.61964        0.777655   -1.58092     1.12252
 -0.738564    0.341992   -2.24906        1.0063      0.249455    2.52767
  0.339343    0.307758   -0.496446       0.530877    0.809195   -2.0146
 -0.0110729  -0.906679    0.0675756  …   1.13557     1.07278     0.107747
 -0.540206   -0.60

In [129]:
Flux.params([m1.Wview,m2.Wview])[1]

784×50 view(::Matrix{Float64}, :, [8668, 4828, 4943, 268, 4750, 4716, 5044, 5512, 4413, 3249  …  2268, 8761, 2742, 5225, 615, 1077, 5561, 9293, 9936, 316]) with eltype Float64:
  0.780976   -1.58842     0.257255   …  -1.01771   -1.51509      1.65517
 -1.25786    -0.20717    -1.25822        1.15939    0.471244    -0.671743
  1.16555     0.246312    0.670447      -0.179959   0.0654916   -1.3876
  0.8739     -0.400095    0.719001      -0.309691  -1.4436      -0.215796
  1.51037     0.589977   -0.687516       0.472304  -0.798547     1.25884
 -0.0151287   0.863321   -1.44593    …   0.122563   0.00187823   1.78311
  1.52352    -1.05138    -0.315093       0.871862   0.631322     0.713788
 -0.452527   -0.201084    0.0795442      0.206121   0.85076      1.10328
  1.57659    -0.420705    2.22916        0.118898   0.223046    -1.10166
  0.56751    -1.594       1.97352        1.74448    0.825656     0.471841
 -1.28322    -0.0955239  -0.648455   …   0.510153  -0.354664    -0.0479271
  0.973216    1