fix test on "[te] Fix clamp with uint8 args"
Riddle me this, Batman: how could `torch.clamp(torch.tensor([0], dtype=torch.uint8), -10, 10)` equal `10`?  The answer: the min/max args are first cast to the dtype of the input, giving min=246 and max=10.  Then you have to apply Min and Max in the right order: `Min(Max(in, min), max)`.  Differ in any way and you're doomed.  Hooray.
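
As a quick illustration (not part of the diff), a minimal eager-mode sketch of the semantics described above; the printed values assume the cast-first, `Min(Max(in, min), max)` ordering that eager uses:

```python
import torch

# Eager semantics described above: the scalar bounds are cast to the input
# dtype first, so min=-10 wraps around to 246 while max=10 stays 10 ...
x = torch.tensor([0], dtype=torch.uint8)
print(torch.clamp(x, -10, 10))          # tensor([10], dtype=torch.uint8)

# ... and then Min and Max are applied in exactly this order.
lo = torch.tensor(-10).to(torch.uint8)  # 246
hi = torch.tensor(10).to(torch.uint8)   # 10
print(torch.min(torch.max(x, lo), hi))  # tensor([10], dtype=torch.uint8)
```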

This PR makes TE match eager mode for this operator, and fixes a major facepalm in the LLVM min/max codegen where we were always generating signed comparisons.
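
A small sketch (again illustrative, not from the diff) of why an always-signed comparison breaks uint8 min/max: the bit pattern of 246 reads back as -10 when reinterpreted as a signed byte, so a signed max picks the wrong operand:

```python
import torch

# 246 stored as uint8, reinterpreted as int8, reads back as -10; a signed
# comparison would therefore conclude Max(0, 246) == 0 instead of 246.
v = torch.tensor([246], dtype=torch.uint8)
print(v.view(torch.int8))  # tensor([-10], dtype=torch.int8)
```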

Differential Revision: [D25456366](https://our.internmc.facebook.com/intern/diff/D25456366/)

[ghstack-poisoned]
bertmaher committed Dec 12, 2020
2 parents 617c2b7 + 693e908 commit 6e611f2
Showing 91 changed files with 2,223 additions and 845 deletions.
1 change: 0 additions & 1 deletion .jenkins/pytorch/codegen-test.sh
@@ -37,7 +37,6 @@ python -m tools.setup_helpers.generate_code \
mkdir -p "$OUT"/pyi/torch/_C
mkdir -p "$OUT"/pyi/torch/nn
python -m tools.pyi.gen_pyi \
--declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \
--native-functions-path aten/src/ATen/native/native_functions.yaml \
--deprecated-functions-path tools/autograd/deprecated.yaml \
--out "$OUT"/pyi
1 change: 1 addition & 0 deletions .jenkins/pytorch/multigpu-test.sh
@@ -21,4 +21,5 @@ time python test/run_test.py --verbose -i distributed/test_jit_c10d
time python test/run_test.py --verbose -i distributed/test_distributed_fork
time python test/run_test.py --verbose -i distributed/test_c10d
time python test/run_test.py --verbose -i distributed/test_c10d_spawn
time python test/run_test.py --verbose -i distributed/rpc/test_tensorpipe_agent
assert_git_not_dirty
451 changes: 294 additions & 157 deletions aten/src/ATen/Dispatch.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion aten/src/ATen/TensorIndexing.h
@@ -227,7 +227,7 @@ static inline Tensor applySelect(
static inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
// booleans add a dimension of size 1. true indexes this dimension as if 0:, false as empty.
if (value) {
return at::native::zeros({1}, {}, self.options().dtype(kLong));
return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
} else {
return at::empty({0}, {}, self.options().dtype(kLong));
}
1 change: 1 addition & 0 deletions aten/src/ATen/ThreadLocalState.cpp
@@ -19,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
grad_mode_enabled_ = GradMode::is_enabled();
}
#endif
bumped_record_all_functions_ = at::checkRecordAllFunctions();
}

/* static */
24 changes: 23 additions & 1 deletion aten/src/ATen/ThreadLocalState.h
@@ -38,25 +38,47 @@ class TORCH_API ThreadLocalState {
bool grad_mode_enabled_;
#endif

// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;

friend class ThreadLocalStateGuard;
};

// Guard to set and reset the thread local state
class TORCH_API ThreadLocalStateGuard {
public:
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
: prev_state_(ThreadLocalState()) {
: prev_state_(ThreadLocalState()),
bumped_record_all_functions_(state.bumped_record_all_functions_) {
// Special handling of RecordFunction pre-sampling optimization:
// pre-sampling is enabled (bumped) when there are non-sampled
// (or high-frequency) global or TLS callbacks.
//
// ThreadLocalStateGuard simply resets RecordFunction's TLS and
// hence its thread local callbacks.
//
// Check whether pre-sampling was enabled and preserve it in the async
// task by calling bumpRecordAllFunctions() and the corresponding
// releaseRecordAllFunctions().
if (bumped_record_all_functions_) {
at::bumpRecordAllFunctions();
}
// set the given state across the thread boundary
ThreadLocalState::setThreadLocalState(state);
}

~ThreadLocalStateGuard() {
// restore previously set variables
ThreadLocalState::setThreadLocalState(prev_state_);
if (bumped_record_all_functions_) {
at::releaseRecordAllFunctions();
}
}

private:
const ThreadLocalState prev_state_;
// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;
};

template <typename T>
8 changes: 6 additions & 2 deletions aten/src/ATen/Utils.cpp
@@ -57,8 +57,12 @@ Tensor empty_cpu(
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
}

auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
if (memory_format_opt.has_value()) {
// Restriding a just-created empty contiguous tensor does nothing.
if (*memory_format_opt != MemoryFormat::Contiguous) {
tensor.unsafeGetTensorImpl()->empty_tensor_restride(*memory_format_opt);
}
}

return tensor;
}
@@ -119,14 +119,6 @@ namespace impl {
"You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
};

template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::vector<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
static_assert(!std::is_same<T, at::Scalar>::value,
"You tried to register a kernel with an unsupported input type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
// TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector<T>. Please use List<T> instead.");
};

template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
81 changes: 49 additions & 32 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -371,28 +371,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(A
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
// By default, when there are no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed
// and sets the pre_sampled boolean argument to whether pre-sampling was used;
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks.
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
@@ -429,20 +440,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
const auto& kernel = entry.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
kernel.callBoxed(op, stack);
return;
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
kernel.callBoxed(op, stack);
2 changes: 1 addition & 1 deletion aten/src/ATen/core/ivalue.cpp
@@ -22,7 +22,7 @@ namespace ivalue {

// This is in ivalue.cpp because we need to access Type::annotation_str, which
// is declared in jit_type.h
void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
void checkCustomClassType(const Type* expected_type, const Type* actual_type) {
// NB: doing pointer comparison here
// If in the future there ever arises a need to call operator== on custom class
// Type's, this needs to be changed!
14 changes: 9 additions & 5 deletions aten/src/ATen/core/ivalue.h
@@ -949,8 +949,8 @@ TORCH_API ska::flat_hash_map<std::type_index, c10::ClassTypePtr>&
getCustomClassTypeMap();

template <typename T>
c10::ClassTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
c10::ClassTypePtr getCustomClassTypeImpl() {
auto& tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(std::type_index(typeid(T)));
if (res == tmap.end()) {
throw c10::Error("Can't find class id in custom class type map", "");
@@ -959,9 +959,13 @@ c10::ClassTypePtr getCustomClassType() {
}

template <typename T>
inline bool isCustomClassRegistered() {
auto tmap = c10::getCustomClassTypeMap();
return tmap.find(std::type_index(typeid(T))) != tmap.end();
const c10::ClassTypePtr& getCustomClassType() {
// Classes are never unregistered from getCustomClassTypeMap and the
// hash lookup can be a hot path, so just cache.
// For the same reason, it's fine if this ends up getting duplicated across
// DSO boundaries for whatever reason.
static c10::ClassTypePtr cache = getCustomClassTypeImpl<T>();
return cache;
}

TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
27 changes: 15 additions & 12 deletions aten/src/ATen/core/ivalue_inl.h
@@ -172,7 +172,7 @@ inline at::Generator IValue::toGenerator() const& {
namespace ivalue {

void CAFFE2_API
checkCustomClassType(TypePtr expected_type, TypePtr actual_type);
checkCustomClassType(const Type* expected_type, const Type* actual_type);

template <typename T>
using Shared = c10::intrusive_ptr<T>;
@@ -820,8 +820,8 @@ c10::intrusive_ptr<T> IValue::toCustomClass() && {
obj->slots().size() == 1,
"Tried to cast IValue to custom class but it did "
"not contain a custom class!");
auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
ivalue::checkCustomClassType(expected_type, type());
const Type* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
ivalue::checkCustomClassType(expected_type, type().get());
auto userObj =
c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
return userObj;
@@ -838,8 +838,8 @@ c10::intrusive_ptr<T> IValue::toCustomClass() const& {
obj->slots().size() == 1,
"Tried to cast IValue to custom class but it did "
"not contain a custom class!");
auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
ivalue::checkCustomClassType(expected_type, type());
const Type* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
ivalue::checkCustomClassType(expected_type, type().get());
auto userObj =
c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
return userObj;
@@ -1149,13 +1149,16 @@ template <
typename T,
std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
IValue::IValue(c10::intrusive_ptr<T> custom_class) {
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<T>>()) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
auto classType = c10::getCustomClassType<c10::intrusive_ptr<T>>();
TypePtr classType = []() {
try {
return c10::getCustomClassType<c10::intrusive_ptr<T>>();
} catch (const c10::Error&) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
}();
auto ivalue_obj = c10::ivalue::Object::create(
c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
19 changes: 12 additions & 7 deletions aten/src/ATen/core/jit_type.h
@@ -1727,13 +1727,18 @@ namespace detail {
template <typename T>
struct getTypePtr_ final {
static TypePtr call() {
TORCH_CHECK(
isCustomClassRegistered<T>(),
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
auto res = getCustomClassType<T>();
TypePtr res = []() {
try {
return getCustomClassType<T>();
} catch (const c10::Error&) {
TORCH_CHECK(
false,
"Type ",
c10::util::get_fully_qualified_type_name<T>(),
" could not be converted to any of the known types."
);
}
}();
return std::dynamic_pointer_cast<Type>(std::move(res));
}
};
