In [8]:
# Constructs a quantum circuit g with parameters θ, then differentiates the recursive algorithm given in Section 5.1 of https://arxiv.org/abs/1112.2184 to obtain the gradient of p_θ(x) with respect to θ, where x is a measurement outcome of g|0>. The differentiation takes polynomial time thanks to memoization.
# We then compare our results to the finite-difference gradient.
using Yao, FLOYao
using LinearAlgebra

function create_circuit(nq::Int)
    layers = 2 #Number of brick-wall layers in the circuit
    g = chain(nq)
    for _ in 1:layers
        for i in 1:2:nq-1
            push!(g, rot(kron(nq, i => X, i+1 => X), 0.)) #Nearest-neighbor XX rotation gates
        end
        for i in 2:2:nq-1
            push!(g, rot(kron(nq, i => X, i+1 => Y), 0.)) #Nearest-neighbor XY rotation gates
        end
        for i in 1:nq
            push!(g, put(nq, i => Rz(0.))) #Single qubit Z rotation gates
        end
    end
    return g
end

# Infix alias for `kron`, so Kronecker products can be written as `A ⊗ B`.
⊗ = kron

# Computes `reg.state * G * reg.state'`, where G is block-diagonal with one
# 2x2 block [0 1; -1 0] per qubit of `reg`.
function covariance_matrix(reg::MajoranaReg)
    block = [0 1; -1 0]
    G = kron(I(nqubits(reg)), block)
    state = reg.state
    return state * G * state'
end

# Builds a signed Pauli string on `nq` qubits for the pair of Majorana indices (i, j).
# Per the original author's note this is (im/2)γ_iγ_j, assuming i≠j (the operator
# convention is not derivable from this file alone — confirm against FLOYao).
# Majorana index m sits on qubit (m+1) ÷ 2. When both indices sit on the same qubit
# the result is a single Z; otherwise it is X/Y on the two end qubits with a chain of
# Z on every qubit in between. Swapping i and j flips the overall sign.
function majoranaindices2kron(nq, i, j)
    lo, hi = min(i, j), max(i, j)
    site_lo = (lo + 1) ÷ 2
    site_hi = (hi + 1) ÷ 2
    paulis = Any[]
    if site_lo == site_hi #Both indices live on the same qubit
        sign = 1
        push!(paulis, site_lo => Z)
    else
        sign = (i % 2 == j % 2) ? 1 : -1
        if lo % 2 == 0
            push!(paulis, site_lo => X)
        else
            push!(paulis, site_lo => Y)
            sign = -sign
        end
        # Every qubit strictly between the end points contributes a Z and a sign flip
        for site in site_lo+1:site_hi-1
            push!(paulis, site => Z)
            sign = -sign
        end
        push!(paulis, site_hi => (hi % 2 == 0 ? Y : X))
    end
    sign = i > j ? -sign : sign
    return sign*kron(nq, paulis...)
end

# Wraps the Pauli string produced by `majoranaindices2kron(nq, i, j)` in an `Add` block.
# The original author's note says this represents [γ_i,γ_j]=2γ_iγ_j and that the `Add`
# wrapper is needed so that `expect'` accepts it as an observable (prefactor convention
# relative to majoranaindices2kron not verified here — TODO confirm).
function majorana_commutator(nq, i, j)
    pauli_string = majoranaindices2kron(nq, i, j)
    return Add(pauli_string)
end

# Evolves all matrices, probabilities and gradients by nq steps, in place.
#
# For each qubit i = 1..nq this fills `probabilities[i]` and `grad_probabilities[:, i]`
# for the measurement outcome bit `b[i]`, updating the upper triangles of `temp_m`
# (dim x dim) and `temp_grad_m` (nparams x dim x dim) as it goes, where dim = 2*nq.
#
# Arguments
#   reg                 initial MajoranaReg; its qubit count fixes nq
#   theta               circuit parameters, dispatched into the global circuit `g`
#   b                   measured bit string, indexed by qubit
#   temp_m              work buffer, overwritten with the covariance matrix of g applied to reg
#   temp_grad_m         work buffer for the parameter gradient of every upper-triangle entry
#   probabilities       output, one entry per qubit
#   grad_probabilities  output, one column per qubit
#
# Reads the module-level circuit `g`. Prints per-iteration and total wall-clock times.
# Returns nothing.
function update_opt!(reg::MajoranaReg, theta, b, temp_m, temp_grad_m, probabilities, grad_probabilities)
    nq = nqubits(reg) # take the size from the register rather than from a global
    dim = 2*nq
    t_tot = 0.0 # must be initialised before the loop accumulates into it
    for i in 1:nq
        t = time()
        if i == 1
            dispatch!(g, theta)
            # Fill the caller's buffer instead of rebinding the local name
            temp_m .= covariance_matrix(apply(reg, g))
            for p in 1:dim
                for q in p+1:dim
                    ham = majorana_commutator(nq, p, q)
                    # Second component of expect' is taken as the gradient w.r.t. the circuit parameters
                    temp_grad_m[:,p,q] = expect'(ham, reg => g)[2]
                end
            end
        else
            prev_bit = b[i-1]
            cur_prob = probabilities[i-1]
            cur_grad_prob = grad_probabilities[:, i-1]
            cur_prefactor = (-1)^prev_bit / (2*cur_prob)
            cur_grad_prefactor = (-1)^prev_bit / (2*cur_prob^2)
            # Rows belonging to the qubit that was just conditioned on
            r1 = 2*(i-1) - 1
            r2 = 2*(i-1)
            # The gradient must be updated first: it needs the not-yet-updated temp_m.
            # @inbounds assumes the buffers have the shapes listed in the header.
            @inbounds for p in r2+1:dim
                for q in p+1:dim
                    # Iterate over EVERY parameter index (the previous `for s in size(...)`
                    # visited only the last one).
                    for s in axes(temp_grad_m, 1)
                        temp_grad_m[s,p,q] -= cur_grad_prefactor * ((-cur_grad_prob[s] * temp_m[r1,p] * temp_m[r2,q]) + (cur_prob * (temp_grad_m[s,r1,p] * temp_m[r2,q] + temp_m[r1,p] * temp_grad_m[s,r2,q])))
                        temp_grad_m[s,p,q] += cur_grad_prefactor * ((-cur_grad_prob[s] * temp_m[r1,q] * temp_m[r2,p]) + (cur_prob * (temp_grad_m[s,r1,q] * temp_m[r2,p] + temp_m[r1,q] * temp_grad_m[s,r2,p])))
                    end
                end
            end
            for p in r2+1:dim
                for q in p+1:dim
                    temp_m[p,q] -= cur_prefactor * (temp_m[r1,p] * temp_m[r2,q])
                    temp_m[p,q] += cur_prefactor * (temp_m[r1,q] * temp_m[r2,p])
                end
            end
        end
        # Probability of observing bit b[i] on qubit i, and its parameter gradient
        ni = b[i]
        probabilities[i] = (1+(-1)^ni * temp_m[2*i-1, 2*i]) / 2
        grad_probabilities[:, i] = (-1)^ni * temp_grad_m[:,2*i-1, 2*i] / 2
        diff = (time() - t)
        t_tot += diff
        println("iteration $i: $diff")
    end
    println("total time: $t_tot")
    return nothing
end

# Returns (probabilities, ∇_θ log p_θ(b)), evaluated at 'theta' (parameters of the circuit)
# and 'b' (measurement result). 'reg' is the initial register and must be of type
# MajoranaReg (e.g. FLOYao.zero_state(nq)).
#
# The remaining arguments are work/output buffers filled by `update_opt!`; `probabilities`
# is returned as-is (not copied). The gradient is accumulated as the sum over qubits i of
# grad_probabilities[:, i] / probabilities[i].
function log_grad_opt(reg::MajoranaReg, theta, b, temp_m, temp_grad_m, probabilities, grad_probabilities)
    update_opt!(reg, theta, b, temp_m, temp_grad_m, probabilities, grad_probabilities)
    grad_log = zeros(length(theta))
    # Use the register's own qubit count rather than a global
    for i in 1:nqubits(reg)
        # Accumulate in place; the view avoids copying the column
        grad_log .+= @view(grad_probabilities[:, i]) ./ probabilities[i]
    end
    return probabilities, grad_log
end

using Yao.BitBasis
using Flux

# Turns the output of measure into an Int vector: reads the first element of `g_output`
# from its last position backwards, taking `nq` entries (`nq` is the module-level qubit count).
function postprocess(g_output::Vector)
    reversed_bits = Any[g_output[1][end - k + 1] for k in 1:nq]
    return Int.(reversed_bits)
end
# Packs the first `nbatch` measurements into an nq x nbatch matrix, one column per sample.
# Each measurement is passed through `breflect` and then splatted into its column.
# `nq` is the module-level qubit count; the default for `nbatch` reads a global `batchsize`.
function d_postprocess(measurement::Vector, nbatch = batchsize)
    reflected = map(breflect, measurement)
    columns = Matrix{Any}(undef, nq, nbatch)
    for col in 1:nbatch
        columns[:, col] = [reflected[col]...]
    end
    return columns
end

# Generator loss estimated from `nbatch` samples of the circuit output.
#
# Arguments
#   reg     initial register (the circuit is applied to a copy of it here)
#   g       parameterised circuit
#   theta   parameters dispatched into `g`
#   nbatch  number of measurement shots
#
# Returns the scalar -Σ_i log(d(x_i)) * p(x_i), where x_i are the sampled bit strings,
# `d` is the module-level critic and p(x_i) is the probability of x_i in the evolved state.
function g_loss(reg, g, theta, nbatch)
    dispatch!(g, theta)
    # Sample from the EVOLVED state; measuring `reg` directly would ignore `theta`
    # altogether. This mirrors what reinforce_grad_loss does.
    evolved = apply(reg, g)
    measurements = measure(evolved, nshots = nbatch)
    discriminator_output = log.(d(d_postprocess(measurements, nbatch)))
    probs = Float64[FLOYao.bitstring_probability(evolved, measurements[i]) for i in 1:nbatch]
    # The product is a one-element array; unwrap it to a scalar
    return first(-discriminator_output * probs)
end

# Score-function style gradient estimate built from `nbatch` samples of the global circuit `g`.
#
# Dispatches `theta` into `g`, samples `nbatch` bit strings from apply(reg, g), evaluates
# log.(d(...)) on them with the global critic `d`, and weights each sample's
# ∇_θ log p_θ(x) (from log_grad_opt, started from FLOYao.zero_state(nq)) by that value.
# Gradients are memoised per distinct bit string. Prints the memo table, then returns the
# result of the file's `mean` over the batch dimension as a vector.
function reinforce_grad_loss(reg, theta, nbatch)
    dispatch!(g, theta)
    T = Float64
    measurements = measure(apply(reg, g), nshots = nbatch)
    discriminator_output = log.(d(d_postprocess(measurements, nbatch)))
    # Scratch buffers shared by every log_grad_opt call; they are fully rewritten on each
    # call, so they never need to be reset between samples.
    dim = 2*nq
    nparams = nparameters(g)
    temp_m = Matrix{T}(undef, dim, dim)
    temp_grad_m = Array{T}(undef, nparams, dim, dim)
    probabilities = Vector{T}(undef, nq)
    grad_probabilities = Matrix{T}(undef, nparams, nq)
    grad_p = Matrix{T}(undef, nparams, nbatch)
    # Cache: bit string => gradient of its log-probability
    sampled = Dict{BitStr{nq, BigInt}, Vector{T}}()
    for i in 1:nbatch
        bitstr = measurements[i]
        grad_p[:, i] = get!(sampled, bitstr) do
            last(log_grad_opt(FLOYao.zero_state(nq), theta, bitstr, temp_m, temp_grad_m, probabilities, grad_probabilities))
        end
    end
    println(sampled)
    weighted = discriminator_output .* grad_p
    return vec(mean(weighted; dims = 2))
end
# Arithmetic mean of `x` along `dims`.
# Divides the sum by the number of elements reduced into each output entry (not by the
# total number of elements of `x`), so e.g. `dims = 2` on an m x n matrix divides by n.
function mean(x; dims)
    total = sum(x; dims)
    # max(…, 1) guards against a zero-length result when `x` has an empty kept dimension
    reduced_count = length(x) ÷ max(length(total), 1)
    return total / reduced_count
end

# Driver: defines the module-level globals read by the functions in this file
# (nq, d, g) and runs a single gradient estimate with random circuit parameters.
nq = 9 #Number of qubits
# Critic network: nq inputs -> 10 hidden units (relu) -> 1 output (sigmoid)
d = Chain(Dense(nq, 10, relu), Dense(10, 1, sigmoid))
nparams = sum(length, Flux.params(d))
println("Number of parameters in critic: $nparams")
g = create_circuit(nq)
# Random initial circuit parameters, drawn uniformly from [0, 2π)
p = rand(nparameters(g)).*2π
println(p)
reg = FLOYao.zero_state(nq)
nshots = 100
reinforce_grad_loss(reg, p, nshots)

Number of parameters in critic: 111
[5.073367330561793, 4.726473130123383, 4.834011651447692, 5.043724044369155, 2.4754712054595314, 4.81686463642964, 0.6490460269287002, 3.4713266318819866, 0.8050680552680681, 3.173515681675435, 2.76834123353028, 0.2236472601483282, 1.8346128307720777, 1.8632899515710255, 1.5923104157772212, 0.5765083660910959, 3.481784337951925, 2.693435316045182, 5.018352497489409, 5.022205332744463, 0.46093197343961934, 2.2759022913424225, 4.537635541847899, 4.286201260620129, 5.750435734836762, 3.106819897696939, 3.9162705447552884, 5.9765407227357645, 5.2721479689620265, 4.827027171832567, 0.5470470015137346, 4.3944926373182644, 5.212584008614471, 5.0377628749762]


BenchmarkTools.Trial: 8 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m591.703 ms[22m[39m … [35m765.856 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m9.32% … 8.74%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m682.433 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m9.15%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m675.235 ms[22m[39m ± [32m 53.201 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m9.27% ± 0.52%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m█[34m [39m[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m█[39m█[39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[39m▁