In [75]:
import numpy as np 
import matplotlib.pyplot as plt
from skimage.filters import median
from numba import jit, njit, prange


In [76]:
N_COL, N_ROW = 1000,640
invalid_raio = 0.01

test_arr = np.random.randint(0, 255, (N_COL, N_ROW), dtype=int)
N_invalid = int(N_COL*N_ROW*invalid_raio)
print (f"Total pixel number: {N_COL*N_ROW}")
print (f"Invalide value number: {N_invalid}")

# assign invalid value as -1
xlist = np.random.randint(0, N_COL, N_invalid)
ylist = np.random.randint(0, N_ROW, N_invalid)
test_arr[xlist, ylist] = -1

Total pixel number: 640000
Invalide value number: 6400


In [77]:
def test_raw_loop(input_arr, filter="median",kernel_size=3):
    xlist,ylist = np.where(input_arr<0)
    for ix,iy in zip(xlist,ylist):
        arr_kernel = input_arr[max(ix-kernel_size,0):min(ix+kernel_size+1,N_COL), max(iy-kernel_size,0):min(iy+kernel_size+1,N_ROW)]
        if filter == "median":
            input_arr[ix,iy] = np.median(arr_kernel[arr_kernel>=0])
        elif filter == "mean":
            input_arr[ix,iy] = np.mean(arr_kernel[arr_kernel>=0])
        else:
            raise ValueError("filter should be either median or mean")
    return input_arr

In [78]:
%%timeit

output_arr = test_raw_loop(test_arr, filter="median",kernel_size=3)

1.45 ms ± 68.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [79]:
@jit(nopython=True)
def test_numba_jit(input_arr, filter="median",kernel_size=3):
    for ix,iy in zip(xlist,ylist):
        arr_kernel = input_arr[max(ix-kernel_size,0):min(ix+kernel_size+1,N_COL), max(iy-kernel_size,0):min(iy+kernel_size+1,N_ROW)]
        if filter == "median":
            input_arr[ix,iy] = np.nanmedian(arr_kernel)
        elif filter == "mean":
            input_arr[ix,iy] = np.nanmean(arr_kernel)
        else:
            raise ValueError("filter should be either median or mean")
    return input_arr

test_arr = test_arr.astype(np.uint8)
test_arr[test_arr<0] = 10

In [67]:
%%timeit
output_arr = test_numba_jit(test_arr, filter="median",kernel_size=3)

The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see https://numba.readthedocs.io/en/stable/user/parallel.html#diagnostics for help.

File "../../tmp/ipykernel_2328/3317236485.py", line 2:
<source missing, REPL/exec in use?>



2.96 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [107]:

@njit(parallel=True)
def test_numba_para(input_arr, filter="median",kernel_size=3):
    NX,NY = input_arr.shape
    for i in prange(NX):
        for j in prange(NY):
            if input_arr[i,j]<0:
                collect_list = []
                for k_i in prange(max(i-kernel_size,0),min(i+kernel_size+1,N_COL)):
                    for k_j in prange(max(j-kernel_size,0),min(j+kernel_size+1,N_ROW)):
                        if input_arr[i+k_i,j+k_j]>=0:
                            collect_list.append(input_arr[i+k_i,j+k_j])
                if filter == "median":
                    input_arr[i,j] = np.median(np.array(collect_list))
                elif filter == "mean":
                    input_arr[i,j] = np.mean(np.array(collect_list))
                else:
                    raise ValueError("filter should be either median or mean")
    return input_arr

In [109]:
%%timeit
output_arr = test_numba_para(test_arr, filter="median",kernel_size=3)

2.18 µs ± 35.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
