Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify dev_tools.notebooks.utils.rewrite_notebook #6030

Merged
merged 8 commits into from
Mar 14, 2023
6 changes: 2 additions & 4 deletions dev_tools/notebooks/isolated_notebook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env):

notebook_file = os.path.basename(notebook_path)

rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path)
rewritten_notebook_path = rewrite_notebook(notebook_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we are (and always have been) leaking temporary files. Could you fix that while you're at it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viathor - thank you for the review. The temporary file cleanup turned out to be a bit more involved - could you please take another quick look?


cmd = f"""
mkdir -p out/{notebook_rel_dir}
Expand Down Expand Up @@ -208,9 +208,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env):
f"instead of `pip install cirq` to this notebook, and exclude it from "
f"dev_tools/notebooks/isolated_notebook_test.py."
)

if rewritten_notebook_descriptor:
os.close(rewritten_notebook_descriptor)
os.remove(rewritten_notebook_path)


@pytest.mark.parametrize("notebook_path", NOTEBOOKS_DEPENDING_ON_UNRELEASED_FEATURES)
Expand Down
6 changes: 2 additions & 4 deletions dev_tools/notebooks/notebook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_notebooks_against_released_cirq(notebook_path):
notebook_file = os.path.basename(notebook_path)
notebook_rel_dir = os.path.dirname(os.path.relpath(notebook_path, "."))
out_path = f"out/{notebook_rel_dir}/{notebook_file[:-6]}.out.ipynb"
rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path)
rewritten_notebook_path = rewrite_notebook(notebook_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

papermill_flags = "--no-request-save-on-cell-execute --autosave-cell-every 0"
cmd = f"""mkdir -p out/{notebook_rel_dir}
papermill {rewritten_notebook_path} {out_path} {papermill_flags}"""
Expand All @@ -83,6 +83,4 @@ def test_notebooks_against_released_cirq(notebook_path):
f"notebook (in Github Actions, you can download it from the workflow artifact"
f" 'notebook-outputs')"
)

if rewritten_notebook_descriptor:
os.close(rewritten_notebook_descriptor)
os.remove(rewritten_notebook_path)
41 changes: 20 additions & 21 deletions dev_tools/notebooks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def rewrite_notebook(notebook_path):

* Lines in this file without `->` are ignored.

* Lines in this file with `->` are split into two (if there are mulitple `->` it is an
* Lines in this file with `->` are split into two (if there are multiple `->` it is an
error). The first of these is compiled into a pattern match, via `re.compile`, and
the second is the replacement for that match.

Expand All @@ -82,42 +82,41 @@ def rewrite_notebook(notebook_path):
It is the responsibility of the caller of this method to delete the new file.

Returns:
Tuple of a file descriptor and the file path for the rewritten file. If no `.tst` file
was found, then the file descriptor is None and the path is `notebook_path`.
The absolute path to the rewritten file in temporary directory.
If no `.tst` file exists the new file is a copy of the input notebook.

Raises:
AssertionError: If there are multiple `->` per line, or not all of the replacements
are used.
"""
notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst'
if not os.path.exists(notebook_test_path):
return None, notebook_path

# Get the rewrite rules.
patterns = []
with open(notebook_test_path, 'r') as f:
for line in f:
if '->' in line:
notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst'
if os.path.exists(notebook_test_path):
with open(notebook_test_path, 'r') as f:
pattern_lines = (line for line in f if '->' in line)
for line in pattern_lines:
parts = line.rstrip().split('->')
assert len(parts) == 2, f'Replacement lines may only contain one -> but was {line}'
patterns.append((re.compile(parts[0]), parts[1]))

used_patterns = set()
with open(notebook_path, 'r') as original_file:
new_file_descriptor, new_file_path = tempfile.mkstemp(suffix='.ipynb')
with open(new_file_path, 'w') as new_file:
for line in original_file:
new_line = line
for pattern, replacement in patterns:
new_line = pattern.sub(replacement, new_line)
if new_line != line:
used_patterns.add(pattern)
break
new_file.write(new_line)
lines = original_file.readlines()
for i, line in enumerate(lines):
for pattern, replacement in patterns:
new_line = pattern.sub(replacement, line)
if new_line != line:
lines[i] = new_line
used_patterns.add(pattern)
break

assert len(patterns) == len(used_patterns), (
'Not all patterns where used. Patterns not used: '
f'{set(x for x, _ in patterns) - used_patterns}'
)

return new_file_descriptor, new_file_path
with tempfile.NamedTemporaryFile(mode='w', suffix='-rewrite.ipynb', delete=False) as new_file:
new_file.writelines(lines)

return new_file.name
21 changes: 11 additions & 10 deletions dev_tools/notebooks/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import filecmp
import os
import shutil
import tempfile
Expand All @@ -37,40 +38,40 @@ def write_test_data(ipynb_txt, tst_txt):
def test_rewrite_notebook():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

assert path != ipynb_path
with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 4'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


def test_rewrite_notebook_multiple():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\nd = 4->d = 1')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 1'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


def test_rewrite_notebook_ignore_non_seperator_lines():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\n# comment')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 4'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


Expand All @@ -80,11 +81,11 @@ def test_rewrite_notebook_no_tst_file():
with open(ipynb_path, 'w') as f:
f.write('d = 5\nd = 4')

descriptor, path = dt.rewrite_notebook(ipynb_path)

assert descriptor is None
assert path == ipynb_path
path = dt.rewrite_notebook(ipynb_path)
assert path != ipynb_path
assert filecmp.cmp(path, ipynb_path)

os.remove(path)
shutil.rmtree(directory)


Expand Down