diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp index 362dc57c37d..fc167dd71e8 100644 --- a/extension/aten_util/aten_bridge.cpp +++ b/extension/aten_util/aten_bridge.cpp @@ -73,6 +73,8 @@ torch::executor::ScalarType torch_to_executorch_scalar_type( return torch::executor::ScalarType::Short; case c10::ScalarType::Half: return torch::executor::ScalarType::Half; + case c10::ScalarType::BFloat16: + return torch::executor::ScalarType::BFloat16; case c10::ScalarType::Int: return torch::executor::ScalarType::Int; case c10::ScalarType::Float: @@ -103,6 +105,8 @@ c10::ScalarType executorch_to_torch_scalar_type( return c10::ScalarType::Short; case torch::executor::ScalarType::Half: return c10::ScalarType::Half; + case torch::executor::ScalarType::BFloat16: + return c10::ScalarType::BFloat16; case torch::executor::ScalarType::Int: return c10::ScalarType::Int; case torch::executor::ScalarType::Float: