Skip to content

Commit

Permalink
Support for bfloat16 for binary, unary operators in reference impleme…
Browse files Browse the repository at this point in the history
…ntation (onnx#6166)

### Description
Supports bfloat16 binary, unary operations if ml_dtypes is installed.
Partially answer onnx#6151.

### Motivation and Context
numpy does not support bfloat16 natively but pytorch or tensorflow does.
The reference implementation should support that as well.

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Signed-off-by: Xavier Dupré <xadupre@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
xadupre and justinchuby committed Jun 21, 2024
1 parent 159fa47 commit 0e808c5
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 3 deletions.
48 changes: 48 additions & 0 deletions onnx/reference/custom_element_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,58 @@

import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None # type: ignore[assignment]

bfloat16 = np.dtype((np.uint16, {"bfloat16": (np.uint16, 0)}))
float8e4m3fn = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)}))
float8e4m3fnuz = np.dtype((np.uint8, {"e4m3fnuz": (np.uint8, 0)}))
float8e5m2 = np.dtype((np.uint8, {"e5m2": (np.uint8, 0)}))
float8e5m2fnuz = np.dtype((np.uint8, {"e5m2fnuz": (np.uint8, 0)}))
uint4 = np.dtype((np.uint8, {"uint4": (np.uint8, 0)}))
int4 = np.dtype((np.int8, {"int4": (np.int8, 0)}))


def convert_from_ml_dtypes(array: np.ndarray) -> np.ndarray:
"""Detects the type and changes into one of the ONNX
defined custom types when ``ml_dtypes`` is installed.
Args:
array: Numpy array with a dtype from ml_dtypes.
Returns:
numpy array
"""
if not ml_dtypes:
return array
if array.dtype == ml_dtypes.bfloat16:
return array.view(dtype=bfloat16)
return array


def convert_to_ml_dtypes(array: np.ndarray) -> np.ndarray:
"""Detects the type and changes into one of the type
defined in ``ml_dtypes`` if installed.
Args:
array: array
Returns:
numpy Numpy array with a dtype from ml_dtypes.
"""
dt = array.dtype
new_dt = None
if dt == bfloat16 and array.dtype.descr[0][0] == "bfloat16":
assert ml_dtypes, (
f"ml_dtypes is not installed and the tensor cannot "
f"be converted into ml_dtypes.{array.dtype.descr[0][0]}"
)

new_dt = ml_dtypes.bfloat16

if new_dt:
return array.view(dtype=new_dt).reshape(array.shape)

return array
12 changes: 12 additions & 0 deletions onnx/reference/ops/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import numpy as np

from onnx.onnx_pb import NodeProto
from onnx.reference.custom_element_types import (
convert_from_ml_dtypes,
convert_to_ml_dtypes,
)
from onnx.reference.op_run import OpRun, RuntimeTypeError


Expand All @@ -23,13 +27,15 @@ def run(self, x): # type: ignore
Supports only unary operators.
"""
self._log("-- begin %s.run(1 input)", self.__class__.__name__)
x = convert_to_ml_dtypes(x)
try:
res = self._run(x)
except TypeError as e:
raise TypeError(
f"Issues with types {', '.join(str(type(_)) for _ in [x])} "
f"(unary operator {self.__class__.__name__!r})."
) from e
res = (convert_from_ml_dtypes(res[0]),)
self._log("-- done %s.run -> %d outputs", self.__class__.__name__, len(res))
return self._check_and_fix_outputs(res)

Expand Down Expand Up @@ -79,13 +85,16 @@ def run(self, x, y): # type: ignore
f"(operator '{self.__class__.__name__!r}', "
f"shapes {x.shape}, {y.shape})."
)
x = convert_to_ml_dtypes(x)
y = convert_to_ml_dtypes(y)
try:
res = self._run(x, y)
except (TypeError, ValueError) as e:
raise TypeError(
f"Issues with types {', '.join(str(type(_)) for _ in [x, y])} "
f"(binary operator {self.__class__.__name__!r})."
) from e
res = (convert_from_ml_dtypes(res[0]),)
self._log("-- done %s.run -> %d outputs", self.__class__.__name__, len(res))
return self._check_and_fix_outputs(res)

Expand Down Expand Up @@ -124,7 +133,10 @@ def __init__(
self.numpy_fct = numpy_fct

def _run(self, a, b): # type: ignore
a = convert_to_ml_dtypes(a)
b = convert_to_ml_dtypes(b)
res = (self.numpy_fct(a, b),)
res = (convert_from_ml_dtypes(res[0]),)
return self._check_and_fix_outputs(res)


Expand Down
4 changes: 2 additions & 2 deletions onnx/reference/ops/op_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from onnx.reference.ops._op import OpRunUnaryNum
from onnx.reference.op_run import OpRun


class Identity(OpRunUnaryNum):
class Identity(OpRun):
def _run(self, a): # type: ignore
if a is None:
return (None,)
Expand Down
108 changes: 107 additions & 1 deletion onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy as np
import parameterized
import version_utils
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_almost_equal

import onnx.reference.custom_element_types as custom
from onnx import (
Expand Down Expand Up @@ -124,6 +124,20 @@ def wrapper(*args, **kwargs):
return wrapper


def skip_if_no_ml_dtypes(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
try:
import ml_dtypes

del ml_dtypes
except ImportError:
raise unittest.SkipTest("ml-dtypes not installed") from None
fn(*args, **kwargs)

return wrapper


def make_sequence_value_info(name, elem_type, shape):
if isinstance(elem_type, int):
return make_tensor_sequence_value_info(name, elem_type, shape)
Expand Down Expand Up @@ -5916,6 +5930,98 @@ class MyReferenceEvaluator(ReferenceEvaluator):
for v in oinf.functions_.values():
self.assertIsInstance(v, MyReferenceEvaluator)

@parameterized.parameterized.expand(
[
("DOUBLE", 0),
("FLOAT", 0),
("FLOAT16", 1e-3),
("BFLOAT16", 1e-2),
("FLOAT8E4M3FN", 1),
("FLOAT8E4M3FNUZ", 0.9),
("FLOAT8E5M2", 0.85),
("FLOAT8E5M2FNUZ", 0.85),
("INT4", 0.5),
("UINT4", 0.5),
]
)
@skip_if_no_ml_dtypes
def test_add_custom_dtype(self, stype, atol):
itype = getattr(TensorProto, stype)
model = make_model(
make_graph(
[
make_node("Cast", ["X"], ["Xc"], to=itype),
make_node("Cast", ["Y"], ["Yc"], to=itype),
make_node("Neg", ["Yc"], ["Ycn"]),
make_node("Add", ["Xc", "Ycn"], ["Zc"]),
make_node("Cast", ["Zc"], ["Z"], to=TensorProto.FLOAT),
],
"nd",
[
make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None]),
make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None]),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, [None, None, None])],
),
opset_imports=[make_opsetid("", 18)],
ir_version=9,
)

ref = ReferenceEvaluator(model)

x = (np.arange(18) / 18).reshape((2, 3, 3)).astype(np.float32)
y = (np.arange(18) / 180).reshape((2, 3, 3)).astype(np.float32)
feeds = dict(X=x, Y=y)
expected = x - y
got = ref.run(None, feeds)[0]
assert_allclose(expected, got, atol=atol)

@parameterized.parameterized.expand(
[
("DOUBLE",),
("FLOAT",),
("FLOAT16",),
("BFLOAT16",),
# Comparison fails with ml_dtypes
# ("FLOAT8E4M3FN", ),
# ("FLOAT8E4M3FNUZ", ),
# ("FLOAT8E5M2", ),
# ("FLOAT8E5M2FNUZ", ),
# ("INT4", ),
# ("UINT4", ),
]
)
@skip_if_no_ml_dtypes
def test_cmp_custom_dtype(self, stype):
itype = getattr(TensorProto, stype)
model = make_model(
make_graph(
[
make_node("Cast", ["X"], ["Xc"], to=itype),
make_node("Cast", ["Y"], ["Yc"], to=itype),
make_node("Greater", ["Xc", "Yc"], ["Zc"]),
make_node("Cast", ["Zc"], ["Z"], to=TensorProto.BOOL),
],
"nd",
[
make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None]),
make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None]),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, [None, None, None])],
),
opset_imports=[make_opsetid("", 18)],
ir_version=9,
)

ref = ReferenceEvaluator(model)

x = (np.arange(18) / 18).reshape((2, 3, 3)).astype(np.float32)
y = ((np.arange(18) - 9) / 18).reshape((2, 3, 3)).astype(np.float32)
feeds = dict(X=x, Y=y)
expected = x >= y
got = ref.run(None, feeds)[0]
assert_almost_equal(expected, got)

def test_scatter_elements_4d(self):
model = make_model(
make_graph(
Expand Down
1 change: 1 addition & 0 deletions requirements-release.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
build
ipython
nbval
ml-dtypes; python_version >= "3.9"
numpy==1.24.3; python_version<"3.12"
numpy==1.26.0; python_version>="3.12"
parameterized
Expand Down

0 comments on commit 0e808c5

Please sign in to comment.