Skip to content

Commit

Permalink
Support %-based string formatting
Browse files Browse the repository at this point in the history
ghstack-source-id: c45db31e90244ce83cede64109bfcd5360351220
Pull Request resolved: #45976
  • Loading branch information
Ansley Adelaide Ussery committed Oct 12, 2020
1 parent a5c0dbc commit 005970c
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
156 changes: 156 additions & 0 deletions test/jit/test_string_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import sys

import torch
from typing import List

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# This module is only a collection of test cases: it is imported by
# test/test_jit.py, which acts as the real entry point. Fail loudly with
# instructions if someone tries to execute it directly.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

# Tests for %-based ("printf-style") string formatting in TorchScript.
# Each test scripts a tiny function that applies `%` and then checks either
# the formatted result or the RuntimeError message raised when it is called.
class TestStringFormatting(JitTestCase):

    # `%` on two ints must still behave as the arithmetic modulo operator.
    def test_modulo_operator(self):
        @torch.jit.script
        def fn(dividend: int, divisor: int) -> int:
            return dividend % divisor
        self.assertEqual(1, fn(5, 2))

    # Literal template, str argument, `%s` placeholder.
    def test_string_interpolation_with_string_placeholder_and_string_variable(self):
        @torch.jit.script
        def fn(arg1: str):
            return "%s in template" % arg1
        self.assertEqual("foo in template", fn("foo"))

    # The template itself is a runtime variable rather than a literal.
    def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
        @torch.jit.script
        def fn(arg1: str):
            return arg1 % "foo"
        self.assertEqual("foo in template", fn("%s in template"))

    # `%%` is an escaped percent sign and consumes no argument.
    def test_string_interpolation_with_double_percent_in_string(self):
        @torch.jit.script
        def fn(arg1: str):
            return "%s in template %%" % arg1
        self.assertEqual("foo in template %", fn("foo"))

    # A lone trailing `%` is an incomplete specifier and must raise.
    def test_string_interpolation_with_percent_in_string(self):
        with self.assertRaisesRegex(RuntimeError, "Incomplete format specifier"):
            @torch.jit.script
            def fn(arg1: str) -> str:
                return "%s in template %" % arg1
            fn("foo")

    # `%s` accepts a non-string (int) argument.
    def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%s in template" % arg1
        self.assertEqual("1 in template", fn(1))

    # `%d` with an int argument.
    def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%d in template" % arg1
        self.assertEqual("1 in template", fn(1))

    # `%i` is treated the same as `%d`.
    def test_string_interpolation_with_alternate_digit_placeholder(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%i in template" % arg1
        self.assertEqual("1 in template", fn(1))

    # `%d` rejects a str argument.
    def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
        with self.assertRaisesRegex(RuntimeError, "Got String, but a number is required for formatting"):
            @torch.jit.script
            def fn(arg1: str) -> str:
                return "%d in template" % arg1
            fn("1")

    # `%e` rejects a str argument.
    def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
        with self.assertRaisesRegex(RuntimeError, "Got String, but a number is required for formatting"):
            @torch.jit.script
            def fn(arg1: str) -> str:
                return "%e in template" % arg1
            fn("1")

    # `%e` prints scientific notation with a lowercase exponent marker.
    def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%e in template" % arg1
        self.assertEqual("1.000000e+00 in template", fn(1))

    # `%E` prints scientific notation with an uppercase exponent marker.
    def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%E in template" % arg1
        self.assertEqual("1.000000E+00 in template", fn(1))

    # `%f` with a float argument prints six decimal places.
    def test_string_interpolation_with_float_placeholder_and_float_variable(self):
        @torch.jit.script
        def fn(arg1: float) -> str:
            return "%f in template" % arg1
        self.assertEqual("1.000000 in template", fn(1.0))

    # `%f` also accepts an int argument.
    def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%f in template" % arg1
        self.assertEqual("1.000000 in template", fn(1))

    # `%c` with a single-character string.
    def test_string_interpolation_with_char_placeholder_and_char_variable(self):
        @torch.jit.script
        def fn(arg1: str) -> str:
            return "%c in template" % arg1
        self.assertEqual("a in template", fn("a"))

    # `%c` with an int prints the character with that code (97 -> "a").
    def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
        @torch.jit.script
        def fn(arg1: int) -> str:
            return "%c in template" % arg1
        self.assertEqual("a in template", fn(97))

    # `%c` rejects a string that is longer than one character.
    def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
        with self.assertRaisesRegex(RuntimeError, "Got String, but an int or char is required for formatting"):
            @torch.jit.script
            def fn(arg1: str) -> str:
                return "%c in template" % arg1
            fn("foo")

    # Several placeholders are filled, in order, from a tuple of arguments.
    def test_string_interpolation_with_multiple_placeholders(self):
        @torch.jit.script
        def fn(arg1: str, arg2: int, arg3: float) -> str:
            return "%s %d %f in template" % (arg1, arg2, arg3)
        self.assertEqual("foo 1 1.000000 in template", fn("foo", 1, 1))

    # NOTE(review): the test below is disabled in the original commit; the
    # reason is not recorded here. Kept verbatim for whoever re-enables it.
    # def test_string_interpolation_with_double_tuple(self):
    #     with self.assertRaisesRegex(RuntimeError, "Argument to format string must be a single str or Tuple"):
    #         @torch.jit.script
    #         def fn(arg1: str, arg2: int, arg3: float, arg4: str) -> str:
    #             return "%s %d %f %s in template" % (arg1, arg2) (arg3, arg4)
    #         fn("foo", 1, 1.0, "bar")

    # The right-hand side may be an arbitrary expression such as a subscript.
    def test_string_interpolation_with_subscript(self):
        @torch.jit.script
        def fn(arg1: List[str]) -> str:
            return "%s in template" % arg1[0]
        self.assertEqual("foo in template", fn(["foo", "bar"]))

    # More placeholders than arguments must raise.
    def test_string_interpolation_with_too_few_arguments(self):
        with self.assertRaisesRegex(RuntimeError, "Too few arguments for format string"):
            @torch.jit.script
            def fn(arg1: str) -> str:
                return "%s %s in template" % arg1
            fn("foo")

    # More arguments than placeholders must raise.
    def test_string_interpolation_with_too_many_arguments(self):
        with self.assertRaisesRegex(RuntimeError, "Too many arguments for format string"):
            @torch.jit.script
            def fn(arg1: str, arg2: str) -> str:
                return "%s in template" % (arg1, arg2)
            fn("foo", "bar")
1 change: 1 addition & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jit.test_onnx_export import TestONNXExport # noqa: F401
from jit.test_with import TestWith # noqa: F401
from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
Expand Down
14 changes: 13 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,19 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
  // `%` is overloaded: with a string on the left-hand side it is printf-style
  // formatting (lowered to aten::percentFormat); otherwise it is handled like
  // every other binary operator by the shared cases below.
  auto first_arg = emitSugaredExpr(Expr(tree->tree(0)), 0)
                       ->asValue(tree->tree(0)->range(), method);
  auto first_arg_type = first_arg->type()->kind();
  auto first_arg_type_string = std::string(typeKindToString(first_arg_type));
  // Fixed: the condition previously tested `type_string`, a name that is not
  // declared here, instead of the string computed just above.
  if (first_arg_type_string.compare("StringType") == 0) {
    // Keep the right-hand side packed (no tuple unpacking): the runtime op
    // inspects it to decide between a single argument and a tuple of them.
    auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
    auto node = graph->create(aten::percentFormat, values, 1)
                    ->setSourceRange(tree->range());
    Value* output = graph->insertNode(node)->output();
    output->setType(StringType::get());
    return output;
  }
  // Not a string: intentionally fall through to the generic binary-operator
  // cases that follow.
}
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,7 +3161,6 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ RegisterOperators reg(
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
// Vararg operator backing `str % args`. The number of inputs is popped from
// the top of the stack first; percentFormat then reads that many values
// from the stack.
Operator(
    "aten::percentFormat(str self, ...) -> str",
    [](Stack* stack) {
      size_t num_inputs = pop(stack).toInt();
      percentFormat(*stack, num_inputs);
    },
    aliasAnalysisFromSchema()),
Operator(
"aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
[](Stack* stack) {
Expand Down
110 changes: 110 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,116 @@ void format(Stack& stack, size_t num_inputs) {
push(stack, ss.str());
}

// IValue tags are intentionally private, so we need additional logic to cast
// the IValue type to the specified format.
//
// Appends `ival` to `ss`, rendered according to the printf-style conversion
// character `key`:
//   d, i : integer (a double argument is truncated)
//   e, E : scientific notation (E uses an uppercase exponent marker)
//   f, F : fixed-point notation
//   c    : a single character, given as an int code or a 1-character string
//   s    : a string as-is, or any other value via its stream operator
// `precision` is the number of digits after the decimal point for e/E/f/F.
// Raises (via TORCH_CHECK) if `ival` has the wrong type for `key` or if `key`
// is not one of the supported specifiers.
void addFormattedArg(
    char key,
    const IValue& ival,
    std::stringstream& ss,
    int precision = 6) {
  // TODO: Implement precision-based formatting

  // Stream manipulators are sticky. Snapshot the formatting state so that one
  // specifier (e.g. the uppercase flag of %E, or the fixed/scientific float
  // mode) cannot leak into whatever is written to `ss` afterwards.
  const std::ios_base::fmtflags saved_flags = ss.flags();
  const std::streamsize saved_precision = ss.precision();

  switch (key) {
    case 'd':
    case 'i': {
      TORCH_CHECK(
          ival.isScalar(),
          "Got ",
          ival.tagKind(),
          ", but a number is required for formatting");
      if (ival.isInt()) {
        ss << ival.toInt();
      } else {
        // Truncate to 64 bits, the same width the integer path prints, rather
        // than to a 32-bit int.
        ss << static_cast<int64_t>(ival.toDouble());
      }
      break;
    }
    case 'e':
    case 'E': {
      TORCH_CHECK(
          ival.isScalar(),
          "Got ",
          ival.tagKind(),
          ", but a number is required for formatting");
      ss << std::setprecision(precision) << std::scientific;
      if (key == 'E') {
        ss << std::uppercase;
      }
      // Format as double: narrowing to float would discard precision that
      // the argument actually has.
      if (ival.isInt()) {
        ss << static_cast<double>(ival.toInt());
      } else {
        ss << ival.toDouble();
      }
      break;
    }
    case 'f':
    case 'F': {
      TORCH_CHECK(
          ival.isScalar(),
          "Got ",
          ival.tagKind(),
          ", but a number is required for formatting");
      ss << std::setprecision(precision) << std::fixed;
      if (ival.isInt()) {
        ss << static_cast<double>(ival.toInt());
      } else {
        ss << ival.toDouble();
      }
      break;
    }
    case 'c': {
      TORCH_CHECK(
          ival.isInt() ||
              (ival.isString() && ival.toStringRef().length() == 1),
          "Got ",
          ival.tagKind(),
          ", but an int or char is required for formatting");
      if (ival.isInt()) {
        ss << static_cast<char>(ival.toInt());
      } else {
        ss << ival.toStringRef();
      }
      break;
    }
    case 's': {
      if (ival.isString()) {
        ss << ival.toStringRef();
      } else {
        ss << ival;
      }
      break;
    }
    default: {
      TORCH_CHECK(
          false, "The specifier ", key, " is not supported in TorchScript");
    }
  }

  // Undo any manipulators applied above.
  ss.flags(saved_flags);
  ss.precision(saved_precision);
}

// Implements `format % args` for TorchScript strings.
//
// Expects the top `num_inputs` stack entries to be the format string followed
// by its argument, which is either a single value or a tuple of values. Each
// `%<char>` specifier in the format string consumes one argument and is
// rendered by addFormattedArg; `%%` emits a literal percent sign. The inputs
// are then dropped and the formatted string is pushed in their place.
//
// Raises (via TORCH_CHECK) on a trailing lone `%`, and when the number of
// specifiers does not match the number of arguments.
void percentFormat(Stack& stack, size_t num_inputs) {
  // Bind by reference: both operands stay on the stack until the drop() at
  // the end, so copying the string / IValue is unnecessary.
  const auto& format = peek(stack, 0, num_inputs).toStringRef();
  const IValue& args = last(stack, num_inputs - 1)[0];
  const bool args_is_tuple = args.isTuple();
  // A non-tuple right-hand side counts as exactly one argument.
  const size_t args_size =
      args_is_tuple ? args.toTuple()->elements().size() : 1;

  std::stringstream ss;
  size_t used_args = 0;
  size_t begin = 0;
  while (true) {
    const size_t loc = format.find('%', begin);
    if (loc == std::string::npos) {
      // No more specifiers: copy the remaining literal text.
      ss << format.substr(begin);
      break;
    }
    TORCH_CHECK(loc < format.length() - 1, "Incomplete format specifier");
    // Literal text between the previous specifier and this one.
    ss << format.substr(begin, loc - begin);
    const char key = format.at(loc + 1);
    begin = loc + 2;
    if (key == '%') {
      // Escaped percent sign; consumes no argument.
      ss << '%';
      continue;
    }
    TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
    if (args_is_tuple) {
      addFormattedArg(key, args.toTuple()->elements()[used_args], ss);
    } else {
      addFormattedArg(key, args, ss);
    }
    ++used_args;
  }
  TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
  drop(stack, num_inputs);
  push(stack, ss.str());
}

void listUnpack(Stack& stack, size_t num_outputs) {
auto list = pop(stack).toList();
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ void tupleUnpack(Stack& stack);

void format(Stack& stack, size_t num_inputs);

// printf-style `format % args`: consumes the top `num_inputs` stack entries
// (the format string and its argument) and pushes the formatted string.
void percentFormat(Stack& stack, size_t num_inputs);

void listUnpack(Stack& stack, size_t num_outputs);

void tupleConstruct(Stack& stack, size_t num_inputs);
Expand Down

0 comments on commit 005970c

Please sign in to comment.