Skip to content

Commit

Permalink
add support for _modules, reducing special casing of nn.Sequential (p…
Browse files Browse the repository at this point in the history
…ytorch#29495)

Summary:
Pull Request resolved: pytorch#29495

This PR adds support for `_modules`, making it so we no longer need to special case support for `nn.Sequential`. I was getting internal errors around the previous approach using `self.define()`, so i am adding this PR as part of the stack.

Fix for pytorch#28998

Test Plan: Imported from OSS

Differential Revision: D18412561

Pulled By: eellison

fbshipit-source-id: a8b24ebee39638fccf63b2701f65f8bb0de84faa
  • Loading branch information
Elias Ellison authored and ttumiel committed Mar 4, 2020
1 parent 6141207 commit 6dd4ed8
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 64 deletions.
22 changes: 22 additions & 0 deletions test/test_jit.py
Expand Up @@ -9824,6 +9824,28 @@ def forward(self, inp):
m2 = M2()
m2(torch.zeros(4, 3))

def test_named_modules(self):
class MyMod(torch.nn.Module):
def __init__(self):
super(MyMod, self).__init__()
self.mod = (nn.ReLU())
self.mod2 = (nn.ReLU())
self.mod3 = nn.Sequential(nn.Sequential(nn.ReLU()))

@torch.jit.export
def method(self, x):
mod_names = ""
for name, mod in self.named_modules():
mod_names = mod_names + " " + name
x = mod(x)
return mod_names, x

def forward(self, x):
return x + 2

mod = torch.jit.script(MyMod())
self.assertEqual(mod.method(torch.tensor(2)), MyMod().method(torch.tensor(2)))

def test_script_module_const(self):
class M(torch.jit.ScriptModule):

Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/register_prim_ops.cpp
Expand Up @@ -2933,7 +2933,6 @@ RegisterOperators reg2({
DEFINE_DIVMOD_MIXED_OP(float, int),

#undef DEFINE_DIVMOD_MIXED_OP

Operator(
"aten::hash(str t) -> int",
hashValue<std::string>,
Expand Down
118 changes: 78 additions & 40 deletions torch/csrc/jit/script/python_sugared_value.cpp
Expand Up @@ -219,9 +219,41 @@ Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
return self_;
}

SugaredValuePtr ModuleValue::desugarModuleContainer(
bool get_keys,
bool get_values,
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue> self,
const std::string& prefix) {
auto prefix_value =
std::make_shared<SimpleValue>(insertConstant(*m.graph(), prefix));

keys.push_back(prefix_value);
values.push_back(self);

auto module_dict = self->getSugaredModuleDict(loc, m);
auto keys_iter = module_dict->keys_;
auto module_values_iter = module_dict->modules_;
for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
std::shared_ptr<SugaredValue> module_sugared_value =
module_values_iter->tup_.at(i);
auto module_value =
std::dynamic_pointer_cast<ModuleValue>(module_sugared_value);

auto keys_value = keys_iter->tup_.at(i);
auto key_string = toIValue(keys_value->asValue(loc, m))->toStringRef();
std::string submodule_prefix = prefix;
if (prefix != "") {
submodule_prefix = prefix + ".";
}
submodule_prefix = submodule_prefix + key_string;
recurseThroughNestedModules(
loc, m, keys, values, module_value, submodule_prefix);
};
}

std::shared_ptr<SugaredModuleDict> ModuleValue::getSugaredModuleDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> submoduleNames;
Expand All @@ -242,28 +274,39 @@ SugaredValuePtr ModuleValue::desugarModuleContainer(
auto mod_v = std::make_shared<ModuleValue>(
module_v, concreteType_->findSubmoduleConcreteType(name));

if (get_keys) {
keys.push_back(name_v);
}
if (get_values) {
values.push_back(mod_v);
}
keys.push_back(name_v);
values.push_back(mod_v);
}

if (get_keys && !get_values) {
return std::make_shared<SugaredTupleValue>(keys);
} else if (get_values && !get_keys) {
return std::make_shared<SugaredTupleValue>(values);
} else if (get_values && get_keys) {
auto key_list = std::make_shared<SugaredTupleValue>(keys);
auto value_list = std::make_shared<SugaredTupleValue>(values);
return std::make_shared<SugaredModuleDict>(
std::make_shared<ModuleValue>(self_, concreteType_),
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}

std::shared_ptr<SugaredValue> SugaredModuleDict::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
if (field == "keys") {
return std::make_shared<ModuleDictMethod>(keys_, "keys");
} else if (field == "values") {
return std::make_shared<ModuleDictMethod>(modules_, "values");
} else if (field == "items") {
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, key_list);
iterator->addChild(loc, m, value_list);
return iterator->iter(loc, m);
} else {
TORCH_INTERNAL_ASSERT(false);
}
iterator->addChild(loc, m, keys_);
iterator->addChild(loc, m, modules_);
return std::make_shared<ModuleDictMethod>(iterator, "items");
} else if (field == "named_modules") {
auto iterator = std::make_shared<IterableTree>();
std::vector<SugaredValuePtr> keys;
std::vector<SugaredValuePtr> values;
recurseThroughNestedModules(loc, m, keys, values, self_, "");
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
return std::make_shared<ModuleDictMethod>(iterator, "named_modules");
};
TORCH_INTERNAL_ASSERT(false);
}

// helper function for instantiating a SugaredValue from an IValue
Expand Down Expand Up @@ -316,24 +359,16 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(

// 2. Special case: for module dicts we manually desugar items(), keys(),
// values() calls into the appropriate method.
// TODO: These could be represented as first class methods probably.
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (field == "items" || field == "keys" || field == "values") {
bool get_keys = false;
bool get_values = false;
if (field == "items") {
get_keys = true;
get_values = true;
} else if (field == "values") {
get_values = true;
} else {
get_keys = true;
}
return std::make_shared<ModuleDictMethod>(
desugarModuleContainer(get_keys, get_values, loc, m), field);
return getSugaredModuleDict(loc, m)->attr(loc, m, field);
}
}

if (field == "named_modules") {
return getSugaredModuleDict(loc, m)->attr(loc, m, "named_modules");
}

// 3. Check if this is the name of an overloaded method.

// This can also be a call to a non-script module, or a plain
Expand Down Expand Up @@ -403,11 +438,14 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
}

// iterating over a dictionary returns the keys, iterating over a
// list returns the values
const bool get_keys = iterableModuleKind == IterableModuleKind::DICT;
const bool get_values = iterableModuleKind == IterableModuleKind::LIST;
return desugarModuleContainer(get_keys, get_values, loc, m);
auto module_dict = getSugaredModuleDict(loc, m);
if (iterableModuleKind == IterableModuleKind::DICT) {
return module_dict->keys_;
} else if (iterableModuleKind == IterableModuleKind::LIST) {
return module_dict->modules_;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}

std::shared_ptr<SugaredValue> PythonClassValue::attr(
Expand Down
83 changes: 60 additions & 23 deletions torch/csrc/jit/script/python_sugared_value.h
Expand Up @@ -102,6 +102,33 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
Value* the_list_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};

std::string kind() const override {
return name_;
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
}

SugaredValuePtr iterable_;
const std::string name_;
};

struct SugaredModuleDict;

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
Expand Down Expand Up @@ -136,6 +163,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
->call(loc, caller, inputs, attributes, n_binders);
}

std::shared_ptr<SugaredModuleDict> getSugaredModuleDict(
const SourceRange& loc,
Function& m);

void setAttr(
const SourceRange& loc,
Function& m,
Expand All @@ -144,40 +175,46 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {

SugaredValuePtr iter(const SourceRange& loc, Function& m) override;

SugaredValuePtr desugarModuleContainer(
bool get_keys,
bool get_values,
const SourceRange& loc,
Function& m);

private:
Value* self_;
std::shared_ptr<ConcreteModuleType> concreteType_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue> self,
const std::string& prefix);

// Used to support named_modules()
struct VISIBILITY_HIDDEN SugaredModuleDict : public SugaredValue {
explicit SugaredModuleDict(
std::shared_ptr<ModuleValue> self,
std::shared_ptr<SugaredTupleValue> keys,
std::shared_ptr<SugaredTupleValue> modules) {
self_ = std::move(self);
keys_ = std::move(keys);
modules_ = std::move(modules);
}

std::string kind() const override {
return name_;
return "ModuleDict";
}

std::shared_ptr<SugaredValue> call(
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
}
Function& m,
const std::string& field) override;

SugaredValuePtr iterable_;
const std::string name_;
SugaredValuePtr iter(const SourceRange& loc, Function& m) {
return keys_;
};

std::shared_ptr<ModuleValue> self_;
std::shared_ptr<SugaredTupleValue> keys_;
std::shared_ptr<SugaredTupleValue> modules_;
};

struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
Expand Down

0 comments on commit 6dd4ed8

Please sign in to comment.