diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 4b6612ffb7f64e..9630f6f0e8cc34 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -142,16 +142,11 @@ static auto& kUnsupportedOps = HloOpcode::kCall}; static auto& kUnimplementedOps = *new absl::flat_hash_set{ - HloOpcode::kConvolution, - HloOpcode::kDot, - HloOpcode::kDynamicUpdateSlice, - HloOpcode::kMap, - HloOpcode::kReduceWindow, - // Has a custom approximation in XLA: - HloOpcode::kErf, -}; - -static auto& kF32SupportedOps = *new absl::flat_hash_set{ + HloOpcode::kConvolution, HloOpcode::kDot, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kMap, HloOpcode::kReduceWindow, + // Custom approximations in XLA: + HloOpcode::kErf, HloOpcode::kTanh, + // Incorrect NaN handling: HloOpcode::kMaximum, HloOpcode::kMinimum, HloOpcode::kClamp}; bool IsUnsupportedConstant(const HloInstruction* instr) { @@ -776,12 +771,6 @@ bool IsHloOpSupported(const HloInstruction* instr, return false; } - // TODO(jreiffers): Fix the F64 lowering for these ops. - if (kF32SupportedOps.contains(instr->opcode()) && - instr->shape().element_type() == F64) { - return false; - } - return !(kUnsupportedOps.contains(instr->opcode()) || kUnimplementedOps.contains(instr->opcode()) || IsUnsupportedConstant(instr) || IsUnsupportedTuple(instr) ||