Use MTA for amp grad unscaling, enforce op math type in MTA functors, and allow op lambdas #44778
File: aten/src/ATen/native/cuda/AmpKernels.cu (path inferred from the kernels in this diff)
```diff
@@ -3,9 +3,13 @@
 #include <math.h>
 
 #include <ATen/ATen.h>
+#include <ATen/DeviceGuard.h>
 #include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/ForeachFunctors.cuh>
 #include <ATen/native/cuda/Loops.cuh>
-#include <ATen/native/TensorIterator.h>
+#include <ATen/native/ForeachUtils.h>
+
 
 namespace {
 // Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
```
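For orientation: the wrapper this comment introduces appears in the diff only as the next hunk's context signature. A minimal sketch of what it plausibly contains, assuming the body simply forwards to the CUDA math library's float `isfinite` named in the comment above:

```cpp
// Sketch only: the signature matches the hunk context below; the body is an
// assumption based on the "thin wrapper" comment above.
// __host__ __device__ lets both host code and gpu_kernel lambdas call it.
static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float val) {
  return isfinite(val);  // resolves to the CUDA math library overload in device code
}
```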
```diff
@@ -33,49 +37,136 @@ static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float v
 namespace at {
 namespace native {
 
-// Multiplies scaled_grad in-place by inv_scale.  If an element of scaled_grad was inf or NaN sets found_inf to 1.0.
-//
-// Args:
-// scaled_grad:  A (scaled) gradient tensor.  May contain infs or NaNs.
-// found_inf:  A single-element float tensor to which 1.0 will be written if any gradients contain infs/nans.
-//             Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
-// inv_scale:  The inverse of the scale factor by which scaled_grad is currently multiplied.
-//
-// Returns:
-// A tuple with references to scaled_grad, which is now unscaled in place, and found_inf,
-// which is now guaranteed to contain 1.0 if an inf or NaN was found in scaled_grad.
+namespace {
+// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_.
+// Handles individual tensors that are acceptable to unscale but not MTA-safe.
 void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
                                              Tensor& found_inf,
                                              const Tensor& inv_scale)
 {
-  TORCH_CHECK(scaled_grad.is_cuda(), "scaled_grad must be a CUDA tensor.");
+  // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks.
+
+  // It's not obvious gpu_kernel always guards onto its argument.  Guarding here just in case.
+  const OptionalDeviceGuard device_guard(device_of(scaled_grad));
+
+  // Acts on scaled_grad in place.
+  auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    iter.dtype(),
+    "_amp_non_finite_check_and_unscale_cuda",
+    [&iter, &found_inf, &inv_scale] {
+      auto* found_inf_ptr = found_inf.data_ptr<float>();
+      auto* inv_scale_ptr = inv_scale.data_ptr<float>();
+
+      using opmath_t = get_opmath_t<scalar_t>::opmath_t;
```
Review comment (on the `opmath_t` line): what's wrong with using `acc_type<scalar_t, true>`? That's what's used in all other places.

Reply: If it does the mapping we want, that makes sense. I'll double check the behavior.
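For readers following this thread: below is a sketch of the mapping `get_opmath_t` appears to perform, inferred from its uses in this diff. The exact definition lives with the foreach functors and is not shown here; the `BFloat16` specialization is an assumption by analogy.

```cpp
// Sketch: low-precision storage types do their math in float, while every
// other type computes in itself.
template<typename T>
struct get_opmath_t {
  using opmath_t = T;
};

template<>
struct get_opmath_t<at::Half> {
  using opmath_t = float;  // half-precision storage, single-precision math
};

template<>
struct get_opmath_t<at::BFloat16> {
  using opmath_t = float;  // assumed, by analogy with Half
};
```

`acc_type<scalar_t, true>` performs a similar half-to-float promotion for floating types, which is what the thread above is weighing.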
```diff
+      gpu_kernel(iter,
+                 [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t {
+                   auto val = static_cast<opmath_t>(val_in);
+                   if (!isfinite_ensure_cuda_math(val)) {
```
Review comment (on the `isfinite_ensure_cuda_math` call): time goes on, nature heals, maybe `std::isfinite` works now?
```diff
+                     *found_inf_ptr = 1.f;
+                   }
+                   // Every thread accesses inv_scale, but it will hit in cache.
+                   const auto inv_scale_val = *inv_scale_ptr;
+                   return static_cast<scalar_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
+                 });
+    });
+}
+} // anonymous namespace
+
+
+// Multiplies each tensor in scaled_grads by inv_scale in-place.
+// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0.
+// Uses multi tensor apply (MTA) to process all MTA-safe tensors.
+//
+// Args:
+// scaled_grads:  A TensorList of scaled gradient tensors.  May contain infs or NaNs.
+// found_inf:  A single-element float tensor to which 1.0 will be written if any gradient contains infs/nans.
+//             Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
+// inv_scale:  The inverse of the scale factor by which scaled_grads are currently multiplied.
+void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
+                                                     Tensor& found_inf,
+                                                     const Tensor& inv_scale)
+{
+  if (scaled_grads.size() == 0) {
+    return;
+  }
+
   TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
   TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
   TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
   TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
   TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
   TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
-  TORCH_CHECK(scaled_grad.layout() == at::kStrided, "scaled_grad must be a strided (not sparse) Tensor.");
-
-  // Act on scaled_grad in place.
-  auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);
+
+  // Ensures client code (GradScaler) filtered scaled_grads by dtype.
+  check_foreach_api_restrictions(scaled_grads);
+
+  std::vector<std::vector<at::Tensor>> tensor_lists;
+
+  // is_non_overlapping_and_dense() is not available in Python.
+  // GradScaler can't filter for it. We need to filter here.
+  if (can_use_fast_route(scaled_grads)) {
+    // Hopefully common case.
+    // can_use_fast_route is true, which confirms:
+    //  - all scaled_grads are strided
+    //  - all scaled_grads are non overlapping and dense
+    //  - all scaled_grads are on the same device
+    TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
+    // Sets up MTA launch to use scaled_grads as-is.
+    tensor_lists.emplace_back(scaled_grads.vec());
+  } else {
+    // Hopefully uncommon case.
+    // can_use_fast_route is an all-or-nothing check.  In this path it was false,
+    // so any of the above confirmations could have gone wrong.
+    // We filter MTA-safe tensors into an MTA-able list.
+    // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel.
+    // If a tensor is unacceptable, we throw an error to blame GradScaler.
+    tensor_lists.resize(1);
+    tensor_lists[0].reserve(scaled_grads.size());
+    auto expected_device = scaled_grads[0].device();
+    for (const Tensor& t : scaled_grads) {
+      // Ensures GradScaler filtered scaled_grads by device.
+      TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
+      TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
+      TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
+      if (!t.is_non_overlapping_and_dense()) {
+        // t is acceptable but not MTA-safe.  Falls back to single-tensor TensorIterator kernel.
+        _amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
+                                                found_inf,
+                                                inv_scale);
+      } else {
+        tensor_lists[0].push_back(t);
+      }
+    }
+    if (tensor_lists[0].size() == 0) {
+      return;
+    }
+  }
+
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    iter.dtype(),
-    "_amp_non_finite_check_and_unscale_cuda",
-    [&iter, &found_inf, &inv_scale] {
+    tensor_lists[0][0].scalar_type(),
+    "_amp_foreach_non_finite_check_and_unscale_cuda",
+    [&tensor_lists, &found_inf, &inv_scale] {
       auto* found_inf_ptr = found_inf.data_ptr<float>();
       auto* inv_scale_ptr = inv_scale.data_ptr<float>();
 
-      gpu_kernel(iter, [found_inf_ptr, inv_scale_ptr]GPU_LAMBDA(scalar_t val) -> scalar_t {
-          float fval = static_cast<float>(val);
-          // See isfinite_ensure_cuda_math above.
-          if (!isfinite_ensure_cuda_math(fval)) {
-            *found_inf_ptr = 1.f;
-          }
-          const auto inv_scale_val = *inv_scale_ptr; // Every thread accesses inv_scale, but it will hit in cache.
-          return static_cast<scalar_t>(inv_scale_val == 1.f ? fval : fval*inv_scale_val);
-        });
+      using opmath_t = get_opmath_t<scalar_t>::opmath_t;
+
+      // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
+      multi_tensor_apply<1>(tensor_lists,
+                            UnaryOpFunctor_<scalar_t>(),
+                            [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
+                              // There is a slight asymmetry here with the TensorIterator kernel above.
+                              // MTA Functors ensure val comes in as opmath_t rather than scalar_t.
+                              if (!isfinite_ensure_cuda_math(val)) {
+                                *found_inf_ptr = 1.f;
+                              }
+                              // Every thread accesses inv_scale, but it will hit in cache.
+                              const auto inv_scale_val = *inv_scale_ptr;
+                              return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
+                            });
     });
 }
```
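The "slight asymmetry" comment is about who owns the dtype casts. Below is a simplified sketch of how an MTA functor can guarantee the op lambda always receives `opmath_t`; this is an assumed shape, since the real `UnaryOpFunctor_` lives in ForeachFunctors.cuh and adds chunk bookkeeping and vectorized memory access.

```cpp
// Illustrative functor, not the PR's exact code: loads scalar_t storage,
// promotes to opmath_t, applies the user-supplied op, and demotes on store.
template<typename scalar_t>
struct UnaryOpFunctorSketch_ {
  template<typename Op>
  __device__ void operator()(scalar_t* x, int n, Op op) {
    using opmath_t = typename get_opmath_t<scalar_t>::opmath_t;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      // The functor performs both casts, so the same op lambda works
      // unchanged for every storage dtype the dispatch covers.
      x[i] = static_cast<scalar_t>(op(static_cast<opmath_t>(x[i])));
    }
  }
};
```

This is what lets the unscale lambda above be written once against `opmath_t` while the kernel still reads and writes `scalar_t` storage.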
File: aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu (path inferred from the dispatch names below)
```diff
@@ -6,8 +6,9 @@ namespace at { namespace native {
 
 template<template<class> class Op>
 std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
     std::vector<std::vector<at::Tensor>> tensor_lists;
     std::vector<at::Tensor> vec_res;
+    vec_res.reserve(tensors.size());
```
Review comment (on `vec_res.reserve(tensors.size())`): Should we do the same for

Reply: I think I've already done so, ctrl+f "reserve(". Lmk if you spot any location I missed.
```diff
     for (const auto& t: tensors) {
         vec_res.emplace_back(at::native::empty_like(t));
     }
@@ -16,18 +17,26 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> s
     tensor_lists.emplace_back(vec_res);
 
     AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
-        multi_tensor_apply<2>(tensor_lists, scalars, BinaryOpScalarListFunctor<scalar_t, Op>());
+        using opmath_t = get_opmath_t<scalar_t>::opmath_t;
+        multi_tensor_apply<2>(tensor_lists,
+                              scalars,
+                              BinaryOpScalarListFunctor<scalar_t>(),
+                              Op<opmath_t>());
     });
     return tensor_lists[1];
 }
 
 template<template<class> class Op>
 void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
     std::vector<std::vector<at::Tensor>> tensor_lists;
     tensor_lists.emplace_back(tensors.vec());
 
     AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
-        multi_tensor_apply<1>(tensor_lists, scalars, BinaryOpScalarListFunctor_<scalar_t, Op>());
+        using opmath_t = get_opmath_t<scalar_t>::opmath_t;
+        multi_tensor_apply<1>(tensor_lists,
+                              scalars,
+                              BinaryOpScalarListFunctor_<scalar_t>(),
+                              Op<opmath_t>());
     });
 }
```
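Passing the op as a separate `Op<opmath_t>()` argument keeps the functor op-agnostic. A hypothetical instantiation showing the pattern; the wrapper name is illustrative, though the real callers pass ops such as `std::plus` or `std::multiplies` the same way:

```cpp
#include <functional>  // std::multiplies

// Hypothetical wrapper: BinaryOpScalarListFunctor handles loads/stores in
// scalar_t, while std::multiplies<opmath_t> does the arithmetic in math type.
std::vector<at::Tensor> foreach_mul_scalarlist_cuda_sketch(at::TensorList tensors,
                                                           at::ArrayRef<double> scalars) {
  return foreach_binary_op<std::multiplies>(tensors, scalars);
}
```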
Review comment (on the fallback's "It's not obvious gpu_kernel always guards onto its argument" note): it doesn't, only `gpu_kernel_with_scalars` does.
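Given that reply, the explicit `OptionalDeviceGuard` in the single-tensor fallback above is doing real work. A sketch of the general pattern, with an illustrative helper (not from the PR) that mirrors the fallback's structure:

```cpp
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// Because gpu_kernel does not switch devices itself, make t's device current
// for the duration of the launch.
void scale_in_place_cuda(at::Tensor& t, float alpha) {
  const at::OptionalDeviceGuard device_guard(at::device_of(t));
  auto iter = at::TensorIterator::unary_op(t, t);  // acts on t in place
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "scale_in_place_cuda", [&] {
    at::native::gpu_kernel(iter, [alpha] GPU_LAMBDA (scalar_t x) -> scalar_t {
      return static_cast<scalar_t>(static_cast<float>(x) * alpha);
    });
  });
}
```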