Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support %-based string formatting #45976

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
148 changes: 148 additions & 0 deletions test/jit/test_string_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import os
import sys

import torch
from typing import List

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    # This module is only meant to be collected through test/test_jit.py,
    # so a direct invocation is rejected with instructions.
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)

class TestStringFormatting(JitTestCase):

def test_modulo_operator(self):
def fn(dividend: int, divisor: int) -> int:
return dividend % divisor
self.checkScript(fn, (5, 2))

def test_string_interpolation_with_string_placeholder_and_string_variable(self):
def fn(arg1: str):
return "%s in template" % arg1
self.checkScript(fn, ("foo",))

def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
def fn(arg1: str):
return arg1 % "foo"
self.checkScript(fn, ("%s in template",))

def test_string_interpolation_with_double_percent_in_string(self):
def fn(arg1: str):
return "%s in template %%" % arg1
self.checkScript(fn, ("foo",))

def test_string_interpolation_with_percent_in_string(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s in template %" % arg1 # noqa: F501

with self.assertRaisesRegex(RuntimeError, "Incomplete format specifier"):
fn("foo")

def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%s in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%d in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_alternate_digit_placeholder(self):
def fn(arg1: int) -> str:
return "%i in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%d in template" % arg1

with self.assertRaisesRegex(RuntimeError, "%d requires a number for formatting, but got String"):
fn("1")
ansley marked this conversation as resolved.
Show resolved Hide resolved

def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%e in template" % arg1

with self.assertRaisesRegex(RuntimeError, "%e requires a number for formatting, but got String"):
fn("1")

def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%e in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%E in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_float_placeholder_and_float_variable(self):
def fn(arg1: float) -> str:
return "%f in template" % arg1
self.checkScript(fn, (1.0,))

def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%f in template" % arg1
self.checkScript(fn, (1,))

def test_string_interpolation_with_char_placeholder_and_char_variable(self):
def fn(arg1: str) -> str:
return "%c in template" % arg1
self.checkScript(fn, ("a",))

def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%c in template" % arg1
self.checkScript(fn, (97,))

def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%c in template" % arg1

with self.assertRaisesRegex(RuntimeError, "%c requires an int or char for formatting, but got String"):
fn("foo")

def test_string_interpolation_with_multiple_placeholders(self):
def fn(arg1: str, arg2: int, arg3: float) -> str:
return "%s %d %f in template" % (arg1, arg2, arg3)
self.checkScript(fn, ("foo", 1, 1))

def test_string_interpolation_with_subscript(self):
def fn(arg1: List[str]) -> str:
return "%s in template" % arg1[0]
self.checkScript(fn, (["foo", "bar"],))

def test_string_interpolation_with_too_few_arguments(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s %s in template" % arg1

with self.assertRaisesRegex(RuntimeError, "Too few arguments for format string"):
fn("foo")

def test_string_interpolation_with_too_many_arguments(self):
@torch.jit.script
def fn(arg1: str, arg2: str) -> str:
return "%s in template" % (arg1, arg2) # noqa: F507

with self.assertRaisesRegex(RuntimeError, "Too many arguments for format string"):
fn("foo", "bar")

def test_string_interpolation_with_unknown_format_specifier(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%a in template" % arg1 # noqa: F501

with self.assertRaisesRegex(RuntimeError, "The specifier %a is not supported in TorchScript format strings"):
fn("foo")
1 change: 1 addition & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jit.test_onnx_export import TestONNXExport # noqa: F401
from jit.test_with import TestWith # noqa: F401
from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
Expand Down
51 changes: 33 additions & 18 deletions torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3124,6 +3124,22 @@ struct to_ir {
return std::make_shared<SimpleValue>(rpc_node_output);
}

// Emits a call for a binary-operator tree and returns the resulting Value.
//
// The node kind and the overload name are both looked up from the tree's
// kind and operand count; the operands are gathered without unpacking. The
// call goes through makeMagic(), which is given the overload name and a
// BuiltinFunction for the node kind.
Value* emitBinaryOp(const TreeRef& tree) {
  const auto& operands = tree->trees();
  const auto operand_count = operands.size();
  auto builtin_kind = getNodeKind(tree->kind(), operand_count);
  auto overload_name = getOperatorOverload(tree->kind(), operand_count);
  auto call_args = getNamedValues(operands, /*maybe_unpack=*/false);

  // `in` supplies its operands in reverse order (the object being checked
  // comes second), so exchange the first two before making the call.
  if (tree->kind() == TK_IN) {
    std::iter_swap(call_args.begin(), call_args.begin() + 1);
  }

  auto callee = makeMagic(
      overload_name,
      std::make_shared<BuiltinFunction>(builtin_kind, at::nullopt));
  auto result = callee->call(tree->range(), method, call_args, {}, 0);
  return asSimple(result);
}

Value* emitSimpleExpr(
const TreeRef& tree,
const TypePtr& type_hint = nullptr) {
Expand All @@ -3136,6 +3152,21 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0)
->asValue(tree->tree(0)->range(), method);
auto const& lhs_type = lhs->type();
if (lhs_type == StringType::get()) {
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
Value* output = graph->insertNode(node)->output();
output->setType(StringType::get());
return output;
} else {
return emitBinaryOp(tree);
}
}
ansley marked this conversation as resolved.
Show resolved Hide resolved
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,28 +3179,12 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
case TK_LSHIFT:
case TK_RSHIFT: {
const auto& inputs = tree->trees();
auto kind = getNodeKind(tree->kind(), inputs.size());
auto overload = getOperatorOverload(tree->kind(), inputs.size());
auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);

if (tree->kind() == TK_IN) {
// For `in` the arguments are in reverse order (the object being
// checked is second)
std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
}

return asSimple(
makeMagic(
overload, std::make_shared<BuiltinFunction>(kind, at::nullopt))
->call(tree->range(), method, named_values, {}, 0));
}
case TK_RSHIFT:
return emitBinaryOp(tree);
case TK_IS:
case TK_ISNOT:
case TK_AND:
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ RegisterOperators reg(
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
Operator(
"aten::percentFormat(str self, ...) -> str",
[](Stack* stack) {
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
Operator(
"aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
[](Stack* stack) {
Expand Down