In [None]:
# default_exp util

# Util ssb

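> GPU helpers for evaluating the double-disk overlap function and for estimating the scan rotation angle of a 4D-STEM dataset.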

In [None]:
#export 
from smpr3d.torch_imports import *
import numpy as np
import cupy as cp
import time
from numba import cuda
from math import *
import math as m
import cmath as cm

In [None]:
#export 

def disk_overlap_function(Qx_all, Qy_all, Kx_all, Ky_all, aberrations, theta_rot, alpha, lam):
    """
    Evaluate the disk overlap function on the GPU for a batch of scan frequencies.

    Allocates the output array and launches ``disk_overlap_kernel`` with one thread per
    output element.

    :param Qx_all: 1D array of x scan frequencies, one entry per batch item
    :param Qy_all: 1D array of y scan frequencies, indexed like ``Qx_all`` by the kernel
    :param Kx_all: 1D array of detector x frequencies (sets the last output axis)
    :param Ky_all: 1D array of detector y frequencies (sets the middle output axis)
    :param aberrations: aberration coefficients, forwarded unchanged to the kernel
    :param theta_rot: rotation angle applied to (Qx, Qy) inside the kernel
    :param alpha: aperture half-angle, forwarded unchanged to the kernel
    :param lam: wavelength, forwarded unchanged to the kernel
    :return: cupy complex64 array of shape (len(Qx_all), len(Ky_all), len(Kx_all))
    """
    n_batch = Qx_all.shape[0]
    Gamma = cp.zeros((n_batch, Ky_all.shape[0], Kx_all.shape[0]), dtype=cp.complex64)

    # Nothing to compute: a launch with zero blocks is invalid, so return early.
    if Gamma.size == 0:
        return Gamma

    threadsperblock = 2 ** 8
    # Integer ceil-division; one thread per output element.
    blockspergrid = (Gamma.size + threadsperblock - 1) // threadsperblock

    # The kernel decodes the flat thread index with strides measured in elements, not bytes.
    # `np.int` was removed in NumPy 1.24, so use exact integer division and an explicit dtype.
    strides = cp.array([s // Gamma.itemsize for s in Gamma.strides], dtype=np.int64)

    disk_overlap_kernel[blockspergrid, threadsperblock](Gamma, strides, Qx_all, Qy_all, Kx_all, Ky_all, aberrations,
                                                        theta_rot, alpha, lam)
    return Gamma


@cuda.jit
def disk_overlap_kernel(Γ, strides, Qx_all, Qy_all, Kx_all, Ky_all, aberrations, theta_rot, alpha, lam):
    """
    CUDA kernel: each thread fills one element Γ[j, iky, ikx] of the disk overlap function.

    :param Γ: 3D complex output array, written in place
    :param strides: per-axis strides of Γ in elements, used to decode the flat thread index
    :param Qx_all: 1D array, read at index j
    :param Qy_all: 1D array, read at index j
    :param Kx_all: 1D array, read at index ikx
    :param Ky_all: 1D array, read at index iky
    :param aberrations: aberration coefficients; indices 0..11 are read by chi3
    :param theta_rot: angle by which (Qx, Qy) is rotated before use (passed to cos/sin, so radians)
    :param alpha: maximum aperture angle, compared against asin(|q| * lam)
    :param lam: wavelength
    """
    def aperture2(qx, qy, lam, alpha_max):
        # Hard aperture: True when the scattering angle asin(|q| * lam) is below alpha_max.
        # The result only depends on qx**2 + qy**2, so the argument order does not matter.
        qx2 = qx ** 2
        qy2 = qy ** 2
        q = m.sqrt(qx2 + qy2)
        ktheta = m.asin(q * lam)
        return ktheta < alpha_max

    def chi3(qy, qx, lam, C):
        """
        Aberration phase, written as polynomials in the cartesian angles u = qx * lam, v = qy * lam.

        Note the argument order: qy comes first, then qx.

        :param qy: y component of the spatial frequency
        :param qx: x component of the spatial frequency
        :param lam: wavelength in Angstrom
        :param C:   (12 ,) aberration coefficients, see the index table in the body
        :return: the polynomial sum multiplied by 2 * pi / lam
        """

        u = qx * lam
        v = qy * lam
        u2 = u ** 2
        u3 = u ** 3
        u4 = u ** 4
        # fifth powers are not needed: no 5th-order terms are evaluated below

        v2 = v ** 2
        v3 = v ** 3
        v4 = v ** 4
        # (same for v ** 5)

        # Index -> coefficient name, as labelled in the original author's notes:
        #   C[0]  = C1
        #   C[1]  = C12a
        #   C[2]  = C12b
        #   C[3]  = C21a
        #   C[4]  = C21b
        #   C[5]  = C23a
        #   C[6]  = C23b
        #   C[7]  = C3
        #   C[8]  = C32a
        #   C[9]  = C32b
        #   C[10] = C34a
        #   C[11] = C34b

        chi = 0

        # r^2 = x^2 + y^2
        chi += 1 / 2 * C[0] * (u2 + v2)  # r^2
        # r^2 cos(2*phi) = x^2 - y^2
        # r^2 sin(2*phi) = 2*x*y
        chi += 1 / 2 * (C[1] * (u2 - v2) + 2 * C[2] * u * v)  # r^2 cos(2 phi) + r^2 sin(2 phi)
        # r^3 cos(3*phi) = x^3 - 3*x*y^2,  r^3 sin(3*phi) = 3*y*x^2 - y^3
        chi += 1 / 3 * (C[5] * (u3 - 3 * u * v2) + C[6] * (3 * u2 * v - v3))  # r^3 cos(3phi) + r^3 sin(3 phi)
        # r^3 cos(phi) = x^3 + x*y^2
        # r^3 sin(phi) = y*x^2 + y^3
        chi += 1 / 3 * (C[3] * (u3 + u * v2) + C[4] * (v3 + u2 * v))  # r^3 cos(phi) + r^3 sin(phi)
        # r^4 = x^4 + 2*x^2*y^2 + y^4
        chi += 1 / 4 * C[7] * (u4 + v4 + 2 * u2 * v2)  # r^4
        # r^4 cos(4*phi) = x^4 - 6*x^2*y^2 + y^4
        chi += 1 / 4 * C[10] * (u4 - 6 * u2 * v2 + v4)  # r^4 cos(4 phi)
        # r^4 sin(4*phi) = 4*x^3*y - 4*x*y^3
        chi += 1 / 4 * C[11] * (4 * u3 * v - 4 * u * v3)  # r^4 sin(4 phi)
        # r^4 cos(2*phi) = x^4 - y^4
        chi += 1 / 4 * C[8] * (u4 - v4)
        # r^4 sin(2*phi) = 2*x^3*y + 2*x*y^3
        chi += 1 / 4 * C[9] * (2 * u3 * v + 2 * u * v3)
        # Fifth-order identities, kept for reference only (not evaluated):
        # r^5 cos(phi)   = x^5 + 2*x^3*y^2 + x*y^4
        # r^5 sin(phi)   = y*x^4 + 2*x^2*y^3 + y^5
        # r^5 cos(3*phi) = x^5 - 2*x^3*y^2 - 3*x*y^4
        # r^5 sin(3*phi) = 3*y*x^4 + 2*x^2*y^3 - y^5
        # r^5 cos(5*phi) = x^5 - 10*x^3*y^2 + 5*x*y^4
        # r^5 sin(5*phi) = 5*y*x^4 - 10*x^2*y^3 + y^5

        # scale the polynomial sum by 2 * pi / lam
        chi *= 2 * np.pi / lam

        return chi

    gs = Γ.shape
    N = gs[0] * gs[1] * gs[2]
    # Flat global thread index, decoded into (j, iky, ikx) using the element strides.
    n = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    j = n // strides[0]
    iky = (n - j * strides[0]) // strides[1]
    ikx = (n - (j * strides[0] + iky * strides[1])) // strides[2]

    # Threads past the end of Γ (from rounding up the grid size) do nothing.
    if n < N:
        Qx = Qx_all[j]
        Qy = Qy_all[j]
        Kx = Kx_all[ikx]
        Ky = Ky_all[iky]

        # Rotate the scan frequency (Qx, Qy) by theta_rot.
        Qx_rot = Qx * cos(theta_rot) - Qy * sin(theta_rot)
        Qy_rot = Qx * sin(theta_rot) + Qy * cos(theta_rot)

        Qx = Qx_rot
        Qy = Qy_rot

        # A, Ap, Am: aperture * exp(-i chi) evaluated at K, K + Q and K - Q respectively.
        # aperture2 is called with (y, x) although declared (qx, qy); it is symmetric, so this is harmless.
        chi = chi3(Ky, Kx, lam, aberrations)
        A = aperture2(Ky, Kx, lam, alpha) * cm.exp(-1j * chi)
        chi = chi3(Ky + Qy, Kx + Qx, lam, aberrations)
        Ap = aperture2(Ky + Qy, Kx + Qx, lam, alpha) * cm.exp(-1j * chi)
        chi = chi3(Ky - Qy, Kx - Qx, lam, aberrations)
        Am = aperture2(Ky - Qy, Kx - Qx, lam, alpha) * cm.exp(-1j * chi)

        # Γ = conj(A(K)) * A(K - Q) - A(K) * conj(A(K + Q))
        Γ[j, iky, ikx] = A.conjugate() * Am - A * Ap.conjugate()

In [None]:
#export 
def double_overlap_intensitities_in_range(G_max, thetas, Qx_max, Qy_max, Kx, Ky, aberrations,
                                          aberration_angles, alpha_rad, lam, do_plot=False):
    """
    Compute the summed overlap intensity sum(|G_max * conj(Gamma(theta))|) for each trial angle.

    :param G_max: selected slices of the G function; must broadcast against the
                  (n_batch, len(Ky), len(Kx)) output of ``disk_overlap_function``
    :param thetas: sequence of trial rotation angles, forwarded one by one as ``theta_rot``
    :param Qx_max: x scan frequencies of the selected slices
    :param Qy_max: y scan frequencies of the selected slices
    :param Kx: 1D detector x frequencies
    :param Ky: 1D detector y frequencies
    :param aberrations: aberration coefficients forwarded to ``disk_overlap_function``
    :param aberration_angles: unused; kept so existing callers keep working
    :param alpha_rad: aperture half-angle forwarded to ``disk_overlap_function``
    :param lam: wavelength forwarded to ``disk_overlap_function``
    :param do_plot: if True, scatter-plot intensity against angle in degrees
    :return: numpy array of length len(thetas) with one intensity per trial angle
    :raises RuntimeError: if there are angles to evaluate but no CUDA device is available
    """
    xp = cp.get_array_module(G_max)
    intensities = np.zeros((len(thetas)))

    # The overlap function only has a CUDA implementation. Fail with a clear message instead of
    # hitting an unbound `Gamma` inside the loop.
    # NOTE(review): `th` is presumably torch, provided by the star import — confirm.
    if len(thetas) > 0 and not th.cuda.is_available():
        raise RuntimeError("double_overlap_intensitities_in_range requires a CUDA device, but none is available")

    for i, theta_rot in enumerate(thetas):
        Gamma = disk_overlap_function(Qx_max, Qy_max, Kx, Ky, aberrations, theta_rot, alpha_rad, lam)
        # Convert the (possibly device-resident) 0-d result explicitly before storing it in a numpy array.
        intensities[i] = float(xp.sum(xp.abs(G_max * Gamma.conj())))

    if do_plot:
        # NOTE(review): `plt` is not imported explicitly in this notebook; presumably it comes
        # from the star import at the top — confirm.
        f, ax = plt.subplots()
        ax.scatter(np.rad2deg(thetas), intensities)
        plt.show()

    return intensities

In [None]:
#export
def find_rotation_angle_with_double_disk_overlap(G, lam, k_max, dxy, alpha_rad, mask=None, n_fit=6, ranges=[360, 30],
                                                 partitions=[144, 120], verbose=False, manual_frequencies=None, aberrations=None):
    """
        Finds the best rotation angle by maximizing the double disk overlap intensity of the 4D dataset. Only valid
        for datasets where the scan step size is roughly on the same length scale as the illumination half-angle alpha.

        The search is coarse-to-fine: each stage samples ``partitions[j]`` angles on an interval of width
        ``ranges[j]`` degrees centred on the best angle of the previous stage (0 for the first stage).

    :param G: G function. 4DSTEM dataset Fourier transformed along the scan coordinates, shape (ny, nx, nky, nkx)
    :param lam: wavelength, forwarded to the overlap kernel
    :param k_max: pair of sample spacings passed to ``fftfreq`` for the detector axes ([0] -> Kx, [1] -> Ky)
    :param dxy: pair of sample spacings passed to ``fftfreq`` for the scan axes ([0] -> Qx, [1] -> Qy)
    :param alpha_rad: illumination half-angle, forwarded to the overlap kernel
    :param mask: optional array multiplied onto the summed |G| before ranking the scan frequencies
    :param n_fit: number of object spatial frequencies to fit
    :param ranges: width of the search interval of each stage, in degrees
    :param partitions: number of trial angles of each stage
    :param verbose: print the selected frequencies and the best angle after each stage
    :param manual_frequencies: optional index (tuple of index arrays into the first two axes of G)
                               used instead of the automatic selection
    :param aberrations: optional aberration coefficients (12 values); zeros if None
    :return: tuple ``(max_ind, thetas, intensities)`` of the last stage; the best rotation angle
             in radians is ``thetas[max_ind]``
    :raises ValueError: if ``ranges`` or ``partitions`` is empty
    """
    if len(ranges) == 0 or len(partitions) == 0:
        raise ValueError("ranges and partitions must each contain at least one entry")

    ny, nx, nky, nkx = G.shape
    xp = cp.get_array_module(G)

    def get_qx_qy_1D(M, dx, dtype, fft_shifted=False):
        # 1D frequency axes for M = [n_x, n_y] samples with spacings dx = [d_x, d_y].
        qxa = xp.fft.fftfreq(M[0], dx[0]).astype(dtype)
        qya = xp.fft.fftfreq(M[1], dx[1]).astype(dtype)
        if fft_shifted:
            qxa = xp.fft.fftshift(qxa)
            qya = xp.fft.fftshift(qya)
        return qxa, qya

    real_dtype = G[0, 0, 0, 0].real.dtype
    Kx, Ky = get_qx_qy_1D([nkx, nky], k_max, real_dtype, fft_shifted=True)
    Qx, Qy = get_qx_qy_1D([nx, ny], dxy, real_dtype, fft_shifted=False)

    if aberrations is None:
        aberrations = xp.zeros((12))
    aberration_angles = xp.zeros((12))

    if manual_frequencies is None:
        # Rank the scan frequencies by their total |G| summed over the detector axes.
        Gabs = xp.sum(xp.abs(G), (2, 3))
        ranking = Gabs * mask if mask is not None else Gabs
        inds = xp.argsort(ranking.ravel()).get()
        # Take the n_fit entries just below the top one; the single strongest entry is skipped
        # (presumably the zero-frequency term — confirm).
        strongest_object_frequencies = np.unravel_index(inds[-1 - n_fit:-1], G.shape[:2])
    else:
        strongest_object_frequencies = manual_frequencies

    G_max = G[strongest_object_frequencies]
    Qy_max = Qy[strongest_object_frequencies[0]]
    Qx_max = Qx[strongest_object_frequencies[1]]

    if verbose:
        print(f"strongest_object_frequencies: {strongest_object_frequencies}")

    best_angle = 0

    for j, (range_deg, parts) in enumerate(zip(ranges, partitions)):
        half_width = np.deg2rad(range_deg / 2)
        thetas = np.linspace(best_angle - half_width, best_angle + half_width, parts)
        intensities = double_overlap_intensitities_in_range(G_max, thetas, Qx_max, Qy_max, Kx, Ky, aberrations,
                                                            aberration_angles, alpha_rad, lam, do_plot=False)

        # Index of the largest intensity of this stage; it seeds the next, narrower stage.
        max_ind = np.argsort(intensities)[-1]
        best_angle = thetas[max_ind]
        if verbose:
            print(f"Iteration {j}: current best rotation angle: {np.rad2deg(best_angle)}")

    return max_ind, thetas, intensities