In [None]:
using LoopVectorization
using LinearAlgebra
import Optim as opt

# Basic Setup

In [None]:
# Number of particle species; sizes the third axis of the per-species arrays.
N_species = 3
# Particle count for each species, in the same order as `species`.
numbers = [18, 18, 36]
# Display labels for the species (presumably spin-up/down electrons and holes — confirm).
species = ["e↑", "e↓", "h"]
# Mass of each species; enters the kinetic term as 1/(2*mass).
masses  = [0.1, 0.1, 1]
# Layer index of each species; equal indices select the intra-layer interaction,
# different indices the inter-layer one.
layers  = [1,1,2]

# When true, species 1 and 2 share one set of orbital coefficients in the optimizer vector.
spin_locked = true

# Length scale: L is chosen below so that the area per species-3 particle is π * rsM^2.
rsM = 16
# Decay length in the inter-layer interaction factor exp(-q * d)
# (presumably the layer separation — confirm).
d = 7

# Columns are the unit-cell vectors in units of L (60° apart, i.e. a triangular lattice).
T = [1 1/2; 0 √3/2]
# Linear cell size, fixed by rsM and the species-3 particle count.
L = rsM * √(π * numbers[3]/det(T))
# Cell area.
A = L^2 * det(T)

In [None]:
# Radius of the hexagonal momentum cutoff, in integer reciprocal-lattice steps.
Rk = 7
# Reciprocal-lattice basis (columns), dual to the real-space cell `L * T`.
G = 2π / L * inv(T)'

# Hexagonal distance of an integer index pair `k` from the origin.
khex_dist = k -> (abs(k[1]) + abs(k[1] - k[2]) + abs(k[2])) / 2
# Length of the wavevector whose integer indices are `k`.
knorm = k -> norm(G * k)

# Integer index pairs inside the cutoff; first index varies slowest.
k_mesh = filter(
    k -> khex_dist(k) <= Rk,
    [[i, j] for i in -Rk:Rk for j in -Rk:Rk],
);
# The same points as Cartesian wavevectors.
kpoints = map(k -> G * k, k_mesh)
# Reverse lookup: integer index pair -> position in `k_mesh`.
k_mesh_lookup = Dict(zip(k_mesh, eachindex(k_mesh)))

Nk = length(k_mesh);
# All (i, j) index pairs, `i` outer and `j` inner.
HF_loop_indices = vec([(i, j) for j in 1:Nk, i in 1:Nk]);

# Per-species kinetic matrices: diagonal, |k|^2 / (2 m_a).
kinetic = zeros(Float32, Nk, Nk, N_species)

for a in 1:N_species
    kinetic[:, :, a] .= diagm(knorm.(k_mesh) .^ 2) ./ (2 * masses[a])
end

# Momentum transfer |k_i - k_j| for every pair of mesh points.
Coulomb_q = [knorm(ka - kb) for ka in k_mesh, kb in k_mesh];

# 2π / (A q) within a layer, with the q = 0 entry set to zero.
intra_Coulomb = map(q -> q == 0.0 ? 0.0 : 1 / q, Coulomb_q) * 2π / A;
# Between layers: opposite sign and damped by exp(-q d).
inter_Coulomb = @. -intra_Coulomb * exp(-Coulomb_q * d);

# Slice 1 = intra-layer, slice 2 = inter-layer.
all_Coulomb = cat(intra_Coulomb, inter_Coulomb; dims=3);

In [None]:
# hopping_precomputes[i, j, k] is the position in `k_mesh` of
#     k_mesh[k] - (k_mesh[i] - k_mesh[j]),
# or 0 when that point falls outside the mesh.
hopping_precomputes = zeros(Int16, Nk, Nk, Nk)

# A single dictionary lookup replaces the linear `in k_mesh` scan plus second
# lookup; the first index runs innermost to match column-major storage.
for k in 1:Nk, j in 1:Nk, i in 1:Nk
    target = k_mesh[k] - (k_mesh[i] - k_mesh[j])
    hopping_precomputes[i, j, k] = get(k_mesh_lookup, target, 0)
end

# HF Setup

In [None]:
# One Nk × n_a coefficient buffer per species (filled by CSD!).
CS_mats = map(n -> zeros(ComplexF32, Nk, n), numbers)

# Complex density (D) and Fock (F) matrices, one Nk × Nk slice per species.
D_mats, F_mats = (zeros(ComplexF32, Nk, Nk, N_species) for _ in 1:2)

# Real and imaginary parts of D and F kept as separate Float32 arrays.
Dr_mats, Di_mats, Fr_mats, Fi_mats = (zeros(Float32, Nk, Nk, N_species) for _ in 1:4);

In [None]:
#Core Functions
"""
    CSD!(C_mats)

Refresh the global density buffers from the orbital coefficients `C_mats`
(one matrix per species).

For each species `a`, with `C = C_mats[a]`, this stores `C * inv(C' * C)` in
`CS_mats[a]` and `C * inv(C' * C) * C'` in `D_mats[:, :, a]`. `D_mats` is then
complex-conjugated in place and split into `Dr_mats` / `Di_mats`.
Returns `Di_mats`.
"""
function CSD!(C_mats)
    for a in 1:N_species
        C = C_mats[a]
        gram = cholesky(Hermitian(C' * C))

        # Right-dividing by the factorization applies inv(C' * C).
        CS_mats[a] .= C / gram

        D_slice = view(D_mats, :, :, a)
        mul!(D_slice, CS_mats[a], C')
    end

    conj!(D_mats)
    Dr_mats .= real.(D_mats)
    Di_mats .= imag.(D_mats)
    return Di_mats
end

"""
    Fock!(; fock = 1)

Rebuild the global Fock matrices from the current density buffers
`Dr_mats` / `Di_mats`.

`Fr_mats` is reset to `kinetic` and `Fi_mats` to zero. For every species `a`
the interaction terms are then accumulated and the result is written to
`F_mats` as `complex.(Fr_mats, Fi_mats)`.

# Keywords
- `fock`: scale factor applied to the exchange term (`0` switches it off).

Returns `F_mats`.
"""
function Fock!(;fock = 1)
    Fr_mats .= kinetic
    Fi_mats .= 0.
    
    for a in 1:N_species        
        @tturbo for i in 1:Nk
            for j in 1:Nk
                for k in 1:Nk
                    # i = k'+q, j = k', k = k, l = k - q
                    # Hartree: (k'+q)(k')  (k-q)(k)
                    # Fock:    (k'+q)(k)   (k-q)(k')
                    
                    # l == 0 marks "k - q lies outside the mesh". `lind` keeps
                    # the array index valid and `ifelse` masks that
                    # contribution without branching.
                    l = hopping_precomputes[i,j,k]
                    lind = max(1, l)
                    
                    # Direct term: sum over every species b (was hard-coded
                    # as 1:3), choosing the intra- or inter-layer interaction
                    # from the layer indices.
                    for b in 1:N_species
                        layer_sep = abs(layers[a] - layers[b]) + 1
                        Coulomb_V = all_Coulomb[i,j,layer_sep]
                        
                        Fr_mats[i,j,a] += Coulomb_V * ifelse(l == 0, 0, Dr_mats[lind, k, b])
                        Fi_mats[i,j,a] += Coulomb_V * ifelse(l == 0, 0, Di_mats[lind, k, b])
                    end
                    
                    # Exchange term: same species only, intra-layer interaction.
                    Coulomb_V = all_Coulomb[i,j,1] * fock
                    Fr_mats[i,k,a] -= Coulomb_V * ifelse(l == 0, 0, Dr_mats[lind, j, a])
                    Fi_mats[i,k,a] -= Coulomb_V * ifelse(l == 0, 0, Di_mats[lind, j, a])
                end
            end
        end
    end
    
    F_mats .= complex.(Fr_mats, Fi_mats)
end

"""
    energy_calc!(C_mats)

Update the global density and Fock buffers from `C_mats` (via `CSD!` and
`Fock!`) and return the energy as a `Float32`:

    1/2 * Σ_{p,q,s} [ (kinetic + Fr) * Dr - Fi * Di ][p, q, s]
"""
function energy_calc!(C_mats)
    CSD!(C_mats)
    Fock!()

    total::Float32 = 0.0f0

    @tturbo for p in 1:Nk
        for q in 1:Nk
            for s in 1:N_species
                total += (kinetic[p,q,s] + Fr_mats[p,q,s]) * Dr_mats[p,q,s]
                total -= Fi_mats[p,q,s] * Di_mats[p,q,s]
            end
        end
    end

    return total / 2
end

"""
    grad_calc!(C_mats)

Update the global density and Fock buffers from `C_mats` and return one
gradient matrix per species, `2 * (I - conj(D_a)) * F_a * CS_a`.

When `spin_locked` is true the species-2 gradient is added onto the
species-1 entry.
"""
function grad_calc!(C_mats)
    CSD!(C_mats)
    Fock!()

    out = map(1:N_species) do a
        # D_mats is stored conjugated by CSD!; undo that here.
        D_conj = conj(D_mats[:, :, a])
        2 * (I - D_conj) * F_mats[:, :, a] * CS_mats[a]
    end

    if spin_locked
        out[1] += out[2]
    end

    return out
end

"""
    slater_to_vec(C_mats)

Flatten the per-species matrices into one long vector (column-major, species
after species). With `spin_locked` only entries 1 and 3 are kept.
"""
function slater_to_vec(C_mats)
    kept = spin_locked ? C_mats[[1, 3]] : C_mats
    return vcat(vec.(kept)...)
end

# Cumulative particle counts with a leading zero: column offsets of each species.
cnumbers = [0; cumsum(numbers)];

"""
    vec_to_slater(vec)

Inverse of `slater_to_vec`: cut the long vector into one `Nk`-row matrix per
species. With `spin_locked`, the leading `numbers[1] * Nk` entries are used
for both species 1 and species 2.
"""
function vec_to_slater(vec)
    n_shared = numbers[1] * Nk
    flat = if spin_locked
        shared = vec[1:n_shared]
        vcat(shared, shared, vec[n_shared+1:end])
    else
        vec
    end

    out = map(1:N_species) do a
        first_idx = Nk * cnumbers[a] + 1
        last_idx = Nk * cnumbers[a+1]
        reshape(flat[first_idx:last_idx], Nk, :)
    end

    return out
end

# Objective on the flat optimizer vector.
energy! = function (x)
    return energy_calc!(vec_to_slater(x))
end
# In-place gradient on the flat optimizer vector; the result is copied into `G`.
grad! = function (G, x)
    slater_grad = grad_calc!(vec_to_slater(x))
    return copy!(G, slater_to_vec(slater_grad))
end

In [None]:
# Small random starting coefficients for every species.
C_mats_init = map(n -> 0.01 * rand(ComplexF32, Nk, n), numbers)

# Toggle: seed species 3 with plane-wave phases centred on a hole_sq × hole_sq
# lattice of points spanning the cell.
if true
    hole_sq = round(Int, sqrt(numbers[3]))
    trig_lat = vec([T * [i, j] * L / hole_sq for j in 0:hole_sq-1, i in 0:hole_sq-1]);
    hole_init = [exp(-1im * k' * x) for k in kpoints, x in trig_lat];

    C_mats_init[3] += hole_init;
end

# Toggle: seed species 1 and 2 by occupying the numbers[1] shortest wavevectors.
if true
    inds = sortperm(norm.(kpoints))[1:numbers[1]]
    fermi_init = Float64.((1:Nk) .== permutedims(inds))
    C_mats_init[1] += fermi_init
    C_mats_init[2] += fermi_init
end;

In [None]:
# Warm-start from the previous solution when this cell is re-run in a live
# session. On a fresh run `sol` does not exist yet, so fall back to the
# initial guess instead of failing with an UndefVarError.
long_vec = @isdefined(sol) ? sol.minimizer : slater_to_vec(C_mats_init)
optim_options = opt.Options(time_limit=300, f_reltol=1e-7,);
sol = opt.optimize(energy!, grad!, long_vec, opt.ConjugateGradient(),optim_options)