# Depth section algorithm after H.G. Brown, P.P. Pelz, et al. (TODO: add arXiv citation)

>

In [None]:
#default_exp algorithm

In [None]:
#export 

import torch
torch.set_printoptions(precision=10)
import numba.cuda as cuda
import cmath as cm
import numpy as np
import torch as th
import math as m
from smpr3d.util import *
from tqdm import tqdm
import scipy.io as sio
import numpy as np

In [None]:
#export 

import math as m
import cmath as cm
from numba import cuda
@cuda.jit
def propagate_parallax_shift_Smatrix_kernel(S, lam, q, q2, beam_coords, t, out):
    # CUDA kernel: one thread per element of the 3D array S.
    #
    # For element (beam, row, col) it multiplies S by exp(1j * t[0] * phase)
    # and stores the product in out, where
    #   phase = pi * lam[0] * q2[row, col]
    #           + 2 * pi * (q[0, row, col] * shift_y + q[1, row, col] * shift_x)
    # and shift_y / shift_x are tan(q[k] * lam[0]) evaluated at the position
    # beam_coords[beam] of the beam this element belongs to.
    #
    # S, out      : indexed as [beam, row, col]
    # lam, t      : only element 0 is read
    # q           : indexed as [component, row, col], components 0 (y) and 1 (x)
    # q2          : indexed as [row, col]
    # beam_coords : indexed as [beam, 0] and [beam, 1]; used as indices into q
    n_beams, n_rows, n_cols = S.shape
    plane_size = n_rows * n_cols
    flat_idx = cuda.threadIdx.x + cuda.blockDim.x * cuda.blockIdx.x

    # The launch grid is rounded up to whole blocks, so surplus threads exit.
    if flat_idx >= n_beams * plane_size:
        return

    # Unravel the flat thread index into (beam, row, col).
    beam = flat_idx // plane_size
    in_plane = flat_idx - beam * plane_size
    row = in_plane // n_cols
    col = in_plane - row * n_cols

    wavelength = lam[0]
    beam_row = beam_coords[beam, 0]
    beam_col = beam_coords[beam, 1]

    # Per-beam shifts; cmath.tan returns a complex value even for real input.
    shift_y = cm.tan(q[0, beam_row, beam_col] * wavelength)
    shift_x = cm.tan(q[1, beam_row, beam_col] * wavelength)

    quadratic_term = cm.pi * wavelength * q2[row, col]
    linear_term = 2 * cm.pi * (q[0, row, col] * shift_y + q[1, row, col] * shift_x)
    propagator = cm.exp(1j * t[0] * (quadratic_term + linear_term))

    out[beam, row, col] = S[beam, row, col] * propagator


def propagate_parallax_shift_Smatrix(S, lam, q, q2, beam_coords, t, out=None):
    """
    Launch propagate_parallax_shift_Smatrix_kernel over every element of S.

    :param S: torch tensor unpacked by the kernel as (B, NY, NX)
    :param lam: array whose element 0 is read by the kernel (wavelength)
    :param q: array indexed by the kernel as [component, row, col]
    :param q2: array indexed by the kernel as [row, col]
    :param beam_coords: array indexed by the kernel as [beam, 0] and [beam, 1]
    :param t: array whose element 0 is read by the kernel
    :param out: optional output buffer; a zero tensor like S is allocated when None
    :return: the output buffer the kernel wrote to
    """
    result = th.zeros_like(S) if out is None else out

    # Use the largest block the current device allows and enough blocks to
    # cover every element of S.
    threads_per_block = cuda.get_current_device().MAX_THREADS_PER_BLOCK
    n_elements = np.prod(np.array(S.shape))
    n_blocks = m.ceil(n_elements / threads_per_block)

    # Launch on torch's current stream.
    # NOTE(review): this passes the raw stream handle (an int); confirm that
    # numba accepts it when the current torch stream is not the default one.
    torch_stream = th.cuda.current_stream().cuda_stream

    launch = propagate_parallax_shift_Smatrix_kernel[n_blocks, threads_per_block, torch_stream]
    launch(S, lam, q, q2, beam_coords, t, result)
    return result

import torch as th
from tqdm import tqdm
def depth_section(smpr_solution: SMPRSolution, wavelength: float, t: th.Tensor) -> th.Tensor:
    """
    Create a depth section from the S-Matrix by propagating each beam back in space and interfering them.

    Steps:
      1. Interpolate the reconstructed S-Matrix to the full set of beams using the
         natural neighbor weights stored in the metadata.
      2. Fourier transform every interpolated beam.
      3. For every defocus in ``t``: apply the propagation / parallax-shift phase,
         sum over beams and transform the sum back.

    :param smpr_solution: SMPRSolution with a reconstructed (B, NY, NX) S-Matrix and SMeta metadata
    :param wavelength: float, wavelength in Angstrom
    :param t: (T,) 1D tensor of defocus values to create depth section images at
    :return: (T, NY, NX) complex CPU tensor holding the exit waves at depths t
    """
    smeta = smpr_solution.s_matrix_meta
    smatrix = smpr_solution.smatrix
    wt = th.as_tensor(smeta.natural_neighbor_weights).cuda().float()

    # Each row of wt mixes the reconstructed beams into one beam of the full
    # S-Matrix. The results stay on the GPU; they are needed there for the FFT
    # anyway, so there is no reason to round-trip them through host memory.
    S = th.stack([
        th.sum(wi[:, None, None] * smatrix, 0)
        for wi in tqdm(wt, desc="Interpolating full S-Matrix")
    ])

    S1 = th.fft.fft2(S, norm='ortho')
    device = S1.device
    lam1 = th.tensor([wavelength], device=device)

    # fft2 always yields a complex tensor, so S1.dtype keeps the result complex
    # even if the S-Matrix itself were real-valued.
    EW = th.zeros(len(t), S.shape[1], S.shape[2], dtype=S1.dtype, device=th.device('cpu'))

    # The kernel overwrites every element of its output, so one scratch buffer
    # can be reused for all depths instead of allocating a new one each time.
    SS = th.zeros_like(S1)
    for i, T in enumerate(tqdm(t, desc="Optical sectioning")):
        tt = th.tensor([T], device=device)
        propagate_parallax_shift_Smatrix(S1, lam1, smeta.q, smeta.q2, smeta.all_beams_coords, tt, out=SS)
        # Interfere all propagated beams and go back to real space.
        EW[i] = th.fft.ifft2(th.sum(SS, dim=0), norm='ortho').cpu()
    return EW

# old code
# @cuda.jit
# def propagate_parallax_shift_Smatrix_kernel(S, lam, q, q2, beam_coords, t, out):
#     n = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
#     B, NY, NX, c = S.shape
#     N = B * NY * NX
# 
#     b = n // (NY * NX)
#     ny = (n - b * (NY * NX)) // NX
#     nx = (n - b * (NY * NX) - ny * NX)
# 
#     if n < N:
#         dy = cm.tan(q[0, beam_coords[b, 0], beam_coords[b, 1]] * lam[0])
#         dx = cm.tan(q[1, beam_coords[b, 0], beam_coords[b, 1]] * lam[0])
#         phase = cm.pi * lam[0] * q2[ny, nx] + 2 * cm.pi * (q[0, ny, nx] * dy + q[0, ny, nx] * dx)
#         val = cm.exp(1j * t[0] * phase)
#         Sc = S[b, ny, nx, 0] + 1j * S[b, ny, nx, 1]
#         v = Sc * val
#         out[b, ny, nx, 0] = v.real
#         out[b, ny, nx, 1] = v.imag
# 
# def propagate_parallax_shift_Smatrix(S, lam, q, q2, beam_coords, t, out=None):
#     if out is None:
#         out = th.zeros_like(S)
#     gpu = cuda.get_current_device()
#     stream = th.cuda.current_stream().cuda_stream
#     threadsperblock = gpu.MAX_THREADS_PER_BLOCK
#     blockspergrid = m.ceil(np.prod(np.array(S.shape[:-1])) / threadsperblock)
#     propagate_parallax_shift_Smatrix_kernel[blockspergrid, threadsperblock, stream](S, lam, q, q2, beam_coords, t, out)
#     return out
# 
# def depth_section(S, beam_indices, beam_coords, lam, dx, output_beam_mask, beam_block_mask, t):
#     """
#     
#     :param S: 
#     :param beam_indices: (B,) integer indices of beams
#     :param beam_coords: (B, 2) integer coordinates of beams
#     :param lam: wavelength
#     :param dx: real-space sampling of S-matrix
#     :param output_beam_mask: mask applied after shifting S-matrix to zero 
#     :param beam_block_mask: mask applied after summing S-matrix
#     :param t: (n_depths,) defocus in Angstrom
#     :return: 
#     """
#     device = S.device
#     dtype = S.dtype
# 
#     # Get array shape of scattering matrix
#     B, Y, X = S.shape[:-1]
#     S1 = th.fft(S, 2, True)
#     q = fourier_coordinates_2D([Y, X], dx, centered=False)
#     q = th.from_numpy(q).type(dtype).to(device)
#     q2 = q[0] ** 2 + q[1] ** 2
#     # Shift all beams to origin for each component of the scattering matrix / remove beam tilts
#     mask = output_beam_mask.to(device)
#     mask2 = beam_block_mask.unsqueeze(0).unsqueeze(-1).to(device)
#     for ib, beam in zip(beam_indices, beam_coords):
#         by, bx = [x.item() for x in beam]
#         mask_shift = th.roll(mask, [-by, -bx], (0, 1))
#         S1[ib] = th.roll(S1[ib], [-by, -bx], [0, 1]) * mask_shift[..., None]
# 
#     beam_coords1 = th.as_tensor(beam_coords, device=device)
#     lam1 = th.tensor([lam], device=device)
# 
#     EW = th.zeros(len(t), Y, X, dtype=dtype, device=th.device('cpu'))
#     for i, T in enumerate(tqdm(t, desc="Optical section")):
#         tt = th.tensor([T], device=device)
#         SS = propagate_parallax_shift_Smatrix(S1, lam1, q, q2, beam_coords1, tt)
#         EW[i] = th.ifft(th.sum(SS * mask2.expand_as(SS), axis=0), 2, True).cpu()
# 
#     return EW