In [29]:
using LinearAlgebra
using Zygote
using ChainRulesCore
using Base.Threads

"""
    F2G(F, nintg)

Bin the squared magnitudes of `F` into non-overlapping `nintg × nintg` blocks.

Each column of `F` is treated as a flattened, column-major `ncells × ncells`
image with `ncells^2 == size(F, 1)`. The second dimension of `F` (called
`nfreqs` here) is carried through as a trailing axis. `abs2.(F)` is summed over
every block, giving an `npix × npix × nfreqs` result with `npix = ncells ÷ nintg`.
The result is a view into a freshly allocated summed array.

# Throws
- `ArgumentError` if `nintg` is not positive.
- `DimensionMismatch` if `size(F, 1)` is not a perfect square, or if `ncells`
  is not a multiple of `nintg`.
"""
function F2G(F, nintg)
    nintg > 0 || throw(ArgumentError("nintg must be positive, got $nintg"))
    nrows = size(F, 1)
    # size(F, 2) is 1 for a vector input, so 1-D F also works.
    nfreqs = size(F, 2)
    # isqrt is exact integer arithmetic; verify the input really is square.
    ncells = isqrt(nrows)
    ncells^2 == nrows ||
        throw(DimensionMismatch("size(F, 1) = $nrows is not a perfect square"))
    ncells % nintg == 0 ||
        throw(DimensionMismatch("ncells = $ncells is not divisible by nintg = $nintg"))
    npix = ncells ÷ nintg
    # Axes 1 and 3 index the position inside a block; axes 2 and 4 index the block.
    binned = sum(reshape(abs2.(F), (nintg, npix, nintg, npix, nfreqs)); dims=(1, 3))
    return @view binned[1, :, 1, :, :]
end

"""
    G2mse(G, W, u, η)

Return the Euclidean norm of the residual `W * (G * u + η) - s`.

`s[k]` is the sum of the `k`-th consecutive chunk of `u`, where each chunk has
length `length(u) ÷ size(W, 1)`. Despite the name, the value is the plain
(unsquared, unaveraged) 2-norm. `length(u)` must be a multiple of `size(W, 1)`;
otherwise the `reshape` throws.
"""
function G2mse(G, W, u, η)
    nbands = size(W, 1)
    chunk = size(u, 1) ÷ nbands
    # Column k of the reshaped u holds entries (k-1)*chunk+1 : k*chunk.
    target = vec(sum(reshape(u, chunk, nbands); dims=1))
    residual = W * (G * u + η) - target
    return norm(residual)
end

"""
    loss(G, W, uη)

Return the mean of `G2mse(G, W, u, η)` over every `(u, η)` pair in `uη`.

Gradients with respect to `G` and `W` are intended to come from the custom
`ChainRulesCore.rrule(::typeof(loss), G, W, uη)` defined in this file. This
method supplies the plain (non-AD) value so that calling `loss` directly
returns a number.

# Throws
- `ArgumentError` if `uη` is empty, since the mean is undefined.
"""
function loss(G, W, uη)
    isempty(uη) && throw(ArgumentError("uη must contain at least one (u, η) pair"))
    total = sum(((u, η),) -> G2mse(G, W, u, η), uη)
    return total / length(uη)
end

"""
    ChainRulesCore.rrule(::typeof(loss), G, W, uη)

Reverse rule for `loss`, the mean of `G2mse(G, W, u, η)` over the `(u, η)` pairs
in `uη`.

The forward pass runs each sample in its own task (`Threads.@spawn`). For each
sample it computes the value and its gradient with respect to `G` and `W` using
`Zygote.pullback`. Each task returns its own results instead of writing into
shared accumulators, so there is no concurrent mutation.

The returned pullback scales the averaged gradients by the incoming cotangent
`Δ`. `uη` is treated as data and gets `ZeroTangent()`.

# Throws
- `ArgumentError` if `uη` is empty.
"""
function ChainRulesCore.rrule(::typeof(loss), G, W, uη)
    isempty(uη) && throw(ArgumentError("uη must contain at least one (u, η) pair"))
    n = length(uη)

    # One task per sample; each yields (value, ∂G_i, ∂W_i). The per-sample
    # pullback is evaluated exactly once.
    tasks = map(uη) do (u, η)
        Threads.@spawn begin
            val, back = Zygote.pullback((x, y) -> G2mse(x, y, u, η), G, W)
            ∂Gi, ∂Wi = back(one(val))
            (val, ∂Gi, ∂Wi)
        end
    end
    results = fetch.(tasks)

    # Reduce serially after all tasks finish, averaging over the samples.
    ℓ = sum(r -> r[1], results) / n
    ∂G = sum(r -> r[2], results) / n
    ∂W = sum(r -> r[3], results) / n

    function loss_pullback(Δ)
        # Chain rule: scale by the upstream cotangent instead of ignoring it.
        δ = unthunk(Δ)
        return NoTangent(), δ * ∂G, δ * ∂W, ZeroTangent()
    end
    return ℓ, loss_pullback
end

In [30]:
# Problem sizes: W is M×m and G is m×n, so u must have length n and η length m.
# G2mse additionally requires n to be a multiple of M.
M, m, n = 25, 100, 100
W = rand(M, m)
G = rand(m, n)
u = rand(n)
η = rand(m)
G2mse(G, W, u, η)
# Each sample is a (u, η) pair in that order: u has length n (it multiplies G),
# and η has length m (it is added to G*u).
uη = [(rand(n), rand(m)) for i in 1:10]
println(length(uη))
(u, η) = uη[2]
ret, back = Zygote.pullback(loss, G, W, uη)
back(1)


10


([1.703297536905581 1.477335939744295 … 0.9814537029159867 1.1402620009897428; 1.7403361750995319 1.509402482103013 … 1.0027940118738778 1.165009572492453; … ; 1.4903388174930325 1.2926566905944816 … 0.8586643919158201 0.9975950415594343; 1.4438051414085706 1.2522688798319357 … 0.8319212409742238 0.9665246425227256], [5.135948594112162 5.473425119619742 … 4.751541173340523 4.827891351480451; 4.78148060257274 5.095352787541089 … 4.423594989023461 4.494529199957237; … ; 5.045927817927058 5.377833647056707 … 4.668325271663655 4.743166678056749; 5.0196448859538165 5.349425132100313 … 4.6438339923382435 4.718517327965918], nothing)

In [6]:
# Scratch checks of column-major reshape/sum behaviour.
a = reshape(1:1:36, (6, 6))
display(a)
# vec followed by reshape round-trips to the same 6×6 layout.
display(reshape(vec(a), (6, 6)))
# Sum 2×2 blocks: dims 1 and 3 index the position inside each block,
# dims 2 and 4 index the block, giving a 3×3 result.
display(sum(reshape(a, (2, 3, 2, 3)), dims=(1, 3))[1, :, 1, :])
b = vec(a)
display(size(b))
# The flattened vector gives the same block sums as the matrix.
display(sum(reshape(b, (2, 3, 2, 3)), dims=(1, 3))[1, :, 1, :])
display(b)
# Sum each consecutive group of 3 entries (columns of the 3×12 reshape).
display(sum(reshape(b, (3, 12)), dims=1)[1, :])
c = zeros(size(a))
display(c)
ij = [(1, 2), (2, 2), (3, 4)]
for (i, j) in ij
    println((i, j))
end
# Use `idx` as the loop variable so neither Base's `CartesianIndex` type
# nor LinearAlgebra's `I` is shadowed.
for idx in CartesianIndices(a)
    println((idx, idx[1], idx[2]))
end

6×6 reshape(::StepRange{Int64, Int64}, 6, 6) with eltype Int64:
 1   7  13  19  25  31
 2   8  14  20  26  32
 3   9  15  21  27  33
 4  10  16  22  28  34
 5  11  17  23  29  35
 6  12  18  24  30  36

6×6 reshape(::StepRange{Int64, Int64}, 6, 6) with eltype Int64:
 1   7  13  19  25  31
 2   8  14  20  26  32
 3   9  15  21  27  33
 4  10  16  22  28  34
 5  11  17  23  29  35
 6  12  18  24  30  36

3×3 Matrix{Int64}:
 18  66  114
 26  74  122
 34  82  130

(36,)

3×3 Matrix{Int64}:
 18  66  114
 26  74  122
 34  82  130

1:1:36

12-element Vector{Int64}:
   6
  15
  24
  33
  42
  51
  60
  69
  78
  87
  96
 105

6×6 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0

(1, 2)
(2, 2)
(3, 4)
(CartesianIndex(1, 1), 1, 1)
(CartesianIndex(2, 1), 2, 1)
(CartesianIndex(3, 1), 3, 1)
(CartesianIndex(4, 1), 4, 1)
(CartesianIndex(5, 1), 5, 1)
(CartesianIndex(6, 1), 6, 1)
(CartesianIndex(1, 2), 1, 2)
(CartesianIndex(2, 2), 2, 2)
(CartesianIndex(3, 2), 3, 2)
(CartesianIndex(4, 2), 4, 2)
(CartesianIndex(5, 2), 5, 2)
(CartesianIndex(6, 2), 6, 2)
(CartesianIndex(1, 3), 1, 3)
(CartesianIndex(2, 3), 2, 3)
(CartesianIndex(3, 3), 3, 3)
(CartesianIndex(4, 3), 4, 3)
(CartesianIndex(5, 3), 5, 3)
(CartesianIndex(6, 3), 6, 3)
(CartesianIndex(1, 4), 1, 4)
(CartesianIndex(2, 4), 2, 4)
(CartesianIndex(3, 4), 3, 4)
(CartesianIndex(4, 4), 4, 4)
(CartesianIndex(5, 4), 5, 4)
(CartesianIndex(6, 4), 6, 4)
(CartesianIndex(1, 5), 1, 5)
(CartesianIndex(2, 5), 2, 5)
(CartesianIndex(3, 5), 3, 5)
(CartesianIndex(4, 5), 4, 5)
(CartesianIndex(5, 5), 5, 5)
(CartesianIndex(6, 5), 6, 5)
(CartesianIndex(1, 6), 1, 6)
(CartesianIndex(2, 6), 2, 6)
(CartesianIndex(3, 6), 3, 6)
(CartesianIndex(4, 6),

In [7]:
"""
    f(x, y)

Return the 2-norm of the elementwise expression `x .^ 2 .+ y .^ 3`.
"""
function f(x, y)
    combined = @. x^2 + y^3
    return norm(combined)
end
# Evaluate f once, then display its Zygote gradient w.r.t. both arguments.
f([1, 1], [2, 3])
Zygote.gradient(f, [1, 1], [2, 3])

([0.6120183608262413, 1.9040571225705285], [3.6721101649574477, 25.704771154702133])