Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support %-based string formatting #45976

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
1 change: 1 addition & 0 deletions c10/util/C++17.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <utility>
#include <memory>
#include <sstream>
#include <iostream>
ansley marked this conversation as resolved.
Show resolved Hide resolved
#include <string>
#include <cstdlib>
#include <functional>
Expand Down
114 changes: 114 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,6 +2581,120 @@ def forward(self, x):


class TestScript(JitTestCase):

ansley marked this conversation as resolved.
Show resolved Hide resolved
def test_modulo_operator(self):
    """`%` between two ints must still compile to arithmetic modulo."""
    @torch.jit.script
    def remainder(lhs: int, rhs: int) -> int:
        return lhs % rhs

    result = remainder(5, 2)
    self.assertEqual(1, result)

def test_string_interpolation_with_string_placeholder_and_string_varaiable(self):
    """`%s` with a str argument substitutes the string unchanged."""
    @torch.jit.script
    def render(value: str):
        return "%s in template" % value

    rendered = render("foo")
    self.assertEqual("foo in template", rendered)

def test_string_interpolation_with_string_placeholder_and_digit_varaiable(self):
    """`%s` with an int argument renders the integer's decimal text."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%s in template" % value

    rendered = render(1)
    self.assertEqual("1 in template", rendered)

def test_string_interpolation_with_digit_placeholder_and_digit_varaiable(self):
    """`%d` with an int argument renders the integer."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%d in template" % value

    rendered = render(1)
    self.assertEqual("1 in template", rendered)

def test_string_interpolation_with_alternate_digit_placeholder(self):
    """`%i` is accepted as an alias of `%d`."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%i in template" % value

    rendered = render(1)
    self.assertEqual("1 in template", rendered)

def test_string_interpolation_with_digit_placeholder_and_string_varaiable(self):
    """`%d` with a str argument must fail with a 'number is required' error."""
    expected_error = "%d format: A number is required, not String"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def render(value: str) -> str:
            return "%d in template" % value

        render("1")

def test_string_interpolation_with_exponent_placeholder_and_string_varaiable(self):
    """`%e` with a str argument must fail with a 'number is required' error."""
    expected_error = "%e format: A number is required, not String"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def render(value: str) -> str:
            return "%e in template" % value

        render("1")

def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_varaiable(self):
    """`%e` with an int renders lowercase scientific notation, 6 decimals."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%e in template" % value

    rendered = render(1)
    self.assertEqual("1.000000e+00 in template", rendered)

def test_string_interpolation_with_capital_exponent_placeholder_and_digit_varaiable(self):
    """`%E` with an int renders uppercase scientific notation, 6 decimals."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%E in template" % value

    rendered = render(1)
    self.assertEqual("1.000000E+00 in template", rendered)

def test_string_interpolation_with_float_placeholder_and_float_varaiable(self):
    """`%f` with a float renders fixed-point notation with 6 decimals."""
    @torch.jit.script
    def render(value: float) -> str:
        return "%f in template" % value

    rendered = render(1.0)
    self.assertEqual("1.000000 in template", rendered)

def test_string_interpolation_with_float_placeholder_and_digit_varaiable(self):
    """`%f` with an int promotes the value and renders 6 decimals."""
    @torch.jit.script
    def render(value: int) -> str:
        return "%f in template" % value

    rendered = render(1)
    self.assertEqual("1.000000 in template", rendered)

def test_string_interpolation_with_char_placeholder_and_char_varaiable(self):
    """`%c` with a one-character string inserts that character."""
    @torch.jit.script
    def render(value: str) -> str:
        return "%c in template" % value

    rendered = render("a")
    self.assertEqual("a in template", rendered)

def test_string_interpolation_with_char_placeholder_and_digit_varaiable(self):
    """`%c` with an int inserts the character with that code (97 -> 'a')."""
    @torch.jit.script
    def render(code: int) -> str:
        return "%c in template" % code

    rendered = render(97)
    self.assertEqual("a in template", rendered)

def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
    """`%c` with a multi-character string must be rejected."""
    expected_error = "%c requires int or char"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def render(value: str) -> str:
            return "%c in template" % value

        render("foo")

def test_string_interpolation_with_multiple_placeholders(self):
    """A tuple on the right-hand side fills several placeholders in order."""
    @torch.jit.script
    def render(text: str, count: int, ratio: float) -> str:
        return "%s %d %f in template" % (text, count, ratio)

    rendered = render("foo", 1, 1)
    self.assertEqual("foo 1 1.000000 in template", rendered)

def test_string_interpolation_with_subscript(self):
    """The right-hand side may be a subscript expression, not just a name."""
    @torch.jit.script
    def render(items: List[str]) -> str:
        return "%s in template" % items[0]

    rendered = render(["foo", "bar"])
    self.assertEqual("foo in template", rendered)

def test_string_interpolation_with_too_few_arguments(self):
    """More placeholders than arguments must raise at call time."""
    expected_error = "Too few arguments for format string"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def render(value: str) -> str:
            return "%s %s in template" % value

        render("foo")

def test_string_interpolation_with_too_many_arguments(self):
    """More arguments than placeholders must raise at call time."""
    expected_error = "Too many arguments for format string"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def render(first: str, second: str) -> str:
            return "%s in template" % (first, second)

        render("foo", "bar")

def test_pretty_print_function(self):
@torch.jit.script
def foo(x):
Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,16 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
if (tree->tree(0)->kind() == TK_STRINGLITERAL) {
ansley marked this conversation as resolved.
Show resolved Hide resolved
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
Value* output = graph->insertNode(node)->output();
output->setType(StringType::get());
return output;
}
}
ansley marked this conversation as resolved.
Show resolved Hide resolved
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,7 +3158,6 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ RegisterOperators reg(
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
Operator(
"aten::percentFormat(str self, ...) -> str",
[](Stack* stack) {
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
Operator(
"aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
[](Stack* stack) {
Expand Down
102 changes: 102 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,108 @@ void format(Stack& stack, size_t num_inputs) {
push(stack, ss.str());
}

// IValue tags are intentionally private, so we need additional logic to cast
// the IValue type to the specified format.
std::string getFormattedArg(char key, const IValue& ival, int precision=6) {
//TODO: Implement precison-based formatting
std::stringstream ss;
bool added = false;
ansley marked this conversation as resolved.
Show resolved Hide resolved
switch (key) {
case 'd':
case 'i': {
ansley marked this conversation as resolved.
Show resolved Hide resolved
TORCH_CHECK(
ival.isScalar(),
"%",
key,
" format: A number is required, not ",
ival.tagKind());
if (ival.isDouble()) {
std::stringstream().swap(ss);
ss << static_cast<int>(ival.toDouble());
added = true;
}
break;
}
case 'e':
case 'E': {
TORCH_CHECK(
ival.isScalar(),
"%",
key,
" format: A number is required, not ",
ansley marked this conversation as resolved.
Show resolved Hide resolved
ival.tagKind());
ss << std::setprecision(precision) << std::scientific;
if (key == 'E') {
ss << std::uppercase;
ansley marked this conversation as resolved.
Show resolved Hide resolved
}
if (ival.isInt()) {
ss << static_cast<float>(ival.toInt());
} else {
ss << static_cast<float>(ival.toDouble());
}
added = true;
break;
}
case 'f':
case 'F': {
TORCH_CHECK(
ival.isScalar(),
"%",
key,
" format: A number is required, not ",
ival.tagKind());
ss << std::setprecision(precision) << std::fixed;
ansley marked this conversation as resolved.
Show resolved Hide resolved
if (ival.isInt()) {
ss << static_cast<float>(ival.toInt());
} else {
ss << static_cast<float>(ival.toDouble());
}
added = true;
break;
}
case 'c': {
TORCH_CHECK(ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1), "%c requires int or char");
if (ival.isInt()) {
ss << static_cast<char>(ival.toInt());
added = true;
}
break;
}
}
if (!added) {
ss << ival;
}
return ss.str();
ansley marked this conversation as resolved.
Show resolved Hide resolved
}

// Implements Python-style "template % args" string interpolation.
//
// The top `num_inputs` stack entries are the template string followed by its
// argument: either a single value or a tuple of values. They are replaced by
// the single formatted string.
//
// The '%' operator is binary, so the IR emitter always supplies exactly two
// inputs (template, args); only the first value after the template is read.
//
// Supported conversions: d i e E f F s c, plus "%%" for a literal '%'.
// A lone '%' as the very last character is copied through verbatim.
//
// Raises (via TORCH_CHECK) on an unsupported conversion character and when
// the number of arguments does not match the number of placeholders.
void percentFormat(Stack& stack, size_t num_inputs) {
  static const std::string kSpecifiers = "dieEfFsc";

  TORCH_CHECK(
      num_inputs >= 2,
      "percentFormat expects a format string and its arguments");

  // Both references point into `stack` and stay valid until drop() below.
  const std::string& format = peek(stack, 0, num_inputs).toStringRef();
  const IValue& args = last(stack, num_inputs - 1)[0];
  const bool is_tuple = args.isTuple();
  const size_t args_size = is_tuple ? args.toTuple()->elements().size() : 1;

  std::stringstream ss;
  size_t used_args = 0;
  size_t begin = 0;
  while (true) {
    const size_t loc = format.find('%', begin);
    // No further '%', or a '%' with nothing after it: copy the tail as-is.
    if (loc == std::string::npos || loc + 1 >= format.length()) {
      ss << format.substr(begin);
      break;
    }
    ss << format.substr(begin, loc - begin);
    const char key = format[loc + 1];
    begin = loc + 2;
    if (key == '%') {
      // "%%" is an escaped percent sign and consumes no argument.
      ss << '%';
      continue;
    }
    TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
    TORCH_CHECK(
        kSpecifiers.find(key) != std::string::npos,
        "The specifier ",
        key,
        " is not supported in TorchScript");
    const IValue& arg = is_tuple ? args.toTuple()->elements()[used_args] : args;
    ss << getFormattedArg(key, arg);
    ++used_args;
  }
  TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
  drop(stack, num_inputs);
  push(stack, ss.str());
}

void listUnpack(Stack& stack, size_t num_outputs) {
auto list = pop(stack).toList();
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ void tupleUnpack(Stack& stack);

void format(Stack& stack, size_t num_inputs);

void percentFormat(Stack& stack, size_t num_inputs);

void listUnpack(Stack& stack, size_t num_outputs);

void tupleConstruct(Stack& stack, size_t num_inputs);
Expand Down