In [11]:
using BSON: @load
using Flux
using Flux: chunk
using Flux.Data: DataLoader
using ImageFiltering
using Images
using ImageIO
using MLDatasets: FashionMNIST
using LinearAlgebra
using MLDatasets
using Plots

In [18]:
function PLUGIn_onebitCS(G, y, A, max_iter, stepsize, tolerance, out_toggle)
    
    (_, z_dim) = size(Flux.params(G[1])[1]);
    W = I(z_dim)
  
    #normalize the weights of the network
    for i in 1:length(G)
        _, s, _ = svd(Flux.params(G[i])[1])
        W = Flux.params(G[i])[1] * W /s[1]
    end
  
    z = randn(z_dim)
    iter = 1
    succ_error = 1
  
    while iter <= max_iter && succ_error > tolerance
      
      # d gives the PLUGIn direction
      d = ( A * G(z) - y .* abs.(A * G(z)) )
      d = W'*A'* d
      z -= stepsize * d
      succ_error = norm(stepsize * d)
      if iter % out_toggle == 0  
          println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
      end
      iter += 1
    end
    println("====> In quasi-gradient: Iteration: $iter Successive error: $succ_error")
  
    return z
  end

PLUGIn_onebitCS (generic function with 1 method)

In [21]:
#setup a synthetic problem
G = Chain(
    Dense(20, 500, relu, bias = false; initW =(out,in) ->  randn(500, 20)/sqrt(500)),
    Dense(500, 500, relu, bias = false; initW =(out,in) -> randn(500, 500)/sqrt(500)),
    Dense(500, 784, relu, bias = false; initW =(out,in) -> randn(784, 500)/sqrt(784))
)


z = randn(20)
m = 5000; A = randn(m, 784)/sqrt(m)
y = sign.(A*G(z)) 

stepsize = 3
tolerance = 1e-14
max_iter = 10000
out_toggle = 1000
z_rec = PLUGIn_onebitCS(G,y,A, max_iter, stepsize, tolerance, out_toggle)
recov_error = norm(z/norm(z) - z_rec/norm(z_rec))
recon_error = norm(G(z) - G(z_rec))
println("recovery error: $recov_error, reconstruction error: $recon_error")

====> In quasi-gradient: Iteration: 1000 Successive error: 0.00032662300985185214
====> In quasi-gradient: Iteration: 2000 Successive error: 0.00016026067369383698
====> In quasi-gradient: Iteration: 3000 Successive error: 0.00014630131913014052
====> In quasi-gradient: Iteration: 4000 Successive error: 0.0001891693574401531
====> In quasi-gradient: Iteration: 5000 Successive error: 0.00010269929519833588
====> In quasi-gradient: Iteration: 6000 Successive error: 0.00010697987395349423
====> In quasi-gradient: Iteration: 7000 Successive error: 0.00010707679978769945
====> In quasi-gradient: Iteration: 8000 Successive error: 9.410321501303891e-5
====> In quasi-gradient: Iteration: 9000 Successive error: 8.454578441782778e-5
====> In quasi-gradient: Iteration: 10000 Successive error: 9.341849983813008e-5
====> In quasi-gradient: Iteration: 10001 Successive error: 9.341849983813008e-5
recovery error: 0.1119972676527201, reconstruction error: 1.1528867860001517


In [6]:
#using Zygote to solve ERM for a synthetic problem
G = Chain(
    Dense(20, 500, relu, bias = false; initW =(out,in) ->  randn(500, 20)/sqrt(500)),
    Dense(500, 500, relu, bias = false; initW =(out,in) -> randn(500, 500)/sqrt(500)),
    Dense(500, 784, relu, bias = false; initW =(out,in) -> randn(784, 500)/sqrt(784))
)


z = randn(20)
m = 5000; A = randn(m, 784)/sqrt(m)
y = sign.(A*G(z)) + 1e-14 * randn(m)

stepsize = 1
tolerance = 1e-14
max_iter = 10000
out_toggle = 1000

loss = norm()



z_rec = PLUGIn_onebitCS(G,y,A, max_iter, stepsize, tolerance, out_toggle)
recov_error = norm(z/norm(z) - z_rec/norm(z_rec))
recon_error = norm(G(z) - G(z_rec))
println("recovery error: $recov_error, reconstruction error: $recon_error")

10-element Vector{Float64}:
  1.0
  1.0
 -1.0
  1.0
  1.0
  1.0
 -1.0
  1.0
  1.0
  1.0