diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py
index 4dc891ee984..3f7da88a793 100644
--- a/backends/mlx/ops.py
+++ b/backends/mlx/ops.py
@@ -418,6 +418,32 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
     REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))
 
 
+# ---------------------------------------------------------------------------
+# Numerical checks
+# ---------------------------------------------------------------------------
+
+
+@REGISTRY.register(target=[torch.ops.aten.isnan.default])
+def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    """Handle aten.isnan - check for NaN values element-wise.
+
+    isnan(x) is equivalent to x != x (NaN is the only value not equal to itself).
+    """
+    args = P.args(n)
+    require_args(args, 1, 1, "aten.isnan")
+    require_kwargs(P.kwargs(n), set(), "aten.isnan")
+    x = args[0]
+    out = P.make_or_get_slot(n)
+    P.emit(
+        NotEqualNode(
+            a=P.slot_to_tid(x),
+            b=P.slot_to_tid(x),
+            out=P.slot_to_tid(out),
+        )
+    )
+    return out
+
+
 _BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
     (
         [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],
diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py
index e5ece4931b9..7ba3902e436 100644
--- a/backends/mlx/test/test_ops.py
+++ b/backends/mlx/test/test_ops.py
@@ -4004,6 +4004,22 @@ def fn(shape, dtype):
     return fn
 
 
+def _nan_input_fn(nan_frac: float = 0.3):
+    """Return a callable(shape, dtype) that generates inputs with some NaN values.
+
+    Args:
+        nan_frac: Fraction of elements to set to NaN (default 0.3 = 30%).
+    """
+
+    def fn(shape, dtype):
+        x = torch.randn(shape, dtype=dtype)
+        mask = torch.rand(shape) > (1.0 - nan_frac)
+        x[mask] = float("nan")
+        return (x,)
+
+    return fn
+
+
 # Standard shape and dtype configs used by unary tests.
 _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
 _SHAPES_2 = [(16,), (4, 4)]
@@ -4095,6 +4111,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "abs", "op_fn": torch.abs},
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
+    {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
     # activations
     {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
     {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},