Description
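`pt.isnan` gives different results depending on the compilation mode. The default (C) backend and JAX correctly flag the NaN entry, but the NUMBA backend returns `False` for every element. Reproducer: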
import numpy as np
import pytensor
import pytensor.tensor as pt

y = pt.dvector('y')
nan_mask = pt.isnan(y)

# Compile the same graph with the default (C) backend, JAX, and Numba
funcs = [pytensor.function([y], nan_mask, mode=mode) for mode in [None, 'JAX', 'NUMBA']]
[f(np.array([np.nan, 1, 2, 3])) for f in funcs]
# C:     array([ True, False, False, False])
# JAX:   Array([ True, False, False, False], dtype=bool)
# NUMBA: array([False, False, False, False])
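For reference, the expected mask can be checked in plain NumPy, independent of any PyTensor backend. This is a minimal sketch: `reference_nan_mask` is a hypothetical helper for illustration, not part of PyTensor. It also relies on the IEEE 754 property that NaN is the only value that compares unequal to itself, so `x != x` should agree with `np.isnan`. The C and JAX results above match this reference; the NUMBA result does not.

```python
import numpy as np

def reference_nan_mask(values):
    """Hypothetical helper: NaN mask computed two ways with plain NumPy (no PyTensor)."""
    x = np.asarray(values, dtype=np.float64)
    via_isnan = np.isnan(x)
    # IEEE 754: NaN is the only value that compares unequal to itself
    via_self_compare = x != x
    # Both formulations should produce the same mask
    assert np.array_equal(via_isnan, via_self_compare)
    return via_isnan

mask = reference_nan_mask([np.nan, 1, 2, 3])
print(mask)  # [ True False False False]
```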