In [None]:
import numpy as np
from numba import jit
import timeit
from fortwald.dirlatt import dirvec
from fortwald.dirfast import dirvec as dirfast

In [None]:
def nk( maxki ):
    """Number of lattice vectors generated for a given ``maxki``.

    Counts the origin plus, for every vector with non-negative integer
    components up to ``maxki``, the vector itself and one copy per non-zero
    component with that component negated:
    (0,0,0), (0,0,1), (0,0,-1), ..., (0,1,1), (0,-1,1), (0,1,-1), ...,
    (1,1,1), (-1,1,1), (1,-1,1), (1,1,-1), ...

    Equivalent to 4*(maxki+1)**3 - 3*(maxki+1)**2, i.e.
    1 + 6*maxki + 9*maxki**2 + 4*maxki**3 (origin, axes, planes, interior).
    """
    side = maxki + 1
    return side * side * (4 * maxki + 1)

In [None]:
# Cross-check nk against an independent per-region count:
# origin + 6 per magnitude on the axes + 9 per (i, j) pair in the
# coordinate planes + 4 per (i, j, k) triple off the planes.
assert [nk(i) for i in range(3)] == [1, 20, 81]
assert all(nk(m) == 1 + 6 * m + 9 * m ** 2 + 4 * m ** 3 for m in range(10))

In [None]:
@jit
def nvec_jit( maxki=3 ):
    """Return the k-vectors for the Fourier part of Ewald, as float64.

    Rows follow the enumeration counted by ``nk``:
      1. the origin,
      2. the axis vectors (0,0,±i), (0,±i,0), (±i,0,0),
      3. the coordinate-plane vectors, each positive pair plus one
         sign flip per non-zero component,
      4. the off-plane vectors (i,j,k) plus one sign flip per component,
    for every magnitude 1..maxki.

    Arguments:
        maxki (int): max absolute value for kx, ky, kz

    Output:
        vec (ndarray of float64): k vectors, shape (nk(maxki), 3)
    """
    # Row count equals nk(maxki) == (maxki+1)**2 * (4*maxki+1). It is written
    # out here because a plain (non-jitted) Python function such as ``nk``
    # cannot be called from nopython-compiled code.
    n_rows = (maxki + 1) ** 2 * (4 * maxki + 1)
    # np.zeros needs a tuple (not a list) shape in nopython mode.
    vec = np.zeros((n_rows, 3))
    # Row 0 is the origin (0, 0, 0), already zero from np.zeros.
    cnt = 1
    # Axes: both signs for each axis.
    for i in range(1, maxki + 1):
        vec[cnt], vec[cnt+1] = [0,0,i], [0,0,-i]
        vec[cnt+2], vec[cnt+3] = [0,i,0], [0,-i,0]
        vec[cnt+4], vec[cnt+5] = [i,0,0], [-i,0,0]
        cnt += 6
    # Coordinate planes: the positive pair plus one flip per non-zero component.
    for i in range(1, maxki + 1):
        for j in range(1, maxki + 1):
            vec[cnt], vec[cnt+1], vec[cnt+2] = [0,i,j], [0,-i,j], [0,i,-j]
            vec[cnt+3], vec[cnt+4], vec[cnt+5] = [i,0,j], [-i,0,j], [i,0,-j]
            vec[cnt+6], vec[cnt+7], vec[cnt+8] = [i,j,0], [-i,j,0], [i,-j,0]
            cnt += 9
    # Off-plane vectors: the positive triple plus one flip per component.
    for i in range(1, maxki + 1):
        for j in range(1, maxki + 1):
            for k in range(1, maxki + 1):
                vec[cnt] = [i, j, k]
                vec[cnt+1] = [-i, j, k]
                vec[cnt+2] = [i, -j, k]
                vec[cnt+3] = [i, j, -k]
                cnt += 4
    return vec

In [None]:
def nvec( maxki=3 ):
    """Pure-Python reference version of ``nvec_jit`` with float32 output.

    Emits the same rows in the same order as ``nvec_jit``: the origin, then
    the axis vectors, then the coordinate-plane vectors, then the off-plane
    vectors, for every magnitude 1..maxki.

    Arguments:
        maxki (int): max absolute value for kx, ky, kz

    Output:
        ndarray of float32, shape (nk(maxki), 3)
    """
    out = np.zeros([nk(maxki), 3], dtype=np.float32)
    row = 0
    out[row] = (0, 0, 0)
    row += 1
    mags = range(1, maxki + 1)

    # axes: both signs along z, y, x
    for a in mags:
        for v in ([0, 0, a], [0, 0, -a], [0, a, 0], [0, -a, 0], [a, 0, 0], [-a, 0, 0]):
            out[row] = v
            row += 1

    # coordinate planes yz, xz, xy: positive pair plus one flip per component
    for a in mags:
        for b in mags:
            for v in ([0, a, b], [0, -a, b], [0, a, -b],
                      [a, 0, b], [-a, 0, b], [a, 0, -b],
                      [a, b, 0], [-a, b, 0], [a, -b, 0]):
                out[row] = v
                row += 1

    # off-plane: positive triple plus one flip per component
    for a in mags:
        for b in mags:
            for c in mags:
                for v in ([a, b, c], [-a, b, c], [a, -b, c], [a, b, -c]):
                    out[row] = v
                    row += 1
    return out

In [None]:
def nvec_if( maxki=3 ):
    """Lattice vectors built from each non-negative (i, j, k) plus sign flips.

    Loops over every (i, j, k) with 0 <= i, j, k <= maxki (i slowest, k
    fastest). For each one it writes the vector itself, followed by one copy
    per strictly positive component with that component negated.

    The result has nk(maxki) rows (float32) and contains the same set of
    vectors as ``nvec``, but in a different row order.
    """
    out = np.zeros([nk(maxki), 3], dtype=np.float32)
    row = 0
    span = range(maxki + 1)
    for a in span:
        for b in span:
            for c in span:
                base = (a, b, c)
                out[row] = base
                row += 1
                for axis, comp in enumerate(base):
                    if comp > 0:
                        out[row] = base
                        out[row, axis] = -comp
                        row += 1
    return out

In [None]:
# nvec_jit (float64) and nvec (float32) emit the same rows in the same
# order, so compare them exactly over several sizes, not just the origin.
for m in range(4):
    assert np.array_equal(nvec_jit(m), nvec(m)), f"nvec_jit != nvec for maxki={m}"
# np.array_equal gives one bool and is False on a shape mismatch, whereas
# `==` broadcasts silently and prints an element-wise array.
# NOTE(review): the Fortran routines are only compared at maxki=0 (the
# origin); their row order/shape for larger maxki is not visible here.
print(np.array_equal(nvec_if(0), dirvec(0)))
print(np.array_equal(nvec_if(0), dirfast(0)))

In [None]:
# Benchmark all five implementations at the same size; each call builds
# nk(15) = 16**2 * 61 = 15616 vectors.
maxki = 15

# numba-compiled version (float64). numba compiles on the first call, which
# presumably already happened in the comparison cell above — TODO confirm
# when running cells out of order.
%timeit nvec_jit(maxki)
# pure-Python reference, same row order as nvec_jit (float32)
%timeit nvec(maxki)
# pure-Python sign-flip variant, different row order (float32)
%timeit nvec_if(maxki)
# Fortran routines from fortwald
%timeit dirvec(maxki)
%timeit dirfast(maxki)

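Timings for `maxki = 15` (15616 vectors per call):

| method                        | mean/loop | std. dev. |
|-------------------------------|-----------|-----------|
| numba `@jit` (`nvec_jit`)     | 1.51 ms   | ±7.96 µs  |
| pure Python (`nvec`)          | 13.6 ms   | ±57.9 µs  |
| Python sign flips (`nvec_if`) | 39.8 ms   | ±430 µs   |
| Fortran (`dirvec`)            | 542 µs    | ±3.26 µs  |
| Fortran, fast (`dirfast`)     | 17.4 µs   | ±158 ns   |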

In [None]:
#import itertools as it
#comb = list(it.product(range(1,3), repeat=3))
#print(comb)

In [None]:
from numba import njit

@njit
def gausssum_njit( n ):
    """Sum 1 + 2 + ... + n with an explicit loop (numba loop benchmark).

    Returns 0 when n <= 0.
    """
    total = 0
    for k in range(1, n + 1):
        total += k
    return total


print(gausssum_njit(1e9))
%timeit gausssum_njit(1e9)