# Summary of first half-day:

* Various types of profiler: `time`, `timeit`, `prun`, `lprun`, `memit`
* Loop-style coding in `Python`
* Vectorized coding with `NumPy`
* Introduction to the `GIL` and the impossibility of having efficient threads in Python
* Multiprocessing and the communication issue

# Addendum

* Hybrid approaches: mix loops over the reciprocal space map with `numpy.sum`,
  which lowers the memory consumption
* Einsum version with `NumPy`
* Multiprocessing with shared memory
* Cython version using: $ e^{a+b} = e^{a} e^{b} $


In [None]:
# Reload necessary packages, magics, functions and variables
%matplotlib inline 
%load_ext memory_profiler
%load_ext line_profiler

import os
import math
import cmath
import numpy as np
from matplotlib.pyplot import subplots
from matplotlib.colors import LogNorm

# Number of cores to use (never more than 4, never less than 1)
if hasattr(os, 'sched_getaffinity'):  # Some Unix only
    # Get the number of cores the Python process has access to
    # This provides the number of cores requested in SLURM
    n_cpu = len(os.sched_getaffinity(0))
else:
    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to a single core instead of failing in min() below
    n_cpu = os.cpu_count() or 1
n_cpu = max(1, min(4, n_cpu))


# Validation functions

def validate_sq(result):
    """Return the relative error of `result` for the square crystal.

    The reference array is loaded from ``reference_sq.npy`` in the current
    directory. The error is the largest absolute deviation from the
    reference, divided by the largest reference value.
    """
    expected = np.load("reference_sq.npy")
    deviation = np.abs(expected - result)
    return deviation.max() / expected.max()

def validate_ci(result):
    """Return the relative error of `result` for the circular crystal (exercises).

    The reference array is loaded from ``reference_ci.npy`` in the current
    directory. The error is the largest absolute deviation from the
    reference, divided by the largest reference value.
    """
    expected = np.load("reference_ci.npy")
    deviation = np.abs(np.array(expected) - result)
    worst = deviation.max()
    return worst / expected.max()

def display(result):
    """Show `result` as a log-scaled image titled "Bragg peak".

    The axis extent and the crystal size in the title are taken from the
    module-level `h`, `k` and `N`; `result` is transposed before plotting.
    """
    figure, axes = subplots()
    figure.suptitle("Bragg peak")
    bounds = (h.min(), h.max(), k.min(), k.max())
    axes.imshow(result.T, extent=bounds, norm=LogNorm(), origin='lower')
    axes.set_xlabel('H')
    axes.set_ylabel('K')
    axes.set_title(f"Crystal {N}x{N}")

    
# Simulation parameters
H, K = 0, 4        # Miller indices of the reflection
N = 32             # number of unit cells per direction
oversampling = 3   # points per Laue fringe (2 = Nyquist frequency)
R = N/2            # radius of the crystal
e0 = 0.01          # maximum strain at surface
w = 5.             # width of the strain profile below the surface

# Real-space coordinates: unit-cell indices along each direction
n = np.arange(N)
m = np.arange(N)

# Reciprocal-space coordinates: one unit wide, centred on the (H, K) reflection
h = np.arange(H-0.5, H+0.5, 1./(oversampling*N))
k = np.arange(K-0.5, K+0.5, 1./(oversampling*N))

## Hybrid approaches: Mix loops for the reciprocal space map and `numpy.sum`

In [None]:
# Hybrid method: mix loops with numpy.sum:
import itertools

def laue_hybrid(N, h, k):
    """Compute the Laue intensity map with Python loops over (h, k).

    For every pair of reciprocal-space coordinates the N x N lattice sum
    of exp(2j*pi*(h*n + k*m)) is evaluated with NumPy, so only one N x N
    temporary exists at a time.

    Returns an array of shape (h.size, k.size) holding |sum|**2.
    """
    row_idx = np.arange(N)[np.newaxis, :]  # n, shape (1, N)
    col_idx = row_idx.T                    # m, shape (N, 1)
    two_pi_j = 2j*np.pi
    intensity = np.zeros((h.size, k.size))
    for pos_h, val_h in enumerate(h):
        phase_h = val_h*row_idx  # depends on h only, reused for every k
        for pos_k, val_k in enumerate(k):
            amplitude = np.exp(two_pi_j*(phase_h + val_k*col_idx)).sum()
            intensity[pos_h, pos_k] = abs(amplitude)**2
    return intensity

# Compute the intensity once (timed) and check its error against the reference
%time intensity = laue_hybrid(N, h, k)
print("Error:", validate_sq(intensity))
display(intensity)
# Keep the timing result (-o) for the results table, then measure memory use
perf_sq_hybrid = %timeit -o laue_hybrid(N, h, k)
%memit laue_hybrid(N, h, k)

## Einsum version with `NumPy`

In [None]:
# Using Einstein's summation: https://ajcr.net/Basic-guide-to-einsum/

def laue_einsum(N, h, k):
    """Compute the Laue intensity map with a single `numpy.einsum` call.

    The phase factors exp(2j*pi*h*n) and exp(2j*pi*k*m) are built as two
    2D arrays; the einsum sums over n and m and leaves the (h, k) axes.

    Returns an array of shape (h.size, k.size) holding |sum|**2.
    """
    cells = np.arange(N)  # same index range for n and m
    two_pi_j = 2j * np.pi
    phase_h = np.exp(two_pi_j * np.outer(h, cells))
    phase_k = np.exp(two_pi_j * np.outer(k, cells))
    amplitude = np.einsum('hn,km->hk', phase_h, phase_k)
    return np.abs(amplitude)**2

# Compute the intensity once (timed) and check its error against the reference
%time intensity = laue_einsum(N, h, k)
print("Error:", validate_sq(intensity))
display(intensity)
# Keep the timing result (-o) for the results table, then measure memory use
perf_sq_einsum = %timeit -o laue_einsum(N, h, k)
%memit laue_einsum(N, h, k)               

In [None]:
# Using Einstein's summation for the circular strained crystal

def circ_einsum(N, h, k):
    """Compute the intensity map of a strained circular crystal with einsum.

    Lattice sites are kept when they lie within a radius of N/2 around the
    centre. Each kept site is displaced by a tanh strain profile that reads
    the module-level `e0` and `w`.

    Returns an array of shape (h.size, k.size) holding |sum|**2.
    """
    half = N / 2
    axis = np.arange(N) - half
    unit = np.ones(N)
    grid_x = np.outer(axis, unit)  # x coordinate of every site
    grid_y = np.outer(unit, axis)  # y coordinate of every site
    dist2 = grid_x*grid_x + grid_y*grid_y
    inside = dist2 <= half**2
    dist = np.sqrt(dist2[inside])
    profile = 1. + np.tanh((dist - half) / w)
    stretch = 1. + e0 * profile
    pos_x = grid_x[inside] * stretch
    pos_y = grid_y[inside] * stretch
    two_pi_j = 2j * np.pi
    # Same multiplication order as (2j*pi*h)*x to keep results bit-identical
    phase_h = np.exp(two_pi_j * h.reshape(-1, 1)*pos_x.reshape(1, -1))
    phase_k = np.exp(two_pi_j * k.reshape(-1, 1)*pos_y.reshape(1, -1))
    amplitude = np.einsum('hx,kx->hk', phase_h, phase_k)
    return np.abs(amplitude)**2

# Compute the intensity once (timed) and check its error against the circular reference
%time intensity = circ_einsum(N, h, k)
print("Error:", validate_ci(intensity))
display(intensity)
# Keep the timing result (-o), then measure memory use
perf_ci_einsum = %timeit -o circ_einsum(N, h, k)
%memit circ_einsum(N, h, k)  

## Multiprocessing with shared memory

In [None]:
# Using shared memory ...
from multiprocessing.pool import ThreadPool
from multiprocessing import shared_memory
from itertools import product

def laue_mps(N, h, k):
    """Compute the Laue intensity map with a thread pool and shared memory.

    Each (h, k) pair is one task. The workers write their value into a
    NumPy array backed by a `SharedMemory` segment, which is copied into
    a regular array before the segment is released.

    Parameters:
        N: size of the lattice; n and m both run over range(N).
        h, k: 1D arrays of reciprocal-space coordinates.

    Returns:
        Array of shape (h.size, k.size) holding |sum|**2.

    The pool size is read from the module-level `n_cpu`.
    """
    result = np.zeros((h.size, k.size))
    if result.size == 0:
        # SharedMemory refuses to create a zero-byte segment
        return result

    # Lattice indices are identical for every task: build them once
    n = np.arange(N).reshape(-1, 1)
    m = np.arange(N).reshape(1, -1)

    def laue_sp(i, v_h, j, v_k, res):
        """Store the intensity for one (v_h, v_k) pair at res[i, j]."""
        res[i, j] = np.abs(np.exp(2j*np.pi*(v_h*n + v_k*m)).sum())**2

    shm = shared_memory.SharedMemory(create=True, size=result.nbytes)
    try:
        shared_result = np.ndarray(result.shape, dtype=result.dtype, buffer=shm.buf)
        tasks = (idx_h + idx_k + (shared_result,)
                 for idx_h, idx_k in product(enumerate(h), enumerate(k)))
        with ThreadPool(n_cpu) as pool:
            pool.starmap(laue_sp, tasks)
        # Copy out of the shared buffer while it is still mapped
        result[:] = shared_result
    finally:
        # Always release the segment, also when a worker raised.
        # unlink() must be *called*, otherwise the segment is never destroyed.
        shm.close()
        shm.unlink()
    return result

# Compute the intensity once (timed) and check its error against the reference
%time intensity = laue_mps(N, h, k)
print("Error:", validate_sq(intensity))
display(intensity)
# Keep the timing result (-o) for the results table, then measure memory use
perf_sq_mps = %timeit -o laue_mps(N, h, k)
%memit laue_mps(N, h, k)               

## Cython

In [None]:
import os

# Export the core count held in n_cpu as OMP_NUM_THREADS (read by OpenMP runtimes)
os.environ["OMP_NUM_THREADS"] = str(n_cpu)
# Load the Cython extension, which provides the %%cython cell magic
%load_ext Cython

In [None]:
%%cython --compile-args=-fopenmp --link-args=-fopenmp -a
#cython: embedsignature=True, language_level=3, binding=True
#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False,
## This is for development:
## cython: profile=True, warn.undeclared=True, warn.unused=True, warn.unused_result=False, warn.unused_arg=True

import numpy as np
from cython.parallel import prange
from libc.math cimport sqrt, pi, tanh

# With Cython3: from libc.complex cimport cabs, cexp
# Accessing C code from cython (out of the scope for today)
cdef extern from "complex.h" nogil:
    double cabs(double complex)
    double complex cexp(double complex)


def laue_cython(int N, 
                double[::1] h, 
                double[::1] k,
                double w,
                double e0):
    # Compute |sum_n sum_m exp(2j*pi*(h*n + k*m))|**2 for every (h, k) pair.
    # Uses exp(a+b) = exp(a)*exp(b): cexp is only called for the 1D factors
    # exp(2j*pi*h*n) and exp(2j*pi*k*m); the double sum multiplies them.
    #
    # N    : n and m both run over range(N)
    # h, k : contiguous 1D arrays of doubles (reciprocal-space coordinates)
    # Returns a NumPy array of shape (h.shape[0], k.shape[0]).
    #
    # NOTE(review): w and e0 are accepted but never read in this body, and
    # v_h, v_k, radius, N_2 and strain are declared but unused.
    cdef:
        double complex[:, ::1] ehn, ekm
        double[:, ::1] result
        double complex tmp, two_j_pi
        double complex v_h, v_k
        double radius, N_2, strain
        int i_h, i_k, m, n, h_size, k_size
        
    two_j_pi = np.pi*2j
    h_size = h.shape[0]
    k_size = k.shape[0]
    # Lookup tables for the phase factors, filled in the nogil section below
    ehn = np.empty((h_size, N), dtype=np.complex128)
    ekm = np.empty((k_size, N), dtype=np.complex128)
    result = np.zeros((h_size, k_size))
    
    with nogil:
        # ehn[i_h, n] = exp(2j*pi*h[i_h]*n)
        for i_h in range(h_size):
            for n in range(N):
                ehn[i_h, n] = cexp(two_j_pi*h[i_h]*n)

        # ekm[i_k, m] = exp(2j*pi*k[i_k]*m)
        for i_k in range(k_size):
            for m in range(N):
                ekm[i_k, m] = cexp(two_j_pi*k[i_k]*m)

        # The outer loop over h uses prange; the cell is compiled with -fopenmp
        for i_h in prange(h_size):
            for i_k in range(k_size):
                tmp = 0.0
                # Accumulate the product of the two tabulated factors over the lattice
                for n in range(N):
                    for m in range(N):
                        tmp += ehn[i_h, n] * ekm[i_k, m]
                # result was zero-initialised, so this stores the squared modulus
                result[i_h, i_k] += cabs(tmp)**2

    # Hand the memoryview back as a NumPy array
    return np.asarray(result)

In [None]:
# Compute the intensity once (timed) and check its error against the reference
%time intensity = laue_cython(N, h, k, w, e0)
print("Error:", validate_sq(intensity))
display(intensity)
# Keep the timing result (-o) for the results table, then measure memory use
perf_sq_cython = %timeit -o laue_cython(N, h, k, w, e0)
%memit laue_cython(N, h, k, w, e0)

## Results

For the square crystal:

In [None]:
# Summary table: best runtime of each implementation and its speed-up
# relative to the hybrid (loops + numpy.sum) version
print("                                    Runtime (ms)  Speed-up (x)")
ref = perf_sq_hybrid.best
print(f"Hybrid: for loops+numpy.sum        {1000*perf_sq_hybrid.best:6.1f} ms      {ref/perf_sq_hybrid.best:6.3f}x")
print(f"Numpy Einsum                       {1000*perf_sq_einsum.best:6.1f} ms      {ref/perf_sq_einsum.best:6.3f}x")
print(f"Multiprocessing with shared memory {1000*perf_sq_mps.best:6.1f} ms      {ref/perf_sq_mps.best:6.3f}x")
# The identity used by the Cython version is a product: e**(a+b) = e**a * e**b
print(f"Cython with e**(a+b)=e**a*e**b     {1000*perf_sq_cython.best:6.1f} ms      {ref/perf_sq_cython.best:6.3f}x")