Skip to content

Commit

Permalink
Support %-based string formatting
Browse files Browse the repository at this point in the history
ghstack-source-id: 23134fdb20f9e6c36ab1c76729debb8f32ed710f
Pull Request resolved: #45976
  • Loading branch information
Ansley Adelaide Ussery committed Oct 8, 2020
1 parent de0d0bd commit a388199
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ namespace c10 {
_(aten, append) \
_(aten, item) \
_(aten, format) \
_(aten, percentFormat) \
_(aten, __not__) \
_(aten, __is__) \
_(aten, __isnot__) \
Expand Down
1 change: 1 addition & 0 deletions c10/util/C++17.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <utility>
#include <memory>
#include <sstream>
#include <iostream>
#include <string>
#include <cstdlib>
#include <functional>
Expand Down
63 changes: 63 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2585,6 +2585,69 @@ def forward(self, x):


class TestScript(JitTestCase):
def test_percent_operator_overloading(self):
"""
Test that the '%' token can be parsed as both the modulo operator and as the string formatting operator
"""
def test_modulo_operator():
@torch.jit.script
def fn(dividend: int, divisor: int) -> int:
return dividend % divisor
self.assertEqual(1, fn(5, 2))

def test_string_interpolation_with_string_placeholder_and_string_varaiable():
@torch.jit.script
def fn(arg1: str):
return "%s in template" % arg1
self.assertEqual("string in template", fn("string"))

def test_string_interpolation_with_string_placeholder_and_digit_varaiable():
@torch.jit.script
def fn(arg1: int) -> str:
return "%s in template" % arg1
self.assertEqual("1 in template", fn(1))

def test_string_interpolation_with_digit_placeholder_and_digit_varaiable():
@torch.jit.script
def fn(arg1: int) -> str:
return "%d in template" % arg1
self.assertEqual("1 in template", fn(1))

def test_string_interpolation_with_digit_placeholder_and_string_varaiable():
with self.assertRaises(RuntimeError):
@torch.jit.script
def fn(arg1: str) -> str:
return "%d in template" % arg1
fn("1")

def test_string_interpolation_with_float_placeholder_and_float_varaiable():
@torch.jit.script
def fn(arg1: float) -> str:
return "%f in template" % arg1
self.assertEqual("1.000000 in template", fn(1.0))

def test_string_interpolation_with_float_placeholder_and_digit_varaiable():
@torch.jit.script
def fn(arg1: int) -> str:
return "%f in template" % arg1
self.assertEqual("1.000000 in template", fn(1))

def test_string_interpolation_with_too_few_arguments():
with self.assertRaises(RuntimeError):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s %s in template" % arg1
fn("string")

test_modulo_operator()
test_string_interpolation_with_string_placeholder_and_string_varaiable()
test_string_interpolation_with_string_placeholder_and_digit_varaiable()
test_string_interpolation_with_digit_placeholder_and_digit_varaiable()
test_string_interpolation_with_digit_placeholder_and_string_varaiable()
test_string_interpolation_with_float_placeholder_and_float_varaiable()
test_string_interpolation_with_float_placeholder_and_digit_varaiable()
test_string_interpolation_with_too_few_arguments()

def test_pretty_print_function(self):
@torch.jit.script
def foo(x):
Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,16 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
if (tree->tree(0)->kind() == TK_STRINGLITERAL) {
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
Value* output = graph->insertNode(node)->output();
output->setType(StringType::get());
return output;
}
}
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,7 +3158,6 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ RegisterOperators reg(
format(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
[](Stack* stack) {
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
[](Stack* stack) {
Expand Down
73 changes: 73 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,79 @@ void format(Stack& stack, size_t num_inputs) {
push(stack, ss.str());
}

// IValue tags are intentionally private, so we need additional logic to cast
// the IValue type to the specified format.
std::string getFormattedArg(char key, const IValue& ival) {
std::stringstream ss;
bool added = false;
switch (key) {
case 'd':
case 'i': {
TORCH_CHECK(
ival.isScalar(),
"%",
key,
" format: A number is required, not ",
ival.tagKind());
if (ival.isDouble()) {
std::stringstream().swap(ss);
ss << static_cast<int>(ival.toDouble());
added = true;
}
break;
}
case 'f': {
TORCH_CHECK(
ival.isScalar(),
"%",
key,
" format: A number is required, not ",
ival.tagKind());
ss << std::setprecision(6) << std::fixed;
if (ival.isInt()) {
ss << static_cast<float>(ival.toInt());
} else {
ss << static_cast<float>(ival.toDouble());
}
added = true;
break;
}
}
if (!added) {
ss << ival;
}
return ss.str();
}

void percentFormat(Stack& stack, size_t num_inputs) {
// TODO: add support for more specifiers
std::vector<char> specifiers = {'d', 'i', 'f', 's'};
auto format = peek(stack, 0, num_inputs).toStringRef();
auto args = last(stack, num_inputs - 1);
std::stringstream ss;
for (size_t begin = 0, used_args = 0; true; ++used_args) {
size_t loc = format.find('%', begin);
if (loc >= format.length() - 1) {
ss << format.substr(begin);
break;
}
ss << format.substr(begin, loc - begin);
TORCH_CHECK(
used_args < args.size(),
"Too few arguments for format string: ",
format);
char key = format.at(loc + 1);
if (std::find(specifiers.begin(), specifiers.end(), key) !=
specifiers.end()) {
auto ins = getFormattedArg(key, args[used_args]);
ss << ins;
begin = loc + 2;
}
}
drop(stack, num_inputs);
push(stack, ss.str());
}

void listUnpack(Stack& stack, size_t num_outputs) {
auto list = pop(stack).toList();
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ void tupleUnpack(Stack& stack);

void format(Stack& stack, size_t num_inputs);

void percentFormat(Stack& stack, size_t num_inputs);

void listUnpack(Stack& stack, size_t num_outputs);

void tupleConstruct(Stack& stack, size_t num_inputs);
Expand Down

0 comments on commit a388199

Please sign in to comment.