# In [None]:
using LinearAlgebra
using IterTools
using NPZ
using FFTW
using Base.Threads
using PyPlot

# Report the runtime environment: the Julia thread count first, then the working directory.
for report in (Threads.nthreads, pwd)
    println(report())
end

# In [None]:
"""
    supercell_2d(frac, nx, ny)

Tile the 2D site coordinates `frac` (an `N×2` matrix, one site per row — presumably
fractional coordinates) over the integer lattice shifts `(i, j)` with
`i ∈ -div(nx,2):div(nx,2)` and `j ∈ -div(ny,2):div(ny,2)`.

Return all shifted sites stacked into a `(N * ncells)×2` matrix. Shifts are visited with
`i` varying fastest; within one shift the rows keep the order of `frac`.

The shift ranges are symmetric, so an even `nx` (or `ny`) yields `nx + 1` (or `ny + 1`)
cells along that direction.

Throws `DimensionMismatch` if `frac` does not have exactly two columns.
"""
function supercell_2d(frac, nx, ny)
    size(frac, 2) == 2 || throw(DimensionMismatch(
        "frac must have 2 columns (one 2D site per row), got size $(size(frac))"))
    nx_half, ny_half = div(nx, 2), div(ny, 2)
    shifts = Iterators.product(-nx_half:nx_half, -ny_half:ny_half)
    nsite = size(frac, 1)
    # preallocate instead of growing with vcat (which is quadratic in the cell count)
    frac_n = Matrix{promote_type(Float64, eltype(frac))}(undef, nsite * length(shifts), 2)
    for (cell, (i, j)) in enumerate(shifts)
        rows = ((cell - 1) * nsite + 1):(cell * nsite)
        # shift by the lattice vector (i, j): i goes to column 1, j to column 2
        frac_n[rows, 1] .= view(frac, :, 1) .+ i
        frac_n[rows, 2] .= view(frac, :, 2) .+ j
    end
    return frac_n
end

"""
    adjoint_mat(x)

Return the adjugate (classical adjoint) of the square matrix `x`, i.e. the transpose of
its cofactor matrix, so that `x * adjoint_mat(x) ≈ det(x) * I`.

This is *not* Julia's `adjoint` (conjugate transpose): no complex conjugation is applied.

The cofactor expansion is well defined for singular `x` too. By contrast,
`det(x) * inv(x)` throws `SingularException` when `inv` hits an exactly zero pivot.

Throws `DimensionMismatch` if `x` is not square.
"""
function adjoint_mat(x)
    n = LinearAlgebra.checksquare(x)
    T = float(eltype(x))
    n == 1 && return fill(one(T), 1, 1)
    adj = Matrix{T}(undef, n, n)
    for j in 1:n, i in 1:n
        # adj[i, j] is the (j, i) cofactor: delete row j and column i of x
        cof = det(x[setdiff(1:n, j), setdiff(1:n, i)])
        adj[i, j] = isodd(i + j) ? -cof : cof
    end
    return adj
end

"""
    d_kxkykz(cij, k_vec, ijkl2ij_dict)

Build the 3×3 matrix `D[i, j] = Σ_{m,n} C[(i,m), (j,n)] * k_m * k_n`.

Each index pair `(a, b)` is turned into a row/column of `cij` through `ijkl2ij_dict`
(the Voigt map). `k_vec` supplies the three wave-vector components. The result is always
a `Matrix{ComplexF64}`.
"""
function d_kxkykz(cij, k_vec, ijkl2ij_dict)
    voigt(a, b) = ijkl2ij_dict[(a, b)]
    acc = zeros(ComplexF64, 3, 3)
    # i innermost, then j, m, n: each D[i, j] receives its (m, n) terms with m varying fastest
    for n in 1:3, m in 1:3, j in 1:3, i in 1:3
        acc[i, j] += cij[voigt(i, m), voigt(j, n)] * k_vec[m] * k_vec[n]
    end
    return acc
end

"""
    permu_sign(permu, permu_dict)

Return `permu_dict[permu]` when `permu` is a key of `permu_dict`, and `0` otherwise.

With the Levi-Civita table built in this file, this gives `ε(permu)`: `0` for any index
tuple that is not a permutation.
"""
permu_sign(permu, permu_dict) = get(permu_dict, permu, 0)

"""
    f_mn(kx, ky, cij, k3_list, ijkl2ij_dict, permu_dict; ktol=1e-6)

Sum over the samples `k3_list` of the third wave-vector component and return the tuple
`(f_13, f_23, g_13, g_23)` for the in-plane wave vector `(kx, ky)`.

For every sample `k3`, let `k = [kx, ky, k3]`, `D = d_kxkykz(cij, k, ijkl2ij_dict)` and
`N = adjoint_mat(D)`. The term accumulated over `j, l, p, q, s ∈ 1:3` is

    C[(m,3),(j,l)] * C[(p,q),(r,s)] * k[q] * N[l,p] / (2π det(D))
        * (kx*ε(j,s,2) - ky*ε(j,s,1)) * Δk3

The four outputs use these index values:
- `f_13`: `(m, r) = (1, 1)`
- `f_23`: `(m, r) = (2, 1)`
- `g_13`: `(m, r) = (1, 2)`
- `g_23`: `(m, r) = (2, 2)`

Index pairs are mapped to rows/columns of `cij` through `ijkl2ij_dict`. `ε` is looked up
with `permu_sign` in `permu_dict`.

# Arguments
- `k3_list`: at least two samples. `Δk3` is taken from the first spacing only, so the
  samples are presumably uniformly spaced (TODO confirm at call sites).
- `ktol`: samples with `norm(k) < ktol` are skipped, because `D(k)` is the zero
  (singular) matrix at `k = 0`.

# Throws
- `ArgumentError` if `k3_list` has fewer than two samples.
"""
function f_mn(kx, ky, cij, k3_list, ijkl2ij_dict, permu_dict; ktol=1e-6)
    length(k3_list) >= 2 || throw(ArgumentError(
        "k3_list needs at least two samples to define dk3, got $(length(k3_list))"))
    i0 = firstindex(k3_list)
    dk3 = k3_list[i0 + 1] - k3_list[i0]
    n = 3 #* for σ_m3

    # Loop-invariant lookups, hoisted out of the k3 and (j,l,p,q,s) loops.
    voigt_1n, voigt_2n = ijkl2ij_dict[(1, n)], ijkl2ij_dict[(2, n)]
    voigt_1s = [ijkl2ij_dict[(1, s)] for s in 1:3]
    voigt_2s = [ijkl2ij_dict[(2, s)] for s in 1:3]
    # kx*ε(j,s,2) - ky*ε(j,s,1) depends only on (j, s)
    eps_k = [kx*permu_sign((j, s, 2), permu_dict) - ky*permu_sign((j, s, 1), permu_dict)
             for j in 1:3, s in 1:3]

    # Start from a concrete complex zero (not Int 0) so the accumulators keep one type
    # through the loop; D is always complex, so every term is at least ComplexF64.
    T = promote_type(ComplexF64, eltype(cij), typeof(kx), typeof(ky), eltype(k3_list))
    sum_f_13 = sum_f_23 = sum_g_13 = sum_g_23 = zero(T)

    for k3 in k3_list
        k_vec = [kx, ky, k3]
        norm(k_vec) < ktol && continue
        d_mat = d_kxkykz(cij, k_vec, ijkl2ij_dict)
        n_mat = adjoint_mat(d_mat)
        d_det = det(d_mat)

        for (j, l, p, q, s) in Iterators.product(1:3, 1:3, 1:3, 1:3, 1:3)
            jl, pq = ijkl2ij_dict[(j, l)], ijkl2ij_dict[(p, q)]
            cmnjl_1n, cmnjl_2n = cij[voigt_1n, jl], cij[voigt_2n, jl]
            cpq1s_f, cpq2s_g = cij[pq, voigt_1s[s]], cij[pq, voigt_2s[s]]
            kq, n_lp, e_js = k_vec[q], n_mat[l, p], eps_k[j, s]

            sum_f_13 += cmnjl_1n*cpq1s_f*kq*n_lp/(2*pi*d_det)*e_js*dk3
            sum_f_23 += cmnjl_2n*cpq1s_f*kq*n_lp/(2*pi*d_det)*e_js*dk3
            sum_g_13 += cmnjl_1n*cpq2s_g*kq*n_lp/(2*pi*d_det)*e_js*dk3
            sum_g_23 += cmnjl_2n*cpq2s_g*kq*n_lp/(2*pi*d_det)*e_js*dk3
        end
    end
    return sum_f_13, sum_f_23, sum_g_13, sum_g_23
end

#* some constants
mjm2eva = 6.24150965*1e-5 #* mJ/m^2 to eV/Å^2
eva2gpa = 160.2176621 #* eV/Å^3 to GPa
specie_denote = "" #TODO species label

#* lattice initialize
nx, ny = 0, 0 #TODO nx*ny prim cells
a, c = 0, 0 #TODO lattice constants
ab_angle = pi/2
b = sqrt(a^2+c^2)

#TODO cij part (1120-0001-1100)
cij = npzread("")
#* Voigt map: index pair (i, j) -> row/column (1-6) of cij; symmetric in (i, j)
ijkl2ij_dict = Dict(
    (1,1)=>1, (2,2)=>2, (3,3)=>3,
    (2,3)=>4, (1,3)=>5, (1,2)=>6,
    (2,1)=>6, (3,1)=>5, (3,2)=>4)
#* Levi-Civita symbol: +1 for even and -1 for odd permutations of (1, 2, 3);
#* tuples not listed here (repeated indices) are treated as 0 by permu_sign
permu_dict = Dict(
    (1,2,3)=>1, (2,3,1)=>1, (3,1,2)=>1,
    (1,3,2)=>-1, (2,1,3)=>-1, (3,2,1)=>-1)
cij = cij/eva2gpa #* convert GPa to eV/Å^3

#* in real lattice the minimum spacing for 1100 projection plane will be b/4
fine_grid_n = 2 #TODO resolution of grid
lx, ly = a*nx/fine_grid_n, c*ny/fine_grid_n
prim_cell_frac = Matrix{Float64}([0.0 0.0])
#* a-c lattice grid
prim_cell_vec = Matrix{Float64}([
    a/fine_grid_n 0;
    0 c/fine_grid_n;
])

#* real-space grid: lattice_grid_2d[ii, jj, :] = (offset along a, offset along c)
len_nx, len_ny = nx, ny
lattice_grid_2d = zeros((len_nx, len_ny, 2))
half_nx, half_ny = div(nx, 2), div(ny, 2)
#* array index ii sits at lattice offset ii-1-half_nx, i.e. -half_nx:half_nx for odd nx
#* and -half_nx:half_nx-1 for even nx (a symmetric -half:half range has nx+1 entries
#* for even nx and would overrun the array)
for jj in 1:len_ny, ii in 1:len_nx
    lattice_grid_2d[ii, jj, 1] = (ii - 1 - half_nx)*a/fine_grid_n
    lattice_grid_2d[ii, jj, 2] = (jj - 1 - half_ny)*c/fine_grid_n
end

#* initialize frequency grid
len_kx, len_ky = len_nx*2, len_ny*2 #* for mirroring the grid
#* attention: FFTW's fftfreq(n, fs) takes the sampling rate fs (here 1/grid spacing),
#* whereas scipy's fftfreq(n, d) takes the spacing d
kx, ky = fftfreq(len_kx, len_nx/lx)*2*pi, fftfreq(len_ky, len_ny/ly)*2*pi

#* grid along k3
kz_lim = 0
kz_sample = 0 #TODO number of k3 samples (f_mn needs at least 2)
kz = range(-kz_lim, kz_lim, length=kz_sample)
kz = Complex.(kz)
f_13_mesh = zeros(ComplexF64, len_kx, len_ky)
f_23_mesh = zeros(ComplexF64, len_kx, len_ky)
g_13_mesh = zeros(ComplexF64, len_kx, len_ky)
g_23_mesh = zeros(ComplexF64, len_kx, len_ky)

prod_kxky = collect(Iterators.product(1:len_kx, 1:len_ky))
#* each iteration writes only its own (ki, kj) entry, so the threaded loop is race-free
Threads.@threads for (ki, kj) in prod_kxky
    kx_val, ky_val = kx[ki], ky[kj]
    f_13, f_23, g_13, g_23 = f_mn(kx_val, ky_val, cij, kz, ijkl2ij_dict, permu_dict)
    f_13_mesh[ki, kj] = f_13
    f_23_mesh[ki, kj] = f_23
    g_13_mesh[ki, kj] = g_13
    g_23_mesh[ki, kj] = g_23
end