From fd75506e6760218bc465fc781d441e1774f95858 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 19 Jun 2019 14:06:54 -0700 Subject: [PATCH] [jit] add for in string support --- test/test_jit.py | 21 +++++++++++++++++++++ torch/csrc/jit/script/compiler.cpp | 5 +++-- torch/csrc/jit/script/sugared_value.cpp | 3 +++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 47082652ffe3b..c37f25fe82659 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10047,6 +10047,27 @@ def list_iterables(x): return x ''') + def test_for_in_string(self): + def test_strings(x): + # type: (str) -> str + reverse = "" + for c in x: + reverse = c + reverse + return reverse + + self.checkScript(test_strings, ("hello",)) + self.checkScript(test_strings, ("",)) + + def test_list_strings(x): + # type: (List[str]) -> str + result = "" + for sub_str in x: + result += sub_str + return result + + self.checkScript(test_list_strings, (["hello", "world"],)) + self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) + def test_for_tuple_unpack(self): def for_tuple_unpack(x, y): for i, j in [[3, 4], [5, 6], [7, 8]]: diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 89bd9c141ad22..1679c6021d234 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1343,13 +1343,14 @@ struct to_ir { auto sv = emitSugaredExpr(itrs[0], 1); // We will get IterableTree for builtinFunctions zip() and enumerate(), - // RangeValue for range(), and SimpleValue for types like List, Tensor, Dict. + // RangeValue for range(), and SimpleValue for types like List/Tensor/Dict/String. auto range_val = std::dynamic_pointer_cast(sv); auto siv = std::dynamic_pointer_cast(sv); auto iterable_tree = std::dynamic_pointer_cast(sv); if ((siv && (siv->getValue()->type()->kind() == TypeKind::ListType - || siv->getValue()->type()->isSubtypeOf(TensorType::get())) + || siv->getValue()->type()->isSubtypeOf(TensorType::get()) + || siv->getValue()->type()->isSubtypeOf(StringType::get())) ) || range_val || iterable_tree) { emitLoopCommon(stmt.range(), body, sv, targets, {}); return; diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index ba58e0cf7f4fe..49a3ea67d2653 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -265,6 +265,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { TypePtr val_type = val->type(); Graph& g = *m.graph(); if (val_type->cast() || + val_type->cast() || val_type->isSubtypeOf(TensorType::get())) { return g.insert(aten::len, {val}, {}, loc); } else { @@ -280,6 +281,8 @@ Value* SimpleValue::getelem(const SourceRange&loc, Function& m, Value* i) { Value* cur_elem = nullptr; if (val_type->cast()) { cur_elem = g.insert(aten::select, {val, i}, {}, loc); + } else if (val_type->cast()) { + cur_elem = g.insert(prim::StringIndex, {val, i}, {}, loc); } else if (val_type->isSubtypeOf(TensorType::get())) { cur_elem = g.insert(aten::select, {val, 0, i}, {}, loc); } else {