Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support %-based string formatting #45976

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
156 changes: 156 additions & 0 deletions test/jit/test_string_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import sys

import torch
from typing import List

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")

class TestStringFormatting(JitTestCase):

def test_modulo_operator(self):
@torch.jit.script
def fn(dividend: int, divisor: int) -> int:
return dividend % divisor
self.assertEqual(1, fn(5, 2))

def test_string_interpolation_with_string_placeholder_and_string_variable(self):
@torch.jit.script
def fn(arg1: str):
return "%s in template" % arg1
self.assertEqual("foo in template", fn("foo"))
ansley marked this conversation as resolved.
Show resolved Hide resolved

def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
@torch.jit.script
def fn(arg1: str):
return arg1 % "foo"
self.assertEqual("foo in template", fn("%s in template"))

def test_string_interpolation_with_double_percent_in_string(self):
@torch.jit.script
def fn(arg1: str):
return "%s in template %%" % arg1
self.assertEqual("foo in template %", fn("foo"))

def test_string_interpolation_with_percent_in_string(self):
with self.assertRaisesRegex(RuntimeError, "Incomplete format specifier"):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s in template %" % arg1
ansley marked this conversation as resolved.
Show resolved Hide resolved
fn("foo")

def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%s in template" % arg1
self.assertEqual("1 in template", fn(1))

def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%d in template" % arg1
self.assertEqual("1 in template", fn(1))

def test_string_interpolation_with_alternate_digit_placeholder(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%i in template" % arg1
self.assertEqual("1 in template", fn(1))

def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
with self.assertRaisesRegex(RuntimeError, "Got String, but a number is required for formatting"):
@torch.jit.script
def fn(arg1: str) -> str:
return "%d in template" % arg1
fn("1")
ansley marked this conversation as resolved.
Show resolved Hide resolved

def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
with self.assertRaisesRegex(RuntimeError, "Got String, but a number is required for formatting"):
@torch.jit.script
def fn(arg1: str) -> str:
return "%e in template" % arg1
fn("1")

def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%e in template" % arg1
self.assertEqual("1.000000e+00 in template", fn(1))

def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%E in template" % arg1
self.assertEqual("1.000000E+00 in template", fn(1))

def test_string_interpolation_with_float_placeholder_and_float_variable(self):
@torch.jit.script
def fn(arg1: float) -> str:
return "%f in template" % arg1
self.assertEqual("1.000000 in template", fn(1.0))

def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%f in template" % arg1
self.assertEqual("1.000000 in template", fn(1))

def test_string_interpolation_with_char_placeholder_and_char_variable(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%c in template" % arg1
self.assertEqual("a in template", fn("a"))

def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
@torch.jit.script
def fn(arg1: int) -> str:
return "%c in template" % arg1
self.assertEqual("a in template", fn(97))

def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
with self.assertRaisesRegex(RuntimeError, "Got String, but an int or char is required for formatting"):
@torch.jit.script
def fn(arg1: str) -> str:
return "%c in template" % arg1
fn("foo")

def test_string_interpolation_with_multiple_placeholders(self):
@torch.jit.script
def fn(arg1: str, arg2: int, arg3: float) -> str:
return "%s %d %f in template" % (arg1, arg2, arg3)
self.assertEqual("foo 1 1.000000 in template", fn("foo", 1, 1))

#def test_string_interpolation_with_double_tuple(self):
# with self.assertRaisesRegex(RuntimeError, "Argument to format string must be a single str or Tuple"):
# @torch.jit.script
# def fn(arg1: str, arg2: int, arg3: float, arg4: str) -> str:
# return "%s %d %f %s in template" % (arg1, arg2) (arg3, arg4)
# fn("foo", 1, 1.0, "bar")

def test_string_interpolation_with_subscript(self):
@torch.jit.script
def fn(arg1: List[str]) -> str:
return "%s in template" % arg1[0]
self.assertEqual("foo in template", fn(["foo", "bar"]))

def test_string_interpolation_with_too_few_arguments(self):
with self.assertRaisesRegex(RuntimeError, "Too few arguments for format string"):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s %s in template" % arg1
fn("foo")

def test_string_interpolation_with_too_many_arguments(self):
with self.assertRaisesRegex(RuntimeError, "Too many arguments for format string"):
@torch.jit.script
def fn(arg1: str, arg2: str) -> str:
return "%s in template" % (arg1, arg2)
ansley marked this conversation as resolved.
Show resolved Hide resolved
fn("foo", "bar")
1 change: 1 addition & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jit.test_onnx_export import TestONNXExport # noqa: F401
from jit.test_with import TestWith # noqa: F401
from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
Expand Down
14 changes: 13 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,19 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
auto first_arg = emitSugaredExpr(Expr(tree->tree(0)), 0)->asValue(tree->tree(0)->range(), method);
ansley marked this conversation as resolved.
Show resolved Hide resolved
auto first_arg_type = first_arg->type()->kind();
auto first_arg_type_string = std::string(typeKindToString(first_arg_type));
ansley marked this conversation as resolved.
Show resolved Hide resolved
if (type_string.compare("StringType") == 0) {
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
Value* output = graph->insertNode(node)->output();
output->setType(StringType::get());
return output;
}
}
ansley marked this conversation as resolved.
Show resolved Hide resolved
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,7 +3161,6 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ RegisterOperators reg(
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
Operator(
"aten::percentFormat(str self, ...) -> str",
[](Stack* stack) {
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
Operator(
"aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
[](Stack* stack) {
Expand Down
110 changes: 110 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,116 @@ void format(Stack& stack, size_t num_inputs) {
push(stack, ss.str());
}

// IValue tags are intentionally private, so we need additional logic to cast
// the IValue type to the specified format.
void addFormattedArg(char key, const IValue& ival, std::stringstream& ss, int precision=6) {
//TODO: Implement precison-based formatting
switch (key) {
case 'd':
case 'i': {
ansley marked this conversation as resolved.
Show resolved Hide resolved
TORCH_CHECK(
ival.isScalar(), "Got ", ival.tagKind(), ", but a number is required for formatting");
if (ival.isInt()) {
ss << ival.toInt();
}
else {
ss << static_cast<int>(ival.toDouble());
}
break;
}
case 'e':
case 'E': {
TORCH_CHECK(
ival.isScalar(), "Got ", ival.tagKind(), ", but a number is required for formatting");
ss << std::setprecision(precision) << std::scientific;
if (key == 'E') {
ss << std::uppercase;
ansley marked this conversation as resolved.
Show resolved Hide resolved
}
if (ival.isInt()) {
ss << static_cast<float>(ival.toInt());
} else {
ss << static_cast<float>(ival.toDouble());
}
break;
}
case 'f':
case 'F': {
TORCH_CHECK(
ival.isScalar(), "Got ", ival.tagKind(), ", but a number is required for formatting");
ss << std::setprecision(precision) << std::fixed;
ansley marked this conversation as resolved.
Show resolved Hide resolved
if (ival.isInt()) {
ss << static_cast<float>(ival.toInt());
} else {
ss << static_cast<float>(ival.toDouble());
}
break;
}
case 'c': {
TORCH_CHECK(ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1), "Got ", ival.tagKind(), ", but an int or char is required for formatting");
if (ival.isInt()) {
ss << static_cast<char>(ival.toInt());
}
ansley marked this conversation as resolved.
Show resolved Hide resolved
else {
ss << ival.toStringRef();
}
break;
}
case 's': {
if (ival.isString()) {
ss << ival.toStringRef();
}
else {
ss << ival;
}
break;
}
default: {
TORCH_CHECK(false, "The specifier ", key, " is not supported in TorchScript");
}
}
}

void percentFormat(Stack& stack, size_t num_inputs) {
auto format = peek(stack, 0, num_inputs).toStringRef();
auto args = last(stack, num_inputs - 1)[0];
auto args_size = 1; // assumed size
if (args.isTuple()) {
args_size = args.toTuple()->elements().size();
}
std::stringstream ss;
size_t used_args = 0;
size_t begin = 0;
while (true) {
size_t loc = format.find('%', begin);
if (loc == std::string::npos) {
ss << format.substr(begin);
break;
}
TORCH_CHECK(loc < format.length() - 1, "Incomplete format specifier");
ss << format.substr(begin, loc - begin);
if (format.at(loc + 1) == '%') {
ss << '%';
begin = loc + 2;
ansley marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
char key = format.at(loc + 1);
IValue arg;
if (args.isTuple()) {
arg = args.toTuple()->elements()[used_args];
}
else {
arg = args;
}
addFormattedArg(key, arg, ss);
begin = loc + 2;
++used_args;
}
TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
drop(stack, num_inputs);
push(stack, ss.str());
}

void listUnpack(Stack& stack, size_t num_outputs) {
auto list = pop(stack).toList();
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ void tupleUnpack(Stack& stack);

void format(Stack& stack, size_t num_inputs);

void percentFormat(Stack& stack, size_t num_inputs);

void listUnpack(Stack& stack, size_t num_outputs);

void tupleConstruct(Stack& stack, size_t num_inputs);
Expand Down