Skip to content

Commit

Permalink
add support for _modules, reducing special casing of nn.Sequential
Browse files Browse the repository at this point in the history
ghstack-source-id: 90941b7813c3f03ca0a63755cc42f6da9f32bf81
Pull Request resolved: #29495
  • Loading branch information
eellison committed Nov 12, 2019
1 parent 10455d6 commit 3a2b794
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 69 deletions.
23 changes: 23 additions & 0 deletions test/test_jit.py
Expand Up @@ -8804,6 +8804,29 @@ def forward(self, v):
with self.assertRaisesRegex(Exception, "object is not iterable"):
print([val for val in m])

def test_underscore_modules(self):
class MyMod(torch.nn.Module):
def __init__(self):
super(MyMod, self).__init__()
self.mod = nn.ReLU()
self.mod2 = nn.ReLU()

def forward(self, x):
li = torch.jit.annotate(List[str], [])
for mod_name in self._modules:
li.append(mod_name)

for mod in self._modules.values():
x = mod(x)

for mod_name, mod in self._modules.items():
x = mod(x)
li.append(mod_name)

return li, x

self.checkModule(MyMod(), (torch.tensor(1.),))

def test_attr_qscheme_script(self):
class Foo(torch.nn.Module):
def __init__(self):
Expand Down
57 changes: 19 additions & 38 deletions torch/csrc/jit/script/python_sugared_value.cpp
Expand Up @@ -208,9 +208,7 @@ Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
return self_;
}

SugaredValuePtr ModuleValue::desugarModuleContainer(
bool get_keys,
bool get_values,
std::shared_ptr<SugaredModuleDict> ModuleValue::getSugaredModuleDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> submoduleNames;
Expand All @@ -235,28 +233,13 @@ SugaredValuePtr ModuleValue::desugarModuleContainer(
auto mod_v = std::make_shared<ModuleValue>(
module_v, concreteType_->findSubmoduleConcreteType(name));

if (get_keys) {
keys.push_back(name_v);
}
if (get_values) {
values.push_back(mod_v);
}
keys.push_back(name_v);
values.push_back(mod_v);
}

if (get_keys && !get_values) {
return std::make_shared<SugaredTupleValue>(keys);
} else if (get_values && !get_keys) {
return std::make_shared<SugaredTupleValue>(values);
} else if (get_values && get_keys) {
auto key_list = std::make_shared<SugaredTupleValue>(keys);
auto value_list = std::make_shared<SugaredTupleValue>(values);
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, key_list);
iterator->addChild(loc, m, value_list);
return iterator->iter(loc, m);
} else {
TORCH_INTERNAL_ASSERT(false);
}
return std::make_shared<SugaredModuleDict>(
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}

// This method controls how we desugar attribute lookups on ScriptModules.
Expand Down Expand Up @@ -295,21 +278,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
// TODO: These could be represented as first class methods probably.
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (field == "items" || field == "keys" || field == "values") {
bool get_keys = false;
bool get_values = false;
if (field == "items") {
get_keys = true;
get_values = true;
} else if (field == "values") {
get_values = true;
} else {
get_keys = true;
}
return std::make_shared<ModuleDictMethod>(
desugarModuleContainer(get_keys, get_values, loc, m), field);
return getSugaredModuleDict(loc, m)->attr(loc, m, field);
}
}

if (field == "_modules") {
return getSugaredModuleDict(loc, m);
}

// 4. Check if this is the name of an overloaded method.

// This can also be a call to a non-script module, or a plain
Expand Down Expand Up @@ -376,11 +352,16 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
}

auto module_dict = getSugaredModuleDict(loc, m);
// iterating over a dictionary returns the keys, iterating over a
// list returns the values
const bool get_keys = iterableModuleKind == IterableModuleKind::DICT;
const bool get_values = iterableModuleKind == IterableModuleKind::LIST;
return desugarModuleContainer(get_keys, get_values, loc, m);
if (iterableModuleKind == IterableModuleKind::DICT) {
return module_dict->keys_;
} else if (iterableModuleKind == IterableModuleKind::LIST) {
return module_dict->modules_;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}

std::shared_ptr<SugaredValue> PythonClassValue::attr(
Expand Down
99 changes: 68 additions & 31 deletions torch/csrc/jit/script/python_sugared_value.h
Expand Up @@ -102,6 +102,70 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
Value* the_list_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};

std::string kind() const override {
return name_;
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
}

SugaredValuePtr iterable_;
const std::string name_;
};

// Used to support mod._modules
struct VISIBILITY_HIDDEN SugaredModuleDict : public SugaredValue {
explicit SugaredModuleDict(
std::shared_ptr<SugaredValue> keys,
std::shared_ptr<SugaredValue> modules) {
keys_ = std::move(keys);
modules_ = std::move(modules);
}

std::string kind() const override {
return "ModuleDict";
}

std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
const std::string& field) override {
if (field == "keys") {
return std::make_shared<ModuleDictMethod>(keys_, "keys");
} else if (field == "values") {
return std::make_shared<ModuleDictMethod>(modules_, "values");
} else if (field == "items") {
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, keys_);
iterator->addChild(loc, m, modules_);
return std::make_shared<ModuleDictMethod>(iterator, "items");
} else {
throw ErrorReport(loc) << "_Modules only supports keys, values, or items";
}
}

SugaredValuePtr iter(const SourceRange& loc, Function& m) {
return keys_;
};

std::shared_ptr<SugaredValue> keys_;
std::shared_ptr<SugaredValue> modules_;
};

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
Expand Down Expand Up @@ -136,6 +200,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
->call(loc, caller, inputs, attributes, n_binders);
}

std::shared_ptr<SugaredModuleDict> getSugaredModuleDict(
const SourceRange& loc,
Function& m);

void setAttr(
const SourceRange& loc,
Function& m,
Expand All @@ -144,42 +212,11 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {

SugaredValuePtr iter(const SourceRange& loc, Function& m) override;

SugaredValuePtr desugarModuleContainer(
bool get_keys,
bool get_values,
const SourceRange& loc,
Function& m);

private:
Value* self_;
std::shared_ptr<ConcreteModuleType> concreteType_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};

std::string kind() const override {
return name_;
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
}

SugaredValuePtr iterable_;
const std::string name_;
};

struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
BooleanDispatchValue(py::dict dispatched_fn)
: dispatched_fn_(std::move(dispatched_fn)) {}
Expand Down

0 comments on commit 3a2b794

Please sign in to comment.