Skip to content

Commit

Permalink
Simplify dev_tools.notebooks.utils.rewrite_notebook (#6030)
Browse files Browse the repository at this point in the history
* Simplify dev_tools.notebooks.utils.rewrite_notebook

- Do not return the OS file descriptor, it was only used
  to close the file in rewrite_notebook callers
- Return only the filename of the rewritten notebook
- Always create a temporary file, even if identical to the input notebook,
  so it can be safely removed after a test
- Clean up temporary files produced by rewrite_notebook in the tests
  • Loading branch information
pavoljuhas committed Mar 14, 2023
1 parent 8b97aa5 commit f636c5f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 39 deletions.
6 changes: 2 additions & 4 deletions dev_tools/notebooks/isolated_notebook_test.py
Expand Up @@ -177,7 +177,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env):

notebook_file = os.path.basename(notebook_path)

rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path)
rewritten_notebook_path = rewrite_notebook(notebook_path)

cmd = f"""
mkdir -p out/{notebook_rel_dir}
Expand Down Expand Up @@ -208,9 +208,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env):
f"instead of `pip install cirq` to this notebook, and exclude it from "
f"dev_tools/notebooks/isolated_notebook_test.py."
)

if rewritten_notebook_descriptor:
os.close(rewritten_notebook_descriptor)
os.remove(rewritten_notebook_path)


@pytest.mark.parametrize("notebook_path", NOTEBOOKS_DEPENDING_ON_UNRELEASED_FEATURES)
Expand Down
6 changes: 2 additions & 4 deletions dev_tools/notebooks/notebook_test.py
Expand Up @@ -67,7 +67,7 @@ def test_notebooks_against_released_cirq(notebook_path):
notebook_file = os.path.basename(notebook_path)
notebook_rel_dir = os.path.dirname(os.path.relpath(notebook_path, "."))
out_path = f"out/{notebook_rel_dir}/{notebook_file[:-6]}.out.ipynb"
rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path)
rewritten_notebook_path = rewrite_notebook(notebook_path)
papermill_flags = "--no-request-save-on-cell-execute --autosave-cell-every 0"
cmd = f"""mkdir -p out/{notebook_rel_dir}
papermill {rewritten_notebook_path} {out_path} {papermill_flags}"""
Expand All @@ -83,6 +83,4 @@ def test_notebooks_against_released_cirq(notebook_path):
f"notebook (in Github Actions, you can download it from the workflow artifact"
f" 'notebook-outputs')"
)

if rewritten_notebook_descriptor:
os.close(rewritten_notebook_descriptor)
os.remove(rewritten_notebook_path)
41 changes: 20 additions & 21 deletions dev_tools/notebooks/utils.py
Expand Up @@ -69,7 +69,7 @@ def rewrite_notebook(notebook_path):
* Lines in this file without `->` are ignored.
* Lines in this file with `->` are split into two (if there are mulitple `->` it is an
* Lines in this file with `->` are split into two (if there are multiple `->` it is an
error). The first of these is compiled into a pattern match, via `re.compile`, and
the second is the replacement for that match.
Expand All @@ -82,42 +82,41 @@ def rewrite_notebook(notebook_path):
It is the responsibility of the caller of this method to delete the new file.
Returns:
Tuple of a file descriptor and the file path for the rewritten file. If no `.tst` file
was found, then the file descriptor is None and the path is `notebook_path`.
The absolute path to the rewritten file in temporary directory.
If no `.tst` file exists the new file is a copy of the input notebook.
Raises:
AssertionError: If there are multiple `->` per line, or not all of the replacements
are used.
"""
notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst'
if not os.path.exists(notebook_test_path):
return None, notebook_path

# Get the rewrite rules.
patterns = []
with open(notebook_test_path, 'r') as f:
for line in f:
if '->' in line:
notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst'
if os.path.exists(notebook_test_path):
with open(notebook_test_path, 'r') as f:
pattern_lines = (line for line in f if '->' in line)
for line in pattern_lines:
parts = line.rstrip().split('->')
assert len(parts) == 2, f'Replacement lines may only contain one -> but was {line}'
patterns.append((re.compile(parts[0]), parts[1]))

used_patterns = set()
with open(notebook_path, 'r') as original_file:
new_file_descriptor, new_file_path = tempfile.mkstemp(suffix='.ipynb')
with open(new_file_path, 'w') as new_file:
for line in original_file:
new_line = line
for pattern, replacement in patterns:
new_line = pattern.sub(replacement, new_line)
if new_line != line:
used_patterns.add(pattern)
break
new_file.write(new_line)
lines = original_file.readlines()
for i, line in enumerate(lines):
for pattern, replacement in patterns:
new_line = pattern.sub(replacement, line)
if new_line != line:
lines[i] = new_line
used_patterns.add(pattern)
break

assert len(patterns) == len(used_patterns), (
'Not all patterns where used. Patterns not used: '
f'{set(x for x, _ in patterns) - used_patterns}'
)

return new_file_descriptor, new_file_path
with tempfile.NamedTemporaryFile(mode='w', suffix='-rewrite.ipynb', delete=False) as new_file:
new_file.writelines(lines)

return new_file.name
21 changes: 11 additions & 10 deletions dev_tools/notebooks/utils_test.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import filecmp
import os
import shutil
import tempfile
Expand All @@ -37,40 +38,40 @@ def write_test_data(ipynb_txt, tst_txt):
def test_rewrite_notebook():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

assert path != ipynb_path
with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 4'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


def test_rewrite_notebook_multiple():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\nd = 4->d = 1')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 1'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


def test_rewrite_notebook_ignore_non_seperator_lines():
directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\n# comment')

descriptor, path = dt.rewrite_notebook(ipynb_path)
path = dt.rewrite_notebook(ipynb_path)

with open(path, 'r') as f:
rewritten = f.read()
assert rewritten == 'd = 3\nd = 4'

os.close(descriptor)
os.remove(path)
shutil.rmtree(directory)


Expand All @@ -80,11 +81,11 @@ def test_rewrite_notebook_no_tst_file():
with open(ipynb_path, 'w') as f:
f.write('d = 5\nd = 4')

descriptor, path = dt.rewrite_notebook(ipynb_path)

assert descriptor is None
assert path == ipynb_path
path = dt.rewrite_notebook(ipynb_path)
assert path != ipynb_path
assert filecmp.cmp(path, ipynb_path)

os.remove(path)
shutil.rmtree(directory)


Expand Down

0 comments on commit f636c5f

Please sign in to comment.