Skip to content

Commit

Permalink
Update on "Use object identity for deepcopy memo"
Browse files Browse the repository at this point in the history
Copy of #126089, with some fixes & tests (TODO)

[TODO description]


[ghstack-poisoned]
  • Loading branch information
davidberard98 committed May 14, 2024
1 parent 8cc43dd commit 2a7cfe5
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
c10::optional<at::Device> device = c10::nullopt) const;

c10::intrusive_ptr<Object> deepcopy(
IValue::HashAliasedIValueMap& memo,
IValue::HashIdentityIValueMap& memo,
c10::optional<at::Device> device = c10::nullopt) const;

bool is_weak_compilation_ref() const {
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/jit/api/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ Module Module::deepcopy(c10::optional<at::Device> device) const {

Module Module::clone(bool inplace) const {
std::unordered_map<TypePtr, TypePtr> type_remap;
IValue::HashAliasedIValueMap memo;
IValue::HashIdentityIValueMap memo;
const std::unordered_set<std::string> ignored_methods;
const std::unordered_set<std::string> ignored_attributes;
return clone_impl(
Expand All @@ -335,15 +335,15 @@ Module Module::clone(
const std::unordered_set<std::string>& ignored_methods,
const std::unordered_set<std::string>& ignored_attributes) const {
std::unordered_map<TypePtr, TypePtr> type_remap;
IValue::HashAliasedIValueMap memo;
IValue::HashIdentityIValueMap memo;
return clone_impl(
type_remap, inplace, memo, ignored_methods, ignored_attributes);
}

Module Module::clone_impl(
std::unordered_map<TypePtr, TypePtr>& type_remap,
bool inplace,
IValue::HashAliasedIValueMap memo,
IValue::HashIdentityIValueMap memo,
const std::unordered_set<std::string>& ignored_methods,
const std::unordered_set<std::string>& ignored_attributes) const {
// Create a new _ivalue in the same compilation unit.
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/api/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ struct TORCH_API Module : public Object {
Module clone_impl(
std::unordered_map<TypePtr, TypePtr>& type_remap,
bool inplace,
IValue::HashAliasedIValueMap memo,
IValue::HashIdentityIValueMap memo,
const std::unordered_set<std::string>& ignored_methods,
const std::unordered_set<std::string>& ignored_attributes) const;

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/passes/quantization/insert_observers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class ModuleCloneHelper {
const ModuleQConfigMap& module_qconfig_map,
bool inplace = false) {
std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
IValue::HashAliasedIValueMap memo;
IValue::HashIdentityIValueMap memo;
return clone_impl(
module, module_qconfig_map, type_remap, inplace, std::move(memo));
}
Expand All @@ -103,7 +103,7 @@ class ModuleCloneHelper {
const ModuleQConfigMap& module_qconfig_map,
std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
bool inplace,
IValue::HashAliasedIValueMap memo) {
IValue::HashIdentityIValueMap memo) {
auto qconfig = module_qconfig_map.at(module._ivalue());
auto type = module.type();
// Create a new _ivalue in the same compilation unit.
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,13 +668,13 @@ static constexpr std::array<const char*, 48> magic_method_names = {
};

struct DeepCopyMemoTable {
std::shared_ptr<IValue::HashAliasedIValueMap> map;
std::shared_ptr<IValue::HashIdentityIValueMap> map;
};

IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
if (!memo.contains(py::str("__torch_script_memo_table"))) {
memo["__torch_script_memo_table"] =
DeepCopyMemoTable{std::make_shared<IValue::HashAliasedIValueMap>()};
DeepCopyMemoTable{std::make_shared<IValue::HashIdentityIValueMap>()};
}
auto& ivalue_memo =
*py::cast<DeepCopyMemoTable>(memo["__torch_script_memo_table"]).map;
Expand Down

0 comments on commit 2a7cfe5

Please sign in to comment.