[JIT] Use type cache in erasing shape information (#55828)
Summary:
`unshapedType` can be very slow on a graph with many modules and recursively contained classes, because for each Type it has to recursively descend into and map over every contained type, repeating that work each time the same type is encountered. Speed it up with a type cache.
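
For intuition, here is a minimal, self-contained sketch (not the PyTorch implementation; `ToyType`, `unshape`, and the field names are illustrative) of the memoization pattern the commit applies: recursive descent over a type, with results keyed by the original type object so that a class or module type referenced from many places is processed only once.

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Toy stand-in for a recursive JIT type: a tensor leaf that may carry sizes,
// plus contained types for tuple/class-like containers.
struct ToyType {
  std::vector<int64_t> sizes;                        // shape info to erase
  std::vector<std::shared_ptr<ToyType>> contained;   // nested types
};
using ToyTypePtr = std::shared_ptr<ToyType>;

// Memoized "unshape": the cache is keyed by the original type node, so shared
// subtrees are mapped exactly once instead of once per reference.
ToyTypePtr unshape(const ToyTypePtr& t,
                   std::unordered_map<ToyTypePtr, ToyTypePtr>& cache) {
  auto it = cache.find(t);
  if (it != cache.end()) {
    return it->second;  // cache hit: this exact type object was already mapped
  }
  auto result = std::make_shared<ToyType>();  // `sizes` left empty: shape erased
  for (const auto& child : t->contained) {
    result->contained.push_back(unshape(child, cache));
  }
  cache[t] = result;
  return result;
}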

Pull Request resolved: #55828

Reviewed By: ngimel

Differential Revision: D27717995

Pulled By: eellison

fbshipit-source-id: f1d502bef0356e78100c27bf00f6caf08a75d68c
eellison authored and facebook-github-bot committed Apr 13, 2021
1 parent 8f953ef commit bbdb37b
Showing 2 changed files with 51 additions and 9 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/jit_type.h

@@ -1541,6 +1541,8 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
 // what is the type, ignoring extra size/shape information?
 // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)

+// xxx: be careful with calls because this can be very slow. If calling this on a graph
+// use `EraseShapeInformation` in shape_analysis.h
 inline TypePtr unshapedType(const TypePtr& type) {
   if (type->isSubtypeOf(TensorType::get())) {
     return TensorType::get();
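
To make the new warning concrete, here is a hedged sketch of the two call patterns, assuming an existing `std::shared_ptr<Graph>`; the wrapper function names are hypothetical, and only `unshapedType` and `EraseShapeInformation` come from the patch.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

using namespace torch::jit;

// Per-value calls: every unshapedType() call re-derives the unshaped type
// from scratch, which is what the comment above warns against on large graphs.
void eraseInputShapesNaively(const std::shared_ptr<Graph>& graph) {
  for (Value* v : graph->inputs()) {
    v->setType(c10::unshapedType(v->type()));
  }
}

// Graph-level pass from shape_analysis.h: after this commit it walks the whole
// graph (including sub-blocks and subgraphs) with a single shared type cache.
void eraseAllShapes(const std::shared_ptr<Graph>& graph) {
  EraseShapeInformation(graph);
}
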
58 changes: 49 additions & 9 deletions torch/csrc/jit/passes/shape_analysis.cpp

@@ -2172,29 +2172,69 @@ void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {

 namespace {
 
-void EraseShapeInformation(at::ArrayRef<Value*> vals) {
+using TypeCache = std::unordered_map<TypePtr, TypePtr>;
+
+TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);
+
+TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
+  if (type->isSubtypeOf(TensorType::get())) {
+    return TensorType::get();
+  }
+  std::vector<TypePtr> unshaped_contained_types;
+  for (const auto& contained_type : type->containedTypes()) {
+    unshaped_contained_types.push_back(
+        getOrCreateUnshapedType(contained_type, unshaped_type_cache));
+  }
+  return type->withContained(unshaped_contained_types);
+}
+
+TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) {
+  auto maybe_cached_type = unshaped_type_cache.find(type);
+  if (maybe_cached_type != unshaped_type_cache.end()) {
+    return maybe_cached_type->second;
+  }
+  auto unshaped_type = unshapedTypeImpl(type, unshaped_type_cache);
+  unshaped_type_cache[type] = unshaped_type;
+  return unshaped_type;
+}
+
+void EraseShapeInformation(
+    const std::shared_ptr<Graph>& graph,
+    TypeCache& unshaped_type_cache);
+
+void EraseShapeInformation(
+    at::ArrayRef<Value*> vals,
+    TypeCache& unshaped_type_cache) {
   for (Value* v : vals) {
-    v->setType(unshapedType(v->type()));
+    v->setType(getOrCreateUnshapedType(v->type(), unshaped_type_cache));
   }
 }
 
-void EraseShapeInformation(Block* b) {
-  EraseShapeInformation(b->inputs());
-  EraseShapeInformation(b->outputs());
+void EraseShapeInformation(Block* b, TypeCache& unshaped_type_cache) {
+  EraseShapeInformation(b->inputs(), unshaped_type_cache);
+  EraseShapeInformation(b->outputs(), unshaped_type_cache);
   for (Node* n : b->nodes()) {
-    EraseShapeInformation(n->outputs());
+    EraseShapeInformation(n->outputs(), unshaped_type_cache);
     for (Block* sb : n->blocks()) {
-      EraseShapeInformation(sb);
+      EraseShapeInformation(sb, unshaped_type_cache);
     }
     if (n->hasAttribute(attr::Subgraph)) {
-      EraseShapeInformation(n->g(attr::Subgraph));
+      EraseShapeInformation(n->g(attr::Subgraph), unshaped_type_cache);
    }
   }
 }
 
+void EraseShapeInformation(
+    const std::shared_ptr<Graph>& graph,
+    TypeCache& unshaped_type_cache) {
+  EraseShapeInformation(graph->block(), unshaped_type_cache);
+}
+
 } // anonymous namespace
 
 void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
-  EraseShapeInformation(graph->block());
+  TypeCache unshaped_type_cache;
+  EraseShapeInformation(graph->block(), unshaped_type_cache);
 }
 } // namespace jit
 } // namespace torch
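
One design note, illustrated by a minimal standalone sketch (the `main` below is hypothetical, assuming the ATen headers are available; only the `std::unordered_map<TypePtr, TypePtr>` cache type comes from the patch): `std::hash` and `operator==` on `shared_ptr` use pointer identity, so the cache reuses work exactly when the graph shares the same underlying Type object, which is the common case for module and class types referenced from many values. The cache also lives only for the duration of one `EraseShapeInformation(graph)` call.

#include <ATen/core/jit_type.h>
#include <iostream>
#include <unordered_map>

int main() {
  using c10::TypePtr;
  std::unordered_map<TypePtr, TypePtr> cache;

  // TensorType::get() returns a shared singleton, so both calls yield the
  // same underlying pointer and therefore the same cache key.
  TypePtr a = c10::TensorType::get();
  TypePtr b = c10::TensorType::get();
  cache[a] = a;
  std::cout << (cache.count(b) == 1) << "\n";  // prints 1: same Type object
  return 0;
}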
