In [9]:
import numba
import numpy as np
import math

import scipy.special

In [69]:
EPS = 1e-8

@numba.jit(nopython=True, fastmath=True, parallel=True)
def digamma(arr, eps=EPS):
    """Digamma function (arr is assumed to be 2-dimensional)"""
    lgamma_prime = np.zeros_like(arr)
    if arr.ndim == 2:
        for j in numba.prange(arr.shape[0]):
            for i in numba.prange(arr.shape[1]):
                lgamma_prime[j, i] = (math.lgamma(arr[j, i] + eps) - math.lgamma(arr[j, i])) / eps
    elif arr.ndim == 1:
        for i in numba.prange(arr.shape[0]):
            lgamma_prime[i] = (math.lgamma(arr[i] + eps) - math.lgamma(arr[i])) / eps
    return lgamma_prime


@numba.jit(nopython=True, fastmath=True, parallel=True)
def expect_alpha(as_po, ar_po):
    """Compute the expectation of alpha"""
    return as_po / ar_po


@numba.jit(nopython=True, fastmath=True)
def outer_func(arr_list):
    dim = len(arr_list)
    res_list = list()
    for i in numba.prange(dim):
        arr = arr_list[i]
        res = np.zeros_like(arr)
        res += digamma(arr)
        res_list.append(res)
    return res_list

In [70]:
arr_list = [np.random.random(size=(2, 5)) * 10.0 for i in range(3)]
arr_list

[array([[0.22672, 6.34059, 9.06638, 1.15469, 6.63486],
        [2.16741, 3.7488 , 8.76838, 4.10292, 8.21289]]),
 array([[9.50271, 6.26632, 7.53082, 4.61954, 7.26448],
        [5.49205, 5.05811, 1.14606, 8.18206, 8.97928]]),
 array([[2.43879, 3.78439, 8.07794, 9.26902, 2.89617],
        [0.78869, 3.27529, 0.32494, 4.44891, 7.52243]])]

In [72]:
arr = arr_list[0]
digamma(arr[:, 0])

array([-4.66646,  0.52545])

In [103]:
np.random.seed(2)
arr = np.random.random(size=(4, 2))

arr_max = np.expand_dims(arr.max(axis=1), 1)
arr -= arr_max

arr = np.exp(arr)
arr /= np.expand_dims(arr.sum(axis=1), 1)

arr.round(2)

array([[0.6 , 0.4 ],
       [0.53, 0.47],
       [0.52, 0.48],
       [0.4 , 0.6 ]])

In [104]:
np.random.seed(2)
arr = np.random.random(size=(4, 2))

arr_max = np.expand_dims(arr.max(axis=1), 0)
arr = (arr.T - arr_max).T

arr = np.exp(arr)
arr /= np.expand_dims(arr.sum(axis=1), 1)

arr.round(2)

array([[0.6 , 0.4 ],
       [0.53, 0.47],
       [0.52, 0.48],
       [0.4 , 0.6 ]])

In [21]:
#@numba.njit()
def func(arr):
    return arr.max()

func(arr)

0.9346108200349625