In [2]:
using BSON: @load
using Flux
using Flux: chunk
using Flux.Data: DataLoader
using ImageFiltering
using Images
using ImageIO
using MLDatasets: FashionMNIST
using LinearAlgebra
using MLDatasets
using Plots

In [77]:
function PLUGIn_onebitCS(G, y, A, max_iter, stepsize, tolerance,lambda, out_toggle)
    
    (_, z_dim) = size(Flux.params(G[1])[1]);
    (m, _) = size(A)
    W = I(z_dim)
  
    #normalize the weights of the network
    for i in 1:length(G)
        _, s, _ = svd(Flux.params(G[i])[1])
        W = Flux.params(G[i])[1] * W /s[1]
    end
  
    z = randn(z_dim)
    iter = 1
    succ_error = 1
  
    while iter <= max_iter && succ_error > tolerance
      
      # d gives the PLUGIn direction
      # d = ( A * G(z) - y .* abs.(A * G(z)) )
      d = ( sign.(A * G(z)) - y )
      d = W'*A'* d

      # d = W'*G(z) - lambda *W'*A'*y/m
      z -= stepsize * d
      succ_error = norm(stepsize * d)
      if iter % out_toggle == 0  
          println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
      end
      iter += 1
    end
    println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
  
    return z
  end

PLUGIn_onebitCS (generic function with 2 methods)

In [83]:
#setup a synthetic problem
G = Chain(
    Dense(5, 75, relu, bias = false; initW =(out,in) ->  randn(75, 5)/sqrt(75)),
    Dense(75, 75, relu, bias = false; initW =(out,in) -> randn(75, 75)/sqrt(75)),
    Dense(75, 150, relu, bias = false; initW =(out,in) -> randn(150, 75)/sqrt(150))
)


z = randn(5)
m = 100000; A = randn(m, 150)/sqrt(m)
y = sign.(A*G(z)) + 1e-14 * randn(m)

stepsize = 1
tolerance = 1e-14
max_iter = 10000
out_toggle = 1000
lambda = 100
z_rec = PLUGIn_onebitCS(G,y,A, max_iter, stepsize, tolerance, lambda, out_toggle)
recov_error = norm(z/norm(z) - z_rec/norm(z_rec))
recon_error = norm(G(z) - G(z_rec))
println("recovery error: $recov_error, reconstruction error: $recon_error")

====> In quasi-gradient: Iteration: 1000 Successive error: 0.006830297176155472
====> In quasi-gradient: Iteration: 2000 Successive error: 0.0038048801707020008
====> In quasi-gradient: Iteration: 3000 Successive error: 0.00271062180622516
====> In quasi-gradient: Iteration: 4000 Successive error: 0.0082761849669763
====> In quasi-gradient: Iteration: 5000 Successive error: 0.008892733204487593
====> In quasi-gradient: Iteration: 6000 Successive error: 0.00908612830493478
====> In quasi-gradient: Iteration: 7000 Successive error: 0.0033661705918312606
====> In quasi-gradient: Iteration: 7453 Successive error: 3.067225334389885e-15
recovery error: 3.5396289516830904e-5, reconstruction error: 14.466736641511517


In [68]:
#using Zygote to solve ERM for a synthetic problem


G = Chain(
    Dense(5, 75, relu, bias = false; initW =(out,in) ->  randn(75, 5)/sqrt(75)),
    Dense(75, 75, relu, bias = false; initW =(out,in) -> randn(75, 75)/sqrt(75)),
    Dense(75, 150, relu, bias = false; initW =(out,in) -> randn(150, 75)/sqrt(150))
)


z = randn(5)
m = 5000; A = randn(m, 150)/sqrt(m)
y = sign.(A*G(z)) + 1e-14 * randn(m)

stepsize = .1
tolerance = 1e-7
max_iter = 1000
out_toggle = 1000
lambda = 10

z_rec = randn(5)


for i in 1:max_iter

    d = gradient(z_rec -> norm(G(z_rec),2)^2 - 2*lambda * y'*(A * G(z_rec))/m, z_rec)[1]
    z_rec -= stepsize * d
    succ_error = norm(stepsize * d)
    if i % out_toggle == 0  
        println("====> In quasi-gradient: Iteration: $i Successive error: $succ_error")
    end
end


# z_rec = PLUGIn_onebitCS(G,y,A, max_iter, stepsize, tolerance, out_toggle)
recov_error = norm(z/norm(z) - z_rec/norm(z_rec))
recon_error = norm(G(z) - G(z_rec))
println("recovery error: $recov_error, reconstruction error: $recon_error")

====> In quasi-gradient: Iteration: 1000 Successive error: 5.186242623879012e-5
recovery error: 0.05487163727428996, reconstruction error: 0.22544613623562687
