Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Add support for named_modules() #29495

Closed
wants to merge 8 commits into from
22 changes: 22 additions & 0 deletions test/test_jit.py
Expand Up @@ -9641,6 +9641,28 @@ def forward(self, inp):
m2 = M2()
m2(torch.zeros(4, 3))

def test_named_modules(self):
class MyMod(torch.nn.Module):
def __init__(self):
super(MyMod, self).__init__()
self.mod = (nn.ReLU())
self.mod2 = (nn.ReLU())
self.mod3 = nn.Sequential(nn.Sequential(nn.ReLU()))

@torch.jit.export
def method(self, x):
mod_names = ""
for name, mod in self.named_modules():
mod_names = mod_names + " " + name
x = mod(x)
return mod_names, x

def forward(self, x):
return x + 2

mod = torch.jit.script(MyMod())
self.assertEqual(mod.method(torch.tensor(2)), MyMod().method(torch.tensor(2)))

def test_script_module_const(self):
class M(torch.jit.ScriptModule):

Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/register_prim_ops.cpp
Expand Up @@ -3168,7 +3168,6 @@ RegisterOperators reg2({
DEFINE_DIVMOD_MIXED_OP(float, int),

#undef DEFINE_DIVMOD_MIXED_OP

Operator(
"aten::hash(str t) -> int",
hashValue<std::string>,
Expand Down
97 changes: 55 additions & 42 deletions torch/csrc/jit/script/python_sugared_value.cpp
Expand Up @@ -219,9 +219,41 @@ Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
return self_;
}

SugaredValuePtr ModuleValue::desugarModuleContainer(
bool get_keys,
bool get_values,
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue> self,
const std::string& prefix) {
auto prefix_value =
std::make_shared<SimpleValue>(insertConstant(*m.graph(), prefix));

keys.push_back(prefix_value);
values.push_back(self);

auto module_dict = self->getSugaredModuleDict(loc, m);
auto keys_iter = module_dict->keys_;
auto module_values_iter = module_dict->modules_;
for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
std::shared_ptr<SugaredValue> module_sugared_value =
module_values_iter->tup_.at(i);
auto module_value =
std::dynamic_pointer_cast<ModuleValue>(module_sugared_value);

auto keys_value = keys_iter->tup_.at(i);
auto key_string = toIValue(keys_value->asValue(loc, m))->toStringRef();
std::string submodule_prefix = prefix;
if (prefix != "") {
submodule_prefix = prefix + ".";
}
submodule_prefix = submodule_prefix + key_string;
recurseThroughNestedModules(
loc, m, keys, values, module_value, submodule_prefix);
};
}

std::shared_ptr<SugaredModuleDict> ModuleValue::getSugaredModuleDict(
const SourceRange& loc,
Function& m) {
std::vector<std::string> submoduleNames;
Expand All @@ -242,28 +274,14 @@ SugaredValuePtr ModuleValue::desugarModuleContainer(
auto mod_v = std::make_shared<ModuleValue>(
module_v, concreteType_->findSubmoduleConcreteType(name));

if (get_keys) {
keys.push_back(name_v);
}
if (get_values) {
values.push_back(mod_v);
}
keys.push_back(name_v);
values.push_back(mod_v);
}

if (get_keys && !get_values) {
return std::make_shared<SugaredTupleValue>(keys);
} else if (get_values && !get_keys) {
return std::make_shared<SugaredTupleValue>(values);
} else if (get_values && get_keys) {
auto key_list = std::make_shared<SugaredTupleValue>(keys);
auto value_list = std::make_shared<SugaredTupleValue>(values);
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, key_list);
iterator->addChild(loc, m, value_list);
return iterator->iter(loc, m);
} else {
TORCH_INTERNAL_ASSERT(false);
}
return std::make_shared<SugaredModuleDict>(
std::make_shared<ModuleValue>(self_, concreteType_),
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}

// helper function for instantiating a SugaredValue from an IValue
Expand Down Expand Up @@ -316,24 +334,16 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(

// 2. Special case: for module dicts we manually desugar items(), keys(),
// values() calls into the appropriate method.
// TODO: These could be represented as first class methods probably.
if (concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (field == "items" || field == "keys" || field == "values") {
bool get_keys = false;
bool get_values = false;
if (field == "items") {
get_keys = true;
get_values = true;
} else if (field == "values") {
get_values = true;
} else {
get_keys = true;
}
return std::make_shared<ModuleDictMethod>(
desugarModuleContainer(get_keys, get_values, loc, m), field);
if (field == "items" || field == "keys" || field != "values") {
eellison marked this conversation as resolved.
Show resolved Hide resolved
return getSugaredModuleDict(loc, m)->attr(loc, m, field);
}
}

if (field == "named_modules") {
return getSugaredModuleDict(loc, m)->attr(loc, m, "named_modules");
}

// 3. Check if this is the name of an overloaded method.

// This can also be a call to a non-script module, or a plain
Expand Down Expand Up @@ -403,11 +413,14 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
}

// iterating over a dictionary returns the keys, iterating over a
// list returns the values
const bool get_keys = iterableModuleKind == IterableModuleKind::DICT;
const bool get_values = iterableModuleKind == IterableModuleKind::LIST;
return desugarModuleContainer(get_keys, get_values, loc, m);
auto module_dict = getSugaredModuleDict(loc, m);
if (iterableModuleKind == IterableModuleKind::DICT) {
return module_dict->keys_;
} else if (iterableModuleKind == IterableModuleKind::LIST) {
return module_dict->modules_;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}

std::shared_ptr<SugaredValue> PythonClassValue::attr(
Expand Down
106 changes: 84 additions & 22 deletions torch/csrc/jit/script/python_sugared_value.h
Expand Up @@ -102,6 +102,33 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
Value* the_list_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};

std::string kind() const override {
return name_;
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
}

SugaredValuePtr iterable_;
const std::string name_;
};

struct SugaredModuleDict;

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
Expand Down Expand Up @@ -136,6 +163,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
->call(loc, caller, inputs, attributes, n_binders);
}

std::shared_ptr<SugaredModuleDict> getSugaredModuleDict(
const SourceRange& loc,
Function& m);

void setAttr(
const SourceRange& loc,
Function& m,
Expand All @@ -144,40 +175,71 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {

SugaredValuePtr iter(const SourceRange& loc, Function& m) override;

SugaredValuePtr desugarModuleContainer(
bool get_keys,
bool get_values,
const SourceRange& loc,
Function& m);

private:
Value* self_;
std::shared_ptr<ConcreteModuleType> concreteType_;
};

struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
explicit ModuleDictMethod(SugaredValuePtr iterable, const std::string& name)
: iterable_(iterable), name_(name){};
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue> self,
const std::string& prefix);

// Used to support named_modules()
struct VISIBILITY_HIDDEN SugaredModuleDict : public SugaredValue {
explicit SugaredModuleDict(
std::shared_ptr<ModuleValue> self,
std::shared_ptr<SugaredTupleValue> keys,
std::shared_ptr<SugaredTupleValue> modules) {
self_ = std::move(self);
keys_ = std::move(keys);
modules_ = std::move(modules);
}

std::string kind() const override {
return name_;
return "ModuleDict";
}

std::shared_ptr<SugaredValue> call(
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& f,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
if (inputs.size() || attributes.size()) {
throw ErrorReport(loc)
<< name_ << " method does not accept any arguments";
}
return iterable_;
Function& m,
const std::string& field) override {
eellison marked this conversation as resolved.
Show resolved Hide resolved
if (field == "keys") {
return std::make_shared<ModuleDictMethod>(keys_, "keys");
} else if (field == "values") {
return std::make_shared<ModuleDictMethod>(modules_, "values");
} else if (field == "items") {
auto iterator = std::make_shared<IterableTree>();
iterator->addChild(loc, m, keys_);
iterator->addChild(loc, m, modules_);
return std::make_shared<ModuleDictMethod>(iterator, "items");
} else if (field == "named_modules") {
auto iterator = std::make_shared<IterableTree>();
std::vector<SugaredValuePtr> keys;
std::vector<SugaredValuePtr> values;

auto key_tuple = std::dynamic_pointer_cast<SugaredTupleValue>(keys_);
auto values_tuple =
std::dynamic_pointer_cast<SugaredTupleValue>(modules_);

recurseThroughNestedModules(loc, m, keys, values, self_, "");
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(keys));
iterator->addChild(loc, m, std::make_shared<SugaredTupleValue>(values));
return std::make_shared<ModuleDictMethod>(iterator, "named_modules");
};
TORCH_INTERNAL_ASSERT(false);
}

SugaredValuePtr iterable_;
const std::string name_;
SugaredValuePtr iter(const SourceRange& loc, Function& m) {
return keys_;
};

std::shared_ptr<ModuleValue> self_;
std::shared_ptr<SugaredTupleValue> keys_;
std::shared_ptr<SugaredTupleValue> modules_;
};

struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
Expand Down