Remove more uses of DimensionedTensorType #23060
Conversation
Force-pushed the DimensionedTensorType branch from 3e724eb to 67bc807.
Force-pushed the DimensionedTensorType branch from 070a914 to 0abb49f.
Force-pushed the DimensionedTensorType branch from f934329 to fcc382e.
@Krovatkin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Looks good. Thanks for cleaning these up.
```diff
-  auto tensor_type = input->type()->cast<DimensionedTensorType>();
   // As of now, we do the decomposition for batchnorm/layernorm on GPU device only
-  if (!tensor_type || tensor_type->device().is_cpu()) {
+  if (!input->type()->isSubclass(TypeKind::TensorType)) {
```
The canonical way to do this is: `!input->type()->isSubtype(TensorType::get())`
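For illustration, a minimal sketch of the suggested check — this is not the PR's code, and it assumes the `Type::isSubtypeOf` / `TensorType::get()` spelling of the JIT type API plus an illustrative header path:

```cpp
#include <torch/csrc/jit/ir.h>  // illustrative include; the exact path may differ

// Sketch only: a subtype query against the base tensor type, as suggested
// above. Unlike matching one concrete TypeKind, it covers any tensor subtype.
bool isTensorValue(torch::jit::Value* input) {
  return input->type()->isSubtypeOf(torch::jit::TensorType::get());
}
```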
```diff
@@ -182,13 +182,16 @@ struct GraphFuser {
   }

   bool isFusableDevice(Value *v) {
-    auto tensor_type = v->type()->cast<DimensionedTensorType>();
-    if (!tensor_type) {
+    if (!v->type()->isSubclass(TypeKind::TensorType)) {
```
Same here.
```diff
@@ -29,8 +29,10 @@ Node* InsertCastForCond(Value* cond_val, Graph* graph, Node* consumer_node) {
 }

 bool IsCondCastRequired(Value* cond_val) {
   const auto& type = cond_val->type();
-  if (type->isSubclass(TypeKind::DimensionedTensorType)) {
-    return type->expect<DimensionedTensorType>()->scalarType() != c10::kBool;
+  if (type->isSubclass(TypeKind::TensorType)) {
```
and here
torch/jit/__init__.py (Outdated)
```diff
@@ -35,6 +35,13 @@
 if sys.version_info[0] > 2:
     import pathlib

+def _get_type_attr(type, aname):
```
I'd rather have the `dim` and `scalarType` methods handle returning None rather than throw/catch an exception. That makes things hard to debug (`catch throw` now catches spurious things), and it can mask a real exception that is not the one that was expected.
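One way to read this suggestion, as a minimal C++ sketch: report missing metadata as an empty `c10::optional`, so callers branch on presence instead of catching exceptions. `TensorMeta`, `scalarTypeOf`, and `isBoolTensor` are hypothetical names for illustration; only `c10::optional`, `at::ScalarType`, and `c10::kBool` are real PyTorch types.

```cpp
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>

// Hypothetical metadata record: "unknown" is an empty optional the caller
// can test, not an exception the caller has to catch.
struct TensorMeta {
  c10::optional<at::ScalarType> scalar_type;
  c10::optional<int64_t> dim;
};

c10::optional<at::ScalarType> scalarTypeOf(const TensorMeta& meta) {
  return meta.scalar_type;  // empty means "not profiled yet"
}

// Usage: no try/catch involved, so a debugger's `catch throw` stays quiet
// and a real exception is never masked by the expected one.
bool isBoolTensor(const TensorMeta& meta) {
  auto st = scalarTypeOf(meta);
  return st.has_value() && *st == c10::kBool;
}
```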
Commits in this push:
- fix a bug in an assertion
- remove extra new lines
- more changes to remove DimensionedTensorType
- insert a newline
- assert -> torch_check
- only use `ProfiledTensorCreate` on tensors and insert checks elsewhere
- add a missing import
- fixing a call to _get_type_attr
- pass a type instead of a value into _get_type_attr
- don't rely on a tensor type to be dimensioned
- address Zach's feedback
Is there any reason to even keep `TensorType` instead of using `ProfiledTensorType` with everything varying as the default?
Force-pushed the DimensionedTensorType branch from c6068b3 to f770a74.
I think that's the plan :-) eventually
@Krovatkin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Pull Request resolved: pytorch/pytorch#23060
Differential Revision: D16460391
Pulled By: Krovatkin
fbshipit-source-id: b50ee87d22ad18b8cbfff719b199ea876ef172f1
@Krovatkin merged this pull request in 3d15ee1.