From 550b70ff37ecd5441fee6cad8daf64389448f30a Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 24 Dec 2020 19:39:33 +0000 Subject: [PATCH] Add type annotations to _tensorboard_vis.py and hipify_python.py closes gh-49833 --- mypy.ini | 18 ------------------ torch/_C/__init__.pyi.in | 4 ++++ torch/contrib/_tensorboard_vis.py | 3 ++- torch/utils/hipify/hipify_python.py | 4 +++- 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/mypy.ini b/mypy.ini index 8c900bcced76..7d6161bddd17 100644 --- a/mypy.ini +++ b/mypy.ini @@ -104,24 +104,6 @@ ignore_errors = True [mypy-torch._utils] ignore_errors = True -[mypy-torch._overrides] -ignore_errors = True - -[mypy-torch.utils.tensorboard._caffe2_graph] -ignore_errors = True - -[mypy-torch.contrib._tensorboard_vis] -ignore_errors = True - -[mypy-torch.nn.utils.prune] -ignore_errors = True - -[mypy-torch.utils.show_pickle] -ignore_errors = True - -[mypy-torch.utils.hipify.hipify_python] -ignore_errors = True - [mypy-torch.utils.benchmark.examples.*] ignore_errors = True diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 79c93cb191f1..6427a4a4ed80 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -344,6 +344,10 @@ def _propagate_and_assign_input_shapes( propagate: _bool ) -> Graph: ... +# Defined in torch/csrc/jit/runtime/graph_executor.h +class GraphExecutorState: + ... + # Defined in torch/torch/csrc/jit/ir/ir.h class Graph: def eraseInput(self, i: _int) -> None: ... diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index 8f4ca71ff202..e939059762ef 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -1,6 +1,7 @@ import time from collections import defaultdict from functools import partial +from typing import DefaultDict import torch @@ -104,7 +105,7 @@ def inline_graph(subgraph, name, node): for out, val in zip(subgraph.outputs(), node.outputs()): value_map[val.unique()] = rec_value_map[out.unique()] - op_id_counter = defaultdict(int) + op_id_counter: DefaultDict[str, int] = defaultdict(int) def name_for(node): kind = node.kind()[node.kind().index('::') + 2:] diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index d1639d20adba..adc480793d82 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -782,7 +782,9 @@ def repl(m): os.path.relpath(header_filepath, output_directory), all_files, includes, stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) - return templ.format(os.path.relpath(HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"], header_dir)) + value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"] + assert value is not None + return templ.format(os.path.relpath(value, header_dir)) return m.group(0) return repl