
A modest proposal: delete arithmetic overloads from c10::Half #64023

Open
ezyang opened this issue Aug 26, 2021 · 5 comments
Labels
module: half (Related to float16 half-precision floats), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ezyang (Contributor) commented Aug 26, 2021

There are a number of one-off issues where people complain that some half-precision operator doesn't have enough precision, and we fix it by increasing the internal precision of the operator. For example: #41446

The thing is that we basically never do computation using half intrinsics: we always convert to float, do the operation, and then stream the result back to memory as half (it's all about memory bandwidth). All of our arithmetic overloads on c10::Half are implemented this way: cast to float, do the op, cast back.
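For reference, here is roughly what that means in code, assuming the conversions in c10/util/Half.h (this paraphrases the observable behavior of the overloads, not their exact implementation):

#include <c10/util/Half.h>

// What `h1 + h2 * h3` effectively does with the current c10::Half overloads:
// every operand is widened to float, the arithmetic runs in float, and the
// result is only narrowed back to half when it is stored.
float effective_expansion(c10::Half h1, c10::Half h2, c10::Half h3) {
  float f1 = static_cast<float>(h1);
  float f2 = static_cast<float>(h2);
  float f3 = static_cast<float>(h3);
  return f1 + f2 * f3;  // all float math; a caller narrows on store, e.g. c10::Half(result)
}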

So what do you think about a kernel like this?

void smooth_l1_kernel_cuda(TensorIterator& iter, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&iter, beta]() {
    scalar_t beta_val(beta);
    gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      auto z = ::abs(a - b);
      return z < beta_val ? scalar_t(0.5) * z * z / beta_val : z - scalar_t(0.5) * beta_val;
    });
  });
}

It is easy to look at this code for float and conclude that it is good. But what if we run it in half precision? There is a risk that z * z will overflow before we divide it by beta_val, which might bring it back into the representable range of half precision. In fact, we are getting very lucky here: ::abs has no half-precision overload, so it implicitly coerces a - b to float, and the rest of the computation proceeds in float (despite the explicit casts of 0.5 to half precision). Still, not all is well, because we can force an overflow in the subtraction (although it's probably pretty unlikely you'll trigger this in practice):

>>> x = torch.tensor([65504./2+100],dtype=torch.half).cuda()
>>> torch.nn.functional.smooth_l1_loss(x, -x, beta=65504.)
tensor(inf, device='cuda:0', dtype=torch.float16)
>>> torch.nn.functional.smooth_l1_loss(x.float(), -x.float(), beta=65504.)
tensor(32976., device='cuda:0')  # representable in half!

So, I propose we just delete the arithmetic overloads and force people to cast explicitly. That will make things a lot clearer!
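Concretely, a hedged sketch of what the kernel above might look like with the overloads gone, widening explicitly via something like at::opmath_type<> from ATen/OpMathType.h (which maps Half/BFloat16 to float); this is illustrative, not a drop-in patch:

void smooth_l1_kernel_cuda(TensorIterator& iter, double beta) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&iter, beta]() {
    // opmath_t is float for Half/BFloat16 and scalar_t otherwise.
    using opmath_t = at::opmath_type<scalar_t>;
    opmath_t beta_val(beta);
    gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
      // Widen explicitly at the load boundary; everything below runs in opmath_t.
      opmath_t z = ::abs(static_cast<opmath_t>(a) - static_cast<opmath_t>(b));
      opmath_t out = z < beta_val ? opmath_t(0.5) * z * z / beta_val
                                  : z - opmath_t(0.5) * beta_val;
      return static_cast<scalar_t>(out);  // narrow only when storing
    });
  });
}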

cc @ngimel @gchanan

ngimel (Collaborator) commented Aug 26, 2021

This is a great proposal!
Whatever we do for half, we should also do for bfloat16. Right now it's pretty confusing which overloads are used for which operations.
Another example issue, where an unexpected intermediate truncation leads to NaNs: #61523

ngimel (Collaborator) commented Aug 26, 2021

Historical note: initially we didn't have the overloads, so a lot of code in THC/THCUNN was littered with explicit casts, and that also wasn't very readable. The goal of the overloads, AFAIU, was to make writing templated code for non-standard types as easy and clean as it is for the standard ones. But as you point out, it sometimes comes with unexpected accuracy drops, and it still puts a mental load on the author/reviewer to think about whether things should have been cast to a higher-precision type or not. Sometimes we do it right, sometimes we don't.

ezyang added the module: half and triaged labels on Aug 30, 2021
lezcano (Collaborator) commented Oct 21, 2021

Would it be possible to add (yet another) option to TI that casts these halfs and bfloats to the correct type before sending them to the kernel? That way we could support this with minimal code changes (we would still need to remove the _HALF and _AND2 from the DISPATCH macros, though).

ngimel (Collaborator) commented Oct 21, 2021

It shouldn't be an option to TI (it should be part of the TI kernels, not TI itself), and for GPU it's very nearly done in #63884 for binary ops. For unary ops even a wrapper like this isn't needed; just rewriting the lambdas in terms of opmath_t should achieve it. It can't be achieved by simply changing the dispatch, because the lambda itself should read low-precision data and convert it to high precision.
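For example, an opmath_t-style unary lambda would look roughly like this (an illustrative sketch, not a quote of an existing kernel):

void exp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp_cuda", [&]() {
    using opmath_t = at::opmath_type<scalar_t>;
    gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a) -> scalar_t {
      // The lambda still reads and writes scalar_t (so loads/stores stay in
      // low precision), but the math itself is done in opmath_t.
      return static_cast<scalar_t>(::exp(static_cast<opmath_t>(a)));
    });
  });
}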

facebook-github-bot pushed a commit that referenced this issue Oct 22, 2021
Summary:
Removes some of the half math ops to make #64023 possible.

Pull Request resolved: #67048

Reviewed By: mruberry

Differential Revision: D31847249

Pulled By: ngimel

fbshipit-source-id: 8385aacd846bb990e368ff336eb346d847af70b9
ezyang (Contributor, Author) commented Jan 5, 2022

#65851 should help, perhaps, since I can insert the conversions as part of codegen.
