In [None]:
# default_exp util

# Loss functions

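> Numba CUDA kernels and PyTorch launch wrappers for the proximal operator and the sparse gradients of the Gaussian and Poisson loss functions.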


In [None]:
#export 
import numpy as np
import torch as th
import math as m
import numba.cuda as cuda

@cuda.jit
def prox_D_gaussian_kernel(z, z_hat, a, beta, a_strides):
    """
    CUDA kernel, one thread per (d, k, my, mx) element.

    For every element with ``|z_hat| != 0`` the kernel writes

        z = (a + beta * |z_hat|) / (1 + beta) * z_hat / |z_hat|

    i.e. the modulus is replaced while the phase of ``z_hat`` is kept.
    Elements with ``|z_hat| == 0`` are left untouched in ``z``.

    :param z:           D x K x My x Mx x 2, output (last axis: real, imaginary)
    :param z_hat:       D x K x My x Mx x 2, input  (last axis: real, imaginary)
    :param a:           D x K x My x Mx
    :param beta:        1
    :param a_strides:   (4,), strides used to unravel the linear thread index
    :return:            nothing, ``z`` is written in place
    """
    D, K, MY, MX, _ = z.shape
    n_elements = D * K * MY * MX
    n = cuda.threadIdx.x + cuda.blockDim.x * cuda.blockIdx.x

    # Unravel the linear index n into (d, k, my, mx) by peeling off one stride at a time.
    rest = n
    d = rest // a_strides[0]
    rest = rest - d * a_strides[0]
    k = rest // a_strides[1]
    rest = rest - k * a_strides[1]
    my = rest // a_strides[2]
    rest = rest - my * a_strides[2]
    mx = rest // a_strides[3]

    # Threads beyond the last element have nothing to do.
    if n >= n_elements:
        return

    wave = z_hat[d, k, my, mx, 0] + 1j * z_hat[d, k, my, mx, 1]
    modulus = abs(wave)
    # The phase of an exact zero is undefined; such elements are skipped.
    if modulus == 0:
        return

    phase = wave / modulus
    new_modulus = (a[d, k, my, mx] + beta * modulus) / (1.0 + beta)
    projected = new_modulus * phase
    z[d, k, my, mx, 0] = projected.real
    z[d, k, my, mx, 1] = projected.imag

In [None]:
#export 
def prox_D_gaussian(z, z_hat, a, beta):
    """
    Proximal operator of the Gaussian log-likelihood.

    Launches ``prox_D_gaussian_kernel`` with one thread per element of ``a`` on the
    current PyTorch CUDA stream. The result is written into ``z`` in place.

    :param z:           D x K x My x Mx x 2, updated exit waves
    :param z_hat:       D x K x My x Mx x 2, current model exit waves
    :param a:           D x K x My x Mx,     measured amplitudes
    :param beta:        float                hyperparameter

    :return: z
    """
    # The kernel unravels its linear thread index with the strides handed to it. That is
    # only valid for row-major strides of the logical shape, so derive them from
    # ``a.shape`` instead of ``a.stride()``: identical for a contiguous ``a``, and still
    # correct when ``a`` is a permuted/expanded view (the kernel indexes ``a`` by
    # multi-index, which is independent of the memory layout).
    decode_strides = []
    n_elements = 1
    for extent in reversed(a.shape):
        decode_strides.append(n_elements)
        n_elements *= extent
    decode_strides.reverse()

    # Nothing to compute for empty input; this also avoids a launch with zero blocks.
    if n_elements == 0:
        return z

    gpu = cuda.get_current_device()
    stream = th.cuda.current_stream().cuda_stream
    threadsperblock = gpu.MAX_THREADS_PER_BLOCK // 2
    # Integer ceil-division: enough blocks to cover every element.
    blockspergrid = (n_elements + threadsperblock - 1) // threadsperblock
    strides = th.tensor(decode_strides).to(z.device)
    prox_D_gaussian_kernel[blockspergrid, threadsperblock, stream](z, z_hat, a, beta, strides)
    return z

In [None]:
#export 
@cuda.jit
def gradz_poisson_sparse_kernel(out, z, a_indices, a_counts, no_count_indicator, total_cts):
    """
    CUDA kernel, one thread per (d, k) pair.

    For every slot ``i`` of ``a_indices[d, k]`` that does not hold
    ``no_count_indicator``, the flat index is split into (my, mx) using the
    frame width Mx and the kernel writes

        out[d, k, my, mx] = z * (1 - a_counts[d, k, i] / (|z|**2 + 1e-2))

    Pairs whose ``total_cts[d, k]`` is not positive are skipped entirely.

    :param out:         D x K x My x Mx x 2, output (last axis: real, imaginary)
    :param z:           D x K x My x Mx x 2, input  (last axis: real, imaginary)
    :param a_indices:   D x K x counts, flat pixel indices
    :param a_counts:    D x K x counts
    :param no_count_indicator: float or int, index value marking an unused slot
    :param total_cts:   D x K
    :return:            nothing, ``out`` is written in place
    """
    d, k = cuda.grid(2)
    D, K, MY, MX, _ = z.shape

    # Threads outside the D x K grid have nothing to do.
    if d >= D or k >= K:
        return
    if not (total_cts[d, k] > 0):
        return

    n_slots = a_indices.shape[2]
    for slot in range(n_slots):
        flat = a_indices[d, k, slot]
        if flat == no_count_indicator:
            continue
        # Split the flat pixel index into row and column of the frame.
        row = flat // MX
        col = flat - row * MX
        wave = z[d, k, row, col, 0] + 1j * z[d, k, row, col, 1]
        magnitude = abs(wave)
        # The constant 1e-2 keeps the denominator away from zero.
        scale = 1 - (a_counts[d, k, slot] / (magnitude**2+1e-2))
        wave *= scale
        out[d, k, row, col, 0] = wave.real
        out[d, k, row, col, 1] = wave.imag

In [None]:
#export 
def gradz_poisson_sparse(out, z, a_indices, a_counts):
    """
    Gradient of the Poisson loss with respect to z. Sparse version

    Launches ``gradz_poisson_sparse_kernel`` with one thread per (d, k) pair on the
    current PyTorch CUDA stream. The kernel writes its result into ``out``.

    :param out:         D x K x My x Mx x 2, written in place by the kernel
    :param z:           D x K x My x Mx x 2, exit waves
    :param a_indices:   D x K x cts,     integer tensor of measured indices; the largest
                                         value of its dtype marks an unused slot
    :param a_counts:    D x K x cts,     measured counts

    :return: z (the input, unchanged; the result is in ``out``)
    """
    threadsperblock = (2, 32)
    # Integer ceil-division yields plain Python ints. The previous
    # ``np.ceil(...).astype(np.int)`` fails on NumPy >= 1.24, where the ``np.int``
    # alias no longer exists.
    blockspergrid = tuple((extent + t - 1) // t for extent, t in zip(z.shape[:2], threadsperblock))
    # Nothing to compute when D or K is zero; this also avoids a zero-sized launch grid.
    if 0 in blockspergrid:
        return z

    # NOTE(review): the raw integer stream handle is handed to numba as-is; confirm this
    # is accepted for non-default torch streams.
    stream = th.cuda.current_stream().cuda_stream
    no_count_indicator = th.iinfo(a_indices.dtype).max
    # Per-(d, k) total counts; the kernel skips pairs whose total is not positive.
    total_cts = th.sum(a_counts, 2)
    gradz_poisson_sparse_kernel[blockspergrid, threadsperblock, stream](out, z, a_indices, a_counts, no_count_indicator,
                                                                        total_cts)
    # NOTE(review): returns ``z`` rather than ``out``; kept for backward compatibility.
    return z

In [None]:
#export 
@cuda.jit
def gradz_gaussian_sparse_kernel(out, z, a_indices, a_counts, no_count_indicator):
    """
    CUDA kernel, one thread per (d, k) pair.

    For every slot ``i`` of ``a_indices[d, k]`` that does not hold
    ``no_count_indicator``, the flat index is split into (my, mx) using the
    frame width Mx and the kernel writes

        out[d, k, my, mx] = z * (1 - a_counts[d, k, i] / |z|)

    where the denominator is replaced by 1e-3 if ``|z|`` is exactly zero.

    :param out:         D x K x My x Mx x 2, output (last axis: real, imaginary)
    :param z:           D x K x My x Mx x 2, input  (last axis: real, imaginary)
    :param a_indices:   D x K x counts, flat pixel indices
    :param a_counts:    D x K x counts
    :param no_count_indicator: float or int, index value marking an unused slot
    :return:            nothing, ``out`` is written in place
    """
    d, k = cuda.grid(2)
    D, K, MY, MX, _ = z.shape

    # Threads outside the D x K grid have nothing to do.
    if d >= D or k >= K:
        return

    n_slots = a_indices.shape[2]
    for slot in range(n_slots):
        flat = a_indices[d, k, slot]
        if flat == no_count_indicator:
            continue
        # Split the flat pixel index into row and column of the frame.
        row = flat // MX
        col = flat - row * MX
        wave = z[d, k, row, col, 0] + 1j * z[d, k, row, col, 1]
        magnitude = abs(wave)
        if magnitude == 0:
            # Fixed fallback denominator for an exactly vanishing wave.
            scale = 1 - (float(a_counts[d, k, slot]) / 1e-3)
        else:
            scale = 1 - (float(a_counts[d, k, slot]) / magnitude)
        wave *= scale
        out[d, k, row, col, 0] = wave.real
        out[d, k, row, col, 1] = wave.imag

In [None]:
#export 

def gradz_gaussian_sparse(out, z, a_indices, a_counts):
    """
    Gradient of the Gaussian loss with respect to z. Sparse version

    Launches ``gradz_gaussian_sparse_kernel`` with one thread per (d, k) pair on the
    current PyTorch CUDA stream. The kernel writes its result into ``out``.

    :param out:         D x K x My x Mx x 2, written in place by the kernel
    :param z:           D x K x My x Mx x 2, exit waves
    :param a_indices:   D x K x cts,     integer tensor of measured indices; the largest
                                         value of its dtype marks an unused slot
    :param a_counts:    D x K x cts,     measured values belonging to ``a_indices``

    :return: z (the input, unchanged; the result is in ``out``)
    """
    threadsperblock = (2, 32)
    # Integer ceil-division yields plain Python ints. The previous
    # ``np.ceil(...).astype(np.int)`` fails on NumPy >= 1.24, where the ``np.int``
    # alias no longer exists.
    blockspergrid = tuple((extent + t - 1) // t for extent, t in zip(z.shape[:2], threadsperblock))
    # Nothing to compute when D or K is zero; this also avoids a zero-sized launch grid.
    if 0 in blockspergrid:
        return z

    # NOTE(review): the raw integer stream handle is handed to numba as-is; confirm this
    # is accepted for non-default torch streams.
    stream = th.cuda.current_stream().cuda_stream
    no_count_indicator = th.iinfo(a_indices.dtype).max
    gradz_gaussian_sparse_kernel[blockspergrid, threadsperblock, stream](out, z, a_indices, a_counts, no_count_indicator)
    # NOTE(review): returns ``z`` rather than ``out``; kept for backward compatibility.
    return z