🐛 Bug
The program crashes when using F.normalize() with the AMP patch from #2654.
Here's the error message:
2021-04-02 10:32:04.080678: W tensorflow/core/framework/op_kernel.cc:1763] OP_REQUIRES failed at xrt_compile_ops.cc:220 : Internal: Seen floating point types of different precisions in %dot.123 = f32[8,5]{1,0} dot(f32[8,10]{1,0} %add.110, f16[10,5]{1,0} %transpose.122), lhs_contracting_dims={1}, rhs_contracting_dims={0}, but mixed precision is disallowed.
To Reproduce
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch.cuda.amp import autocast  # assuming torch.cuda.amp.autocast, per the #2654 patch

xla = xm.xla_device()

with autocast(True):
    m = nn.Linear(5, 10).to(device=xla)
    x = torch.rand(8, 5).to(device=xla).requires_grad_(True)
    x = m(x)
    x = F.normalize(x, dim=1)  # <-- this line causes the exception
    loss = x.sum()
loss.backward()
xm.mark_step()
Steps to reproduce the behavior:
- Run the code above.
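For context, a quick dtype check (a rough sketch, reusing m and the imports from the snippet above; x2 is just a fresh illustrative input) suggests where the mixed types come from: the linear output stays float16 under autocast while the normalized result appears to come back as float32, so the backward matmul against the float16 weight would mix f32 and f16 exactly as in the HLO error above.

with autocast(True):
    x2 = torch.rand(8, 5).to(device=xla)
    y = m(x2)                   # linear under autocast -> expected float16
    z = F.normalize(y, dim=1)   # normalize appears to produce float32
print(y.dtype, z.dtype)         # expected: torch.float16 torch.float32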
Expected behavior
The snippet should compile and run the backward pass without the mixed-precision error.
Environment
- Reproducible on XLA backend: XLA GPU
- torch_xla version: 1.7.0 with the AMP (automatic mixed precision) support from #2654
Additional context
The error message says the dot op doesn't support mixed precision, so I tried calling XlaHelpers::PromoteValues before building the dot op, as shown below. That seems to solve the problem, but I'm not sure whether this is the best way to fix it.
NodePtr Dot(const Value& input, const Value& weight) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_weight = loctx->GetOutputOp(node.operand(1));
    // Modification: promote both operands to a common element type so the
    // resulting dot does not see mixed f32/f16 inputs.
    std::tie(xla_input, xla_weight) =
        XlaHelpers::PromoteValues(xla_input, xla_weight);
    return node.ReturnOp(BuildDot(xla_input, xla_weight), loctx);
  };
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildDot(operands[0], operands[1]);
  };
  return GenericOp(OpKind(at::aten::mm), {input, weight},
                   [&]() {
                     return InferOutputShape({input.shape(), weight.shape()},
                                             lower_for_shape_fn);
                   },
                   std::move(lower_fn));
}