Skip to content

Commit

Permalink
[JIT] Improve May Contain Alias Using Contained Elements (#32326)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #32326

Now that we have type-level granularity we can improve `mayContainAlias` queries. Each new values is initialized as containing the wildcard set of each contained mutable type. Whenever a value is added to a container it is set to the wildcard set. Now, to check if any two values contain overlapping values, we can just check if the `containedMemoryLocations` of two sets overlap.

Test Plan: Imported from OSS

Differential Revision: D19563262

Pulled By: eellison

fbshipit-source-id: c6d7489749c14b2054a6d50ef75baca699ada471
  • Loading branch information
Elias Ellison authored and facebook-github-bot committed Jan 29, 2020
1 parent 25d33a2 commit c729614
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 62 deletions.
8 changes: 8 additions & 0 deletions test/cpp/jit/test_alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ void testContainerAliasing() {
graph(%inp: Tensor[]):
%x : str = prim::Constant[value="a"]()
%y : Tensor = prim::Constant()
%z : Tensor = prim::Constant()
%a : (Tensor) = prim::TupleConstruct(%y)
%b : Dict(str, Tensor) = prim::DictConstruct(%x, %y)
%c : Tensor[] = prim::ListConstruct(%y)
Expand All @@ -557,14 +558,17 @@ void testContainerAliasing() {

auto str_output = vmap["x"];
auto ten_output = vmap["y"];
auto local_var = vmap["z"];
AliasDb aliasDb(graph);

AT_ASSERT(graph->outputs().size() == 3);
for (auto out : graph->outputs()) {
AT_ASSERT(aliasDb.mayContainAlias(ten_output, out));
AT_ASSERT(!aliasDb.mayContainAlias(local_var, out));
}

AT_ASSERT(aliasDb.mayContainAlias(ten_output, graph->inputs()));
AT_ASSERT(!aliasDb.mayContainAlias(local_var, graph->inputs()));

AT_ASSERT(aliasDb.mayContainAlias({ten_output}, graph->outputs()));
AT_ASSERT(!aliasDb.mayContainAlias(str_output, graph->outputs()));
Expand Down Expand Up @@ -967,12 +971,16 @@ void testWildcards() {
AT_ASSERT(!aliasDb.hasWriters(int_list));
AT_ASSERT(aliasDb.hasWriters(opt_ten_list));
AT_ASSERT(aliasDb.hasWriters(ten_list));
AT_ASSERT(!aliasDb.mayContainAlias(int_list, opt_ten_list));
AT_ASSERT(aliasDb.mayContainAlias(ten_list, opt_ten_list));
AT_ASSERT(aliasDb.mayAlias(ten_list, opt_ten_list));

auto list_of_tensor_lists = vmap["ten_ten_list"];
AT_ASSERT(aliasDb.mayContainAlias(ten_list, list_of_tensor_lists));
AT_ASSERT(aliasDb.mayContainAlias(ten_list, vmap["ten"]));

AT_ASSERT(
!aliasDb.mayContainAlias(vmap["int_int_list"], list_of_tensor_lists));
}

// test invariant container aliasing
Expand Down
3 changes: 2 additions & 1 deletion test/jit/test_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys
import unittest
from torch.testing._internal.common_utils import enable_profiling_mode
from torch.testing._internal.common_utils import enable_profiling_mode, GRAPH_EXECUTOR, ProfilingMode
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -226,6 +226,7 @@ def test_neural_style_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_neural_style(self, device='cuda', check_export_import=False)

@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor")
@staticmethod
def _test_mnist(self, device, check_export_import=True):
# eval() is present because dropout makes this nondeterministic
Expand Down
53 changes: 2 additions & 51 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,20 +843,6 @@ bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
return false;
}

bool AliasDb::cannotCheckAliasContainment(const Value* elem) const {
if (isContainerType(elem->type())) {
if (elem->node()->kind() != prim::TupleConstruct) {
return true;
}
auto inps = elem->node()->inputs();
return std::any_of(inps.begin(), inps.end(), [&](const Value* v) {
return cannotCheckAliasContainment(v);
});
}

return false;
}

bool AliasDb::mayContainAlias(Value* a, Value* b) const {
const std::vector<Value*> a_vec = {a};
const std::vector<Value*> b_vec = {b};
Expand All @@ -877,43 +863,8 @@ std::vector<Element*> AliasDb::getElements(at::ArrayRef<Value*> vs) const {
bool AliasDb::mayContainAlias(
const at::ArrayRef<Value*> a,
const at::ArrayRef<Value*> b) const {
std::vector<Element*> a_elements;
bool a_cannot_check_containment = false;
for (const auto& val : a) {
if (cannotCheckAliasContainment(val)) {
a_cannot_check_containment = true;
break;
}
if (mutableType(val)) {
a_elements.push_back(elementMap_.at(val));
}
}

std::vector<Element*> b_elements;
bool b_cannot_check_containment = false;
for (const auto& val : b) {
if (cannotCheckAliasContainment(val)) {
b_cannot_check_containment = true;
break;
}
if (mutableType(val)) {
b_elements.push_back(elementMap_.at(val));
}
}
// if we can check all elements of one value set and it doesn't contain any
// mutable types, then we can safely return false
if (!a_cannot_check_containment && a_elements.size() == 0) {
return false;
}
if (!b_cannot_check_containment && b_elements.size() == 0) {
return false;
}
// now both elements must contain mutable elements, so if we cannot check
// containment we must return true
if (a_cannot_check_containment || b_cannot_check_containment) {
return true;
}
return memoryDAG_->mayContainAlias(a_elements, b_elements);
auto a_elems = getElements(a);
return a_elems.size() == 0 ? false : memoryDAG_->mayContainAlias(a_elems, getElements(b));
}

// Make each value in the `from` list point to its partner in the `to` list
Expand Down
4 changes: 0 additions & 4 deletions torch/csrc/jit/passes/alias_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ class AliasDb {
// Register `v` as a wildcard value.
c10::optional<Element*> setWildcard(const Value* v);

// Is the element a wildcard or an unhandled container type,
// or does the element contain an element for which that's true
bool cannotCheckAliasContainment(const Value* elem) const;

// Is this a value which will not alias
bool nonAliasingValue(const Value* elem) const;

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/type_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace torch {
namespace jit {

size_t HashType::operator()(const TypePtr& type) const noexcept {
size_t HashType::operator()(const TypePtr& type) const {
if (auto named_type = type->cast<ClassType>()) {
return get_hash(named_type->name().value());
}
Expand All @@ -20,7 +20,7 @@ size_t HashType::operator()(const TypePtr& type) const noexcept {
return get_hash(typekind_hash, hashes);
};

bool EqualType::operator()(const TypePtr& a, const TypePtr& b) const noexcept {
bool EqualType::operator()(const TypePtr& a, const TypePtr& b) const {
return *a == *b;
};

Expand Down
9 changes: 5 additions & 4 deletions torch/csrc/jit/type_hashing.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#pragma once

#include <ATen/core/jit_type.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir.h>

namespace torch {
namespace jit {

struct HashType {
size_t operator()(const TypePtr& type) const noexcept;
struct TORCH_API HashType {
TORCH_API size_t operator()(const TypePtr& type) const;
};

struct EqualType {
bool operator()(const TypePtr& a, const TypePtr& b) const noexcept;
struct TORCH_API EqualType {
TORCH_API bool operator()(const TypePtr& a, const TypePtr& b) const;
};

} // namespace jit
Expand Down

0 comments on commit c729614

Please sign in to comment.