From bbf3095fb92730cb5e18b5f6f218eb31ec94d70c Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 10 Sep 2025 15:30:18 -0700 Subject: [PATCH] Support BFloat16 in exir (#14164) Summary: This diff adds support for BFloat16 dtype in ExecuTorch's exir dialect by updating the type system and operator constraints. Previously, BFloat16 inputs required disabling IR validity checks with `_check_ir_validity=False` to work around the lack of proper type system support. The changes include: - Adding `torch.bfloat16: "BFloat16"` mapping in the supported dtypes configuration - Updating edge.yaml to include BFloat16 anywhere that Half is supported Reviewed By: JacobSzwejbka Differential Revision: D82129165 --- exir/dialects/edge/dtype/supported.py | 1 + exir/dialects/edge/edge.yaml | 422 +++++++++++------------ exir/dialects/edge/test/test_edge_ops.py | 6 +- exir/tests/test_arg_validator.py | 14 +- exir/tests/test_verification.py | 13 +- 5 files changed, 233 insertions(+), 223 deletions(-) diff --git a/exir/dialects/edge/dtype/supported.py b/exir/dialects/edge/dtype/supported.py index 89c269de1d0..810966220af 100644 --- a/exir/dialects/edge/dtype/supported.py +++ b/exir/dialects/edge/dtype/supported.py @@ -23,6 +23,7 @@ torch.int32: "Int", torch.int64: "Long", torch.float16: "Half", + torch.bfloat16: "BFloat16", torch.float: "Float", torch.double: "Double", torch.uint16: "UInt16", diff --git a/exir/dialects/edge/edge.yaml b/exir/dialects/edge/edge.yaml index 6970f7b00cc..2dac0741007 100644 --- a/exir/dialects/edge/edge.yaml +++ b/exir/dialects/edge/edge.yaml @@ -7,7 +7,7 @@ namespace: edge inherits: aten::_log_softmax type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - self: T0 __ret_0: T0 @@ -16,7 +16,7 @@ namespace: edge inherits: aten::_softmax type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - self: T0 __ret_0: T0 @@ -26,12 +26,12 @@ inherits: aten::_to_copy type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -72,7 +72,7 @@ namespace: edge inherits: aten::abs type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -82,7 +82,7 @@ inherits: aten::acos type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -95,7 +95,7 @@ inherits: aten::acosh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -117,7 +117,7 @@ T7: [Double] T8: [Float] T9: [Float, Int] - T10: [Half] + T10: [Half, BFloat16] T11: [Int] T12: [Long] T13: [Short] @@ -254,9 +254,9 @@ type_alias: T0: [Bool] T1: [Bool, Byte] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T4: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T4: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T5: [Bool, Byte, Char, Int, Long, Short] T6: [Bool, Byte, Char, Int, Short] T7: [Bool, Byte, Char, Short] @@ -269,7 +269,7 @@ T14: [Double] T15: [Float] T16: [Float, Int] - T17: [Half] + T17: [Half, BFloat16] T18: [Int] T19: [Long] T20: [Short] @@ -598,7 +598,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -896,7 +896,7 @@ namespace: edge inherits: aten::alias_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -905,7 +905,7 @@ namespace: edge inherits: aten::amax type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -914,7 +914,7 @@ namespace: edge inherits: aten::amin type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -924,7 +924,7 @@ inherits: aten::any type_alias: T0: [Bool] - T1: [Bool, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] type_constraint: - self: T1 @@ -941,7 +941,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -981,7 +981,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -2071,7 +2071,7 @@ namespace: edge inherits: aten::argmax type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] T1: [Long] type_constraint: - self: T0 @@ -2081,7 +2081,7 @@ namespace: edge inherits: aten::argmin type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] T1: [Long] type_constraint: - self: T0 @@ -2091,7 +2091,7 @@ namespace: edge inherits: aten::as_strided_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -2101,7 +2101,7 @@ inherits: aten::asin type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2114,7 +2114,7 @@ inherits: aten::asinh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2127,7 +2127,7 @@ inherits: aten::atan type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2140,7 +2140,7 @@ inherits: aten::atanh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2152,7 +2152,7 @@ namespace: edge inherits: aten::avg_pool2d type_alias: - T0: [Double, Float, Half, Long] + T0: [Double, Float, Half, BFloat16, Long] type_constraint: - self: T0 __ret_0: T0 @@ -2479,7 +2479,7 @@ namespace: edge inherits: aten::bmm type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 mat2: T0 @@ -2489,7 +2489,7 @@ namespace: edge inherits: aten::cat type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - tensors: T0 __ret_0: T0 @@ -2498,7 +2498,7 @@ namespace: edge inherits: aten::ceil type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -2516,7 +2516,7 @@ T6: [Char] T7: [Double] T8: [Float] - T9: [Half] + T9: [Half, BFloat16] T10: [Int] T11: [Long] T12: [Short] @@ -2775,7 +2775,7 @@ namespace: edge inherits: aten::clone type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -2790,7 +2790,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -2831,12 +2831,12 @@ namespace: edge inherits: aten::convolution type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T1: [Byte] T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -2879,12 +2879,12 @@ inherits: aten::copy type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -2926,7 +2926,7 @@ inherits: aten::cos type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2939,7 +2939,7 @@ inherits: aten::cosh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -2951,12 +2951,12 @@ namespace: edge inherits: aten::cumsum type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T1: [Byte] T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -2990,7 +2990,7 @@ namespace: edge inherits: aten::detach_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -3006,7 +3006,7 @@ T4: [Char] T5: [Double] T6: [Float] - T7: [Half] + T7: [Half, BFloat16] T8: [Int] T9: [Long] T10: [Short] @@ -3057,16 +3057,16 @@ inherits: aten::div.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Bool, Byte, Char, Float, Int, Long, Short] - T4: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T4: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T5: [Byte] T6: [Char] T7: [Double] T8: [Float] T9: [Float, UInt16] - T10: [Half] + T10: [Half, BFloat16] T11: [Int] T12: [Long] T13: [Short] @@ -3138,9 +3138,9 @@ inherits: aten::div.Tensor_mode type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -3151,7 +3151,7 @@ T11: [Char, Short] T12: [Double] T13: [Float] - T14: [Half] + T14: [Half, BFloat16] T15: [Int] T16: [Long] T17: [Short] @@ -3225,7 +3225,7 @@ T1: [Bool, Float, Int] T2: [Double] T3: [Float] - T4: [Half] + T4: [Half, BFloat16] T5: [Int] type_constraint: - self: T2 @@ -3643,7 +3643,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Int, Long] T8: [Long] @@ -3685,7 +3685,7 @@ namespace: edge inherits: aten::empty.memory_format type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - dtype: T0 __ret_0: T0 @@ -3695,13 +3695,13 @@ inherits: aten::eq.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Bool, Float, Int] T3: [Byte] T4: [Char] T5: [Double] T6: [Float] - T7: [Half] + T7: [Half, BFloat16] T8: [Int] T9: [Long] T10: [Short] @@ -3752,7 +3752,7 @@ inherits: aten::erf type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -3765,7 +3765,7 @@ inherits: aten::exp type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -3777,7 +3777,7 @@ namespace: edge inherits: aten::expand_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -3792,7 +3792,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -3834,12 +3834,12 @@ inherits: aten::fill.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -3880,7 +3880,7 @@ namespace: edge inherits: aten::floor type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -3890,9 +3890,9 @@ inherits: aten::floor_divide type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -3903,7 +3903,7 @@ T11: [Char, Short] T12: [Double] T13: [Float] - T14: [Half] + T14: [Half, BFloat16] T15: [Int] T16: [Long] T17: [Short] @@ -3981,7 +3981,7 @@ T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -4022,9 +4022,9 @@ inherits: aten::fmod.Tensor type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -4035,7 +4035,7 @@ T11: [Char, Short] T12: [Double] T13: [Float] - T14: [Half] + T14: [Half, BFloat16] T15: [Int] T16: [Long] T17: [Short] @@ -4111,7 +4111,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -4153,13 +4153,13 @@ inherits: aten::full_like type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Bool, Float, Int] T3: [Byte] T4: [Char] T5: [Double] T6: [Float] - T7: [Half] + T7: [Half, BFloat16] T8: [Int] T9: [Long] T10: [Short] @@ -4691,14 +4691,14 @@ inherits: aten::ge.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Bool, Float, Int] T4: [Byte] T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -4745,14 +4745,14 @@ inherits: aten::ge.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -4823,7 +4823,7 @@ namespace: edge inherits: aten::gelu type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - self: T0 __ret_0: T0 @@ -4832,7 +4832,7 @@ namespace: edge inherits: aten::glu type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - self: T0 __ret_0: T0 @@ -4842,14 +4842,14 @@ inherits: aten::gt.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Bool, Float, Int] T4: [Byte] T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -4896,14 +4896,14 @@ inherits: aten::gt.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -4980,7 +4980,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -5188,7 +5188,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -5231,7 +5231,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -5282,7 +5282,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Int, Long] T8: [Long] @@ -5325,7 +5325,7 @@ inherits: aten::isinf type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T1 __ret_0: T0 @@ -5335,7 +5335,7 @@ inherits: aten::isnan type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T1 __ret_0: T0 @@ -5345,14 +5345,14 @@ inherits: aten::le.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Bool, Float, Int] T4: [Byte] T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5399,14 +5399,14 @@ inherits: aten::le.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5480,7 +5480,7 @@ T0: [Bool, Float, Int] T1: [Double] T2: [Float] - T3: [Half] + T3: [Half, BFloat16] type_constraint: - self: T1 negative_slope: T0 @@ -5496,7 +5496,7 @@ namespace: edge inherits: aten::lift_fresh_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -5506,7 +5506,7 @@ inherits: aten::log type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -5519,14 +5519,14 @@ inherits: aten::logical_and type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5598,7 +5598,7 @@ inherits: aten::logical_not type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T1 __ret_0: T0 @@ -5608,14 +5608,14 @@ inherits: aten::logical_or type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5687,14 +5687,14 @@ inherits: aten::logical_xor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5766,7 +5766,7 @@ inherits: aten::logit type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -5779,14 +5779,14 @@ inherits: aten::lt.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Bool, Float, Int] T4: [Byte] T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5833,14 +5833,14 @@ inherits: aten::lt.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half] + T6: [Double, Float, Half, BFloat16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -5917,7 +5917,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -5968,7 +5968,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -6009,7 +6009,7 @@ T1: [Char] T2: [Double] T3: [Float] - T4: [Half] + T4: [Half, BFloat16] T5: [Int] T6: [Long] T7: [Short] @@ -6043,10 +6043,10 @@ namespace: edge inherits: aten::mean.dim type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T1: [Double] T2: [Float] - T3: [Half] + T3: [Half, BFloat16] type_constraint: - self: T0 dtype: T1 @@ -6067,7 +6067,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -6106,9 +6106,9 @@ type_alias: T0: [Bool] T1: [Bool, Byte] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T4: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T4: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T5: [Bool, Byte, Char, Int, Long, Short] T6: [Bool, Byte, Char, Int, Short] T7: [Bool, Byte, Char, Short] @@ -6119,7 +6119,7 @@ T12: [Char, Short] T13: [Double] T14: [Float] - T15: [Half] + T15: [Half, BFloat16] T16: [Int] T17: [Long] T18: [Short] @@ -6192,7 +6192,7 @@ namespace: edge inherits: aten::mm type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 mat2: T0 @@ -6211,7 +6211,7 @@ T6: [Char] T7: [Double] T8: [Float] - T9: [Half] + T9: [Half, BFloat16] T10: [Int] T11: [Long] T12: [Short] @@ -6259,9 +6259,9 @@ inherits: aten::mul.Tensor type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -6273,7 +6273,7 @@ T12: [Char, Short] T13: [Double] T14: [Float] - T15: [Half] + T15: [Half, BFloat16] T16: [Int] T17: [Long] T18: [Short] @@ -6346,7 +6346,7 @@ namespace: edge inherits: aten::native_layer_norm type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - input: T0 weight: T0 @@ -6360,13 +6360,13 @@ inherits: aten::ne.Scalar type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Bool, Float, Int] T3: [Byte] T4: [Char] T5: [Double] T6: [Float] - T7: [Half] + T7: [Half, BFloat16] T8: [Int] T9: [Long] T10: [Short] @@ -6417,14 +6417,14 @@ inherits: aten::ne.Tensor type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T3: [Byte] T4: [Char] T5: [Double] - T6: [Double, Float, Half, UInt16] + T6: [Double, Float, Half, BFloat16, UInt16] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -6495,7 +6495,7 @@ namespace: edge inherits: aten::neg type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -6504,7 +6504,7 @@ namespace: edge inherits: aten::nonzero type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] T1: [Long] type_constraint: - self: T0 @@ -6514,7 +6514,7 @@ namespace: edge inherits: aten::ones type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - dtype: T0 __ret_0: T0 @@ -6523,7 +6523,7 @@ namespace: edge inherits: aten::permute_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -6532,7 +6532,7 @@ namespace: edge inherits: aten::pixel_shuffle type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -6550,7 +6550,7 @@ T6: [Char] T7: [Double] T8: [Float] - T9: [Half] + T9: [Half, BFloat16] T10: [Int] T11: [Long] T12: [Short] @@ -6598,9 +6598,9 @@ inherits: aten::pow.Tensor_Tensor type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -6611,7 +6611,7 @@ T11: [Char, Short] T12: [Double] T13: [Float] - T14: [Half] + T14: [Half, BFloat16] T15: [Int] T16: [Long] T17: [Short] @@ -6682,7 +6682,7 @@ inherits: aten::reciprocal type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -6694,7 +6694,7 @@ namespace: edge inherits: aten::relu type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -6711,7 +6711,7 @@ T5: [Char] T6: [Double] T7: [Float] - T8: [Half] + T8: [Half, BFloat16] T9: [Int] T10: [Long] T11: [Short] @@ -6752,9 +6752,9 @@ inherits: aten::remainder.Tensor type_alias: T0: [Bool, Byte] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Bool, Byte, Char, Int, Long, Short] T5: [Bool, Byte, Char, Int, Short] T6: [Bool, Byte, Char, Short] @@ -6765,7 +6765,7 @@ T11: [Char, Short] T12: [Double] T13: [Float] - T14: [Half] + T14: [Half, BFloat16] T15: [Int] T16: [Long] T17: [Short] @@ -6835,7 +6835,7 @@ namespace: edge inherits: aten::repeat type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -6844,7 +6844,7 @@ namespace: edge inherits: aten::round type_alias: - T0: [Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -6854,7 +6854,7 @@ inherits: aten::rsqrt type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -6872,7 +6872,7 @@ T3: [Double] T4: [Float] T5: [Float, Int] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -6989,7 +6989,7 @@ T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -7035,7 +7035,7 @@ T2: [Char] T3: [Double] T4: [Float] - T5: [Half] + T5: [Half, BFloat16] T6: [Int] T7: [Long] T8: [Short] @@ -7081,7 +7081,7 @@ namespace: edge inherits: aten::select_copy.int type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7091,12 +7091,12 @@ inherits: aten::select_scatter type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -7138,7 +7138,7 @@ inherits: aten::sigmoid type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7150,7 +7150,7 @@ namespace: edge inherits: aten::sign type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -7160,7 +7160,7 @@ inherits: aten::sin type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7173,7 +7173,7 @@ inherits: aten::sinh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7185,7 +7185,7 @@ namespace: edge inherits: aten::slice_copy.Tensor type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7195,12 +7195,12 @@ inherits: aten::slice_scatter type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -7241,7 +7241,7 @@ namespace: edge inherits: aten::split_copy.Tensor type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7250,7 +7250,7 @@ namespace: edge inherits: aten::split_with_sizes_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7260,7 +7260,7 @@ inherits: aten::sqrt type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7272,7 +7272,7 @@ namespace: edge inherits: aten::squeeze_copy.dim type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7281,7 +7281,7 @@ namespace: edge inherits: aten::squeeze_copy.dims type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7290,7 +7290,7 @@ namespace: edge inherits: aten::stack type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - tensors: T0 __ret_0: T0 @@ -7305,7 +7305,7 @@ T3: [Double] T4: [Float] T5: [Float, Int] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -7417,9 +7417,9 @@ inherits: aten::sub.Tensor type_alias: T0: [Byte] - T1: [Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T2: [Byte, Char, Float, Half, Int, Long, Short, UInt16] - T3: [Byte, Char, Half, Int, Long, Short, UInt16] + T1: [Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T2: [Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T4: [Byte, Char, Int, Long, Short] T5: [Byte, Char, Int, Short] T6: [Byte, Char, Short] @@ -7429,7 +7429,7 @@ T10: [Double] T11: [Float] T12: [Float, Int] - T13: [Half] + T13: [Half, BFloat16] T14: [Int] T15: [Long] T16: [Short] @@ -7717,12 +7717,12 @@ inherits: aten::sum.dim_IntList type_alias: T0: [Bool] - T1: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T1: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] T2: [Byte] T3: [Char] T4: [Double] T5: [Float] - T6: [Half] + T6: [Half, BFloat16] T7: [Int] T8: [Long] T9: [Short] @@ -7759,7 +7759,7 @@ namespace: edge inherits: aten::t_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7769,7 +7769,7 @@ inherits: aten::tan type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7782,7 +7782,7 @@ inherits: aten::tanh type_alias: T0: [Bool, Byte, Char, Float, Int, Long, Short, UInt16] - T1: [Double, Half] + T1: [Double, Half, BFloat16] T2: [Float] type_constraint: - self: T0 @@ -7794,7 +7794,7 @@ namespace: edge inherits: aten::transpose_copy.int type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7803,7 +7803,7 @@ namespace: edge inherits: aten::tril type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short] type_constraint: - self: T0 __ret_0: T0 @@ -7812,7 +7812,7 @@ namespace: edge inherits: aten::unbind_copy.int type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7821,7 +7821,7 @@ namespace: edge inherits: aten::unsqueeze_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7830,7 +7830,7 @@ namespace: edge inherits: aten::var.dim type_alias: - T0: [Double, Float, Half] + T0: [Double, Float, Half, BFloat16] type_constraint: - self: T0 __ret_0: T0 @@ -7839,7 +7839,7 @@ namespace: edge inherits: aten::view_copy type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - self: T0 __ret_0: T0 @@ -7850,9 +7850,9 @@ type_alias: T0: [Bool] T1: [Bool, Byte] - T2: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] - T3: [Bool, Byte, Char, Float, Half, Int, Long, Short, UInt16] - T4: [Bool, Byte, Char, Half, Int, Long, Short, UInt16] + T2: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] + T3: [Bool, Byte, Char, Float, Half, BFloat16, Int, Long, Short, UInt16] + T4: [Bool, Byte, Char, Half, BFloat16, Int, Long, Short, UInt16] T5: [Bool, Byte, Char, Int, Long, Short] T6: [Bool, Byte, Char, Int, Short] T7: [Bool, Byte, Char, Short] @@ -7863,7 +7863,7 @@ T12: [Char, Short] T13: [Double] T14: [Float] - T15: [Half] + T15: [Half, BFloat16] T16: [Int] T17: [Long] T18: [Short] @@ -8382,7 +8382,7 @@ namespace: edge inherits: aten::zeros type_alias: - T0: [Bool, Byte, Char, Double, Float, Half, Int, Long, Short, UInt16] + T0: [Bool, Byte, Char, Double, Float, Half, BFloat16, Int, Long, Short, UInt16] type_constraint: - dtype: T0 __ret_0: T0 diff --git a/exir/dialects/edge/test/test_edge_ops.py b/exir/dialects/edge/test/test_edge_ops.py index f621ba1bebb..de0aa205100 100644 --- a/exir/dialects/edge/test/test_edge_ops.py +++ b/exir/dialects/edge/test/test_edge_ops.py @@ -108,14 +108,16 @@ def test_edge_argument_dtype_constraints(self) -> None: if isinstance(arg.type, torch.TensorType): self.assertTrue(isinstance(arg.allowed_types, set)) self.assertEqual( - arg.allowed_types, {torch.float16, torch.float32, torch.float64} + arg.allowed_types, + {torch.float16, torch.float32, torch.float64, torch.bfloat16}, ) for ret in returns: if isinstance(ret.type, torch.TensorType): self.assertTrue(isinstance(ret.allowed_types, set)) self.assertEqual( - ret.allowed_types, {torch.float16, torch.float32, torch.float64} + ret.allowed_types, + {torch.float16, torch.float32, torch.float64, torch.bfloat16}, ) def test_allowed_dtype_set(self) -> None: diff --git a/exir/tests/test_arg_validator.py b/exir/tests/test_arg_validator.py index ede8b224329..5580881ab01 100644 --- a/exir/tests/test_arg_validator.py +++ b/exir/tests/test_arg_validator.py @@ -37,16 +37,15 @@ def forward(self, x): self.assertEqual(len(validator.violating_ops), 0) def test_edge_dialect_fails(self) -> None: - # torch.bfloat16 is not supported by edge::aten::_log_softmax + # torch.complex64 is not supported by edge::aten::add class M(torch.nn.Module): def __init__(self): super().__init__() - self.m = torch.nn.LogSoftmax(dim=1) def forward(self, x): - return self.m(x) + return x + x - inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),) + inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.complex64),) egm = ( to_edge( export(M(), inputs, strict=True), @@ -61,12 +60,13 @@ def forward(self, x): key: EdgeOpOverload = next(iter(validator.violating_ops)) self.assertEqual( key.name(), - ops.edge.aten._log_softmax.default.name(), + ops.edge.aten.add.Tensor.name(), ) self.assertDictEqual( validator.violating_ops[key][0], { - "self": torch.bfloat16, - "__ret_0": torch.bfloat16, + "self": torch.complex64, + "other": torch.complex64, + "__ret_0": torch.complex64, }, ) diff --git a/exir/tests/test_verification.py b/exir/tests/test_verification.py index f18e9d74b75..90073216b2d 100644 --- a/exir/tests/test_verification.py +++ b/exir/tests/test_verification.py @@ -252,14 +252,21 @@ def forward(self, x): self.assertTrue(verifier.is_valid(egm)) def test_edge_sad_with_edge_ops(self) -> None: - # log_softmax only takes float or double Tensor - m = torch.nn.LogSoftmax(dim=1) + # add operation does not support complex64 dtype + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + m = TestModel() with self.assertRaises(SpecViolationError): _ = ( to_edge( export( m, - (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),), + (torch.randn(1, 3, 100, 100).to(dtype=torch.complex64),), strict=True, ) )