Add support for legacy tensor constructors in JIT (#74785)
Summary:
Pull Request resolved: #74785

Fix for https://github.com/facebookresearch/torchdynamo/issues/93

Because these constructors follow a non-standard input schema (variadic integers), they are handled specially in ir_emitter.
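
As an illustration, a minimal sketch of what now scripts (hypothetical usage; the committed coverage lives in test/jit/test_misc.py below):

    import torch

    def fn():
        a = torch.LongTensor([3])    # PyObject-data overload
        b = torch.FloatTensor(2, 3)  # variadic-int size overload
        return a, b

    scripted = torch.jit.script(fn)  # previously failed in the JIT front end
    a, b = scripted()
    assert a.dtype == torch.int64 and b.dtype == torch.float32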

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D35362762

Pulled By: eellison

fbshipit-source-id: 960badf08ba2ab0818af5fd331aff3542051250f
Elias Ellison authored and facebook-github-bot committed Apr 6, 2022
1 parent 75682f1 commit bd579de
Showing 5 changed files with 157 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -65,6 +65,7 @@ namespace c10 {
_(prim, Placeholder) /* debug */ \
_(prim, Print) \
_(prim, EmptyListLiteral) \
_(prim, LegacyTypedConstructor) \
_(prim, PythonOp) \
_(prim, IgnoredPythonOp) \
_(prim, Reverse) \
47 changes: 47 additions & 0 deletions test/jit/test_misc.py
@@ -258,6 +258,53 @@ def test_return():

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager = empty_overload()
        jit = torch.jit.script(empty_overload)()
        eager[:] = 1
        jit[:] = 1
        self.assertEqual(eager, jit)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias
69 changes: 68 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
@@ -37,6 +37,7 @@

#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <atomic>
#include <climits>
#include <set>
@@ -3319,7 +3320,7 @@ struct to_ir {
    auto sv = emitSugaredExpr(apply.callee(), 1);
    auto loc = apply.callee().range();
    if (auto special_form = dynamic_cast<SpecialFormValue*>(sv.get())) {
-     return emitApplySpecialForm(special_form->form(), apply, type_hint);
+     return emitApplySpecialForm(special_form->form(), apply, sv, type_hint);
    }
    auto args = getNamedValues(apply.inputs(), true);
    auto kwargs = emitAttributes(apply.attributes());
@@ -3334,6 +3335,7 @@
  std::shared_ptr<SugaredValue> emitApplySpecialForm(
      Symbol form,
      Apply& apply,
      std::shared_ptr<SugaredValue> sv,
      const TypePtr& type_hint = nullptr) {
    switch (form) {
      case prim::fork: {
@@ -3438,6 +3440,71 @@
        return std::make_shared<SimpleValue>(
            graph->insertNode(graph->createTuple(inp_values))->output());
      }
      case prim::LegacyTypedConstructor: {
        // see legacy_tensor_generic_ctor_new
        // These legacy constructors do not follow schemas that can be
        // typed in native_functions.yaml / JIT type signature and are handled
        // here. Only the two common cases are handled initially:
        // "new(IntArrayRef size, *, Device? device=None)",
        // "new(PyObject* data, *, Device? device=None)",
        // Note: device argument is unused in the kernel
        auto args = getValues(apply.inputs(), true);
        auto kwargs = emitAttributes(apply.attributes());
        auto get_base_error_msg = [&]() {
          std::stringstream base_error_msg;
          base_error_msg
              << "Legacy Tensor Constructor only supports two schemas in TorchScript: \n";
          base_error_msg
              << "'new(IntArrayRef size, *, Device? device=None)',\n";
          base_error_msg << "'new(PyObject* data, *, Device? device=None)\n'";
          return base_error_msg;
        };
        if (kwargs.size() == 1 && kwargs[0].name() != "device") {
          throw ErrorReport(apply)
              << get_base_error_msg().str() << "Got kwarg " << kwargs[0].name();
        }
        if (kwargs.size() > 1) {
          throw ErrorReport(apply)
              << get_base_error_msg().str() << "Got multiple kwargs\n";
        }
        auto dtype = dynamic_cast<LegacyTensorConstructor*>(sv.get())->dtype();
        auto dtype_ivalue = graph->insertConstant(dtype);

        // supporting "new(IntArrayRef size, *, Device? device=None)", through
        // empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout?
        // layout=None, Device? device=None, bool? pin_memory=None,
        // MemoryFormat? memory_format=None) -> Tensor
        bool all_ints = std::all_of(args.begin(), args.end(), [](Value* v) {
          return v->type()->cast<IntType>();
        });
        if (args.size() == 0) {
          // empty inputs == torch.tensor([], dtype=....)
          auto inp_list =
              graph->insertNode(graph->createList(IntType::get(), {}))
                  ->output();
          return std::make_shared<SimpleValue>(graph->insert(
              aten::tensor,
              {inp_list},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else if (all_ints) {
          auto inp_list =
              graph->insertNode(graph->createList(IntType::get(), args))
                  ->output();
          return std::make_shared<SimpleValue>(graph->insert(
              aten::empty,
              {inp_list},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else if (args.size() == 1) {
          return std::make_shared<SimpleValue>(graph->insert(
              aten::tensor,
              {args[0]},
              {NamedValue(apply.range(), "dtype", dtype_ivalue)}));
        } else {
          throw ErrorReport(apply)
              << get_base_error_msg().str()
              << "Got multiple positional arguments that were not all integers";
        }
      }
      case prim::isinstance: {
        checkApplyNumInputs(apply, 2);
        auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]);
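Roughly, the lowering above corresponds to the following eager-mode equivalences (a sketch for intuition, assuming the default CPU device; not the literal kernels invoked):

    import torch

    # all-integer positional args hit the "new(IntArrayRef size, ...)" path
    # and lower to aten::empty with the constructor's dtype (data uninitialized)
    torch.LongTensor(2, 3, 4)  # ~ torch.empty(2, 3, 4, dtype=torch.int64)

    # a single non-integer arg hits the "new(PyObject* data, ...)" path
    # and lowers to aten::tensor with the constructor's dtype
    torch.DoubleTensor([6])    # ~ torch.tensor([6], dtype=torch.float64)

    # no arguments: an empty int list is materialized and passed to aten::tensor
    torch.DoubleTensor()       # ~ torch.tensor([], dtype=torch.float64)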
20 changes: 20 additions & 0 deletions torch/csrc/jit/frontend/sugared_value.h
@@ -1,4 +1,5 @@
#pragma once
#include <c10/util/Optional.h>
#include <functional>
#include <memory>
#include <string>
@@ -618,6 +619,25 @@ struct TORCH_API SpecialFormValue : public SugaredValue {
  Symbol form_;
};

struct TORCH_API LegacyTensorConstructor : public SpecialFormValue {
  LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device)
      : SpecialFormValue(form), device_(device), dtype_(dtype) {}

  static std::shared_ptr<LegacyTensorConstructor> create(
      Symbol form,
      at::ScalarType dtype,
      at::Device device) {
    return std::make_shared<LegacyTensorConstructor>(form, dtype, device);
  }
  at::ScalarType dtype() const {
    return dtype_;
  }

 private:
  at::Device device_;
  at::ScalarType dtype_;
};

// matched against for special handling of range expressions
struct TORCH_API RangeValue : SugaredValue {
  RangeValue(
21 changes: 21 additions & 0 deletions torch/csrc/jit/python/python_sugared_value.cpp
@@ -1,5 +1,7 @@
#include <torch/csrc/jit/python/python_sugared_value.h>

#include <ATen/core/interned_strings.h>
#include <c10/core/ScalarType.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
@@ -1160,6 +1162,25 @@ std::shared_ptr<SugaredValue> toSugaredValue(
      throw ErrorReport(loc) << "Cannot call a ScriptModule that is not"
                             << " a submodule of the caller";
    }
  std::vector<std::pair<const char*, at::ScalarType>> tensor_names = {
      {"BoolTensor", at::ScalarType::Bool},
      {"LongTensor", at::ScalarType::Long},
      {"ByteTensor", at::ScalarType::Byte},
      {"CharTensor", at::ScalarType::Char},
      {"DoubleTensor", at::ScalarType::Double},
      {"FloatTensor", at::ScalarType::Float},
      {"IntTensor", at::ScalarType::Int},
      {"ShortTensor", at::ScalarType::Short},
      {"HalfTensor", at::ScalarType::Half},
  };
  for (const auto& name : tensor_names) {
    if (obj.ptr() == py::module::import("torch").attr(name.first).ptr()) {
      // torch.LongTensor and other related functions create on cpu,
      // TODO: add support for torch.cuda.LongTensor for gpu
      return LegacyTensorConstructor::create(
          prim::LegacyTypedConstructor, name.second, at::kCPU);
    }
  }

  py::object builtin_name =
      py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
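A quick sanity check of the name-to-dtype mapping above (a hedged sketch; dtype correspondences follow the usual torch conventions):

    import torch

    @torch.jit.script
    def half():
        return torch.HalfTensor([1])

    assert half().dtype == torch.float16
    # the legacy constructors always construct on CPU here (see TODO above)
    assert half().device.type == "cpu"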
