Use MTA for amp grad unscaling, enforce op math type in MTA functors, and allow op lambdas #44778

Closed
wants to merge 38 commits

Changes from 31 commits

Commits (38)
8713ff6  procrastiworking (mcarilli, Sep 16, 2020)
518bc3c  Pass scalars as opmath_t to avoid precision loss (mcarilli, Sep 16, 2020)
5d34ee2  Fixed opmath_t for all Functors. Pushing for visibility, still need … (mcarilli, Sep 16, 2020)
64b6748  Change call sites to supply instantiated op functors (mcarilli, Sep 16, 2020)
f978609  Refactor op plumbing to allow lambdas. Everything compiles, test_for… (mcarilli, Sep 17, 2020)
cafe839  some comments (mcarilli, Sep 17, 2020)
36b93c3  Test passes (mcarilli, Sep 21, 2020)
79f4908  test_grad_scaling* all pass (mcarilli, Sep 21, 2020)
eb86bbc  resolve conflicts (mcarilli, Sep 24, 2020)
c1731c1  Fixing subrepos with master take 1, how do i always screw this up (mcarilli, Sep 24, 2020)
8d5ea34  merging in master (mcarilli, Sep 25, 2020)
a5ca21d  Align scalarlist Functors with other ForeachFunctors (mcarilli, Sep 24, 2020)
cb29145  All foreach tests except for four cpu tests pass (mcarilli, Sep 25, 2020)
294637e  tests mimic cuda kernel dtype flow (mcarilli, Sep 25, 2020)
961b780  two comments (mcarilli, Sep 26, 2020)
8cb1bc1  this is why i can't have nice things (mcarilli, Sep 27, 2020)
a5ba727  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 27, 2020)
868d3b1  self.device_type instead of device.startswith (mcarilli, Sep 27, 2020)
efd12f5  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 27, 2020)
e5c7622  hopefully fix bc-breaking error (mcarilli, Sep 27, 2020)
b938310  If on rocm, dont explicitly cast to fp32 before computing expected (mcarilli, Sep 28, 2020)
8f5bfcb  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 28, 2020)
c4deefe  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 28, 2020)
af02a6d  increase tolerances for rocm (mcarilli, Sep 28, 2020)
b612ac5  specify rtol and atol (mcarilli, Sep 29, 2020)
aa98705  Device guard in fallback path (mcarilli, Sep 29, 2020)
3622b28  resolving conflict (mcarilli, Sep 29, 2020)
db2e193  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 29, 2020)
552f614  Adjust rocm bf16 and fp16 atol (mcarilli, Sep 29, 2020)
e787dd2  Fix bad conflict resolve in grad_scaler.py (mcarilli, Sep 29, 2020)
bb20adc  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 29, 2020)
b4d274d  fix indent and mypy error (mcarilli, Sep 29, 2020)
13439de  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 29, 2020)
e3804e3  flake8 (mcarilli, Sep 29, 2020)
c4cfaa9  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 29, 2020)
c9a8944  ensure multiGPU unscaling test tries fallback on both devices (mcarilli, Sep 30, 2020)
c989e1b  Merge remote-tracking branch 'upstream/master' into mta_unscale (mcarilli, Sep 30, 2020)
f84274d  resolving conflict (mcarilli, Sep 30, 2020)
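
Before the file-by-file diff, a standalone toy (editor's illustration, none of these names come from the PR; build with nvcc --extended-lambda) of the two ideas the title packs together: elementwise ops are supplied as lambdas, and the math type ("opmath_t") is float even when the storage type is half.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// scalar_t is the storage type (half here); opmath_t is the math type (float),
// mirroring the get_opmath_t mapping this PR enforces in the MTA functors.
template <typename scalar_t, typename opmath_t, typename op_t>
__global__ void apply_op(scalar_t* data, int n, op_t op) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    opmath_t v = static_cast<opmath_t>(data[i]);  // widen on load
    data[i] = static_cast<scalar_t>(op(v));       // op runs in opmath_t, narrow on store
  }
}

int main() {
  const int n = 8;
  __half* d = nullptr;
  cudaMallocManaged(&d, n * sizeof(__half));
  for (int i = 0; i < n; ++i) d[i] = __float2half(float(i));

  float inv_scale = 0.5f;
  // The elementwise op is just a lambda over the math type, like the GPU_LAMBDAs below.
  apply_op<__half, float><<<1, 32>>>(
      d, n, [inv_scale] __device__ (float v) { return v * inv_scale; });
  cudaDeviceSynchronize();

  for (int i = 0; i < n; ++i) printf("%.1f\n", __half2float(d[i]));
  cudaFree(d);
  return 0;
}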
10 changes: 5 additions & 5 deletions aten/src/ATen/native/ForeachUtils.h
@@ -74,7 +74,7 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
return false;
}

// integral scalar + boolean tensor will result in integral tensor
// integral scalar + boolean tensor will result in integral tensor
if (scalar.isIntegral(/*includeBool*/ false) && t.dtype() == at::kBool) {
return false;
}
@@ -89,17 +89,17 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
for (int64_t i = 0; i < tensors1.size(); i++) {
TORCH_CHECK(tensors1[i].sizes() == tensors2[i].sizes(), "Corresponding tensors from tensor lists have different size.");

if (tensors1[i].device() != expected_device ||
if (tensors1[i].device() != expected_device ||
tensors2[i].device() != expected_device) {
return false;
}

if (tensors1[i].layout() != at::kStrided ||
if (tensors1[i].layout() != at::kStrided ||
tensors2[i].layout() != at::kStrided) {
return false;
}

if (tensors1[i].device() != expected_device ||
if (tensors1[i].device() != expected_device ||
tensors2[i].device() != expected_device) {
return false;
}
@@ -108,7 +108,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
return false;
}

if (!tensors1[i].is_non_overlapping_and_dense() ||
if (!tensors1[i].is_non_overlapping_and_dense() ||
!tensors2[i].is_non_overlapping_and_dense()) {
return false;
}
147 changes: 119 additions & 28 deletions aten/src/ATen/native/cuda/AmpKernels.cu
@@ -3,9 +3,13 @@
#include <math.h>

#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>


namespace {
// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
@@ -33,49 +37,136 @@ static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float v
namespace at {
namespace native {

// Multiplies scaled_grad in-place by inv_scale. If an element of scaled_grad was inf or NaN sets found_inf to 1.0.
//
// Args:
// scaled_grad: A (scaled) gradient tensor. May contain infs or NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if any gradients contain infs/nans.
// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
// inv_scale: The inverse of the scale factor by which scaled_grad is currently multiplied.
//
// Returns:
// A tuple with references to scaled_grad, which is now unscaled in place, and found_inf,
// which is now guaranteed to contain 1.0 if an inf or NaN was found in scaled_grad.
namespace {
// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_.
// Handles individual tensors that are acceptable to unscale but not MTA-safe.
void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
Tensor& found_inf,
const Tensor& inv_scale)
{
TORCH_CHECK(scaled_grad.is_cuda(), "scaled_grad must be a CUDA tensor.");
// The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks.

// It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case.

Review comment (Collaborator): it doesn't, only gpu_kernel_with_scalars does

const OptionalDeviceGuard device_guard(device_of(scaled_grad));

// Acts on scaled_grad in place.
auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
iter.dtype(),
"_amp_non_finite_check_and_unscale_cuda",
[&iter, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();

using opmath_t = get_opmath_t<scalar_t>::opmath_t;

Review comment (Collaborator): what's wrong with using acc_type<scalar_t, true>? That's what's used in all other places.

Reply (mcarilli): If it does the mapping we want, that makes sense. I'll double check the behavior. get_opmath_t doesn't do anything for integer types, while acc_type might. I'm not sure if we do or don't want any pre/post op casting to occur for integer types.
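
As an aside (editor's sketch, not part of the diff), here is what the two mappings in this exchange resolve to; the integer behavior is the difference being raised. The get_opmath_t lines are stated as comments since that helper is introduced alongside this PR's foreach headers and, assuming it follows the usual pattern, only specializes half/bfloat16.

#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>  // provides at::acc_type

// acc_type widens low-precision floats AND integral types:
static_assert(std::is_same<at::acc_type<at::Half, /*is_cuda=*/true>, float>::value,
              "half accumulates in float");
static_assert(std::is_same<at::acc_type<int32_t, /*is_cuda=*/true>, int64_t>::value,
              "int32 accumulates in int64");

// get_opmath_t promotes only half/bfloat16 to float and is the identity elsewhere,
// so integer ops see no pre/post casting:
//   get_opmath_t<at::Half>::opmath_t  -> float
//   get_opmath_t<int32_t>::opmath_t   -> int32_t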


gpu_kernel(iter,
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t {
auto val = static_cast<opmath_t>(val_in);
if (!isfinite_ensure_cuda_math(val)) {
Review comment (Collaborator): time goes on, nature heals, maybe std::isfinite works now?

*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<scalar_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
}
} // anonymous namespace


// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0.
// Uses multi tensor apply (MTA) to process all MTA-safe tensors.
//
// Args:
// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contains infs/nans.
// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied.
void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
Tensor& found_inf,
const Tensor& inv_scale)
{
if (scaled_grads.size() == 0) {
return;
}

TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
TORCH_CHECK(scaled_grad.layout() == at::kStrided, "scaled_grad must be a strided (not sparse) Tensor.");

// Act on scaled_grad in place.
auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);
// Ensures client code (GradScaler) filtered scaled_grads by dtype.
check_foreach_api_restrictions(scaled_grads);

std::vector<std::vector<at::Tensor>> tensor_lists;

// is_non_overlapping_and_dense() is not available in Python.
// GradScaler can't filter for it. We need to filter here.
if (can_use_fast_route(scaled_grads)) {
// Hopefully common case.
// can_use_fast_route is true, which confirms:
// - all scaled_grads are strided
// - all scaled_grads are non overlapping and dense
// - all scaled_grads are on the same device
TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
// Sets up MTA launch to use scaled_grads as-is.
tensor_lists.emplace_back(scaled_grads.vec());
} else {
// Hopefully uncommon case.
// can_use_fast_route is an all-or-nothing check. In this path it was false,
// so any of the above confirmations could have gone wrong.
// We filter MTA-safe tensors into an MTA-able list.
// If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel.
// If a tensor is unacceptable, we throw an error to blame GradScaler.
tensor_lists.resize(1);
tensor_lists[0].reserve(scaled_grads.size());
auto expected_device = scaled_grads[0].device();
for (const Tensor& t : scaled_grads) {
// Ensures GradScaler filtered scaled_grads by device.
TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
if (!t.is_non_overlapping_and_dense()) {
// t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel.
_amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
found_inf,
inv_scale);
} else {
tensor_lists[0].push_back(t);
}
}
if (tensor_lists[0].size() == 0) {
return;
}
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
iter.dtype(),
"_amp_non_finite_check_and_unscale_cuda",
[&iter, &found_inf, &inv_scale] {
tensor_lists[0][0].scalar_type(),
"_amp_foreach_non_finite_check_and_unscale_cuda",
[&tensor_lists, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();

gpu_kernel(iter, [found_inf_ptr, inv_scale_ptr]GPU_LAMBDA(scalar_t val) -> scalar_t {
float fval = static_cast<float>(val);
// See isfinite_ensure_cuda_math above.
if (!isfinite_ensure_cuda_math(fval)) {
*found_inf_ptr = 1.f;
}
const auto inv_scale_val = *inv_scale_ptr; // Every thread accesses inv_scale, but it will hit in cache.
return static_cast<scalar_t>(inv_scale_val == 1.f ? fval : fval*inv_scale_val);
});
using opmath_t = get_opmath_t<scalar_t>::opmath_t;

// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
multi_tensor_apply<1>(tensor_lists,
UnaryOpFunctor_<scalar_t>(),
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
// There is a slight asymmetry here with the TensorIterator kernel above.
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
if (!isfinite_ensure_cuda_math(val)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
}
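
A small host-side usage sketch of the contract spelled out in the comment block above (editor's addition; it assumes the op is exposed to C++ as at::_amp_foreach_non_finite_check_and_unscale_, and in practice GradScaler drives it from Python after grouping grads by device and dtype):

#include <ATen/ATen.h>
#include <vector>

void unscale_sketch() {
  auto f32 = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  std::vector<at::Tensor> grads = {at::randn({16, 16}, f32), at::randn({32}, f32)};

  at::Tensor found_inf = at::zeros({1}, f32);                  // caller pre-zeroes found_inf
  at::Tensor inv_scale = at::full({1}, 1.0f / 65536.0f, f32);  // 1/scale, single-element float

  at::_amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale);
  // grads are now unscaled in place; found_inf holds 1.0 if any element was inf/NaN.
}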

17 changes: 13 additions & 4 deletions aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
@@ -6,8 +6,9 @@ namespace at { namespace native {

template<template<class> class Op>
std::vector<Tensor> foreach_tensor_list_op(TensorList tensors1, TensorList tensors2, Scalar alpha = 1) {
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors1.size());
for (const auto& t: tensors1) {
vec_res.emplace_back(at::native::empty_like(t));
}
@@ -17,20 +18,28 @@ std::vector<Tensor> foreach_tensor_list_op(TensorList tensors1, TensorList tenso
tensor_lists.emplace_back(std::move(vec_res));

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors1[0].scalar_type(), "foreach_binary_op_list_cuda", [&]() {
multi_tensor_apply<3>(tensor_lists, BinaryOpListAlphaFunctor<scalar_t, Op>(), alpha.to<scalar_t>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<3>(tensor_lists,
BinaryOpListAlphaFunctor<scalar_t>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
});

return tensor_lists[2];
}

template<template<class> class Op>
void foreach_tensor_list_op_(TensorList tensors1, TensorList tensors2, Scalar alpha = 1) {
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors1.vec());
tensor_lists.emplace_back(tensors2.vec());

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors1[0].scalar_type(), "foreach_binary_op_list_cuda_", [&]() {
multi_tensor_apply<2>(tensor_lists, BinaryOpListAlphaFunctor_<scalar_t, Op>(), alpha.to<scalar_t>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<2>(tensor_lists,
BinaryOpListAlphaFunctor_<scalar_t>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
});
}
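
For concreteness, a sketch of what a call site supplies under the new plumbing (AddOp below is hypothetical; the real wrappers pass whatever template-template Op they already use):

// Hypothetical binary op; the dispatch body above instantiates it at opmath_t.
template <typename T>
struct AddOp {
  __device__ T operator()(T a, T b) const { return a + b; }
};

// With scalar_t = at::Half the dispatch picks opmath_t = float, so the call above
// becomes roughly:
//   multi_tensor_apply<3>(tensor_lists,
//                         BinaryOpListAlphaFunctor<at::Half>(),
//                         AddOp<float>(),
//                         alpha.to<float>());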

17 changes: 13 additions & 4 deletions aten/src/ATen/native/cuda/ForeachBinaryOpScalar.cu
@@ -8,8 +8,9 @@ template<template<class> class Op>
std::vector<Tensor> foreach_binary_op(TensorList tensors, Scalar scalar) {
check_foreach_api_restrictions(tensors);

std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors.size());
for (const auto& t: tensors) {
vec_res.emplace_back(at::native::empty_like(t));
}
@@ -18,7 +19,11 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, Scalar scalar) {
tensor_lists.emplace_back(std::move(vec_res));

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalar_cuda", [&]() {
multi_tensor_apply<2>(tensor_lists, BinaryOpScalarFunctor<scalar_t, Op>(), scalar.to<scalar_t>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<2>(tensor_lists,
BinaryOpScalarFunctor<scalar_t>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
return tensor_lists[1];
}
@@ -27,11 +32,15 @@ template<template<class> class Op>
void foreach_binary_op_(TensorList tensors, Scalar scalar) {
check_foreach_api_restrictions(tensors);

std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors.vec());

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalar_cuda_", [&]() {
multi_tensor_apply<1>(tensor_lists, BinaryOpScalarFunctor_<scalar_t, Op>(), scalar.to<scalar_t>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<1>(tensor_lists,
BinaryOpScalarFunctor_<scalar_t>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
}

17 changes: 13 additions & 4 deletions aten/src/ATen/native/cuda/ForeachBinaryOpScalarList.cu
@@ -6,8 +6,9 @@ namespace at { namespace native {

template<template<class> class Op>
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors.size());

Review comment (Contributor): Should we do the same for tensor_lists? Here and everywhere else

Reply (mcarilli, Sep 25, 2020): I think I've already done so, ctrl+f "reserve(". Lmk if you spot any location i missed.

for (const auto& t: tensors) {
vec_res.emplace_back(at::native::empty_like(t));
}
@@ -16,18 +17,26 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> s
tensor_lists.emplace_back(vec_res);

AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
multi_tensor_apply<2>(tensor_lists, scalars, BinaryOpScalarListFunctor<scalar_t, Op>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<2>(tensor_lists,
scalars,
BinaryOpScalarListFunctor<scalar_t>(),
Op<opmath_t>());
});
return tensor_lists[1];
}

template<template<class> class Op>
void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors.vec());

AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
multi_tensor_apply<1>(tensor_lists, scalars, BinaryOpScalarListFunctor_<scalar_t, Op>());
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
multi_tensor_apply<1>(tensor_lists,
scalars,
BinaryOpScalarListFunctor_<scalar_t>(),
Op<opmath_t>());
});
}
