
Crashing while using torch.nn.functional.normalize() with amp enabled #2857

@Clive2312


🐛 Bug

The program crashes when using normalize() with the AMP patch from #2654.

Here's the error message:

2021-04-02 10:32:04.080678: W tensorflow/core/framework/op_kernel.cc:1763] OP_REQUIRES failed at xrt_compile_ops.cc:220 : Internal: Seen floating point types of different precisions in %dot.123 = f32[8,5]{1,0} dot(f32[8,10]{1,0} %add.110, f16[10,5]{1,0} %transpose.122), lhs_contracting_dims={1}, rhs_contracting_dims={0}, but mixed precision is disallowed.

To Reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

xla = xm.xla_device()
# autocast below is the one added by the AMP patch in #2654
with autocast(True):
    m = nn.Linear(5, 10).to(device=xla)
    x = torch.rand(8, 5).to(device=xla).requires_grad_(True)
    x = m(x)
    x = F.normalize(x, dim=1)  # <-- this line causes the exception
    loss = x.sum()
    loss.backward()
    xm.mark_step()

Steps to reproduce the behavior:

  1. Run the code above

Expected behavior

F.normalize() runs under autocast without crashing.

Environment

Additional context

The error message says the dot op doesn't support mixed precision, so I tried calling XlaHelpers::PromoteValues before building the dot op here. It seems this solves the problem, but I am not sure if this is the best way to fix it.

NodePtr Dot(const Value& input, const Value& weight) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_weight = loctx->GetOutputOp(node.operand(1));
    // modification: promote both operands to a common element type so the
    // lowered dot never mixes precisions
    std::tie(xla_input, xla_weight) =
        XlaHelpers::PromoteValues(xla_input, xla_weight);
    return node.ReturnOp(BuildDot(xla_input, xla_weight), loctx);
  };
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildDot(operands[0], operands[1]);
  };
  return GenericOp(OpKind(at::aten::mm), {input, weight},
                   [&]() {
                     return InferOutputShape({input.shape(), weight.shape()},
                                             lower_for_shape_fn);
                   },
                   std::move(lower_fn));
}
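
For what it's worth, my guess (an assumption based on torch.cuda.amp casting rules, not verified against the #2654 patch) is that the mismatch arises because the linear runs in float16 under autocast while normalize is computed in float32, so the float32 gradient from normalize meets the float16 weight in the linear's backward. With m and x constructed as in the repro above, a quick dtype check along these lines should show the divergence:

with autocast(True):
    y = m(x)                   # matmul is autocast to float16
    z = F.normalize(y, dim=1)  # norm runs in float32, so z is promoted
    print(y.dtype, z.dtype)    # expected: torch.float16 torch.float32

If that is right, promoting the dot operands to a common type, as in the snippet above, seems like a reasonable place to resolve it.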
