diff --git a/dev_tools/notebooks/isolated_notebook_test.py b/dev_tools/notebooks/isolated_notebook_test.py index c131994a7b9..1ac7b03312c 100644 --- a/dev_tools/notebooks/isolated_notebook_test.py +++ b/dev_tools/notebooks/isolated_notebook_test.py @@ -177,7 +177,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env): notebook_file = os.path.basename(notebook_path) - rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path) + rewritten_notebook_path = rewrite_notebook(notebook_path) cmd = f""" mkdir -p out/{notebook_rel_dir} @@ -208,9 +208,7 @@ def test_notebooks_against_released_cirq(partition, notebook_path, cloned_env): f"instead of `pip install cirq` to this notebook, and exclude it from " f"dev_tools/notebooks/isolated_notebook_test.py." ) - - if rewritten_notebook_descriptor: - os.close(rewritten_notebook_descriptor) + os.remove(rewritten_notebook_path) @pytest.mark.parametrize("notebook_path", NOTEBOOKS_DEPENDING_ON_UNRELEASED_FEATURES) diff --git a/dev_tools/notebooks/notebook_test.py b/dev_tools/notebooks/notebook_test.py index ebe08b275b0..5e4ec69e7b7 100644 --- a/dev_tools/notebooks/notebook_test.py +++ b/dev_tools/notebooks/notebook_test.py @@ -67,7 +67,7 @@ def test_notebooks_against_released_cirq(notebook_path): notebook_file = os.path.basename(notebook_path) notebook_rel_dir = os.path.dirname(os.path.relpath(notebook_path, ".")) out_path = f"out/{notebook_rel_dir}/{notebook_file[:-6]}.out.ipynb" - rewritten_notebook_descriptor, rewritten_notebook_path = rewrite_notebook(notebook_path) + rewritten_notebook_path = rewrite_notebook(notebook_path) papermill_flags = "--no-request-save-on-cell-execute --autosave-cell-every 0" cmd = f"""mkdir -p out/{notebook_rel_dir} papermill {rewritten_notebook_path} {out_path} {papermill_flags}""" @@ -83,6 +83,4 @@ def test_notebooks_against_released_cirq(notebook_path): f"notebook (in Github Actions, you can download it from the workflow artifact" f" 'notebook-outputs')" ) - - if rewritten_notebook_descriptor: - os.close(rewritten_notebook_descriptor) + os.remove(rewritten_notebook_path) diff --git a/dev_tools/notebooks/utils.py b/dev_tools/notebooks/utils.py index 3823702c815..a7e8d6d8f6c 100644 --- a/dev_tools/notebooks/utils.py +++ b/dev_tools/notebooks/utils.py @@ -69,7 +69,7 @@ def rewrite_notebook(notebook_path): * Lines in this file without `->` are ignored. - * Lines in this file with `->` are split into two (if there are mulitple `->` it is an + * Lines in this file with `->` are split into two (if there are multiple `->` it is an error). The first of these is compiled into a pattern match, via `re.compile`, and the second is the replacement for that match. @@ -82,42 +82,41 @@ def rewrite_notebook(notebook_path): It is the responsibility of the caller of this method to delete the new file. Returns: - Tuple of a file descriptor and the file path for the rewritten file. If no `.tst` file - was found, then the file descriptor is None and the path is `notebook_path`. + The absolute path to the rewritten file in temporary directory. + If no `.tst` file exists the new file is a copy of the input notebook. Raises: AssertionError: If there are multiple `->` per line, or not all of the replacements are used. """ - notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst' - if not os.path.exists(notebook_test_path): - return None, notebook_path - # Get the rewrite rules. patterns = [] - with open(notebook_test_path, 'r') as f: - for line in f: - if '->' in line: + notebook_test_path = os.path.splitext(notebook_path)[0] + '.tst' + if os.path.exists(notebook_test_path): + with open(notebook_test_path, 'r') as f: + pattern_lines = (line for line in f if '->' in line) + for line in pattern_lines: parts = line.rstrip().split('->') assert len(parts) == 2, f'Replacement lines may only contain one -> but was {line}' patterns.append((re.compile(parts[0]), parts[1])) used_patterns = set() with open(notebook_path, 'r') as original_file: - new_file_descriptor, new_file_path = tempfile.mkstemp(suffix='.ipynb') - with open(new_file_path, 'w') as new_file: - for line in original_file: - new_line = line - for pattern, replacement in patterns: - new_line = pattern.sub(replacement, new_line) - if new_line != line: - used_patterns.add(pattern) - break - new_file.write(new_line) + lines = original_file.readlines() + for i, line in enumerate(lines): + for pattern, replacement in patterns: + new_line = pattern.sub(replacement, line) + if new_line != line: + lines[i] = new_line + used_patterns.add(pattern) + break assert len(patterns) == len(used_patterns), ( 'Not all patterns where used. Patterns not used: ' f'{set(x for x, _ in patterns) - used_patterns}' ) - return new_file_descriptor, new_file_path + with tempfile.NamedTemporaryFile(mode='w', suffix='-rewrite.ipynb', delete=False) as new_file: + new_file.writelines(lines) + + return new_file.name diff --git a/dev_tools/notebooks/utils_test.py b/dev_tools/notebooks/utils_test.py index c42d4db293c..149f87b1cb5 100644 --- a/dev_tools/notebooks/utils_test.py +++ b/dev_tools/notebooks/utils_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import filecmp import os import shutil import tempfile @@ -37,40 +38,40 @@ def write_test_data(ipynb_txt, tst_txt): def test_rewrite_notebook(): directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3') - descriptor, path = dt.rewrite_notebook(ipynb_path) + path = dt.rewrite_notebook(ipynb_path) assert path != ipynb_path with open(path, 'r') as f: rewritten = f.read() assert rewritten == 'd = 3\nd = 4' - os.close(descriptor) + os.remove(path) shutil.rmtree(directory) def test_rewrite_notebook_multiple(): directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\nd = 4->d = 1') - descriptor, path = dt.rewrite_notebook(ipynb_path) + path = dt.rewrite_notebook(ipynb_path) with open(path, 'r') as f: rewritten = f.read() assert rewritten == 'd = 3\nd = 1' - os.close(descriptor) + os.remove(path) shutil.rmtree(directory) def test_rewrite_notebook_ignore_non_seperator_lines(): directory, ipynb_path = write_test_data('d = 5\nd = 4', 'd = 5->d = 3\n# comment') - descriptor, path = dt.rewrite_notebook(ipynb_path) + path = dt.rewrite_notebook(ipynb_path) with open(path, 'r') as f: rewritten = f.read() assert rewritten == 'd = 3\nd = 4' - os.close(descriptor) + os.remove(path) shutil.rmtree(directory) @@ -80,11 +81,11 @@ def test_rewrite_notebook_no_tst_file(): with open(ipynb_path, 'w') as f: f.write('d = 5\nd = 4') - descriptor, path = dt.rewrite_notebook(ipynb_path) - - assert descriptor is None - assert path == ipynb_path + path = dt.rewrite_notebook(ipynb_path) + assert path != ipynb_path + assert filecmp.cmp(path, ipynb_path) + os.remove(path) shutil.rmtree(directory)