In [None]:
# default_exp util

# Util kernel

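> CUDA kernels and host-side wrappers that operate on sparse frame data stored as flattened pixel indices with per-event counts.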

In [None]:
#export 
import numpy as np
import torch as th
import sigpy as sp
import numpy as np
from smpr3d.torch_imports import *
import numba.cuda as cuda

@cuda.jit
def center_of_mass_kernel(comx, comy, indices, counts, frame_dimensions, no_count_indicator, qx, qy):
    """Accumulate count-weighted ``qx``/``qy`` values for every scan position.

    Each thread of the 2D grid handles one scan position. Every entry of
    ``indices`` at that position is a flattened frame index that is split into
    (row, column) using the frame width; entries equal to
    ``no_count_indicator`` are skipped.

    comx, comy: 2D outputs indexed by scan position, updated with atomic adds.
    indices: 3D array of flattened frame indices.
    counts: 3D array read at the same positions as ``indices``.
    frame_dimensions: two values (rows, columns); only the column count is used.
    no_count_indicator: index value marking an entry that carries no event.
    qx, qy: 2D arrays indexed by frame (row, column).
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, n_events = indices.shape
    frame_height, frame_width = frame_dimensions
    # Threads that fall outside the scan grid have nothing to do.
    if row >= n_rows or col >= n_cols:
        return
    for event in range(n_events):
        flat = indices[row, col, event]
        if flat == no_count_indicator:
            continue
        # Unravel the flattened index into frame coordinates.
        det_y = flat // frame_width
        det_x = flat - det_y * frame_width
        cuda.atomic.add(comy, (row, col), counts[row, col, event] * qy[det_y, det_x])
        cuda.atomic.add(comx, (row, col), counts[row, col, event] * qx[det_y, det_x])

In [None]:
#export

from numba import cuda
@cuda.jit
def sparse_to_dense_datacube_kernel_crop(dc, indices, counts, frame_dimensions, bin, start, end, no_count_indicator):
    """Scatter sparse events into a cropped and binned dense data cube.

    Each thread of the 2D grid handles one scan position of ``dc``. An event is
    used only when its frame coordinates lie in the half-open window
    ``[start, end)`` and its index differs from ``no_count_indicator``; its
    count is then atomically added to the bin ``(coordinate - start) // bin``.

    dc: 4D output indexed as (scan row, scan column, binned row, binned column).
    indices: 3D array of flattened frame indices.
    counts: 3D array read at the same positions as ``indices``.
    frame_dimensions: two values (rows, columns); only the column count is used.
    bin: integer divisor applied to the window-relative coordinates.
    start, end: two-element (row, column) bounds of the crop window.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, binned_height, binned_width = dc.shape
    frame_height, frame_width = frame_dimensions
    if row >= n_rows or col >= n_cols:
        return
    for event in range(indices.shape[2]):
        flat = indices[row, col, event]
        if flat == no_count_indicator:
            continue
        det_y = flat // frame_width
        det_x = flat - det_y * frame_width
        inside_rows = start[0] <= det_y and det_y < end[0]
        inside_cols = start[1] <= det_x and det_x < end[1]
        if inside_rows and inside_cols:
            bin_y = (det_y - start[0]) // bin
            bin_x = (det_x - start[1]) // bin
            cuda.atomic.add(dc, (row, col, bin_y, bin_x), counts[row, col, event])

In [None]:
#export
def sparse_to_dense_datacube_crop(indices, counts, scan_dimensions, frame_dimensions, center, radius, bin=1):
    """Convert sparse events into a dense data cube cropped around ``center``.

    Launches ``sparse_to_dense_datacube_kernel_crop`` with one thread per scan
    position (taken from the first two axes of ``indices``).

    indices: 3D array of flattened frame indices; its dtype maximum is used as
        the "no count" marker and its dtype is also used for the output.
    counts: array read by the kernel at the same positions as ``indices``.
    scan_dimensions: leading two dimensions of the output.
    frame_dimensions: (rows, columns) of the original frames.
    center: crop centre; must support ``center - radius`` arithmetic
        (e.g. an array of two values).
    radius: half-width of the crop window; rounded up to a multiple of ``bin``.
    bin: integer binning factor.

    Returns the dense array of shape
    ``(*scan_dimensions, 2 * radius // bin, 2 * radius // bin)`` allocated with
    the array module of ``indices``.
    """
    xp = sp.backend.get_array_module(indices)
    # Round the radius up so the crop window divides evenly into bins.
    radius = int(np.ceil(radius / bin) * bin)
    start = center - radius
    end = center + radius
    frame_size = 2 * radius // bin

    dc = xp.zeros((*scan_dimensions, frame_size, frame_size), dtype=indices.dtype)

    threadsperblock = (16, 16)
    # Integer ceiling division. The former `.astype(np.int)` fails on current
    # NumPy because the `np.int` alias was removed in NumPy 1.24.
    blockspergrid = tuple(int(-(-extent // block))
                          for extent, block in zip(indices.shape[:2], threadsperblock))

    no_count_indicator = np.iinfo(indices.dtype).max

    sparse_to_dense_datacube_kernel_crop[blockspergrid, threadsperblock](dc, indices, counts, xp.array(frame_dimensions), bin,
                                                                         start, end, no_count_indicator)
    return dc

In [None]:
#export
@cuda.jit
def sparse_to_dense_datacube_crop_gain_mask_kernel(dc, frames, counts,
                                                   frame_dimensions,
                                                   center_frame, center_data,
                                                   radius_data_int, binning,
                                                   fftshift):
    """Scatter sparse events within a radius of ``center_data`` into ``dc``.

    Each thread of the 2D grid handles one scan position of ``dc``. An event is
    used when its Euclidean distance from ``center_data`` is smaller than
    ``radius_data_int``. Its centre-relative coordinates are floor-divided by
    ``binning`` and offset by ``center_frame``; with ``fftshift`` set, the
    result is additionally shifted by ``center_frame`` modulo
    ``2 * center_frame``.

    dc: 4D output indexed as (scan row, scan column, binned row, binned column).
    frames: 3D array of flattened frame indices.
    counts: 3D array read at the same positions as ``frames``.
    frame_dimensions: two values (rows, columns); only the column count is used.
    center_frame: (row, column) centre of the output frame.
    center_data: (row, column) centre in the original frame.

    There is no explicit "no count" marker test here; padding entries are only
    excluded when their decoded position falls outside the radius.
    """
    ny, nx = cuda.grid(2)
    NY, NX, MYBIN, MXBIN = dc.shape
    MY, MX = frame_dimensions
    if ny < NY and nx < NX:
        for i in range(frames[ny, nx].shape[0]):
            idx1d = frames[ny, nx, i]
            # Unravel the flattened index into frame coordinates.
            my = idx1d // MX
            mx = idx1d - my * MX
            my_center = my - center_data[0]
            mx_center = mx - center_data[1]
            dist_center = m.sqrt(my_center ** 2 + mx_center ** 2)
            if dist_center < radius_data_int:
                mybin = int(center_frame[0] + my_center // binning)
                mxbin = int(center_frame[1] + mx_center // binning)
                if fftshift:
                    mybin = (mybin - center_frame[0]) % (center_frame[0] * 2)
                    mxbin = (mxbin - center_frame[1]) % (center_frame[1] * 2)
                # Check both ends of the range: previously only the lower bound
                # was tested, so a data radius larger than the allocated frame
                # led to an out-of-bounds atomic write.
                if mybin >= 0 and mxbin >= 0 and mybin < MYBIN and mxbin < MXBIN:
                    cuda.atomic.add(dc, (ny, nx, mybin, mxbin), counts[ny, nx, i])

In [None]:
#export

def sparse_to_dense_datacube_crop_gain_mask(indices, counts, scan_dimensions, frame_dimensions, center_data, radius_data,
                                            radius_max, binning=1, fftshift=False):
    """Build a dense ``uint8`` data cube from sparse events on CUDA device 1.

    The scan is processed in two row chunks (first ``scan_dimensions[0] // 2``
    rows, then the remainder). Each chunk is accumulated in a float32 buffer on
    ``cuda:1`` by ``sparse_to_dense_datacube_crop_gain_mask_kernel``, cast to
    ``uint8`` and copied into the host result.

    indices: array of flattened frame indices, sliced along its first axis.
    counts: array of per-event counts; must provide ``.astype`` (it is
        converted to float32 before upload).
    scan_dimensions: (rows, columns) of the scan.
    frame_dimensions: (rows, columns) of the original frames.
    center_data: (row, column) centre in the original frame.
    radius_data: radius inside which events are kept; rounded up to a multiple
        of ``binning``.
    radius_max: radius that defines the output frame size
        ``2 * ceil(radius_max / binning)``.
    binning: integer binning factor.
    fftshift: forwarded to the kernel.

    Returns a NumPy ``uint8`` array of shape
    ``(scan_dimensions[0], scan_dimensions[1], frame_size, frame_size)``.

    Note: the device index 1 is hard-coded and device 0 is selected again on
    exit, also when an error occurs.
    """
    radius_data_int = int(np.ceil(radius_data / binning) * binning)
    radius_max_int = int(np.ceil(radius_max / binning) * binning)
    frame_size = 2 * radius_max_int // binning
    print(f'radius_data_int : {radius_data_int} ')
    print(f'radius_max_int  : {radius_max_int} ')
    print(f'Dense frame size: {frame_size}x {frame_size}')

    n_rows, n_cols = int(scan_dimensions[0]), int(scan_dimensions[1])
    half = n_rows // 2
    row_chunks = ((0, half), (half, n_rows))

    dc0 = np.zeros((n_rows, n_cols, frame_size, frame_size), dtype=np.uint8)
    threadsperblock = (16, 16)

    cuda.select_device(1)
    try:
        dev = th.device('cuda:1')
        stream = th.cuda.current_stream().cuda_stream

        # Sized for the larger chunk: with an odd number of scan rows the second
        # chunk has one row more than the first. The old code always allocated
        # `n_rows // 2` rows, which dropped that row and broke the host copy.
        dc = th.zeros((n_rows - half, n_cols, frame_size, frame_size), dtype=th.float32, device=dev)

        center_frame = th.tensor([frame_size // 2, frame_size // 2], device=dev)
        fd = th.as_tensor(frame_dimensions, device=dev)
        center = th.as_tensor(center_data, device=dev)

        dtypes_reported = False
        for row_start, row_stop in row_chunks:
            chunk_rows = row_stop - row_start
            if chunk_rows == 0:
                continue

            # View onto the leading rows of the reusable buffer, cleared before use.
            dc_chunk = dc[:chunk_rows]
            dc_chunk.zero_()

            inds = th.as_tensor(indices[row_start:row_stop, ...], device=dev)
            cts = th.as_tensor(counts[row_start:row_stop, ...].astype(np.float32), dtype=th.float32, device=dev)

            if not dtypes_reported:
                # `frame_dimensions` may be a plain list/tuple without `.dtype`;
                # fall back to the dtype of its tensor copy in that case.
                print('sparse_to_dense_datacube_crop_gain_mask dtypes:', dc.dtype, inds.dtype, cts.dtype,
                      getattr(frame_dimensions, 'dtype', fd.dtype))
                dtypes_reported = True

            # Integer ceiling division (`np.int` was removed in NumPy 1.24).
            blockspergrid = (-(-chunk_rows // threadsperblock[0]), -(-n_cols // threadsperblock[1]))

            sparse_to_dense_datacube_crop_gain_mask_kernel[blockspergrid, threadsperblock, stream](dc_chunk, inds, cts, fd,
                                                                                           center_frame, center,
                                                                                           radius_data_int, binning,
                                                                                           fftshift)

            dc0[row_start:row_stop, ...] = dc_chunk.cpu().type(th.uint8).numpy()
    finally:
        # Always hand the default device back, even if a launch or copy failed.
        cuda.select_device(0)
    return dc0

In [None]:
#export
from numba import cuda
@cuda.jit
def fftshift_kernel(indices, center_frame, scan_dimensions, no_count_indicator):
    """Shift flattened frame indices in place by ``center_frame`` (modulo frame size).

    The frame size is taken to be ``2 * center_frame`` along both axes. Each
    thread of the 2D grid handles one scan position; entries equal to
    ``no_count_indicator`` are left untouched.

    indices: 3D array of flattened frame indices, overwritten in place.
    center_frame: (row, column) centre of the frame.
    scan_dimensions: (rows, columns) of the scan used as the thread bounds.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols = scan_dimensions
    frame_height = center_frame[0] * 2
    frame_width = center_frame[1] * 2
    if row >= n_rows or col >= n_cols:
        return
    for event in range(indices.shape[2]):
        flat = indices[row, col, event]
        if flat == no_count_indicator:
            continue
        det_y = flat // frame_width
        det_x = flat - det_y * frame_width
        # Wrap the centre-relative coordinate back into the frame.
        shifted_y = (det_y - center_frame[0]) % frame_height
        shifted_x = (det_x - center_frame[1]) % frame_width
        indices[row, col, event] = shifted_y * frame_width + shifted_x

In [None]:
#export
from numba import cuda
@cuda.jit
def fftshift_pad_kernel(indices, center_frame, scan_dimensions, new_frame_dimensions, no_count_indicator_old, no_count_indicator_new):
    """Shift flattened indices in place and re-flatten them for a new frame size.

    Indices are decoded with the old frame width ``2 * center_frame[1]``,
    shifted by ``center_frame`` modulo ``new_frame_dimensions`` and re-encoded
    with the new frame width. Entries equal to ``no_count_indicator_old`` are
    replaced by ``no_count_indicator_new``.

    indices: 3D array of flattened frame indices, overwritten in place.
    center_frame: (row, column) centre of the old frame.
    scan_dimensions: (rows, columns) of the scan used as the thread bounds.
    new_frame_dimensions: (rows, columns) of the target frame.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols = scan_dimensions
    old_width = center_frame[1] * 2
    if row >= n_rows or col >= n_cols:
        return
    for event in range(indices.shape[2]):
        flat = indices[row, col, event]
        if flat == no_count_indicator_old:
            # Translate the padding marker to the new convention.
            indices[row, col, event] = no_count_indicator_new
            continue
        det_y = flat // old_width
        det_x = flat - det_y * old_width
        shifted_y = (det_y - center_frame[0]) % (new_frame_dimensions[0])
        shifted_x = (det_x - center_frame[1]) % (new_frame_dimensions[1])
        indices[row, col, event] = shifted_y * new_frame_dimensions[1] + shifted_x

In [None]:
#export
from numba import cuda
@cuda.jit
def virtual_annular_image_kernel(img, indices, counts, radius_inner, radius_outer, center_frame, frame_dimensions, no_count_indicator):
    """Sum the counts of all events that fall inside an annulus, per scan position.

    An event contributes when its Euclidean distance ``r`` from
    ``center_frame`` satisfies ``radius_inner <= r < radius_outer`` and its
    index differs from ``no_count_indicator``.

    img: 2D output indexed by scan position, updated with atomic adds.
    indices: 3D array of flattened frame indices.
    counts: 3D array read at the same positions as ``indices``.
    center_frame: (row, column) centre of the annulus.
    frame_dimensions: two values (rows, columns); only the column count is used.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, n_events = indices.shape
    frame_height, frame_width = frame_dimensions
    if row >= n_rows or col >= n_cols:
        return
    for event in range(n_events):
        flat = indices[row, col, event]
        if flat == no_count_indicator:
            continue
        det_y = flat // frame_width
        det_x = flat - det_y * frame_width
        dy = det_y - center_frame[0]
        dx = det_x - center_frame[1]
        dist = m.sqrt(dy ** 2 + dx ** 2)
        if radius_inner <= dist and dist < radius_outer:
            cuda.atomic.add(img, (row, col), counts[row, col, event])

In [None]:
#export 
@cuda.jit
def crop_symmetric_around_center_kernel(new_frames, old_frames, center_frame, old_frame_dimensions, center_data, radius_data_int):
    """Keep only events within a radius of ``center_data`` and re-centre them.

    For every scan position the retained events are written to ``new_frames``
    contiguously from slot 0, in their original order. The new flattened index
    uses a frame width of ``2 * center_frame[1]``. Slots that are not written
    keep whatever value the caller put there.

    new_frames: 3D output of flattened indices in the cropped frame.
    old_frames: 3D array of flattened indices in the original frame.
    center_frame: (row, column) centre of the cropped frame.
    old_frame_dimensions: two values (rows, columns); only the column count is used.
    center_data: (row, column) centre in the original frame.
    radius_data_int: events at a distance smaller than this are kept.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, n_events = old_frames.shape
    old_height, old_width = old_frame_dimensions
    new_width = center_frame[1] * 2
    if row >= n_rows or col >= n_cols:
        return
    # Next free slot in the output for this scan position.
    write_pos = 0
    for event in range(n_events):
        flat = old_frames[row, col, event]
        det_y = flat // old_width
        det_x = flat - det_y * old_width
        dy = det_y - center_data[0]
        dx = det_x - center_data[1]
        dist = m.sqrt(dy ** 2 + dx ** 2)
        if dist < radius_data_int:
            new_y = int(center_frame[0] + dy)
            new_x = int(center_frame[1] + dx)
            new_frames[row, col, write_pos] = new_y * new_width + new_x
            write_pos += 1

# Cell
def crop_symmetric_around_center(old_frames, old_frame_dimensions, center_data, max_radius):
    """Crop sparse frames to a square window of side ``2 * int(max_radius)``.

    Launches ``crop_symmetric_around_center_kernel``, which keeps the events
    closer than ``int(max_radius)`` to ``center_data`` and packs them at the
    front of each scan position's event list. The event axis of the result is
    then trimmed to the largest number of retained events at any position.

    old_frames: 3D array of flattened frame indices; its dtype maximum is used
        as the "no count" fill value of the output.
    old_frame_dimensions: (rows, columns) of the original frames.
    center_data: (row, column) centre in the original frame.
    max_radius: crop radius; truncated to an integer.

    Returns ``(res, new_frame_dimensions)``: the trimmed index array fetched
    with ``.get()`` and a NumPy array ``[frame_size, frame_size]``.
    """
    xp = sp.backend.get_array_module(old_frames)
    max_radius_int = int(max_radius)
    frame_size = 2 * max_radius_int
    center_frame = xp.array([frame_size // 2, frame_size // 2])
    new_frame_dimensions = np.array([frame_size, frame_size])

    threadsperblock = (16, 16)
    # Integer ceiling division. The former `.astype(np.int)` fails on current
    # NumPy because the `np.int` alias was removed in NumPy 1.24.
    blockspergrid = tuple(int(-(-extent // block))
                          for extent, block in zip(old_frames.shape[:2], threadsperblock))

    # Pre-fill the output with the "no count" marker; the kernel overwrites the
    # leading slots of each scan position with the retained events.
    new_frames = xp.zeros_like(old_frames)
    no_count_indicator = xp.iinfo(new_frames.dtype).max
    new_frames[:] = no_count_indicator

    crop_symmetric_around_center_kernel[blockspergrid, threadsperblock](new_frames, old_frames, center_frame,
                                                                        xp.array(old_frame_dimensions),
                                                                        xp.array(center_data), max_radius_int)

    # Count real events by comparing against the fill marker. The previous
    # `new_frames > 0` test also counted the marker itself (the dtype maximum
    # is positive), so the event axis was never actually trimmed.
    valid_per_position = xp.sum(new_frames != no_count_indicator, 2)
    max_counts = int(xp.max(valid_per_position.ravel()))
    # print(f'max counts: {max_counts}')
    res = new_frames[:, :, :max_counts].get()
    return res, new_frame_dimensions

In [None]:
#export 
from numba import cuda
import math as m

@cuda.jit
def rotate_kernel(frames, center_frame, old_frame_dimensions, center_data, no_count_indicator, angle_rad):
    """Rotate sparse event positions in place by ``angle_rad`` about ``center_data``.

    Each valid index is decoded with the old frame width, rotated around
    ``center_data``, rounded to integer offsets, re-centred on ``center_frame``
    and re-encoded with a frame width of ``2 * center_frame[1]``. Entries equal
    to ``no_count_indicator`` are left untouched.

    frames: 3D array of flattened frame indices, overwritten in place.
    center_frame: (row, column) centre used for the output indices.
    old_frame_dimensions: two values (rows, columns); only the column count is used.
    center_data: (row, column) rotation centre in the input frame.
    angle_rad: rotation angle in radians.
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, n_events = frames.shape
    old_height, old_width = old_frame_dimensions
    new_width = center_frame[1] * 2
    # The angle is the same for every event, so evaluate the trig terms once.
    cos_a = m.cos(angle_rad)
    sin_a = m.sin(angle_rad)
    if row >= n_rows or col >= n_cols:
        return
    for event in range(n_events):
        flat = frames[row, col, event]
        if flat == no_count_indicator:
            continue
        det_y = flat // old_width
        det_x = flat - det_y * old_width
        dy = det_y - center_data[0]
        dx = det_x - center_data[1]
        # Rotate the centre-relative offset and snap it to the pixel grid.
        rot_x = round(dx * cos_a - dy * sin_a)
        rot_y = round(dx * sin_a + dy * cos_a)
        new_y = int(center_frame[0] + rot_y)
        new_x = int(center_frame[1] + rot_x)
        frames[row, col, event] = new_y * new_width + new_x

In [None]:
#export

def rotate(frames, old_frame_dimensions, center, angle_rad):
    """Rotate sparse frames by ``angle_rad`` about ``center`` and return them on the host.

    A copy of ``frames`` is rotated in place by ``rotate_kernel``; the input is
    not modified. ``center`` serves both as the rotation centre in the input
    frame and as the centre of the output frame.

    frames: torch tensor of flattened frame indices (3D, scan rows x scan
        columns x events); its dtype maximum is used as the "no count" marker.
    old_frame_dimensions: (rows, columns) of the input frames.
    center: (row, column) rotation centre.
    angle_rad: rotation angle in radians.

    Returns the rotated indices as a NumPy array.
    """
    threadsperblock = (16, 16)
    # Integer ceiling division. The former `.astype(np.int)` fails on current
    # NumPy because the `np.int` alias was removed in NumPy 1.24.
    blockspergrid = tuple(int(-(-extent // block))
                          for extent, block in zip(frames.shape[:2], threadsperblock))
    no_count_indicator = th.iinfo(frames.dtype).max

    # Work on a copy so the caller's tensor is left untouched.
    new_frames = frames.detach().clone()
    center_dev = th.as_tensor(center, device=frames.device)
    frame_dims_dev = th.as_tensor(old_frame_dimensions, device=frames.device)

    rotate_kernel[blockspergrid, threadsperblock](new_frames, center_dev, frame_dims_dev,
                                                  center_dev, no_count_indicator, angle_rad)
    # torch tensors have no `.get()`; move the result to the host explicitly.
    return new_frames.cpu().numpy()

In [None]:
#export 

@cuda.jit
def sum_kernel(indices, counts, frame_dimensions, sum, no_count_indicator):
    """Accumulate the counts of all scan positions into a single dense frame.

    Each thread of the 2D grid handles one scan position and atomically adds
    the count of every valid event to ``sum`` at the event's frame
    (row, column). Entries equal to ``no_count_indicator`` are skipped.

    indices: 3D array of flattened frame indices.
    counts: 3D array read at the same positions as ``indices``.
    frame_dimensions: two values (rows, columns); only the column count is used.
    sum: 2D output indexed by frame (row, column).
    """
    row, col = cuda.grid(2)
    n_rows, n_cols, n_events = indices.shape
    frame_height, frame_width = frame_dimensions
    if row >= n_rows or col >= n_cols:
        return
    for event in range(n_events):
        flat = indices[row, col, event]
        if flat == no_count_indicator:
            continue
        det_y = flat // frame_width
        det_x = flat - det_y * frame_width
        cuda.atomic.add(sum, (det_y, det_x), counts[row, col, event])

In [None]:
#export
def sum_frames(frames, counts, frame_dimensions):
    """Sum sparse frames over all scan positions into one dense frame.

    ``frames`` is moved to device 0 via ``sp.to_device`` and ``sum_kernel`` is
    launched with one thread per scan position.

    frames: 3D array of flattened frame indices; its dtype maximum is used as
        the "no count" marker.
    counts: array of per-event counts, read at the same positions as ``frames``.
    frame_dimensions: (rows, columns) of the dense frame.

    Returns the summed frame fetched to the host with ``.get()``.
    """
    threadsperblock = (16, 16)
    # Integer ceiling division. The former `.astype(np.int)` fails on current
    # NumPy because the `np.int` alias was removed in NumPy 1.24.
    blockspergrid = tuple(int(-(-extent // block))
                          for extent, block in zip(frames.shape[:2], threadsperblock))

    frames1 = sp.to_device(frames, 0)
    xp = sp.backend.get_array_module(frames1)
    # Named `frame_sum` so the builtin `sum` is not shadowed.
    frame_sum = xp.zeros(frame_dimensions)
    counts1 = xp.array(counts)
    no_count_indicator = xp.iinfo(frames.dtype).max

    sum_kernel[blockspergrid, threadsperblock](frames1, counts1, xp.array(frame_dimensions), frame_sum, no_count_indicator)
    return frame_sum.get()

In [None]:
#export 


@cuda.jit
def rebin_kernel(indices, counts, new_frame_center, old_indices, old_counts, old_frame_center, no_count_indicator,
                 bin_factor):
    """Re-encode sparse events for a frame binned by ``bin_factor``.

    Each thread of the 2D grid handles one scan position. Valid events of
    ``old_indices`` are decoded with a frame width of ``2 * old_frame_center[1]``,
    their centre-relative coordinates are floor-divided by ``bin_factor`` and
    re-centred on ``new_frame_center``, and the result is written together with
    the matching count to consecutive slots of ``indices`` / ``counts``.

    indices, counts: 3D outputs; only the leading slots of each scan position
        are written, the remaining slots keep the caller's values.
    new_frame_center: (row, column) centre of the binned frame.
    old_indices, old_counts: 3D inputs of flattened indices and their counts.
    old_frame_center: (row, column) centre of the original frame.
    no_count_indicator: index value marking an entry that carries no event.
    bin_factor: divisor applied to the centre-relative coordinates.

    NOTE(review): events that land on the same binned pixel are written as
    separate entries, not merged - confirm this is what callers expect.
    """
    ny, nx = cuda.grid(2)
    NY, NX, _ = indices.shape
    # Frame sizes are derived from the centres (size = 2 * centre).
    MY = old_frame_center[0] * 2
    MX = old_frame_center[1] * 2
    MXnew = new_frame_center[1] * 2
    if ny < NY and nx < NX:
        # Next free output slot for this scan position.
        k = 0
        for i in range(old_indices[ny, nx].shape[0]):
            idx1d = old_indices[ny, nx, i]
            # Unravel the flattened index into frame coordinates.
            my = idx1d // MX
            mx = idx1d - my * MX
            my_center = my - old_frame_center[0]
            mx_center = mx - old_frame_center[1]
            if idx1d != no_count_indicator:
                mybin = int(new_frame_center[0] + my_center // bin_factor)
                mxbin = int(new_frame_center[1] + mx_center // bin_factor)
                indices[ny, nx, k] = mybin * MXnew + mxbin
                counts[ny, nx, k] = old_counts[ny, nx, i]
                k += 1