Skip to content

Commit

Permalink
Support %-based string formatting
Browse files Browse the repository at this point in the history
ghstack-source-id: f0e3fc4fc230c283047fd5d211f9ac3b8a32f176
Pull Request resolved: #45976
  • Loading branch information
Ansley Adelaide Ussery committed Oct 7, 2020
1 parent 5ce31b6 commit e934078
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 1 deletion.
43 changes: 43 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2585,6 +2585,49 @@ def forward(self, x):


class TestScript(JitTestCase):
def test_percent_operator_overloading(self):
"""
Test that the '%' token can be parsed as both the modulo operator and as the string formatting operator
"""
def test_str_format():
@torch.jit.script
def fn(arg1: str) -> str:
return "This is my {} in my template".format(arg1)
self.assertEqual("This is my string in my cool template", fn("string"))
print(fn.graph)

def test_modulo_operator():
@torch.jit.script
def fn(dividend: int, divisor: int) -> int:
return dividend % divisor
self.assertEqual(1, fn(5, 2))
print(fn.graph)

def test_string_interpolation_with_string_variable():
@torch.jit.script
def fn(arg1: str) -> str:
return "This is my %s in template" % arg1
print(fn.graph)
self.assertEqual("This is my string in template", fn("string"))

def test_string_interpolation_with_digit_variable():
@torch.jit.script
def fn(arg1: int) -> str:
return "This is my %d in template" % arg1
self.assertEqual("This is my 1 in template", fn(1))

def test_string_interpolation_with_float_variable():
@torch.jit.script
def fn(arg1: int) -> str:
return "This is my %f in template" % arg1
self.assertEqual("This is my 1.0 in template", fn(1.0))

#test_str_format()
#test_modulo_operator()
test_string_interpolation_with_string_variable()
#test_string_interpolation_with_digit_variable()
#test_string_interpolation_with_float_variable()

def test_pretty_print_function(self):
@torch.jit.script
def foo(x):
Expand Down
9 changes: 8 additions & 1 deletion torch/csrc/jit/frontend/ir_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,6 +3136,14 @@ struct to_ir {
return emitBuiltinCall(
tree->range(), *method.graph(), kind, named_values, {});
}
case '%': {
if (tree->tree(0)->kind() == TK_STRINGLITERAL) {
auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
auto node = graph->create(aten::percentFormat, values, 1)
->setSourceRange(tree->range());
return graph->insertNode(node)->output();
}
}
case TK_IN:
case TK_POW:
case TK_NE:
Expand All @@ -3148,7 +3156,6 @@ struct to_ir {
case '/':
case '+':
case '-':
case '%':
case '&':
case '|':
case '^':
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ RegisterOperators reg(
format(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
[](Stack* stack) {
size_t num_inputs = pop(stack).toInt();
percentFormat(*stack, num_inputs);
},
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
[](Stack* stack) {
Expand Down
60 changes: 60 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,66 @@ void format(Stack& stack, size_t num_inputs) {
push(stack, ss.str());
}

// IValue tags are intentionally private, so we need additional logic to cast the
// IValue type to the specified format.
std::string getFormattedArg(char key, IValue ival) {
std::stringstream ss;
ss << ival;
switch (key) {
case 'd':
case 'i': {
if (!ival.isScalar()) {
//AT_ERROR("%" + key + " A number is required, not " + ival.tagKind(), format);
AT_ERROR(" A number is required", format);
}
else if (ival.isDouble()) {
std::stringstream().swap(ss);
ss << (int)ival.toDouble();
}
break;
}
case 'f': {
if (!ival.isScalar()) {
//AT_ERROR("%" + key + " A number is required, not " + ival.tagKind(), format);
AT_ERROR(" A number is required", format);
}
else if (ival.isInt()) {
std::stringstream().swap(ss);
ss << (double)ival.toInt();
}
break;
}
}
return ss.str();
}

void percentFormat(Stack& stack, size_t num_inputs) {
// TODO: add support for more specifiers
std::vector<char> specifiers = { 'd', 'i', 'f', 's'};
auto format = peek(stack, 0, num_inputs).toStringRef();
auto args = last(stack, num_inputs - 1);
std::stringstream ss;
for (size_t begin = 0, used_args = 0; true; ++used_args) {
size_t loc = format.find("%", begin);
if (loc >= format.length() - 1) {
ss << format.substr(begin);
break;
}
ss << format.substr(begin, loc - begin);
if (used_args >= args.size()) {
AT_ERROR("Too few arguments for format string: ", format);
}
char key = format.at(loc + 1);
if (std::find(specifiers.begin(), specifiers.end(), key) != specifiers.end()) {
auto ins = getFormattedArg(key, args[used_args]);
ss << ins;
begin = loc + 2;
}
}
drop(stack, num_inputs);
push(stack, ss.str());
}

void listUnpack(Stack& stack, size_t num_outputs) {
auto list = pop(stack).toList();
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/vararg_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ void tupleUnpack(Stack& stack);

void format(Stack& stack, size_t num_inputs);

void percentFormat(Stack& stack, size_t num_inputs);

void listUnpack(Stack& stack, size_t num_outputs);

void tupleConstruct(Stack& stack, size_t num_inputs);
Expand Down

0 comments on commit e934078

Please sign in to comment.