Commit 2ead9c5

Update on "[dynamo][nn module] Check for duplicate tensors in register_attr_or_module"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
anijain2305 committed May 2, 2024
2 parents e3c721e + f7e5e6b commit 2ead9c5
Showing 2 changed files with 56 additions and 40 deletions.
72 changes: 50 additions & 22 deletions torch/_dynamo/guards.py
@@ -132,6 +132,7 @@ def __init__(self):
         self.cache_entry = None
         self.extra_state = None
         self.id_matched_objs = None
+        self.no_tensor_aliasing_sources = []
 
     def get_guard_lines(self, guard):
         guard_name = guard.__class__.__name__
@@ -2094,6 +2095,7 @@ def add_code_part(code_part, guard, log_only=False):
         # when the CacheEntry is constructed
         guard_fn.cache_entry = None
         guard_fn.extra_state = None
+        guard_fn.no_tensor_aliasing_sources = tensor_check_names
         return guard_fn
 
     def invalidate(self):
@@ -2184,6 +2186,23 @@ def is_recompiles_verbose_enabled():
     return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose")
 
 
+def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
+    duplicate_tensors = []
+    global_scope = dict(guard_manager.global_scope)
+    ids_to_source = collections.defaultdict(list)
+    for tensor_source in guard_manager.no_tensor_aliasing_sources:  # type: ignore[attr-defined]
+        global_scope["__compile_source__"] = tensor_source
+        tensor_id = id(eval(tensor_source, global_scope, scope))
+        ids_to_source[tensor_id].append(tensor_source)
+
+    for key in ids_to_source:
+        if len(ids_to_source[key]) > 1:
+            duplicate_tensors.append(f"{ids_to_source[key]}")
+
+    reason = ", ".join(duplicate_tensors)
+    return [f"Duplicate tensors found: {reason}"]
+
+
 def get_guard_fail_reason(
     guard_fn: GuardFn,
     code: types.CodeType,
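The new helper's core technique can be exercised standalone: evaluate each recorded source expression, group the source strings by the id() of the object they denote, and report any group with more than one name. A minimal sketch, assuming plain dicts in place of Dynamo's global/local scopes (the variable names here are hypothetical, not Dynamo's):

import collections

import torch

# Hypothetical stand-ins for the guarded frame's locals and the recorded
# source expressions; in Dynamo these come from guard_manager and scope.
scope = {"x": torch.randn(4), "z": torch.randn(4)}
scope["y"] = scope["x"]  # "y" aliases "x"
sources = ["x", "y", "z"]

ids_to_source = collections.defaultdict(list)
for src in sources:
    # Group source expressions by the identity of the object they evaluate to.
    ids_to_source[id(eval(src, {}, scope))].append(src)

duplicates = [f"{srcs}" for srcs in ids_to_source.values() if len(srcs) > 1]
print(f"Duplicate tensors found: {', '.join(duplicates)}")
# prints: Duplicate tensors found: ['x', 'y']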
@@ -2198,6 +2217,8 @@ def get_guard_fail_reason(
     scope.update(guard_fn.closure_vars)
     reasons: List[str] = []
 
+    no_tensor_aliasing_check_failed = False
+
     verbose_code_parts: List[str] = []
     if config.enable_cpp_guard_manager:
         guard_manager = guard_fn
@@ -2213,35 +2234,42 @@ def get_guard_fail_reason(
             # walk through this list and find the guard that failed. This is
             # very important for symbolic shape guards which are currently
             # installed as a lambda guard and can encompass a long list of code_parts.
+
         if len(verbose_code_parts) == 1:
-            reasons = verbose_code_parts
-            verbose_code_parts = []
+            if "Duplicate tensor found" in verbose_code_parts[0]:
+                no_tensor_aliasing_check_failed = True
+            else:
+                reasons = verbose_code_parts
+                verbose_code_parts = []
     else:
         verbose_code_parts = guard_fn.verbose_code_parts
         # This is not needed for CPP guard because the verbose check is already
         # run in C++.
         scope["___check_tensors"] = scope["___check_tensors_verbose"]
 
-    for part in verbose_code_parts:
-        global_scope = dict(guard_fn.global_scope)
-        global_scope["__compile_source__"] = part
-        with report_compile_source_on_error():
-            try:
-                fail_reason = eval(part, global_scope, scope)
-            except Exception as e:
-                if is_recompiles_verbose_enabled():
-                    continue
-                else:
-                    raise
-        # Only ___check_tensors knows how to return a fancy fail reason;
-        # for everything else we just report the code that failed
-
-        if isinstance(fail_reason, bool) and not fail_reason:
-            fail_reason = part
-        if isinstance(fail_reason, str):
-            reasons.append(fail_reason)
-        if not is_recompiles_verbose_enabled():
-            break
+    if no_tensor_aliasing_check_failed:
+        reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope)
+    else:
+        for part in verbose_code_parts:
+            global_scope = dict(guard_fn.global_scope)
+            global_scope["__compile_source__"] = part
+            with report_compile_source_on_error():
+                try:
+                    fail_reason = eval(part, global_scope, scope)
+                except Exception as e:
+                    if is_recompiles_verbose_enabled():
+                        continue
+                    else:
+                        raise
+            # Only ___check_tensors knows how to return a fancy fail reason;
+            # for everything else we just report the code that failed
+
+            if isinstance(fail_reason, bool) and not fail_reason:
+                fail_reason = part
+            if isinstance(fail_reason, str):
+                reasons.append(fail_reason)
+            if not is_recompiles_verbose_enabled():
+                break
 
     reason_str = "\n".join(reasons)
     guard_failures[orig_code_map[code]].append(reason_str)
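To see how the two files fit together end to end: the C++ guard below reports only the terse "Duplicate tensor found" string, and the Python path above expands it into the offending source names. A hedged sketch of triggering this from user code, assuming Dynamo's usual behavior of guarding distinct tensor arguments against aliasing (the reason text shown is this patch's format, not verified output):

import torch

torch._logging.set_logs(recompiles=True)

@torch.compile(backend="eager")
def f(x, y):
    return x + y

a = torch.randn(4)
b = torch.randn(4)
f(a, b)  # first compilation; guards record that x and y must not alias
f(a, a)  # x and y now alias -> NO_TENSOR_ALIASING fails and we recompile
# Expected recompile reason, roughly: Duplicate tensors found: ["L['x']", "L['y']"]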
24 changes: 6 additions & 18 deletions torch/csrc/dynamo/guards.cpp
@@ -1273,13 +1273,11 @@ class TENSOR_ALIASING : public RelationalGuard {
 class NO_TENSOR_ALIASING : public RelationalGuard {
  public:
   NO_TENSOR_ALIASING(
-      long unsigned int num_tensors,
-      py::object tensor_names,
+      const py::list& tensor_names,
       py::object verbose_code_parts)
       : RelationalGuard(std::move(verbose_code_parts)),
-        _num_tensors(num_tensors),
-        _tensor_names(std::move(tensor_names)) {
-    _unique_tensors.reserve(num_tensors);
+        _tensor_names(tensor_names) {
+    _unique_tensors.reserve(tensor_names.size());
   }
 
   bool check_nopybind(PyObject* value) override { // borrowed ref
@@ -1303,30 +1301,22 @@ class NO_TENSOR_ALIASING : public RelationalGuard {
     bool result = check_nopybind(value);
 
     if (!result) {
-      std::stringstream fail_reason;
-      fail_reason << "Duplicate tensor found where not expected! ";
-      fail_reason << py::cast<std::string>(_tensor_names[_counter])
-                  << " should not alias to anything, but is aliased."
-                  << " Total number of tensors are " << _num_tensors;
-      return GuardDebugInfo(false, fail_reason.str(), 0);
+      return GuardDebugInfo(
+          false, "Duplicate tensor found where not expected!", 0);
     }
     _counter += 1;
     return GuardDebugInfo(true, 1);
   }
 
   void reset_state() final {
     _counter = 0;
-    for (auto item : _unique_tensors) {
-      Py_DECREF(item.first);
-    }
     _unique_tensors.clear();
   }
 
  private:
-  long unsigned int _num_tensors;
   py::list _tensor_names;
   ska::flat_hash_map<PyObject*, std::nullptr_t> _unique_tensors;
   long unsigned int _counter = 0;
 };
 
 class DYNAMIC_INDICES : public LeafGuard {
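The data structure at work in this guard, sketched in Python: a hash map keyed by object identity makes each aliasing check O(1), and reset_state() clears the map so the next guard evaluation starts fresh. A conceptual analogue only, not the real GuardManager API:

import torch

class NoTensorAliasing:
    """Python analogue of the C++ NO_TENSOR_ALIASING relational guard."""

    def __init__(self, tensor_names):
        self.tensor_names = tensor_names
        self.unique_tensors = {}  # plays the role of _unique_tensors

    def check(self, value):
        # Fail on the second sighting of the same object identity.
        if id(value) in self.unique_tensors:
            return False
        self.unique_tensors[id(value)] = None
        return True

    def reset_state(self):
        # Like the C++ reset_state(): drop per-evaluation bookkeeping.
        self.unique_tensors.clear()

guard = NoTensorAliasing(["L['x']", "L['y']"])
t = torch.randn(2)
print(guard.check(t), guard.check(t))  # True False
guard.reset_state()
print(guard.check(t))  # True again after reset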
@@ -3186,9 +3176,7 @@ void install_no_tensor_aliasing_guard(
   // relational guard. There is one guard object that is shared between multiple
   // guard managers.
   std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>(
-      guard_managers.size(),
-      std::move(tensor_names),
-      std::move(verbose_code_parts));
+      std::move(tensor_names), std::move(verbose_code_parts));
 
   // Register the resetter on the root guard manager, so that it can reset
   // the newly added relational guard when the guard eval fails.
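A design note on install_no_tensor_aliasing_guard: because the property is relational (it spans all guarded tensors), one NO_TENSOR_ALIASING instance is shared by every tensor's guard manager, with a single resetter registered on the root to clear its state between evaluations. A rough Python analogue of that wiring, with illustrative classes that do not mirror the real C++ types:

import torch

class NoAliasGuard:
    """One relational guard instance shared across all checked slots."""

    def __init__(self):
        self.seen = set()

    def check(self, value):
        if id(value) in self.seen:
            return False  # same object seen twice -> aliasing
        self.seen.add(id(value))
        return True

    def reset_state(self):
        self.seen.clear()

class RootGuardManager:
    """Illustrative stand-in for the C++ root guard manager."""

    def __init__(self):
        self.leaf_checks = []  # (guard, getter) pairs, one per guarded slot
        self.resetters = []    # relational guards to clear between evals

    def evaluate(self, scope):
        ok = all(g.check(get(scope)) for g, get in self.leaf_checks)
        for g in self.resetters:
            g.reset_state()  # simplified: the real code resets on guard failure
        return ok

root = RootGuardManager()
shared = NoAliasGuard()  # the single shared instance, as in the C++ code
root.leaf_checks += [(shared, lambda s: s["x"]), (shared, lambda s: s["y"])]
root.resetters.append(shared)

t = torch.randn(2)
print(root.evaluate({"x": t, "y": t}))               # False: x aliases y
print(root.evaluate({"x": t, "y": torch.randn(2)}))  # True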
