[JIT][Reland] add list() support (#42382)
Summary:
Fixes #40869

Resubmit of #33818.

Adds support for `list()` by desugaring it to a list comprehension.

Last time I landed this it made one of the tests slow and got unlanded. I think that's because the previous PR changed the emission of `list()` on a list or str input to a list comprehension, which is the more general way of emitting `list()` but also a bit slower. This version dispatches to the builtin operators for those two cases instead. Hopefully it can land without being reverted this time...
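For illustration, a minimal sketch of what now compiles under `torch.jit.script` after this change (the `demo` function and its inputs are made up for this example; outputs assume ordinary Python `list()` semantics):

```python
import torch
from typing import List

@torch.jit.script
def demo(xs: List[int], s: str):
    a = list(xs)           # list input: dispatched to the aten::list builtin
    b = list(s)            # str input: also the builtin, yields a List[str] of chars
    c = list(range(4))     # other iterables: desugared to [_elem for _elem in range(4)]
    d: List[int] = list()  # empty list() needs a type annotation
    d.append(1)
    return a, b, c, d

print(demo([1, 2], "ab"))  # ([1, 2], ['a', 'b'], [0, 1, 2, 3], [1])
```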

Pull Request resolved: #42382

Reviewed By: navahgar

Differential Revision: D24767674

Pulled By: eellison

fbshipit-source-id: a1aa3d104499226b28f47c3698386d365809c23c
Elias Ellison authored and facebook-github-bot committed Nov 6, 2020
1 parent eaa993a commit f3ad7b2
Showing 4 changed files with 63 additions and 5 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -105,6 +105,7 @@ namespace c10 {
 _(prim, ConstantChunk) \
 _(prim, MMTreeReduce) \
 _(prim, MMBatchSide) \
+_(prim, list) \
 _(prim, min) \
 _(prim, max) \
 _(prim, abs) \
19 changes: 19 additions & 0 deletions test/jit/test_list_dict.py
@@ -130,6 +130,25 @@ def fn(x):
             del x[1:3]
             return x
 
+    def test_list_keyword(self):
+        def foo():
+            return list([1, 2, 3]), list(("a", "b")), list(range(5)), list("abcdefg")  # noqa: C410
+
+        self.checkScript(foo, ())
+
+        def foo2():
+            x: List[int] = list()
+            x.append(1)
+            return x,
+
+        self.checkScript(foo2, ())
+
+        def foo3():
+            return list(list("abc"))
+
+        self.checkScript(foo3, ())
+        FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
+
     def test_min_bool_list(self):
         def jit_min_list(a, b):
             # type: (List[bool], List[bool]) -> List[bool]
40 changes: 39 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
@@ -485,7 +485,7 @@ struct Environment {
       {"all", std::make_shared<BuiltinFunction>(aten::all, at::nullopt)},
       {"divmod",
        std::make_shared<BuiltinFunction>(aten::divmod, at::nullopt)},
-      {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
+      {"list", SpecialFormValue::create(prim::list)},
       {"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
       {"chr", std::make_shared<BuiltinFunction>(aten::chr, at::nullopt)},
       {"bin", std::make_shared<BuiltinFunction>(aten::bin, at::nullopt)},
@@ -2896,6 +2896,44 @@ struct to_ir {
         }
       return iterable_tree;
     }
+    case prim::list: {
+      if (apply.inputs().size() == 0) {
+        TypePtr type = type_hint ? type_hint : ListType::ofTensors();
+        if (!type->cast<ListType>()) {
+          throw ErrorReport(apply.range())
+              << "Expected list type annotation for list(), found "
+              << type_hint->repr_str();
+        }
+        return std::make_shared<SimpleValue>(graph->insertNode(graph->createList(type->expect<ListType>()->getElementType(), {}))->output());
+      }
+      // list(iter) desugars to [_elem for _elem in iter]
+      checkApplyNumInputs(apply, 1);
+      auto iter_input = emitSugaredExpr(apply.inputs()[0], 1);
+
+      // aten::list builtin op is registered for List and Str input
+      // dispatch to the builtin op to avoid perf slowdown on existing uses
+      if (auto simple = asSimple(iter_input)) {
+        if (simple->type()->cast<ListType>() || simple->type()->cast<StringType>()) {
+          return std::make_shared<SimpleValue>(emitBuiltinCall(
+              apply.range(), *method.graph(), aten::list, {simple}, {}));
+        }
+      }
+      const std::string& iter_name = createTempName("$_iter");
+      environment_stack->setSugaredVar(
+          apply.range(),
+          iter_name,
+          iter_input,
+          /*annotated_type=*/nullptr);
+
+      const std::string& elem_name = createTempName("$_elem");
+      auto ident =
+          Var::create(apply.range(), Ident::create(apply.range(), elem_name));
+      auto iter =
+          Var::create(apply.range(), Ident::create(apply.range(), iter_name));
+      auto lc = ListComp::create(apply.range(), ident, ident, iter);
+      return std::make_shared<SimpleValue>(
+          emitListComprehension(lc, type_hint));
+    }
     default:
       TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form);
   }
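In Python terms, the fallback path above performs the rewrite sketched below; `$_iter` and `$_elem` are compiler-generated temporaries (the `$` prefix keeps them from colliding with user variable names), shown here alongside a hypothetical plain-Python equivalent:

```python
# list(it) on an input that is neither a List nor a str is emitted as if written:
#
#     $_iter = it
#     [$_elem for $_elem in $_iter]
#
# which behaves like this ordinary Python function:
def desugared_list(it):
    return [elem for elem in it]

print(desugared_list(range(3)))  # [0, 1, 2]
```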
8 changes: 4 additions & 4 deletions torch/functional.py
@@ -241,15 +241,15 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
         #   product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
         indices = _indices_product(shape[:-2])
         for idx in indices:
-            final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))
+            final_order = list(range(m))
             for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):
                 final_order[k], final_order[j] = final_order[j], final_order[k]
             # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list
             p_idx = _index_tensor_with_indices_list(P, idx)
             p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))
     else:
         P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)
-        final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))
+        final_order = list(range(m))
         for k, j, in enumerate(LU_pivots_zero_idx):
             final_order[k], final_order[j] = final_order[j], final_order[k]
         P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
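For context on the surrounding code, the swap loop in `lu_unpack` turns LAPACK-style pivots into a row permutation by applying each recorded swap in order; a small plain-Python sketch with made-up, zero-indexed pivots:

```python
def pivots_to_permutation(pivots, m):
    # Start from the identity order and apply each recorded
    # row swap in sequence, exactly as the loop above does.
    order = list(range(m))
    for k, j in enumerate(pivots):
        order[k], order[j] = order[j], order[k]
    return order

print(pivots_to_permutation([2, 2, 2], 3))  # [2, 0, 1]
```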
@@ -1334,7 +1334,7 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa
             raise ValueError("dtype argument is not supported in frobenius norm")
 
         if _dim is None:
-            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            _dim = list(range(ndim))
         if out is None:
             return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore
         else:
@@ -1355,7 +1355,7 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa
             raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
     else:
         if _dim is None:
-            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            _dim = list(range(ndim))
 
         if out is None:
             if dtype is None:
