Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Add support for named_modules() #29495

Closed
wants to merge 8 commits into from
23 changes: 23 additions & 0 deletions test/test_jit.py
Expand Up @@ -8765,6 +8765,29 @@ def forward(self, v):
with self.assertRaisesRegex(Exception, "object is not iterable"):
print([val for val in m])

def test_underscore_modules(self):
    # Verifies TorchScript support for iterating a module's `_modules`
    # OrderedDict in all three dict styles: direct iteration (keys),
    # `.values()` (submodules), and `.items()` (name/module pairs).
    class MyMod(torch.nn.Module):
        def __init__(self):
            super(MyMod, self).__init__()
            # Two submodules so iteration actually visits multiple entries.
            self.mod = nn.ReLU()
            self.mod2 = nn.ReLU()

        def forward(self, x):
            # TorchScript needs the explicit List[str] annotation for an
            # empty list literal.
            li = torch.jit.annotate(List[str], [])
            # Iterating `_modules` directly yields the submodule names.
            for mod_name in self._modules:
                li.append(mod_name)

            # `.values()` yields the submodules themselves.
            for mod in self._modules.values():
                x = mod(x)

            # `.items()` yields (name, module) pairs.
            for mod_name, mod in self._modules.items():
                x = mod(x)
                li.append(mod_name)

            return li, x

    # checkModule scripts MyMod and compares scripted vs. eager results.
    self.checkModule(MyMod(), (torch.tensor(1.),))

def test_attr_qscheme_script(self):
class Foo(torch.nn.Module):
def __init__(self):
Expand Down
57 changes: 19 additions & 38 deletions torch/csrc/jit/script/python_sugared_value.cpp
Expand Up @@ -208,9 +208,7 @@ Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
return self_;
}

SugaredValuePtr ModuleValue::desugarModuleContainer(
bool get_keys,
bool get_values,
std::shared_ptr<SugaredModuleDict> ModuleValue::getSugaredModuleDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> submoduleNames;
Expand All @@ -235,28 +233,13 @@ SugaredValuePtr ModuleValue::desugarModuleContainer(
auto mod_v = std::make_shared<ModuleValue>(
module_v, concreteType_->findSubmoduleConcreteType(name));

if (get_keys) {
keys.push_back(name_v);
}
if (get_values) {
values.push_back(mod_v);
}
keys.push_back(name_v);
values.push_back(mod_v);
}

if (get_keys && !get_values) {
return std::make_shared<SugaredTupleValue>(keys);
} else if (get_values && !get_keys) {
return std::make_shared<SugaredTupleValue>(values);
} else if (get_values && get_keys) {
auto key_list = std::make_shared<SugaredTupleValue>(keys);
auto value_list = std::make_shared<SugaredTupleValue>(values);
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, key_list);
iterator->addChild(loc, m, value_list);
return iterator->iter(loc, m);
} else {
TORCH_INTERNAL_ASSERT(false);
}
return std::make_shared<SugaredModuleDict>(
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}

// This method controls how we desugar attribute lookups on ScriptModules.
Expand Down Expand Up @@ -295,21 +278,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
// TODO: These could be represented as first class methods probably.
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (field == "items" || field == "keys" || field == "values") {
bool get_keys = false;
bool get_values = false;
if (field == "items") {
get_keys = true;
get_values = true;
} else if (field == "values") {
get_values = true;
} else {
get_keys = true;
}
return std::make_shared<ModuleDictMethod>(
desugarModuleContainer(get_keys, get_values, loc, m), field);
return getSugaredModuleDict(loc, m)->attr(loc, m, field);
}
}

if (field == "_modules") {
return getSugaredModuleDict(loc, m);
}

// 4. Check if this is the name of an overloaded method.

// This can also be a call to a non-script module, or a plain
Expand Down Expand Up @@ -376,11 +352,16 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
}

auto module_dict = getSugaredModuleDict(loc, m);
// iterating over a dictionary returns the keys, iterating over a
// list returns the values
const bool get_keys = iterableModuleKind == IterableModuleKind::DICT;
const bool get_values = iterableModuleKind == IterableModuleKind::LIST;
return desugarModuleContainer(get_keys, get_values, loc, m);
if (iterableModuleKind == IterableModuleKind::DICT) {
return module_dict->keys_;
} else if (iterableModuleKind == IterableModuleKind::LIST) {
return module_dict->modules_;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}

std::shared_ptr<SugaredValue> PythonClassValue::attr(
Expand Down
99 changes: 68 additions & 31 deletions torch/csrc/jit/script/python_sugared_value.h
Expand Up @@ -102,6 +102,70 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
Value* the_list_;
};

// Sugared zero-argument "method" (`keys`, `values`, or `items`) of a module
// container. Calling it simply hands back the iterable that was built when
// the sugared dict was constructed.
struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
  // `iterable`: the prebuilt value to return when the method is called.
  // `name`: the method name, used for kind() and error reporting.
  // Take the shared_ptr by value and move it to avoid an extra refcount bump.
  explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
      : iterable_(std::move(iterable)), name_(name) {}

  std::string kind() const override {
    return name_;
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& f,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    // These are zero-arg accessors; any argument is a user error.
    if (inputs.size() || attributes.size()) {
      throw ErrorReport(loc)
          << name_ << " method does not accept any arguments";
    }
    return iterable_;
  }

  SugaredValuePtr iterable_; // value returned by call()
  const std::string name_; // "keys", "values", or "items"
};

// Used to support mod._modules: a sugared view of a module's submodule
// dict. Exposes keys(), values(), and items() as callable methods, and
// iterates over the keys (matching Python dict iteration semantics).
struct VISIBILITY_HIDDEN SugaredModuleDict : public SugaredValue {
  // `keys`: sugared tuple of submodule names; `modules`: sugared tuple of
  // the corresponding submodules, in the same order.
  explicit SugaredModuleDict(
      std::shared_ptr<SugaredValue> keys,
      std::shared_ptr<SugaredValue> modules)
      : keys_(std::move(keys)), modules_(std::move(modules)) {}

  std::string kind() const override {
    return "ModuleDict";
  }

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override {
    if (field == "keys") {
      return std::make_shared<ModuleDictMethod>(keys_, "keys");
    } else if (field == "values") {
      return std::make_shared<ModuleDictMethod>(modules_, "values");
    } else if (field == "items") {
      // items() zips keys and modules pairwise via an IterableTree.
      auto iterator = std::make_shared<IterableTree>();
      iterator->addChild(loc, m, keys_);
      iterator->addChild(loc, m, modules_);
      return std::make_shared<ModuleDictMethod>(iterator, "items");
    } else {
      throw ErrorReport(loc) << "_modules only supports keys, values, or items";
    }
  }

  // Iterating a dict yields its keys, as in Python.
  SugaredValuePtr iter(const SourceRange& loc, Function& m) override {
    return keys_;
  }

  std::shared_ptr<SugaredValue> keys_;
  std::shared_ptr<SugaredValue> modules_;
};

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
Expand Down Expand Up @@ -136,6 +200,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
->call(loc, caller, inputs, attributes, n_binders);
}

std::shared_ptr<SugaredModuleDict> getSugaredModuleDict(
const SourceRange& loc,
Function& m);

void setAttr(
const SourceRange& loc,
Function& m,
Expand All @@ -144,42 +212,11 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {

SugaredValuePtr iter(const SourceRange& loc, Function& m) override;

SugaredValuePtr desugarModuleContainer(
bool get_keys,
bool get_values,
const SourceRange& loc,
Function& m);

private:
Value* self_;
std::shared_ptr<ConcreteModuleType> concreteType_;
};

// Sugared wrapper for the `keys`/`values`/`items` methods of a module
// container: calling it with no arguments returns the prebuilt iterable.
struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
  explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
      : iterable_(iterable), name_(name){};

  // The kind reported in diagnostics is the method name itself.
  std::string kind() const override {
    return name_;
  }

  // Reject any call arguments: these methods are zero-arg accessors.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& f,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    if (inputs.size() || attributes.size()) {
      throw ErrorReport(loc)
          << name_ << " method does not accept any arguments";
    }
    return iterable_;
  }

  SugaredValuePtr iterable_; // value returned by call()
  const std::string name_; // "keys", "values", or "items"
};

struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
BooleanDispatchValue(py::dict dispatched_fn)
: dispatched_fn_(std::move(dispatched_fn)) {}
Expand Down
20 changes: 2 additions & 18 deletions torch/jit/_recursive.py
Expand Up @@ -8,7 +8,7 @@

import torch._jit_internal as _jit_internal
from torch.jit.frontend import get_default_args
from torch.nn import Module, Sequential
from torch.nn import Module
from torch._six import get_function_from_type, bind_method


Expand Down Expand Up @@ -382,9 +382,6 @@ def init_fn(script_module):
if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
add_python_attr_to_scripted_model(script_module, nn_module, name)

if isinstance(nn_module, Sequential):
add_sequential_forward(script_module, nn_module)

return script_module


Expand All @@ -404,18 +401,6 @@ def add_python_attr_to_scripted_model(script_model, orig, attr):
if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
setattr(script_model, attr, getattr(orig, attr))

def add_sequential_forward(script_module, nn_module):
    # Compiling Sequential.forward is not currently supported (it relies on
    # self._modules.values(), which is blocking), so when the module still
    # uses the stock Sequential.forward we install an equivalent scripted
    # forward on the script module instead.
    stock_forward = get_function_from_type(Sequential, "forward")
    if getattr(nn_module.forward, "__func__", None) != stock_forward:
        return
    script_module.define("""
def forward(self, input):
    for m in self:
        input = m(input)
    return input
""")

def get_overload_annotations(mod):
# original function => [(mangled overload name, overload function)]
overloads = {}
Expand Down Expand Up @@ -478,8 +463,7 @@ def infer_methods_to_compile(nn_module):
if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
forward_func = getattr(nn_module.forward, "__func__", None)
module_forward = get_function_from_type(torch.nn.Module, "forward")
sequential_forward = get_function_from_type(Sequential, "forward")
if forward_func != module_forward and forward_func != sequential_forward:
if forward_func != module_forward:
methods = ['forward']

exported = []
Expand Down