Skip to content

Commit 269aa7f

Browse files
committed
[PyTorch][JIT] use a better hash table in alias analysis
Pull Request resolved: #69854 ghstack-source-id: 145889449 Differential Revision: [D33039733](https://our.internmc.facebook.com/intern/diff/D33039733/)
1 parent b3f7460 commit 269aa7f

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

torch/csrc/jit/ir/alias_analysis.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <torch/csrc/jit/ir/alias_analysis.h>
22

3+
#include <c10/util/flat_hash_map.h>
34
#include <c10/util/irange.h>
45
#include <torch/csrc/jit/jit_log.h>
56
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
@@ -26,7 +27,7 @@ c10::MaybeOwned<TypePtr> toSingleType(const AliasTypeSet& mut_types) {
2627
class MutableTypePtrHelper {
2728
public:
2829
explicit MutableTypePtrHelper(
29-
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache)
30+
ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache)
3031
: mutable_type_cache_(mutable_type_cache) {}
3132

3233
// Map any mutable type to a type such that all other types which the
@@ -140,12 +141,12 @@ class MutableTypePtrHelper {
140141
return c10::nullopt;
141142
}
142143
}
143-
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
144+
ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
144145
};
145146

146147
bool isMutableTypeImpl(
147148
const TypePtr& type,
148-
std::unordered_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
149+
ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
149150
// Check common cases to avoid recursively constructing type in
150151
// `mapTypeToAliasTypeSetPtrImpl`
151152
auto kind = type->kind();

torch/csrc/jit/ir/alias_analysis.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ class AliasDb {
260260
// Mapping of values to MemoryDAG elements
261261
ska::flat_hash_map<const Value*, Element*> elementMap_;
262262
// All wildcard Elements (one for each unique mutable type)
263-
std::unordered_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
263+
ska::flat_hash_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
264264
Element* getWildcard(const TypePtr& type) const;
265265
c10::optional<Element*> tryGetOrCreateWildcard(const TypePtr& type);
266266
void addContainedTypesToFreshElement(
@@ -276,7 +276,7 @@ class AliasDb {
276276
bool hasWriters(const at::ArrayRef<Value*>& values) const;
277277

278278
// Cached mapping of type ptrs to their mutable types
279-
mutable std::unordered_map<TypePtr, AliasTypeSet> mapped_mutable_types_;
279+
mutable ska::flat_hash_map<TypePtr, AliasTypeSet> mapped_mutable_types_;
280280

281281
/**
282282
* State for tracking write info.

0 commit comments

Comments
 (0)