Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,32 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))


# ---------------------------------------------------------------------------
# Numerical checks
# ---------------------------------------------------------------------------


@REGISTRY.register(target=[torch.ops.aten.isnan.default])
def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.isnan to an element-wise self-inequality test.

    NaN is the only value for which ``x != x`` holds, so emitting a
    NotEqual of the input tensor against itself produces exactly the
    element-wise NaN mask.
    """
    node_args = P.args(n)
    require_args(node_args, 1, 1, "aten.isnan")
    require_kwargs(P.kwargs(n), set(), "aten.isnan")
    src = node_args[0]
    result = P.make_or_get_slot(n)
    # Compare the source tensor with itself: True exactly where NaN.
    node = NotEqualNode(
        a=P.slot_to_tid(src),
        b=P.slot_to_tid(src),
        out=P.slot_to_tid(result),
    )
    P.emit(node)
    return result


_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
(
[torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],
Expand Down
17 changes: 17 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4004,6 +4004,22 @@ def fn(shape, dtype):
return fn


def _nan_input_fn(nan_frac: float = 0.3):
"""Return a callable(shape, dtype) that generates inputs with some NaN values.

Args:
nan_frac: Fraction of elements to set to NaN (default 0.3 = 30%).
"""

def fn(shape, dtype):
x = torch.randn(shape, dtype=dtype)
mask = torch.rand(shape) > (1.0 - nan_frac)
x[mask] = float("nan")
return (x,)

return fn


# Standard shape and dtype configs used by unary tests.
_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
_SHAPES_2 = [(16,), (4, 4)]
Expand Down Expand Up @@ -4095,6 +4111,7 @@ def create_model(self) -> nn.Module:
{"op_name": "abs", "op_fn": torch.abs},
{"op_name": "neg", "op_fn": torch.neg},
{"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
# activations
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
{"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},
Expand Down
Loading