28 commits
593ed03 - The c++ structure might be...ok? Still need to compile. (Jul 27, 2020)
f0d95c0 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Jul 27, 2020)
13ec785 - Adding RNNUtils.h so i don't accidentally rinse it again (Jul 27, 2020)
393abac - compiles? (Jul 30, 2020)
8395faa - It works! (Jul 31, 2020)
b03faac - Integrated test. Forward succeeds for all cases, backward still fail… (Jul 31, 2020)
7be3005 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Jul 31, 2020)
bdad360 - Fix bias backward (Aug 2, 2020)
1df7e92 - merge autocast_mode.cpp (Aug 3, 2020)
0222d5f - _cudnn_rnn not UNBOXED anymore (Aug 3, 2020)
9628fb4 - changed back to impl_UNBOXED (Aug 3, 2020)
a9819c2 - Update _cudnn_rnn signature (Aug 3, 2020)
51c3d8b - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 5, 2020)
8cb7048 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 5, 2020)
18e6467 - addressing some comments (Aug 6, 2020)
39818b9 - split _cudnn_rnn wrapper into ATen/cudnn (Aug 7, 2020)
7e93e94 - more reorg (Aug 7, 2020)
f467272 - AutocastRNN to BUILD.bazel (Aug 7, 2020)
c31cbb6 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 7, 2020)
873a06f - allow some cached_cast calls inline (Aug 9, 2020)
ea3f039 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 9, 2020)
8f2eebf - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 10, 2020)
793d92d - skipIfRocm for test, delete flake8 error (Aug 11, 2020)
e062f67 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 11, 2020)
452cef5 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 11, 2020)
933fcdd - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 14, 2020)
7f10ba1 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 15, 2020)
8cf8c98 - Merge remote-tracking branch 'upstream/master' into amp_rnn (Aug 17, 2020)
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -340,6 +340,7 @@ filegroup(
"aten/src/ATen/cuda/CublasHandlePool.cpp",
"aten/src/ATen/cuda/PinnedMemoryAllocator.cpp",
"aten/src/ATen/cuda/detail/CUDAHooks.cpp",
"aten/src/ATen/cudnn/AutocastRNN.cpp",
"aten/src/ATen/cudnn/Descriptors.cpp",
"aten/src/ATen/cudnn/Handle.cpp",
"aten/src/ATen/cudnn/Types.cpp",
174 changes: 29 additions & 145 deletions aten/src/ATen/autocast_mode.cpp
@@ -59,90 +59,10 @@ int decrement_nesting() {
return --nesting;
}

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
fp16 = 0, // Cast all inputs to at::kHalf before running the op.
fp32, // Cast all inputs to at::kFloat before running the op.
fp32_set_opt_dtype, // Treats functions (like softmax) that
// 1. we'd like to run in fp32 and
// 2. have a c10::optional<ScalarType> arg that controls the output type.
// fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
// don't touch it, otherwise, set it to at::kFloat.
fp32_append_dtype, // Treats functions (like norm) that
// 1. we'd like to run in fp32 and
// 2. have some overloads that accept an output type and other overloads that don't.
// fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
// The wrapper policy is: append at::kFloat to the args, and redispatch to the
// type-aware overload.
promote, // Run in the widest dtype among several args.
};

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
if (current == at::kDouble) {
AT_ERROR("promote type is double in at::autocast::prioritize");
return current;
}
if (nextArg.is_cuda() && nextArg.is_floating_point()) {
auto next = nextArg.scalar_type();
if (next == at::kDouble) {
return current; // ignores double tensors
} else if (current == at::kFloat || next == at::kFloat) {
return at::kFloat; // prioritizes float over half
} else if (current == at::kHalf && next == at::kHalf) {
return at::kHalf;
} else {
AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
return current;
}
} else {
return current;
}
}

// Overload to catch TensorList args (e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
for (const auto& tensor : list) {
current = prioritize(current, tensor);
}
return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template<typename T>
inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(at::ScalarType current) {
return current;
}

// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
// Non-Tensor arguments are ignored.
template<typename Arg0, typename... Args>
inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
auto new_current = prioritize(current, arg0);
return promote_type(new_current, args...);
}

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(const Tensor& arg) {
return (arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
// TODO (possible optimization): Move cast_cache to an inline function in a header
// (+ refactor the can_try_cache branch to call a small non-inline helper function.
// can_try_cache branch is the only part that's hard to inline in other files).
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
// Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
Expand All @@ -165,61 +85,24 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
}
}

// Overload to process optional<Tensor>
c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
if (arg.has_value()) {
return cached_cast(to_type, *arg);
} else {
return c10::nullopt;
}
}

// Overload to process TensorLists
std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
std::vector<Tensor> vec;
vec.reserve(arg.size());
for (const auto& t : arg) {
vec.push_back(cached_cast(to_type, t));
}
return vec;
}

// Template to catch non-Tensor args.
template<typename T>
T cached_cast(at::ScalarType to_type, T arg) {
return arg;
}

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/

// Overload to catch dtype flags
c10::optional<ScalarType> set_opt_dtype(at::ScalarType to_type, const c10::optional<ScalarType>& dtype) {
return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template<typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
return arg;
}

template<typename... Args>
inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
return is_eligible(arg);
}

template<typename... Args>
inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) {
return (is_eligible(arg) ? to_type : arg.scalar_type());
}
// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
fp16 = 0, // Cast all inputs to at::kHalf before running the op.
fp32, // Cast all inputs to at::kFloat before running the op.
fp32_set_opt_dtype, // Treats functions (like softmax) that
// 1. we'd like to run in fp32 and
// 2. have a c10::optional<ScalarType> arg that controls the output type.
// fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
// don't touch it, otherwise, set it to at::kFloat.
fp32_append_dtype, // Treats functions (like norm) that
// 1. we'd like to run in fp32 and
// 2. have some overloads that accept an output type and other overloads that don't.
// fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
// The wrapper policy is: append at::kFloat to the args, and redispatch to the
// type-aware overload.
promote, // Run in the widest dtype among several args.
};

/********************************************************************************************************
Templates to provide wrapper functions
@@ -239,7 +122,7 @@ template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class Ar
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kHalf, args)...);
}
};
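To make the dispatch mechanics concrete, here is a minimal sketch (not part of the diff) of what the fp16 wrapper effectively reduces to for a single op such as at::mm; the function name call_mm_autocast is hypothetical:

#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>

// Hypothetical expansion of WrapFunction_<CastPolicy::fp16, ...> for at::mm:
// exclude the Autocast key so the redispatch cannot re-enter this wrapper,
// then cast every Tensor argument to half before calling the underlying op.
at::Tensor call_mm_autocast(const at::Tensor& self, const at::Tensor& other) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return at::mm(at::autocast::cached_cast(at::kHalf, self),
                at::autocast::cached_cast(at::kHalf, other));
}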
@@ -248,7 +131,7 @@ struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typel
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kFloat, args)...);
}
};
@@ -257,7 +140,7 @@ struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typel
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
if (firstarg_is_eligible(args...)) {
return (*F)(set_opt_dtype(at::kFloat, args)...);
} else {
@@ -272,7 +155,7 @@ struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::t
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
return (*F)(args..., out_type);
}
@@ -282,7 +165,7 @@ struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::ty
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocasting(DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
auto to_type = promote_type(at::kHalf, args...);
return (*F)(cached_cast(to_type, args)...);
}
@@ -319,6 +202,7 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op
"safe to autocast.");
}


#ifndef USE_STATIC_DISPATCH
namespace {
/*****************************************************************************************************************
@@ -422,7 +306,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
// The macro doesn't like this one so I had to write it out manually.
m.impl("native_layer_norm",
TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
TORCH_FN((&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)));
KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
@@ -490,7 +374,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)

m.impl("binary_cross_entropy",
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

}
123 changes: 123 additions & 0 deletions aten/src/ATen/autocast_mode.h
@@ -9,5 +9,128 @@ TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
if (current == at::kDouble) {
AT_ERROR("promote type is double in at::autocast::prioritize");
return current;
}
if (nextArg.is_cuda() && nextArg.is_floating_point()) {
auto next = nextArg.scalar_type();
if (next == at::kDouble) {
return current; // ignores double tensors
} else if (current == at::kFloat || next == at::kFloat) {
return at::kFloat; // prioritizes float over half
} else if (current == at::kHalf && next == at::kHalf) {
return at::kHalf;
} else {
AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
return current;
}
} else {
return current;
}
}

// Overload to catch TensorList args (e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
for (const auto& tensor : list) {
current = prioritize(current, tensor);
}
return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template<typename T>
inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(at::ScalarType current) {
return current;
}

// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
// Non-Tensor arguments are ignored.
template<typename Arg0, typename... Args>
inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
auto new_current = prioritize(current, arg0);
return promote_type(new_current, args...);
}
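As a usage sketch (assuming a CUDA build; tensor names are illustrative), promote_type walks the argument pack and widens to the winning type:

// One half and one float CUDA tensor: the promote policy resolves to float.
// Double tensors and non-Tensor arguments are ignored.
auto x = at::ones({2, 2}, at::device(at::kCUDA).dtype(at::kHalf));
auto y = at::ones({2, 2}, at::device(at::kCUDA).dtype(at::kFloat));
auto to_type = at::autocast::promote_type(at::kHalf, x, /*ignored=*/3.14, y);  // == at::kFloat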

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(const Tensor& arg) {
return (arg.defined() && arg.is_cuda() && arg.is_floating_point() && (arg.scalar_type() != at::kDouble));
}
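A few illustrative cases (a sketch, assuming a CUDA build):

auto a = at::ones({2}, at::device(at::kCUDA).dtype(at::kFloat));   // is_eligible(a): true
auto b = at::ones({2}, at::device(at::kCUDA).dtype(at::kDouble));  // is_eligible(b): false (double)
auto c = at::ones({2}, at::dtype(at::kFloat));                     // is_eligible(c): false (CPU)
at::Tensor d;                                                      // is_eligible(d): false (undefined)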

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(at::ScalarType to_type, const Tensor& arg);

// Overload to process optional<Tensor>
inline c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
if (arg.has_value()) {
return cached_cast(to_type, *arg);
} else {
return c10::nullopt;
}
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
std::vector<Tensor> vec;
vec.reserve(arg.size());
for (const auto& t : arg) {
vec.push_back(cached_cast(to_type, t));
}
return vec;
}

// Template to catch non-Tensor args.
template<typename T>
inline T cached_cast(at::ScalarType to_type, T arg) {
return arg;
}
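A short usage sketch (the tensor name w is illustrative): eligible tensors are converted, fp16 casts of fp32 leaves are memoized until clear_cache(), and everything else passes through unchanged:

auto w = at::ones({4, 4}, at::device(at::kCUDA).dtype(at::kFloat).requires_grad(true));
auto w_half = at::autocast::cached_cast(at::kHalf, w);       // cast to half; cached (fp32 leaf requiring grad)
auto same   = at::autocast::cached_cast(at::kHalf, w_half);  // already half: returned as-is
int64_t n   = at::autocast::cached_cast(at::kHalf, int64_t{7});  // non-Tensor arg: passed through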

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/

// Overload to catch dtype flags
c10::optional<ScalarType> inline set_opt_dtype(at::ScalarType to_type, const c10::optional<ScalarType>& dtype) {
return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template<typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
return arg;
}
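A sketch of this policy for a softmax-style optional dtype argument (values are illustrative):

c10::optional<at::ScalarType> unset;               // user did not pass a dtype
c10::optional<at::ScalarType> forced = at::kHalf;  // user explicitly asked for half
auto a = at::autocast::set_opt_dtype(at::kFloat, unset);   // -> at::kFloat (autocast chooses)
auto b = at::autocast::set_opt_dtype(at::kFloat, forced);  // -> at::kHalf (user choice respected)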

template<typename... Args>
inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
return is_eligible(arg);
}

template<typename... Args>
inline at::ScalarType type_from_firstarg(at::ScalarType to_type, const Tensor& arg, Args... args) {
return (is_eligible(arg) ? to_type : arg.scalar_type());
}
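These two helpers drive the fp32_set_opt_dtype and fp32_append_dtype wrappers in autocast_mode.cpp; a sketch of type_from_firstarg for a norm-style op (tensor names illustrative):

auto t_half = at::ones({3}, at::device(at::kCUDA).dtype(at::kHalf));
auto t_dbl  = at::ones({3}, at::device(at::kCUDA).dtype(at::kDouble));
auto out1 = at::autocast::type_from_firstarg(at::kFloat, t_half);  // at::kFloat: eligible first arg
auto out2 = at::autocast::type_from_firstarg(at::kFloat, t_dbl);   // at::kDouble: ineligible, keeps its dtype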

} // namespace autocast
} // namespace at