In [None]:
# default_exp util

# Util Main

> API details.

In [None]:
#export
from numpy.fft import fftfreq
import numpy as np
from smpr3d.torch_imports import *
from numba import cuda
import math as m
import cmath as cm
import sigpy as sp

In [None]:
#export
def fftshift_checkerboard(w, h):
    """
    Build a checkerboard of alternating -1 / +1 entries.

    :param w: number of horizontal [-1, 1] repetitions; the result has 2 * w columns
    :param h: number of vertical repetitions of the (even row, odd row) pair; the result has 2 * h rows
    :return: array of shape (2 * h, 2 * w) whose entry (r, c) is -1 when r + c is even and +1 otherwise
    """
    even_row = np.r_[w * [-1, 1]]  # rows 0, 2, 4, ... start with -1
    odd_row = np.r_[w * [1, -1]]   # rows 1, 3, 5, ... start with +1
    # np.vstack is the function np.row_stack used to alias; the alias is deprecated.
    return np.vstack(h * (even_row, odd_row))

In [None]:
#export
def cartesian_aberrations(qx, qy, lam, C):
    """
    Aberration phase chi as a polynomial (up to fourth order) in the
    cartesian coordinates u = qx * lam and v = qy * lam.

    :param qx: x coordinates of the reciprocal-space grid
    :param qy: y coordinates of the reciprocal-space grid
    :param lam: wavelength in Angstrom
    :param C:   12 x D coefficients, rows ordered as
                C1, C12a, C12b, C21a, C21b, C23a, C23b, C3, C32a, C32b, C34a, C34b
    :return: 2 * pi / lam times the weighted sum of the polynomial terms.
             Each row of C gets two trailing singleton axes (unsqueeze) before
             it is multiplied with the grid terms.
    """

    def coeff(row):
        # append two singleton axes so the coefficient broadcasts against the grid
        return C[row].unsqueeze(1).unsqueeze(1)

    u = qx * lam
    v = qy * lam
    u2, u3, u4 = u ** 2, u ** 3, u ** 4
    v2, v3, v4 = v ** 2, v ** 3, v ** 4

    C1 = coeff(0)
    C12a, C12b = coeff(1), coeff(2)
    C21a, C21b = coeff(3), coeff(4)
    C23a, C23b = coeff(5), coeff(6)
    C3 = coeff(7)
    C32a, C32b = coeff(8), coeff(9)
    C34a, C34b = coeff(10), coeff(11)

    chi = 0

    # r^2 = x^2 + y^2
    chi += 1 / 2 * C1 * (u2 + v2)
    # r^2 cos(2 phi) = x^2 - y^2,  r^2 sin(2 phi) = 2 x y
    chi += 1 / 2 * (C12a * (u2 - v2) + 2 * C12b * u * v)
    # r^3 cos(3 phi) = x^3 - 3 x y^2,  r^3 sin(3 phi) = 3 x^2 y - y^3
    chi += 1 / 3 * (C23a * (u3 - 3 * u * v2) + C23b * (3 * u2 * v - v3))
    # r^3 cos(phi) = x^3 + x y^2,  r^3 sin(phi) = x^2 y + y^3
    chi += 1 / 3 * (C21a * (u3 + u * v2) + C21b * (v3 + u2 * v))
    # r^4 = x^4 + 2 x^2 y^2 + y^4
    chi += 1 / 4 * C3 * (u4 + v4 + 2 * u2 * v2)
    # r^4 cos(4 phi) = x^4 - 6 x^2 y^2 + y^4
    chi += 1 / 4 * C34a * (u4 - 6 * u2 * v2 + v4)
    # r^4 sin(4 phi) = 4 x^3 y - 4 x y^3
    chi += 1 / 4 * C34b * (4 * u3 * v - 4 * u * v3)
    # r^4 cos(2 phi) = x^4 - y^4
    chi += 1 / 4 * C32a * (u4 - v4)
    # r^4 sin(2 phi) = 2 x^3 y + 2 x y^3
    chi += 1 / 4 * C32b * (2 * u3 * v + 2 * u * v3)
    # Fifth-order terms (r^5 cos/sin of phi, 3 phi, 5 phi) are not included.

    chi *= 2 * np.pi / lam

    return chi

In [None]:
#export
def memory_mb(x, dtype=None):
    """
    Memory footprint in MiB (2 ** 20 bytes).

    :param x: a torch tensor, or a tuple describing a tensor shape
    :param dtype: torch dtype of the elements; must be given when x is a shape tuple
    :return: size in MiB; None when x is neither a tensor nor a tuple
    """
    if isinstance(x, th.Tensor):
        n_bytes = x.nelement() * x.element_size()
        return n_bytes / 2 ** 20
    if isinstance(x, tuple):
        assert dtype is not None, 'memory_mb: dtype must not be None'
        # element size is read off a one-element tensor of the requested dtype
        bytes_per_element = th.zeros(1, dtype=dtype).element_size()
        n_elements = np.prod(np.asarray(x))
        return n_elements * bytes_per_element / 2 ** 20


def memory_gb(x, dtype=None):
    """
    Memory footprint in GiB (2 ** 30 bytes).

    :param x: a torch tensor, or a tuple describing a tensor shape
    :param dtype: torch dtype of the elements; must be given when x is a shape tuple
    :return: size in GiB; None when x is neither a tensor nor a tuple
    :raises AssertionError: if x is a tuple and dtype is None
    """
    if isinstance(x, th.Tensor):
        return x.nelement() * x.element_size() / 2 ** 30
    elif isinstance(x, tuple):
        # the message used to name memory_mb, pointing at the wrong function
        assert dtype is not None, 'memory_gb: dtype must not be None'
        # element size is read off a one-element tensor of the requested dtype
        element_size = th.zeros(1, dtype=dtype).element_size()
        nelement = np.prod(np.asarray(x))
        return nelement * element_size / 2 ** 30

In [None]:
#export
def fourier_coordinates_2D(N, dx=[1.0, 1.0], centered=True):
    """
    Two-dimensional grid of FFT sample frequencies.

    :param N: (2,) number of samples along axis 0 and axis 1
    :param dx: (2,) sample spacing along axis 0 and axis 1
    :param centered: if True, add half a frequency step, 0.5 / (N * dx), to every frequency
    :return: float32 array of shape (2, N[0], N[1]); entry 0 holds the frequencies
             varying along axis 0, entry 1 those varying along axis 1
    """
    freq_axis0 = fftfreq(N[0], dx[0])
    freq_axis1 = fftfreq(N[1], dx[1])
    if centered:
        freq_axis1 = freq_axis1 + 0.5 / N[1] / dx[1]
        freq_axis0 = freq_axis0 + 0.5 / N[0] / dx[0]
    grid_axis1, grid_axis0 = np.meshgrid(freq_axis1, freq_axis0)
    return np.array([grid_axis0, grid_axis1]).astype(np.float32)

In [None]:
#export
def array_split_divpoints(ary, indices_or_sections, axis=0):
    """
    Compute the division points ``np.array_split`` would cut at, without
    splitting anything.

    If `indices_or_sections` is a sequence of indices, the result is the
    list ``[0] + indices + [length]``. If it is an integer n, the axis of
    length l is divided into l % n sections of size l//n + 1 followed by
    sections of size l//n, and the cumulative section boundaries are
    returned as an integer array.

    Examples
    --------
    >>> list(array_split_divpoints(np.arange(8.0), 3))
    [0, 3, 6, 8]
    >>> array_split_divpoints(np.arange(8.0), [2, 5])
    [0, 2, 5, 8]
    """
    # length of the split axis; objects without ``shape`` fall back to ``len``
    try:
        total = ary.shape[axis]
    except AttributeError:
        total = len(ary)

    try:
        # ``len`` raises TypeError for a scalar section count
        len(indices_or_sections)
        return [0] + list(indices_or_sections) + [total]
    except TypeError:
        pass

    n_sections = int(indices_or_sections)
    if n_sections <= 0:
        raise ValueError('number sections must be larger than 0.')
    base, remainder = divmod(total, n_sections)
    sizes = [0] + remainder * [base + 1] + (n_sections - remainder) * [base]
    return np.array(sizes, dtype=np.intp).cumsum()

In [None]:
#export
def R_factor(z, a, world_size=1):
    """
    R error metric sum(| |z| - a |) / sum(|a|).

    :param z: tensor whose modulus is compared against a
    :param a: reference tensor
    :param world_size: when larger than 1, numerator and denominator are each
                       passed through ``dist.all_reduce`` before the division
    :return: ratio of the two 1-norms
    """
    residual = th.norm(th.abs(z) - a, p=1)
    reference = th.norm(a, p=1)
    if world_size > 1:
        for partial in (residual, reference):
            dist.all_reduce(partial)
    return residual / reference

In [None]:
#export
def distance(z, x):
    """
    2-norm distance between two complex vectors after x has been multiplied
    by the unit phase factor exp(-i * angle(<z, x>)), where <z, x> is
    ``th.vdot`` of the flattened inputs.

    :param z: tensor
    :param x: tensor
    :return: norm(z - x * phase, 2)
    """
    overlap = th.vdot(z.ravel(), x.ravel())
    phase = th.exp(1j * (-th.angle(overlap))).to(x.device)
    return th.norm(z - x * phase, 2)

In [None]:
#export
def rel_dist(z, x):
    """
    Relative distance of two complex vectors: ``distance(z, x)`` divided by
    the 2-norm of x.

    :param z: tensor
    :param x: tensor
    :return: distance(z, x) / norm(x, 2)
    """
    return distance(z, x) / th.norm(x, 2)

In [None]:
#export
PARAM_PREFIX = 'pars'
class Param(dict):
    """
    Convenience class: a dictionary that gives access to its keys
    through attributes.

    Note: item and attribute assignment store the value unchanged, so a
    plain dict assigned to a key stays a plain dict:
    >>> p = Param()
    >>> p.x = {}
    >>> p
    Param({'x': {}})

    Use ``update(..., Convert=True)`` to wrap dict-like values in Param objects.

    While dict(p) returns a dictionary, it is not recursive, so it is better in this case
    to use p._to_dict(Recursive=True). However, _to_dict does not check for infinite
    recursion. So please don't store a dictionary (or a Param) inside itself.

    BE: Please note also that the recursive behavior of the update function will create
    new references. This will lead inconsistency if other objects refer to dicts or Params
    in the updated Param instance.
    """
    _display_items_as_attributes = True
    _PREFIX = PARAM_PREFIX

    def __init__(self, __d__=None, **kwargs):
        """
        A Dictionary that enables access to its keys as attributes.
        Same constructor as dict.
        """
        dict.__init__(self)
        if __d__ is not None: self.update(__d__)
        self.update(kwargs)

    def __getstate__(self):
        # Must be a concrete list: a dict_items view cannot be pickled or
        # deep-copied, which made pickle.dumps / copy.deepcopy of a Param fail.
        return list(self.__dict__.items())

    def __setstate__(self, items):
        for key, val in items:
            self.__dict__[key] = val

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self))

    def __setitem__(self, key, value):
        # BE: original behavior modified as implicit conversion may destroy references
        # Use update(value,Convert=True) instead
        return super(Param, self).__setitem__(key, value)

    def __getitem__(self, name):
        return super(Param, self).__getitem__(name)

    def __delitem__(self, name):
        return super(Param, self).__delitem__(name)

    def __delattr__(self, name):
        # attribute deletion removes the key of the same name
        return super(Param, self).__delitem__(name)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; missing keys are
        # reported as AttributeError so hasattr()/getattr() behave as usual.
        try:
            return self.__getitem__(name)
        except KeyError as ke:
            raise AttributeError(ke)

    # attribute assignment stores a key instead of an instance attribute
    __setattr__ = __setitem__

    def copy(self, depth=0):
        """
        :returns Param: A (recursive) copy of P with depth `depth`
        """
        d = Param(self)
        if depth > 0:
            # items() replaces the Python 2 only iteritems(); the snapshot keeps
            # the iteration independent of the assignments below
            for k, v in list(d.items()):
                if isinstance(v, self.__class__): d[k] = v.copy(depth - 1)
        return d

    def __dir__(self):
        """
        Defined to include the keys when using dir(). Useful for
        tab completion in e.g. ipython.
        If you do not wish the dict key's be displayed as attributes
        (although they are still accessible as such) set the class
        attribute `_display_items_as_attributes` to False. Default is
        True.
        """
        if self._display_items_as_attributes:
            return list(self.keys())
        else:
            return []

    def update(self, __d__=None, in_place_depth=0, Convert=False, **kwargs):
        """
        Update Param - almost same behavior as dict.update, except
        that all dictionaries are converted to Param if `Convert` is set
        to True, and update may occur in-place recursively for other Param
        instances that self refers to.

        Parameters
        ----------
        __d__ : dict-like or iterable of (key, value) pairs, optional
                  Items to merge into self.
        Convert : bool
                  If True, convert all dict-like values in self also to Param.
                  *WARNING*
                  This may result in misdirected references in your environment
        in_place_depth : int
                  Counter for recursive in-place updates
                  If the counter reaches zero, the Param to a key is
                  replaced instead of updated
        """

        def _k_v_update(k, v):
            # If an element is itself a dict, convert it to Param
            if Convert and hasattr(v, 'keys'):
                v = Param(v)
            # new key
            if not k in self:
                self[k] = v
            # If this key already exists and is already a Param, update it in place
            elif in_place_depth > 0 and hasattr(v, 'keys') and isinstance(self[k], self.__class__):
                self[k].update(v, in_place_depth - 1)
            # Otherwise just replace it
            else:
                self[k] = v

        if __d__ is not None:
            if hasattr(__d__, 'keys'):
                # Iterate through dict-like argument
                for k, v in __d__.items():
                    _k_v_update(k, v)

            else:
                # here we assume a (key,value) list.
                for (k, v) in __d__:
                    _k_v_update(k, v)

        for k, v in kwargs.items():
            _k_v_update(k, v)

        return None

    def _to_dict(self, Recursive=False):
        """
        Convert to dictionary (recursively if needed).
        """
        d = dict(self)
        if Recursive:
            for k, v in d.items():
                if isinstance(v, self.__class__): d[k] = v._to_dict(Recursive)
        return d

    @classmethod
    def _from_dict(cls, dct):
        """
        Make Param from dict. This is similar to the __init__ call
        """
        return Param(dct.copy())


def validate_standard_param(sp, p=None, prefix=None):
    """\
    validate_standard_param(sp) checks if sp follows the standard parameter convention:
    every entry whose key does not start with '_' must either be a nested container of
    the same type as sp (checked recursively) or unpack into exactly three items.
    Each checked entry is printed; offending entries are flagged with '!!!'.
    validate_standard_param(sp, p) attempts to check if p is a valid implementation of sp.

    NOT VERY SOPHISTICATED FOR NOW!

    :param sp: dict-like standard parameter description
    :param p: must be None; checking an implementation is not implemented
    :param prefix: dotted path of the enclosing keys, used only in the printed report
    :return: True if every checked entry is valid, False otherwise
    :raises RuntimeError: if p is not None
    """
    if p is not None:
        raise RuntimeError('Checking if a param fits with a standard is not yet implemented')

    good = True
    # items() replaces the Python 2 only iteritems()
    for k, v in sp.items():
        if k.startswith('_'): continue
        if type(v) == type(sp):
            pref = k if prefix is None else '.'.join([prefix, k])
            good &= validate_standard_param(v, prefix=pref)
            continue
        name = k if prefix is None else '%s.%s' % (prefix, k)
        try:
            a, b, c = v
        except (TypeError, ValueError):
            # not iterable, or not exactly three items
            good = False
            print('!!! %s = %s <--- Incorrect' % (name, str(v)))
        else:
            print('    %s = %s' % (name, str(v)))

    return good


def format_standard_param(p):
    """\
    Pretty-print a Standard Param class.

    Every entry whose key does not start with '_' becomes one line of the form
    ``key = v[1] #[v[0]] v[2]``; nested containers of the same type as p are
    formatted recursively and their lines are prefixed with ``key.``.

    :param p: dict-like standard parameter description
    :return: list of formatted lines
    """
    lines = []
    if not validate_standard_param(p):
        print('Standard parameter does not')
    # items() replaces the Python 2 only iteritems()
    for k, v in p.items():
        if k.startswith('_'): continue
        if type(v) == type(p):
            sublines = format_standard_param(v)
            lines += [k + '.' + s for s in sublines]
        else:
            lines += ['%s = %s #[%s] %s' % (k, str(v[1]), v[0], v[2])]
    return lines


def asParam(obj):
    """
    Convert the input to a Param.

    Parameters
    ----------
    obj : dict_like
        Input structure, in any format that can be converted to a Param.

    Returns
    -------
    out : Param
        The Param structure built from obj. No copy is done if the input
        is already a Param.
    """
    if isinstance(obj, Param):
        return obj
    return Param(obj)


def make_default(default_dict_or_file):
    """
    Placeholder: intended to convert a description dict to a module dict
    using a possibly verbose Q & A game. Not implemented; always returns None.
    """
    return None

In [None]:
#export

def single_sideband_reconstruction(G, Qx_all, Qy_all, Kx_all, Ky_all, aberrations, theta_rot, alpha_rad,
                                   Ψ_Qp, Ψ_Qp_left_sb, Ψ_Qp_right_sb, eps, lam):
    """
    Launch ``single_sideband_kernel`` with one thread per element of G and wait
    for the device to finish.

    The kernel accumulates into Ψ_Qp, Ψ_Qp_left_sb and Ψ_Qp_right_sb in place;
    nothing is returned. All remaining arguments are forwarded to the kernel unchanged.

    NOTE(review): the arrays are presumably CuPy device arrays, since the array
    module of G must expose ``cuda.Device`` -- confirm against the callers.
    """
    xp = sp.backend.get_array_module(G)
    threadsperblock = 2 ** 8
    blockspergrid = m.ceil(np.prod(G.shape) / threadsperblock)
    # Strides of G counted in elements rather than bytes (nbytes / size is the
    # size of one element). ``np.int`` was removed from NumPy; the builtin
    # ``int`` is what it aliased.
    strides = xp.array((np.array(G.strides) / (G.nbytes / G.size)).astype(int))
    scale = 1
    single_sideband_kernel[blockspergrid, threadsperblock](G, strides, Qx_all, Qy_all, Kx_all, Ky_all, aberrations,
                                                           theta_rot, alpha_rad, Ψ_Qp, Ψ_Qp_left_sb,
                                                           Ψ_Qp_right_sb, eps, lam, scale)
    xp.cuda.Device(Ψ_Qp.device).synchronize()

@cuda.jit
def single_sideband_kernel(G, strides, Qx_all, Qy_all, Kx_all, Ky_all, aberrations, theta_rot, alpha,
                           Ψ_Qp, Ψ_Qp_left_sb, Ψ_Qp_right_sb, eps, lam, scale):
    """
    CUDA kernel: each thread handles one element G[iqy, iqx, iky, ikx] of the
    four-dimensional array G and atomically accumulates its contribution into
    the output arrays at (iqy, iqx).

    :param G: four-dimensional complex input array
    :param strides: strides of G used to unravel the flat thread index
                    (the host wrapper in this file passes them in elements)
    :param Qx_all, Qy_all: coordinates indexed by iqx / iqy
    :param Kx_all, Ky_all: coordinates indexed by ikx / iky
    :param aberrations: (12,) aberration coefficients, see ``chi3``
    :param theta_rot: rotation angle applied to (Qx, Qy), in radians
    :param alpha: aperture half-angle; compared against asin(q * lam)
    :param Ψ_Qp: accumulates G * conj(Γ) where |Γ| > eps inside the bright field
    :param Ψ_Qp_left_sb: accumulates G * conj(Γ) where only the +Q disc overlaps
    :param Ψ_Qp_right_sb: accumulates G * conj(Γ) where only the -Q disc overlaps
    :param eps: threshold on |Γ|
    :param lam: wavelength
    :param scale: value of the aperture function inside the aperture
    """
    def aperture2(qx, qy, lam, alpha_max, scale):
        # ``scale`` where asin(|q| * lam) < alpha_max, 0 elsewhere
        qx2 = qx ** 2
        qy2 = qy ** 2
        q = m.sqrt(qx2 + qy2)
        ktheta = m.asin(q * lam)
        return (ktheta < alpha_max) * scale

    def chi3(qy, qx, lam, C):
        """
        Zernike polynomials in the cartesian coordinate system
        :param qx:
        :param qy:
        :param lam: wavelength in Angstrom
        :param C:   (12 ,)
        :return:
        """

        u = qx * lam
        v = qy * lam
        u2 = u ** 2
        u3 = u ** 3
        u4 = u ** 4

        v2 = v ** 2
        v3 = v ** 3
        v4 = v ** 4

        # Coefficient order in C:
        # C1, C12a, C12b, C21a, C21b, C23a, C23b, C3, C32a, C32b, C34a, C34b

        chi = 0

        # r^2 = x^2 + y^2
        chi += 1 / 2 * C[0] * (u2 + v2)
        # r^2 cos(2 phi) = x^2 - y^2,  r^2 sin(2 phi) = 2 x y
        chi += 1 / 2 * (C[1] * (u2 - v2) + 2 * C[2] * u * v)
        # r^3 cos(3 phi) = x^3 - 3 x y^2,  r^3 sin(3 phi) = 3 x^2 y - y^3
        chi += 1 / 3 * (C[5] * (u3 - 3 * u * v2) + C[6] * (3 * u2 * v - v3))
        # r^3 cos(phi) = x^3 + x y^2,  r^3 sin(phi) = x^2 y + y^3
        chi += 1 / 3 * (C[3] * (u3 + u * v2) + C[4] * (v3 + u2 * v))
        # r^4 = x^4 + 2 x^2 y^2 + y^4
        chi += 1 / 4 * C[7] * (u4 + v4 + 2 * u2 * v2)
        # r^4 cos(4 phi) = x^4 - 6 x^2 y^2 + y^4
        chi += 1 / 4 * C[10] * (u4 - 6 * u2 * v2 + v4)
        # r^4 sin(4 phi) = 4 x^3 y - 4 x y^3
        chi += 1 / 4 * C[11] * (4 * u3 * v - 4 * u * v3)
        # r^4 cos(2 phi) = x^4 - y^4
        chi += 1 / 4 * C[8] * (u4 - v4)
        # r^4 sin(2 phi) = 2 x^3 y + 2 x y^3
        chi += 1 / 4 * C[9] * (2 * u3 * v + 2 * u * v3)
        # Fifth-order terms are not included.

        chi *= 2 * np.pi / lam

        return chi

    gs = G.shape
    N = gs[0] * gs[1] * gs[2] * gs[3]
    n = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    # unravel the flat index n into the four array indices
    iqy = n // strides[0]
    iqx = (n - iqy * strides[0]) // strides[1]
    iky = (n - (iqy * strides[0] + iqx * strides[1])) // strides[2]
    ikx = (n - (iqy * strides[0] + iqx * strides[1] + iky * strides[2])) // strides[3]

    # threads beyond the last element of G do nothing
    if n < N:

        Qx = Qx_all[iqx]
        Qy = Qy_all[iqy]
        Kx = Kx_all[ikx]
        Ky = Ky_all[iky]

        # rotate (Qx, Qy) by theta_rot
        Qx_rot = Qx * m.cos(theta_rot) - Qy * m.sin(theta_rot)
        Qy_rot = Qx * m.sin(theta_rot) + Qy * m.cos(theta_rot)

        Qx = Qx_rot
        Qy = Qy_rot

        # aperture function at K, K + Q and K - Q
        A = aperture2(Ky, Kx, lam, alpha, scale) * cm.exp(-1j * chi3(Ky, Kx, lam, aberrations))
        chi_KplusQ = chi3(Ky + Qy, Kx + Qx, lam, aberrations)
        A_KplusQ = aperture2(Ky + Qy, Kx + Qx, lam, alpha, scale) * cm.exp(-1j * chi_KplusQ)
        chi_KminusQ = chi3(Ky - Qy, Kx - Qx, lam, aberrations)
        A_KminusQ = aperture2(Ky - Qy, Kx - Qx, lam, alpha, scale) * cm.exp(-1j * chi_KminusQ)

        Γ = A.conjugate() * A_KminusQ - A * A_KplusQ.conjugate()

        # m.sqrt (not a bare sqrt) so the kernel depends only on the math alias
        # it already uses above, not on an import made in a later cell
        Kplus = m.sqrt((Kx + Qx) ** 2 + (Ky + Qy) ** 2)
        Kminus = m.sqrt((Kx - Qx) ** 2 + (Ky - Qy) ** 2)
        K = m.sqrt(Kx ** 2 + Ky ** 2)
        bright_field = K < alpha / lam
        double_overlap1 = (Kplus < alpha / lam) * bright_field * (Kminus > alpha / lam)
        double_overlap2 = (Kplus > alpha / lam) * bright_field * (Kminus < alpha / lam)

        Γ_abs = abs(Γ)
        take = Γ_abs > eps and bright_field
        if take:
            val = G[iqy, iqx, iky, ikx] * Γ.conjugate()
            cuda.atomic.add(Ψ_Qp.real, (iqy, iqx), val.real)
            cuda.atomic.add(Ψ_Qp.imag, (iqy, iqx), val.imag)
        if double_overlap1:
            val = G[iqy, iqx, iky, ikx] * Γ.conjugate()
            cuda.atomic.add(Ψ_Qp_left_sb.real, (iqy, iqx), val.real)
            cuda.atomic.add(Ψ_Qp_left_sb.imag, (iqy, iqx), val.imag)
        if double_overlap2:
            val = G[iqy, iqx, iky, ikx] * Γ.conjugate()
            cuda.atomic.add(Ψ_Qp_right_sb.real, (iqy, iqx), val.real)
            cuda.atomic.add(Ψ_Qp_right_sb.imag, (iqy, iqx), val.imag)
        # the (0, 0) entry of every output additionally receives |G|
        if iqx == 0 and iqy == 0:
            val = abs(G[iqy, iqx, iky, ikx]) + 1j * 0
            cuda.atomic.add(Ψ_Qp.real, (iqy, iqx), val.real)
            cuda.atomic.add(Ψ_Qp_left_sb.real, (iqy, iqx), val.real)
            cuda.atomic.add(Ψ_Qp_right_sb.real, (iqy, iqx), val.real)

In [None]:
#export

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import numpy as np
from PIL import Image

def sector_mask(shape, centre, radius, angle_range=(0,360)):
    """
    Return a boolean mask for a circular sector. The start/stop angles in
    `angle_range` should be given in clockwise order.

    :param shape: (2,) shape of the mask
    :param centre: (2,) centre of the sector, in index coordinates of axis 0 and axis 1
    :param radius: radius of the sector; points at exactly this distance are included
    :param angle_range: (start, stop) angles in degrees
    :return: boolean array of the given shape, True inside the sector
    """
    idx0, idx1 = np.ogrid[:shape[0], :shape[1]]
    centre0, centre1 = centre
    start, stop = np.deg2rad(angle_range)

    # a stop angle below the start angle means the sector wraps around
    if stop < start:
        stop += 2 * np.pi

    offset0 = idx0 - centre0
    offset1 = idx1 - centre1

    # squared radius and angle (measured from the start angle, wrapped to [0, 2*pi))
    dist2 = offset0 * offset0 + offset1 * offset1
    angle = (np.arctan2(offset0, offset1) - start) % (2 * np.pi)

    within_radius = dist2 <= radius * radius
    within_angle = angle <= (stop - start)

    return within_radius * within_angle

In [None]:
#export
from math import sqrt
def wavelength(E_eV):
    """
    Relativistic electron wavelength.

    :param E_eV: electron energy in eV
    :return: wavelength in Angstrom
    """
    rest_energy_keV = 510.99906  # electron rest mass in keV
    hc = 12.3984244  # h*c
    energy_keV = E_eV * 1e-3
    return hc / m.sqrt(energy_keV * (2 * rest_energy_keV + energy_keV))  # in Angstrom


def DOF(alpha, E_eV):
    """
    Depth of field 2 * lam / alpha ** 2, with lam the relativistic electron
    wavelength at energy E_eV.

    :param alpha: angle entering the formula squared (presumably the convergence
                  semi-angle in rad -- TODO confirm)
    :param E_eV: electron energy in eV
    :return: 2 * lam / alpha ** 2, where lam carries the factor 10 ** 10
             (presumably metres converted to Angstrom -- TODO confirm)
    """
    # physical constants (the values look like SI units -- TODO confirm)
    electron_mass = 9.109383 * 10 ** -31
    elementary_charge = 1.602177 * 10 ** -19
    light_speed = 299792458
    planck = 6.62607 * 10 ** -34

    non_relativistic = planck / sqrt(2 * electron_mass * elementary_charge * E_eV)
    lam = non_relativistic / sqrt(1 + elementary_charge * E_eV / 2 / electron_mass / light_speed ** 2) * 10 ** 10
    return 2 * lam / alpha ** 2

In [None]:
#export
@cuda.jit
def dense_to_sparse_kernel(dense, indices, counts, frame_dimensions):
    """
    CUDA kernel: for every position (ny, nx) of the four-dimensional array
    `dense`, write the flattened index (my * MX + mx) and the value of each
    strictly positive entry of its frame into consecutive slots of `indices`
    and `counts`.

    The frame is scanned with mx as the outer and my as the inner loop, so the
    slots are filled in that order. Only the MY x MX region given by
    `frame_dimensions` is scanned.
    """
    ny, nx = cuda.grid(2)
    NY, NX, MYBIN, MXBIN = dense.shape
    MY, MX = frame_dimensions
    # threads outside the first two dimensions have nothing to do
    if ny >= NY or nx >= NX:
        return
    filled = 0
    for mx in range(MX):
        for my in range(MY):
            value = dense[ny, nx, my, mx]
            if value > 0:
                indices[ny, nx, filled] = my * MX + mx
                counts[ny, nx, filled] = value
                filled += 1

def advanced_raster_scan(ny=10, nx=10, fast_axis=1, mirror=[1, 1], theta=0, dy=1, dx=1):
    """
    Generates a raster scan.

    NOTE: a later cell of this notebook defines ``advanced_raster_scan`` again;
    that later definition replaces this one once both cells have run. This
    version stores *x* in column 0 and *y* in column 1.

    Parameters
    ----------
    ny, nx : int
        Number of steps in *y* (vertical) and *x* (horizontal) direction

    fast_axis : int
        If not 1, the two index grids are swapped before the positions are built

    mirror : sequence of two numbers
        Factors multiplied onto column 0 and column 1 after centring

    theta : float
        Rotation angle in degrees

    dy, dx : float
        Step size (grid spacing) in *y* and *x*

    Returns
    -------
    pos : ndarray
        A (N,2) float32 array of positions, shifted so that the minimum of each column is 0.
    """
    idx_x, idx_y = np.indices((nx, ny))
    if fast_axis != 1:
        idx_x, idx_y = idx_y, idx_x

    pairs = list(zip(dx * idx_x.ravel(), dy * idx_y.ravel()))
    positions = np.array(pairs).astype(np.float32)

    # centre the grid on the origin
    lower = positions.min(axis=0)
    upper = positions.max(axis=0)
    positions -= lower + (upper - lower) / 2.0

    for col in (0, 1):
        positions[:, col] *= mirror[col]

    # apply the rotation matrix for theta to the row vectors
    theta_rad = theta / 180.0 * np.pi
    cos_t, sin_t = np.cos(theta_rad), np.sin(theta_rad)
    R = np.array([[cos_t, -sin_t],
                  [sin_t, cos_t]])
    positions = positions.dot(R)

    # shift so that the smallest coordinate along each column is zero
    positions -= positions.min(axis=0)
    return positions.astype(np.float32)

In [None]:
#export
def advanced_raster_scan(ny=10, nx=10, fast_axis=1, mirror=[1, 1], theta=0, dy=1, dx=1):
    """
    Generates a raster scan.

    Parameters
    ----------
    ny, nx : int
        Number of steps in *y* (vertical) and *x* (horizontal) direction
        *x* is the fast axis

    fast_axis : int
        If not 1, the two index grids are swapped, which makes *y* the fast axis

    mirror : sequence of two numbers
        Factors multiplied onto the *y* and *x* columns after centring

    theta : float
        Rotation angle in degrees

    dy, dx : float
        Step size (grid spacing) in *y* and *x*

    Returns
    -------
    pos : ndarray
        A (N,2) float32 array of (y, x) positions, shifted so that the
        minimum of each column is 0.
    """
    grid_y, grid_x = np.indices((ny, nx))
    if fast_axis != 1:
        grid_y, grid_x = grid_x, grid_y

    steps = list(zip(dy * grid_y.ravel(), dx * grid_x.ravel()))
    positions = np.array(steps).astype(np.float32)

    # move the centre of the bounding box to the origin
    lo = positions.min(axis=0)
    hi = positions.max(axis=0)
    centre = lo + (hi - lo) / 2.0
    positions -= centre

    positions[:, 0] *= mirror[0]
    positions[:, 1] *= mirror[1]

    # multiply the row vectors with the rotation matrix for theta
    angle = theta / 180.0 * np.pi
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s],
                         [s, c]])
    rotated = positions.dot(rotation)

    # shift so that the smallest coordinate of each column is zero
    rotated -= rotated.min(axis=0)
    return rotated.astype(np.float32)

In [None]:
#export

def get_qx_qy_1D(M, dx, dtype, fft_shifted=False):
    """
    One-dimensional FFT sample frequencies for both axes.

    :param M: (2,) number of samples per axis
    :param dx: (2,) sample spacing per axis; its array module (as reported by
               sigpy) is used to create the result
    :param dtype: dtype of the returned arrays
    :param fft_shifted: if True, apply ``fftshift`` to both arrays
    :return: (qxa, qya), the frequencies built from (M[0], dx[0]) and (M[1], dx[1])
    """
    xp = sp.backend.get_array_module(dx)
    freqs = [xp.fft.fftfreq(M[axis], dx[axis]).astype(dtype) for axis in (0, 1)]
    if fft_shifted:
        freqs = [xp.fft.fftshift(f) for f in freqs]
    return freqs[0], freqs[1]


def get_qx_qy_2D(M, dx, dtype, fft_shifted=False):
    """
    Two-dimensional grids of FFT sample frequencies.

    :param M: (2,) number of samples per axis
    :param dx: (2,) sample spacing per axis; its array module (as reported by
               sigpy) is used to create the result
    :param dtype: dtype of the frequency arrays
    :param fft_shifted: if True, apply ``fftshift`` to both grids
    :return: (qxn, qyn) as returned by ``meshgrid(qxa, qya)`` with its default
             indexing, where qxa / qya are built from (M[0], dx[0]) / (M[1], dx[1])
    """
    xp = sp.backend.get_array_module(dx)
    freq_x = xp.fft.fftfreq(M[0], dx[0]).astype(dtype)
    freq_y = xp.fft.fftfreq(M[1], dx[1]).astype(dtype)
    qxn, qyn = xp.meshgrid(freq_x, freq_y)
    if not fft_shifted:
        return qxn, qyn
    return xp.fft.fftshift(qxn), xp.fft.fftshift(qyn)

Converted 01_util.ipynb.
Converted 10a_fasta.ipynb.


In [None]:
#export
import numpy as np
import torch as th
from math import sqrt

def scatter_add_patches(input: th.Tensor, out: th.Tensor, axes, positions, patch_size, reduce_dim=None) -> th.Tensor:
    """
    Scatter-add K patches of size :patch_size: into :out: at axes [ax1, ax2]. Patch k is added with its
    first element at positions[k]. If :reduce_dim: is given, the input is split into slices of size 1
    along that dimension and every slice is added to the same target, which sums that dimension.

    The target is addressed through ``out.view(-1)`` with flat indices computed from the strides of
    :out:, so :out: has to be viewable as a flat tensor.

    :param input:   K x M1 x M2 at least 3-dimensional tensor of patches,
    :param out: at least two-dimensional tensor that is the scatter_add target
    :param axes: (2,) axes at which to scatter the input
    :param positions: K x 2 LongTensor
    :param patch_size: (2,) LongTensor
    :param reduce_dim: dimension of the input to sum over, or None
    :return: out, the target of scatter_add_
    """
    if reduce_dim is None:
        sources = [input]
    else:
        # one source per slice along reduce_dim, with singleton dimensions dropped
        sources = [piece.squeeze().contiguous() for piece in th.split(input, 1, dim=reduce_dim)]

    n_patches = positions.shape[0]
    stride0 = out.stride(axes[0])
    stride1 = out.stride(axes[1])

    # offsets of the patch elements relative to the patch origin, each of shape patch_size
    offset0 = th.arange(patch_size[0], device=input.device, dtype=th.long).view(patch_size[0], 1).expand(
        patch_size[0], patch_size[1])
    offset1 = th.arange(patch_size[1], device=input.device, dtype=th.long).view(1, patch_size[1]).expand(
        patch_size[0], patch_size[1])

    # flat index into out of every patch element: K x M1 x M2
    index = stride0 * (offset0 + positions[:, 0].view(n_patches, 1, 1)) + stride1 * (
            offset1 + positions[:, 1].view(n_patches, 1, 1))

    # every element at axes[1] is followed by stride1 consecutive entries of the trailing dimensions
    trailing = th.arange(stride1, device=input.device).view(1, 1, 1, stride1)
    index = index.unsqueeze(-1).expand(tuple(index.shape) + (stride1,))
    index = index + trailing

    flat_index = index.view(-1)
    for source in sources:
        out.view(-1).scatter_add_(0, flat_index, source.view(-1))
    return out

In [None]:
#export
def gather_patches(input, axes, positions, patch_size, out=None) -> th.Tensor:
    """
    Gathers K patches of size :patch_size: at axes [ax1, ax2] of the input tensor. The patches are collected
    starting at the K positions :positions:.

    if :input: is an n-dimensional tensor with size (x_0, x_1, x_2, ..., x_a, x_ax1, x_ax2, x_b, ..., x_{n-1})
    then :out: is an n-dimensional tensor with size  (K, x_0, x_1, x_2, ..., x_a, patch_size[0], patch_size[1], x_b, ..., x_{n-1})

    The elements are addressed through flat indices computed from the strides of :input:.
    NOTE(review): this presumably requires a contiguous input with adjacent axes
    (ax2 == ax1 + 1) -- confirm against the callers.

    :param input: at least two-dimensional tensor
    :param axes: axes at which to gather the patches
    :param positions: K x 2 LongTensor
    :param patch_size: (2,) patch size; a LongTensor, a NumPy integer array or a tuple of ints
    :param out: optional tensor passed to ``th.index_select`` to receive the gathered elements
    :return: tensor with the size described above
    """
    r = positions
    K = positions.shape[0]
    # int() accepts Python ints, NumPy integers and one-element tensors alike,
    # whereas .item() only exists on the latter two
    p0, p1 = int(patch_size[0]), int(patch_size[1])

    # condense all dimensions x_0 ... x_a into a single leading dimension;
    # an integer product stays exact where a float tensor product would not
    dim0size = 1
    if axes[0] > 0:
        for d in input.shape[:axes[0]]:
            dim0size *= d
    y = input.view((dim0size,) + tuple(input.shape[axes[0]:]))

    stride0 = input.stride(axes[0])
    stride1 = input.stride(axes[1])

    # offsets of the patch elements relative to the patch origin, each p0 x p1
    index0 = th.arange(p0, device=input.device, dtype=th.long).view(p0, 1).expand(p0, p1)
    index1 = th.arange(p1, device=input.device, dtype=th.long).view(1, p1).expand(p0, p1)

    # flat index of every patch element within one leading slice: K x p0 x p1
    index = stride0 * (index0 + r[:, 0].view(K, 1, 1)) + stride1 * (index1 + r[:, 1].view(K, 1, 1))

    # every element at axes[1] is followed by stride1 consecutive entries of the trailing dimensions
    higher_dim_offsets = th.arange(stride1, device=input.device).view(1, 1, 1, stride1)
    index = index.view(K, p0, p1, 1) + higher_dim_offsets

    # repeat the pattern for every entry of the condensed leading dimension.
    # The stride is read from the un-squeezed view, so it also exists when
    # every dimension of the input has size 1.
    lower_dim_offset = th.arange(dim0size, device=input.device) * y.stride(0)
    lower_dim_offset = lower_dim_offset.view(1, dim0size, 1, 1, 1).long()
    index = index.view(K, 1, p0, p1, stride1) + lower_dim_offset

    index = index.contiguous().view(-1)
    out = th.index_select(y.view(-1), 0, index, out=out)

    out_view = (K,) + tuple(input.shape[:axes[0]]) + (p0, p1) + tuple(input.shape[axes[1] + 1:])
    return out.view(out_view)



In [None]:
#export
def array_split_divpoints(ary, indices_or_sections, axis=0):
    """
    Return the division points that ``np.array_split`` would use for `ary`,
    without performing the split.

    * `indices_or_sections` is a sequence of indices: the result is the list
      ``[0] + indices + [length]``.
    * `indices_or_sections` is an integer n: an axis of length l is divided
      into l % n sections of size l//n + 1 followed by sections of size l//n;
      the result is the integer array of cumulative section boundaries.

    Examples
    --------
    >>> list(array_split_divpoints(np.arange(7.0), 3))
    [0, 3, 5, 7]
    >>> array_split_divpoints(np.arange(7.0), [3])
    [0, 3, 7]
    """
    try:
        axis_length = ary.shape[axis]
    except AttributeError:
        # plain sequences have no shape
        axis_length = len(ary)

    try:
        # a scalar has no len(); the TypeError selects the section-count branch
        len(indices_or_sections)
        return [0] + list(indices_or_sections) + [axis_length]
    except TypeError:
        pass

    count = int(indices_or_sections)
    if count <= 0:
        raise ValueError('number sections must be larger than 0.')
    small, n_large = divmod(axis_length, count)
    # the larger sections come first
    lengths = [0]
    lengths += n_large * [small + 1]
    lengths += (count - n_large) * [small]
    return np.array(lengths, dtype=np.intp).cumsum()

In [None]:
#hide
from nbdev.export import notebook2script
# Run nbdev's notebook-to-script conversion; the cell output lists each notebook it converted.
notebook2script()

Converted 00_torch_imports.ipynb.
Converted 01_util.ipynb.
Converted 01a_util_plot.ipynb.
Converted 01b_util_kernels.ipynb.
Converted 01c_util_ssb.ipynb.
Converted 01d_util_illumination.ipynb.
Converted 01e_util_io.ipynb.
Converted 06_sparse_data.ipynb.
Converted 10a_fasta.ipynb.
Converted 20_setup.ipynb.
Converted 30_operators.ipynb.
Converted 40_operators.kernels.ipynb.
Converted 50_functional.ipynb.
Converted 60_loss_functions.ipynb.
Converted 90_core.ipynb.
Converted index.ipynb.
