In [2]:
!uv pip install numba

[2mUsing Python 3.12.9 environment at: /Users/swap357/Documents/dev/learning-numba/.venv[0m
[2K[2mResolved [1m3 packages[0m [2min 99ms[0m[0m                                          [0m
[2K[2mInstalled [1m3 packages[0m [2min 49ms[0m[0m                                [0m
 [32m+[39m [1mllvmlite[0m[2m==0.44.0[0m
 [32m+[39m [1mnumba[0m[2m==0.61.0[0m
 [32m+[39m [1mnumpy[0m[2m==2.1.3[0m


In [3]:
import numpy as np
from numba import njit

In [8]:
# Baseline with ordinary finite bounds: values below 3 are raised to 3 and
# values above 12 are lowered to 12.
arr = np.array([1, 2, 5, 10, 15])
clipped = arr.clip(3, 12)
print(clipped)

[ 3  3  5 10 12]


In [4]:
@njit
def clip_njit():
    """Minimal numba reproducer: clip a small int array with NaN for both bounds.

    Plain NumPy turns NaN bounds into an all-NaN float result; in this notebook
    the numba-compiled version (numba 0.61.0) returned the input unchanged.
    """
    values = np.array([1, 2])
    return np.clip(values, a_min=np.nan, a_max=np.nan)


In [15]:
# Reference behaviour: NumPy propagates a NaN lower bound into every element,
# even though a_max=2 is finite.
print(np.array([0, 1, 2, 3]).clip(np.nan, 2))

[nan nan nan nan]


In [14]:
# Same idea compiled by numba (both bounds NaN): the int input comes back
# unchanged instead of becoming NaN.
njit_result = clip_njit()
print(njit_result)

[1 2]


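Numba's `np.clip` does not handle NaN bounds the way NumPy does: NumPy propagates a NaN `a_min`/`a_max` into every output element (returning a float array of NaN), whereas numba 0.61.0 effectively ignores the NaN bound — with both bounds NaN it returns the input unchanged.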

In [30]:
# _np_clip_impl(a, a_min, a_max, out)
# https://github.com/numba/numba/blob/c21aa9273ef4298392695b1f4613d29456b53e5c/numba/np/arrayobj.py#L2342
def numba_np_clip(a, a_min, a_max):
    """Pure-Python mirror of numba's ``_np_clip_impl`` loop (linked above).

    Each element is clamped with the builtins ``max`` and then ``min``. Ordered
    comparisons against NaN are always False, so a NaN bound never "wins"
    either comparison and is effectively ignored. That reproduces numba's
    divergence from ``np.clip`` on NaN bounds; this function intentionally
    does NOT fix it.

    The output comes from ``np.empty_like(a)``, so it keeps ``a``'s dtype and
    shape. The bounds are only broadcast for reading.
    """
    out = np.empty_like(a)
    a_bc, lo_bc, hi_bc = np.broadcast_arrays(a, a_min, a_max)
    for idx in np.ndindex(a_bc.shape):
        raised = max(a_bc[idx], lo_bc[idx])
        out[idx] = min(raised, hi_bc[idx])
    return out

In [31]:
# Compare the pure-Python mirror of numba's algorithm against NumPy on the
# same four (a_min, a_max) combinations; each label is printed verbatim.
print("arr:", arr)

bound_cases = (
    ("(arr, a_min=1, a_max=2) :", 1, 2),
    ("(arr, a_min=np.nan, a_max=2) :", np.nan, 2),
    ("(arr, a_min=1, a_max=np.nan) :", 1, np.nan),
    ("(arr, a_min=np.nan, a_max=np.nan) :", np.nan, np.nan),
)
for case_label, lower, upper in bound_cases:
    # Blank separator line before every case (none after the last one).
    print()
    print(case_label)
    print("numba:", numba_np_clip(arr, a_min=lower, a_max=upper))
    print("numpy:", np.clip(arr, a_min=lower, a_max=upper))

arr: [ 1  2  5 10 15]

(arr, a_min=1, a_max=2) :
numba: [1 2 2 2 2]
numpy: [1 2 2 2 2]

(arr, a_min=np.nan, a_max=2) :
numba: [1 2 2 2 2]
numpy: [nan nan nan nan nan]

(arr, a_min=1, a_max=np.nan) :
numba: [ 1  2  5 10 15]
numpy: [nan nan nan nan nan]

(arr, a_min=np.nan, a_max=np.nan) :
numba: [ 1  2  5 10 15]
numpy: [nan nan nan nan nan]


In [42]:
def revised_np_clip(a, a_min, a_max, verbose=True):
    """Element-wise clip that propagates NaN bounds the way ``np.clip`` does.

    This is the same per-element ``max``/``min`` loop as ``numba_np_clip``,
    with these changes:

    * The result dtype is ``np.result_type(a, a_min, a_max)``. An integer ``a``
      clipped with a float/NaN bound therefore gets a float result that can
      hold NaN.
    * If either bound is NaN at an index, the output there is NaN.
    * The output uses the broadcast shape of ``(a, a_min, a_max)``, not
      ``a.shape``. Array-valued bounds that broadcast ``a`` to a larger shape
      therefore work instead of raising IndexError.

    Parameters
    ----------
    a : numpy array (or scalar)
        Values to clip.
    a_min, a_max : scalar or array
        Lower and upper bounds, broadcast against ``a``.
    verbose : bool, default True
        Print the computed result dtype. This is on by default so the
        recorded outputs of the comparison cell are reproduced.

    Returns
    -------
    numpy.ndarray
        A new array of dtype ``np.result_type(a, a_min, a_max)``.
    """
    dtype = np.result_type(a, a_min, a_max)
    if verbose:
        print('required dtype for ret: ', dtype)
    a_b, a_min_b, a_max_b = np.broadcast_arrays(a, a_min, a_max)
    # Allocate from the broadcast shape: e.g. a.shape == (2,) with
    # a_min.shape == (2, 1) yields a (2, 2) result.
    ret = np.empty(a_b.shape, dtype=dtype)
    for index in np.ndindex(a_b.shape):
        val_a = a_b[index]
        val_a_min = a_min_b[index]
        val_a_max = a_max_b[index]
        # Propagate NaN if either bound is NaN (np.clip's behaviour).
        if np.isnan(val_a_min) or np.isnan(val_a_max):
            ret[index] = np.nan
        else:
            # A NaN in `a` itself survives max/min, since comparisons
            # against NaN are False and the first argument is kept.
            ret[index] = min(max(val_a, val_a_min), val_a_max)

    return ret

In [43]:
# Re-run the same four (a_min, a_max) cases against `revised_np_clip`. Print
# both results as before, then assert they really agree with NumPy: same
# values, NaNs in the same positions, and same dtype.
print("arr:", arr)

for label, lo, hi in [
    ("(arr, a_min=1, a_max=2) :", 1, 2),
    ("(arr, a_min=np.nan, a_max=2) :", np.nan, 2),
    ("(arr, a_min=1, a_max=np.nan) :", 1, np.nan),
    ("(arr, a_min=np.nan, a_max=np.nan) :", np.nan, np.nan),
]:
    # Blank separator line before every case (none after the last one).
    print()
    print(label)
    revised = revised_np_clip(arr, a_min=lo, a_max=hi)
    expected = np.clip(arr, a_min=lo, a_max=hi)
    print("revised:", revised)
    print("numpy:", expected)
    # assert_array_equal treats NaNs in matching positions as equal.
    np.testing.assert_array_equal(revised, expected)
    assert revised.dtype == expected.dtype, (revised.dtype, expected.dtype)

arr: [ 1  2  5 10 15]

(arr, a_min=1, a_max=2) :
required dtype for ret:  int64
revised: [1 2 2 2 2]
numpy: [1 2 2 2 2]

(arr, a_min=np.nan, a_max=2) :
required dtype for ret:  float64
revised: [nan nan nan nan nan]
numpy: [nan nan nan nan nan]

(arr, a_min=1, a_max=np.nan) :
required dtype for ret:  float64
revised: [nan nan nan nan nan]
numpy: [nan nan nan nan nan]

(arr, a_min=np.nan, a_max=np.nan) :
required dtype for ret:  float64
revised: [nan nan nan nan nan]
numpy: [nan nan nan nan nan]


NameError: name 'a' is not defined