-
Notifications
You must be signed in to change notification settings - Fork 25.6k
use TypeMeta instead of ScalarType in TensorOptions #12768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Differential Revision: D10419671 Differential Version: 60878909
Differential Revision: D10419671 Differential Version: 61134806
aten/src/ATen/Context.cpp
Outdated
TypeExtendedInterface& getType(TensorOptions options) { | ||
return globalContext().getType( | ||
options.backend(), options.dtype(), options.is_variable()); | ||
options.backend(), dataTypeToScalarType(options.dtype().id()), options.is_variable()); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
} | ||
} | ||
|
||
void set_dtype(optional<ScalarType> dtype) & noexcept { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
return TensorOptions().dtype(dtype); | ||
} | ||
|
||
inline TensorOptions dtype(ScalarType dtype) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/cudnn/Descriptors.h
Outdated
AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size)); | ||
AT_ASSERT(options.device().type() == kCUDA); | ||
AT_ASSERT(options.dtype() == kByte); | ||
AT_ASSERT(dataTypeToScalarType(options.dtype().id()) == kByte); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
options); | ||
AT_CHECK( | ||
at::isFloatingType(options.dtype()), | ||
at::isFloatingType(dataTypeToScalarType(options.dtype().id())), |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved, but with comments.
Differential Revision: D10419671 Differential Version: 61337004
Differential Revision: D10419671 Differential Version: 61345557
Differential Revision: D10419671 Differential Version: 61438577
Differential Revision: D10419671 Differential Version: 61590771
Differential Revision: D10419671 Differential Version: 61603924
Differential Revision: D10419671 Differential Version: 61616762
Differential Revision: D10419671 Differential Version: 61622039
Differential Revision: D10419671 Differential Version: 61622526
|
||
/// Constructs a `TensorOptions` object with the given dtype. | ||
/// legacy constructor to support ScalarType | ||
/* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Differential Revision: D10419671 Differential Version: 61681127
Differential Revision: D10419671 Differential Version: 61685404
Differential Revision: D10419671 Differential Version: 61727366
Differential Revision: D10419671 Differential Version: 61759285
Summary: Pull Request resolved: pytorch/pytorch#12768 Note: DefaultTensorOptions no longer fits in 64-bits. I kept functions that take ScalarType as input to minimize changes for now. Reviewed By: ezyang Differential Revision: D10419671 fbshipit-source-id: 9cc8c5982fde9ff243e03d55c0c52c2aa2c7efd8
Stack:
:white_circle: #12766 Change return type of Tensor::dtype() from ScalarType to TypeMeta 💚
:white_circle: #12767 reduce Device to 32bits 💚
:black_circle: #12768 use TypeMeta instead of ScalarType in TensorOptions 💚
Note: DefaultTensorOptions no longer fits in 64-bits.
I kept functions that take ScalarType as input to minimize changes for now.
Differential Revision: D10419671