diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 992761fedce34e..422071b090e6df 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -2098,6 +2098,11 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand, int64_t inferred_dimension) { + CHECK(operand->shape().is_unbounded_dynamic() || + ShapeUtil::StaticExtentProduct(shape) == + ShapeUtil::StaticExtentProduct(operand->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operand->shape()); return std::make_unique(shape, operand, inferred_dimension); } diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 9439fcd81f2b99..143eaafc1363e3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -1555,11 +1555,6 @@ HloReshapeInstruction::HloReshapeInstruction(const Shape& shape, int64_t inferred_dimension) : HloInstruction(HloOpcode::kReshape, shape), inferred_dimension_(inferred_dimension) { - CHECK(operand->shape().is_unbounded_dynamic() || - ShapeUtil::StaticExtentProduct(shape) == - ShapeUtil::StaticExtentProduct(operand->shape())) - << "shape: " << ShapeUtil::HumanString(shape) - << " operand: " << ShapeUtil::HumanString(operand->shape()); AppendOperand(operand); }