In [None]:
# default_exp util

# Custom modules and functions

> API details.


In [None]:
#export 
from smpr3d.torch_imports import *
from smpr3d.util import cartesian_aberrations, fftshift_checkerboard

class ZernikeProbe2(nn.Module):
    """
    Builds a complex probe ``exp(-1j * chi) * A`` from aberration coefficients.

    ``chi`` is computed by ``cartesian_aberrations`` from the stored reciprocal
    space grid ``q`` and wavelength ``lam``. The output is backpropable.
    """

    def __init__(self, q: th.Tensor, lam, fft_shifted=True):
        """
        Store the reciprocal-space grid and, optionally, the fftshift checkerboard.

        :param q: 2 x M1 x M2 tensor of x coefficients of reciprocal space
        :param lam: wavelength in Angstrom
        :param fft_shifted: if True, the output of ``forward`` is multiplied by
            the checkerboard returned by ``fftshift_checkerboard``
        """
        super().__init__()

        # Constant tensors are registered as non-persistent buffers so that
        # ``module.to(...)`` / ``.cuda()`` / ``.double()`` move and cast them
        # together with the rest of the model, while ``state_dict()`` keeps
        # exactly the keys it had before (no new entries).
        if isinstance(q, nn.Parameter):
            # Preserve the original behaviour: a Parameter stays a parameter.
            self.q = q
        else:
            self.register_buffer('q', q, persistent=False)
        self.lam = lam
        self.fft_shifted = fft_shifted

        if self.fft_shifted:
            # fftshift_checkerboard is called with half of each spatial extent
            # of q; its result is converted from a numpy array below.
            cb = fftshift_checkerboard(q.shape[1] // 2, q.shape[2] // 2)
            self.register_buffer(
                'cb', th.from_numpy(cb).float().to(q.device), persistent=False
            )

    def forward(self, C, A) -> th.Tensor:
        """
        Evaluate the probe for the given aberration coefficients and amplitude.

        :param C: aberration coefficients
        :param A: amplitude tensor, expanded to the shape of the aberration
            surface via ``expand_as``
        :return: (maximum size of all aberration tensors) x MY x MX
        """
        # NOTE(review): q[1] is passed before q[0]; presumably the helper
        # expects (qx, qy) while q is stored as (qy, qx) -- confirm.
        chi = cartesian_aberrations(self.q[1], self.q[0], self.lam, C)
        Psi = th.exp(-1j * chi) * A.expand_as(chi)

        if self.fft_shifted:
            Psi = Psi * self.cb

        return Psi
