Skip to content

Commit

Permalink
Support %-based string formatting
Browse files Browse the repository at this point in the history
ghstack-source-id: 86c85502befd882c8bdefff3d8c3a7d8674bc5d7
Pull Request resolved: #45976
  • Loading branch information
Ansley Adelaide Ussery committed Oct 15, 2020
1 parent 1e654a4 commit b70b60f
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 18 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
148 changes: 148 additions & 0 deletions test/jit/test_string_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
import sys

import torch
from typing import List

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")

class TestStringFormatting(JitTestCase):
    """Tests for the ``%`` operator in scripted functions.

    Two patterns are used throughout:

    * Success cases hand a plain Python function and a tuple of call
      arguments to ``self.checkScript``.
    * Failure cases script the function with ``@torch.jit.script`` and assert
      on the message of the ``RuntimeError`` raised when it is called.
    """

    # ------------------------------------------------------------------
    # `%` with an int on the left-hand side (plain modulo)
    # ------------------------------------------------------------------

    def test_modulo_operator(self):
        def fn(dividend: int, divisor: int) -> int:
            return dividend % divisor
        # `fn` takes two parameters, so the argument tuple needs two values.
        self.checkScript(fn, (5, 2))

    # ------------------------------------------------------------------
    # %s placeholders and literal percent signs
    # ------------------------------------------------------------------

    def test_string_interpolation_with_string_placeholder_and_string_variable(self):
        def fn(arg1: str):
            return "%s in template" % arg1
        self.checkScript(fn, ("foo",))

    def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
        # Here the format string itself is the runtime argument.
        def fn(arg1: str):
            return arg1 % "foo"
        self.checkScript(fn, ("%s in template",))

    def test_string_interpolation_with_double_percent_in_string(self):
        def fn(arg1: str):
            return "%s in template %%" % arg1
        self.checkScript(fn, ("foo",))

    def test_string_interpolation_with_percent_in_string(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%s in template %" % arg1  # noqa: F501

        with self.assertRaisesRegex(RuntimeError, "Incomplete format specifier"):
            fn("foo")

    def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%s in template" % arg1
        self.checkScript(fn, (1,))

    # ------------------------------------------------------------------
    # Integer placeholders: %d / %i
    # ------------------------------------------------------------------

    def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%d in template" % arg1
        self.checkScript(fn, (1,))

    def test_string_interpolation_with_alternate_digit_placeholder(self):
        def fn(arg1: int) -> str:
            return "%i in template" % arg1
        self.checkScript(fn, (1,))

    def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%d in template" % arg1

        with self.assertRaisesRegex(RuntimeError, "%d requires a number for formatting, but got String"):
            fn("1")

    # ------------------------------------------------------------------
    # Exponent placeholders: %e / %E
    # ------------------------------------------------------------------

    def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%e in template" % arg1

        with self.assertRaisesRegex(RuntimeError, "%e requires a number for formatting, but got String"):
            fn("1")

    def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%e in template" % arg1
        self.checkScript(fn, (1,))

    def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%E in template" % arg1
        self.checkScript(fn, (1,))

    # ------------------------------------------------------------------
    # Float placeholder: %f
    # ------------------------------------------------------------------

    def test_string_interpolation_with_float_placeholder_and_float_variable(self):
        def fn(arg1: float) -> str:
            return "%f in template" % arg1
        self.checkScript(fn, (1.0,))

    def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%f in template" % arg1
        self.checkScript(fn, (1,))

    # ------------------------------------------------------------------
    # Character placeholder: %c
    # ------------------------------------------------------------------

    def test_string_interpolation_with_char_placeholder_and_char_variable(self):
        def fn(arg1: str) -> str:
            return "%c in template" % arg1
        self.checkScript(fn, ("a",))

    def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
        def fn(arg1: int) -> str:
            return "%c in template" % arg1
        self.checkScript(fn, (97,))

    def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%c in template" % arg1

        with self.assertRaisesRegex(RuntimeError, "%c requires an int or char for formatting, but got String"):
            fn("foo")

    # ------------------------------------------------------------------
    # Multiple placeholders, subscripts, and argument-count mismatches
    # ------------------------------------------------------------------

    def test_string_interpolation_with_multiple_placeholders(self):
        def fn(arg1: str, arg2: int, arg3: float) -> str:
            return "%s %d %f in template" % (arg1, arg2, arg3)
        self.checkScript(fn, ("foo", 1, 1))

    def test_string_interpolation_with_subscript(self):
        def fn(arg1: List[str]) -> str:
            return "%s in template" % arg1[0]
        self.checkScript(fn, (["foo", "bar"],))

    def test_string_interpolation_with_too_few_arguments(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%s %s in template" % arg1

        with self.assertRaisesRegex(RuntimeError, "Too few arguments for format string"):
            fn("foo")

    def test_string_interpolation_with_too_many_arguments(self):
        @torch.jit.script
        def fn(arg1: str, arg2: str) -> str:
            return "%s in template" % (arg1, arg2)  # noqa: F507

        with self.assertRaisesRegex(RuntimeError, "Too many arguments for format string"):
            fn("foo", "bar")

    # ------------------------------------------------------------------
    # Unsupported specifier
    # ------------------------------------------------------------------

    def test_string_interpolation_with_unknown_format_specifier(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%a in template" % arg1  # noqa: F501

        with self.assertRaisesRegex(RuntimeError, "The specifier %a is not supported in TorchScript format strings"):
            fn("foo")
1 change: 1 addition & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jit.test_onnx_export import TestONNXExport # noqa: F401
from jit.test_with import TestWith # noqa: F401
from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
Expand Down
52 changes: 34 additions & 18 deletions torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3124,6 +3124,23 @@ struct to_ir {
return std::make_shared<SimpleValue>(rpc_node_output);
}

// Helper used solely by `emitSimpleExpr`: emits an operator expression as a
// call to the builtin function that corresponds to the tree's kind, dispatched
// through the operator's overload name.
Value* emitSimpleExprFromBuiltInFunction(const TreeRef& tree) {
  const auto& operands = tree->trees();
  const auto operand_count = operands.size();

  // Resolve the builtin symbol and the overload name before emitting the
  // operands, in the same order as before.
  auto builtin_kind = getNodeKind(tree->kind(), operand_count);
  auto overload_name = getOperatorOverload(tree->kind(), operand_count);
  auto args = getNamedValues(operands, /*maybe_unpack=*/false);

  // `in` takes its arguments in reverse order: the object being checked
  // comes second, so the first two arguments are exchanged.
  if (tree->kind() == TK_IN) {
    std::iter_swap(args.begin(), args.begin() + 1);
  }

  auto builtin = std::make_shared<BuiltinFunction>(builtin_kind, at::nullopt);
  auto callee = makeMagic(overload_name, builtin);
  return asSimple(callee->call(tree->range(), method, args, {}, 0));
}

Value* emitSimpleExpr(
const TreeRef& tree,
const TypePtr& type_hint = nullptr) {
Expand All @@ -3136,6 +3153,21 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0)
->asValue(tree->tree(0)->range(), method);
auto const& lhs_type = lhs->type();
if (lhs_type == StringType::get()) {
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
Value* output = graph->insertNode(node)->output();
output->setType(StringType::get());
return output;
} else {
return emitSimpleExprFromBuiltInFunction(tree);
}
}
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,28 +3180,12 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
case TK_LSHIFT:
case TK_RSHIFT: {
const auto& inputs = tree->trees();
auto kind = getNodeKind(tree->kind(), inputs.size());
auto overload = getOperatorOverload(tree->kind(), inputs.size());
auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);

if (tree->kind() == TK_IN) {
// For `in` the arguments are in reverse order (the object being
// checked is second)
std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
}

return asSimple(
makeMagic(
overload, std::make_shared<BuiltinFunction>(kind, at::nullopt))
->call(tree->range(), method, named_values, {}, 0));
}
case TK_RSHIFT:
return emitSimpleExprFromBuiltInFunction(tree);
case TK_IS:
case TK_ISNOT:
case TK_AND:
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ RegisterOperators reg(
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
// Runtime registration for `aten::percentFormat`. The schema is variadic
// (`str self, ...`) and yields a `str`.
Operator(
"aten::percentFormat(str self, ...) -> str",
[](Stack* stack) {
// The value on top of the stack is popped and read as the input count;
// the remaining stack and that count are forwarded to `percentFormat`
// (defined outside this view), which performs the actual formatting.
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
Operator(
"aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
[](Stack* stack) {
Expand Down

0 comments on commit b70b60f

Please sign in to comment.