Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jit] allow classes to be used in their own methods #20106

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14623,6 +14623,25 @@ def __init__(self, x, y):
self.x = x
self.y = y

def test_class_constructs_itself(self):
@torch.jit.script # noqa: B903
class LSTMStateStack(object):
def __init__(self, num_layers, hidden_size):
# type: (int, int) -> None
self.num_layers = num_layers
self.hidden_size = hidden_size
self.last_state = (
torch.zeros(num_layers, 1, hidden_size),
torch.zeros(num_layers, 1, hidden_size),
)
self.stack = [(self.last_state[0][-1], self.last_state[0][-1])]

def copy(self):
# should be able to construct a class inside its own methods
other = LSTMStateStack(self.num_layers, self.hidden_size)
other.stack = list(self.stack)
return other


class TestLogging(JitTestCase):
def test_bump_numeric_counter(self):
Expand Down
42 changes: 39 additions & 3 deletions torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ namespace {
// A resolver that will inspect the outer Python scope to find `name`.
struct PythonResolver : public Resolver {
explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {}

/**
* While compiling classes, the class type we're compiling will not be
* available in Python, since we haven't finished defining the class yet. So
* in order to make the class type available to its own methods, we need to
* explicitly resolve it.
*
* @param rcb Python function to resolve a name to its Python object in the
* enclosing scope
* @param classname The unqualified classname of the class currently being
* compiled.
* @param classType The class's type.
*/
explicit PythonResolver(
ResolutionCallback rcb,
std::string classname,
ClassTypePtr classType)
: rcb_(std::move(rcb)),
classname_(std::move(classname)),
classType_(std::move(classType)) {}

std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
Expand All @@ -70,6 +91,9 @@ struct PythonResolver : public Resolver {
}

TypePtr resolveType(const std::string& name) const override {
if (classType_ && name == classname_) {
return classType_;
}
AutoGIL ag;
py::object obj = rcb_(name);
if (obj.is(py::none())) {
Expand All @@ -88,11 +112,20 @@ struct PythonResolver : public Resolver {

private:
ResolutionCallback rcb_;
std::string classname_;
ClassTypePtr classType_;
};

std::shared_ptr<PythonResolver> pythonResolver(ResolutionCallback rcb) {
return std::make_shared<PythonResolver>(rcb);
}
std::shared_ptr<PythonResolver> pythonResolver(
ResolutionCallback rcb,
std::string classname,
ClassTypePtr classType) {
return std::make_shared<PythonResolver>(
rcb, std::move(classname), std::move(classType));
}
} // namespace

FunctionSchema getSchemaWithNameAndDefaults(
Expand Down Expand Up @@ -517,15 +550,18 @@ void initJitScriptBindings(PyObject* module) {

m.def(
"_jit_script_class_compile",
[](const ClassDef& classDef, ResolutionCallback rcb) {
[](const std::string& qualifiedName,
const ClassDef& classDef,
ResolutionCallback rcb) {
auto cu = std::make_shared<CompilationUnit>();
auto classType =
ClassType::create(c10::QualifiedName(classDef.name().name()), cu);
ClassType::create(c10::QualifiedName(qualifiedName), cu);
std::vector<ResolverPtr> rcbs;
std::vector<Def> methodDefs;
for (const auto& def : classDef.defs()) {
methodDefs.push_back(def);
rcbs.push_back(pythonResolver(rcb));
rcbs.push_back(
pythonResolver(rcb, classDef.name().name(), classType));
}
cu->define(methodDefs, rcbs, simpleSelf(classType));
});
Expand Down
8 changes: 4 additions & 4 deletions torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,10 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
if inspect.isclass(obj):
if not _is_new_style_class(obj):
raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
name = _qualified_name(obj)
ast = get_jit_class_def(obj, name)
_jit_script_class_compile(ast, _rcb)
_add_script_class(obj, name)
qualified_name = _qualified_name(obj)
ast = get_jit_class_def(obj, obj.__name__)
_jit_script_class_compile(qualified_name, ast, _rcb)
_add_script_class(obj, qualified_name)
return obj
else:
ast = get_jit_def(obj)
Expand Down