Generalize constant_table from tensor only to ivalue #37631

Closed (wants to merge 1 commit)
40 changes: 40 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
@@ -101,6 +101,46 @@ TypePtr IValue::type() const {
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
}

void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
Review comment (Contributor):
I recommend against adding this to IValue: it has no tests and is not even turned on in this PR. It is the kind of simple function that people looking at IValue will assume is tested and implemented correctly. It should be implemented locally where needed unless we commit to completely implementing and testing it.

  if (visitor(*this)) {
    // Short cut.
    return;
  }
  switch (this->tag) {
    case Tag::Tuple:
    case Tag::GenericList: {
      c10::ArrayRef<IValue> elems;
      if (isTuple()) {
        elems = this->toTuple()->elements();
      } else {
        elems = this->toListRef();
      }
      for (auto& elem : elems) {
        elem.visit(visitor);
      }
      break;
    }
    case Tag::GenericDict:
      for (const auto& pair : this->toGenericDict()) {
        pair.value().visit(visitor);
        pair.key().visit(visitor);
      }
      break;
    case Tag::Object: {
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attributes = obj_type->getAttributes();
      for (const auto& attr : attributes) {
        auto attribute = obj_value->getAttr(attr.getName());
        attribute.visit(visitor);
      }
      break;
    }
    default:
      break;
  }
}
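
A minimal usage sketch (not part of this diff) of how a caller might use visit(); the helper name collectTensors is hypothetical. The lambda's return value is the "short cut" above: returning true stops recursion at that value, returning false keeps descending.

// Hypothetical helper: gather every tensor reachable from an IValue.
static std::vector<at::Tensor> collectTensors(const IValue& root) {
  std::vector<at::Tensor> tensors;
  root.visit([&](const IValue& v) {
    if (v.isTensor()) {
      tensors.push_back(v.toTensor());
      return true;  // a tensor has no subvalues; stop recursing here
    }
    return false;  // keep visiting tuples, lists, dicts, and object attributes
  });
  return tensors;
}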

void IValue::getSubValues(HashAliasedIValues& subValues) const {
switch (this->tag) {
case Tag::Tensor:
4 changes: 4 additions & 0 deletions aten/src/ATen/core/ivalue.h
@@ -638,6 +638,10 @@ struct CAFFE2_API IValue final {
// Inserts all subvalues of this in subValues.
void getSubValues(HashAliasedIValues& subValues) const;

// Apply visitor to every subvalue.
// TODO: There are several places that recurse over IValue. This is fragile.
// This visitor should be used to recurse over ivalues.
void visit(const std::function<bool (const IValue &)>& visitor) const;
Review comment (Member):
Can you add a short comment explaining how the return bool is used?
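
One possible wording for that comment, taken from the behavior of the implementation in ivalue.cpp above (a suggestion only, not part of the diff):

// Apply visitor to every subvalue.
// If the visitor returns true for a value, recursion stops there and its
// subvalues are not visited; returning false continues the recursion into
// tuples, lists, dicts, and object attributes.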

IValue deepcopy() const;
IValue deepcopy(
HashAliasedIValueMap& memo) const;
2 changes: 1 addition & 1 deletion aten/src/ATen/test/type_test.cpp
@@ -78,7 +78,7 @@ static TypePtr importType(
std::shared_ptr<CompilationUnit> cu,
const std::string& qual_name,
const std::string& src) {
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
auto source = std::make_shared<torch::jit::Source>(src);
torch::jit::SourceImporter si(
cu,
8 changes: 4 additions & 4 deletions test/cpp/jit/test_class_import.cpp
@@ -36,7 +36,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
@@ -48,7 +48,7 @@
void testClassImport() {
auto cu1 = std::make_shared<CompilationUnit>();
auto cu2 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
@@ -83,7 +83,7 @@ void testClassImport() {
void testScriptObject() {
Module m1("m1");
Module m2("m2");
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
m1._ivalue()->compilation_unit(),
"__torch__.FooTest",
@@ -144,7 +144,7 @@ class FooBar1234(Module):

void testSaveLoadTorchbind() {
auto cu1 = std::make_shared<CompilationUnit>();
std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
// Import different versions of FooTest into two namespaces.
import_libs(
cu1,
4 changes: 2 additions & 2 deletions test/cpp/jit/test_interface.cpp
@@ -35,7 +35,7 @@ static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::Tensor>& tensor_table) {
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
@@ -49,7 +49,7 @@ void testModuleInterfaceSerialization() {
Module parentMod("parentMod", cu);
Module subMod("subMod", cu);

std::vector<at::Tensor> constantTable;
std::vector<at::IValue> constantTable;
import_libs(
cu,
"__torch__.OneForward",
4 changes: 4 additions & 0 deletions test/expect/TestJit.test_non_ascii_string.expect
@@ -0,0 +1,4 @@
def forward(self,
x: Tensor,
y: Tensor) -> str:
return torch.add(self.a, CONSTANTS.c0)
56 changes: 33 additions & 23 deletions test/jit/test_freezing.py
@@ -1,3 +1,4 @@
import unittest
import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase
@@ -12,52 +13,63 @@
"instead.")

class TestFreezing(JitTestCase):
@unittest.skip("temporarily disable the test for fwd compatibility")
def test_freeze_module(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.c2 = "hi\xA1" # not folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.f2 = [(1, "Over \u0e55\u0e57 57")]
self.g = ([1, 2], 3.2, "4.4", torch.tensor([5.5], requires_grad=True)) # folded
self.h = {"layer" : [torch.tensor([7.7], requires_grad=True)]}
self.h2 = {"layer\xB1" : [torch.tensor([8.8], requires_grad=True)]}
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
self.ts = [torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([3.0, 4.0], requires_grad=True)] # folded
self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

def forward(self, x):
return str(self.a) + str(self.b) + self.c + str(self.d) + \
str(self.e) + str(self.f) + str(self.g) + str(self.h['layer']) + str(self.t) + str(self.ts) + str(self.tt)
return str(self.a) + str(self.b) + self.c + self.c2 + str(self.d) + \
str(self.e) + str(self.f) + str(self.f2) + str(self.g) + \
str(self.h) + str(self.h2) + str(self.t) + str(self.ts) + str(self.tt)


m = torch.jit.script(M())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
m._c = torch._C._freeze_module(m._c)

buffer = io.BytesIO()
torch.jit.save(m._c, buffer)
buffer.seek(0)
m2 = torch.jit.load(buffer)
# Check if frozen module looks as below:
# module m {
# attributes {
# tt = ...
# }
# ...
# }
self.assertFalse(m._c.hasattr('a'))
self.assertFalse(m._c.hasattr('b'))
self.assertFalse(m._c.hasattr('c'))
self.assertFalse(m._c.hasattr('d'))
self.assertFalse(m._c.hasattr('e'))
self.assertFalse(m._c.hasattr('f'))
self.assertFalse(m._c.hasattr('g'))
self.assertFalse(m._c.hasattr('h'))
self.assertFalse(m._c.hasattr('t'))
self.assertFalse(m._c.hasattr('ts'))
self.assertFalse(m._c.hasattr('tt'))
output_f = m.forward(input)
self.assertFalse(m2._c.hasattr('a'))
self.assertFalse(m2._c.hasattr('b'))
self.assertFalse(m2._c.hasattr('c'))
self.assertFalse(m2._c.hasattr('c2'))
self.assertFalse(m2._c.hasattr('d'))
self.assertFalse(m2._c.hasattr('e'))
self.assertFalse(m2._c.hasattr('f'))
self.assertFalse(m2._c.hasattr('f2'))
self.assertFalse(m2._c.hasattr('g'))
self.assertFalse(m2._c.hasattr('h'))
self.assertFalse(m2._c.hasattr('h2'))
self.assertFalse(m2._c.hasattr('t'))
self.assertFalse(m2._c.hasattr('ts'))
self.assertFalse(m2._c.hasattr('tt'))
output_f = m2.forward(input)
self.assertEqual(output_s, output_f)

def test_freeze_module_with_submodule(self):
@@ -926,8 +938,6 @@ def forward(self, x):
FileCheck().check_not('GetAttr[name=') \
.run(m._c._get_method('forward').graph)



def test_freeze_module_detach_gradient(self):
mod = nn.Conv2d(8, 3, 4, 2, 1)
self.assertTrue(mod.weight.requires_grad)
19 changes: 19 additions & 0 deletions test/test_jit.py
@@ -2023,6 +2023,25 @@ def forward(self, x, y):
        foo_loaded = torch.jit.load(buffer)
        self.assertExpected(foo_loaded.forward.code)

    @unittest.skip("temporarily disable the test for fwd compatibility")
    def test_non_ascii_string(self):
        class Foo(torch.jit.ScriptModule):
            def __init__(self):
                super(Foo, self).__init__()
                self.a = "Over \u0e55\u0e57 57"

            @torch.jit.script_method
            def forward(self, x, y):
                return self.a + "hi\xA1"

        foo = Foo()
        buffer = io.BytesIO()
        torch.jit.save(foo, buffer)

        buffer.seek(0)
        foo_loaded = torch.jit.load(buffer)
        self.assertExpected(foo_loaded.forward.code)

    def test_function_default_values(self):
        outer_var = torch.tensor(20)
        outer_var2 = torch.tensor(30)
4 changes: 2 additions & 2 deletions torch/csrc/jit/jit_log.cpp
@@ -76,9 +76,9 @@ bool is_enabled(const char* cfname, JitLoggingLevels level) {
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();
}
32 changes: 16 additions & 16 deletions torch/csrc/jit/python/script_init.cpp
@@ -993,23 +993,23 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
return pp.str();
})
.def_property_readonly(
"code_with_constants",
[](Module& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
@@ -1156,10 +1156,10 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](const StrongFunctionPtr& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;

PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printFunction(*self.function_);
return pp.str();
})
@@ -1201,21 +1201,21 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"code",
[](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
return pp.str();
})
.def_property_readonly("code_with_constants", [](Method& self) {
std::vector<at::Tensor> tensors;
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps);
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
std::map<std::string, at::Tensor> consts;
std::map<std::string, at::IValue> consts;
int i = 0;
for (auto const& tensor : tensors) {
consts["c" + std::to_string(i)] = tensor;
for (auto const& constant : constants) {
consts["c" + std::to_string(i)] = constant;
i += 1;
}
return std::make_tuple(pp.str(), consts);
2 changes: 1 addition & 1 deletion torch/csrc/jit/serialization/export_module.cpp
@@ -331,7 +331,7 @@ class ScriptModuleSerializer {
}

caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::Tensor> constant_table_;
std::vector<at::IValue> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
TypeNameUniquer type_name_uniquer_;
4 changes: 2 additions & 2 deletions torch/csrc/jit/serialization/import.cpp
@@ -129,7 +129,7 @@ class ScriptModuleDeserializer final {
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
std::vector<at::Tensor> constants_table_;
std::vector<at::IValue> constants_table_;
SourceImporter source_importer_;
std::string export_prefix_ = "code/";
};
@@ -264,7 +264,7 @@ Module ScriptModuleDeserializer::deserialize(
}
auto tuple = readArchive("constants").toTuple();
for (auto constant : tuple->elements()) {
constants_table_.push_back(constant.toTensor());
constants_table_.push_back(constant.toIValue());
}
auto m = Module(readArchive("data").toObject());
rewriteQuantizedConvForBC(m);