Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiline string dedent support #45580

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/test_jit.py
Expand Up @@ -7205,6 +7205,20 @@ def f(x):
x = torch.rand(3, 4)
self.assertEqual(scripted_f(x), f(x))

def test_multiline_string_dedents(self):
def foo() -> None:
multiline_string_dedent_1 = """
This is a string dedent """
multiline_string_dedent_2 = """ This is a
string dedent """
multiline_string_dedent_3 = """
This is a string
dedent """
multiline_string_dedent_4 = """ This is a string dedent """

scripted_foo = torch.jit.script(foo)
self.assertEqual(scripted_foo(), foo())

# adapted from test in test_torch
def test_tensor_to(self):
template = dedent('''
Expand Down
29 changes: 29 additions & 0 deletions torch/jit/frontend.py
Expand Up @@ -178,6 +178,34 @@ def get_jit_class_def(cls, self_name):
ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, False)
return build_class_def(ctx, py_ast.body[0], methods, properties, self_name)

def check_and_indent_multiline_strings(sourcelines):
"""
This is a helper function which checks for multiline strings and
indents the strings by calculating the leading space and appending
the spaces to each line of the multiline string.The failure to indent
multiline strings causes failures during downstream dedent
Arguments:
sourcelines: This is an array of source lines of the function
Returns:
This function returns the updated indented sources,i.e,sourcelines
"""
indices = []
triple_quotes = '\"\"\"'
# Extract the start and end line number of the multiline string
for index, source in enumerate(sourcelines):
if triple_quotes in source and source.find(triple_quotes) == source.rfind(triple_quotes):
indices.append(index)

# Adding leading space for every line of the multiline string
indices_length = len(indices)
for i in range(0, indices_length, 2):
if i + 1 < indices_length:
start = indices[i]
end = indices[i + 1]
leading_space = len(sourcelines[start]) - len(sourcelines[start].lstrip())
for lines in range(start + 1, end + 1):
sourcelines[lines] = ' ' * leading_space + sourcelines[lines]
return sourcelines

def get_jit_def(fn, def_name, self_name=None):
"""
Expand All @@ -195,6 +223,7 @@ def _forward(self):
self_name: If this function is a method, what the type name of `self` is.
"""
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
sourcelines = check_and_indent_multiline_strings(sourcelines)
source = ''.join(sourcelines)
dedent_src = dedent(source)
py_ast = ast.parse(dedent_src)
Expand Down