Skip to content

Commit

Permalink
Add type annotations to _tensorboard_vis.py and hipify_python.py
Browse files Browse the repository at this point in the history
closes gh-49833
  • Loading branch information
guilhermeleobas committed Dec 25, 2020
1 parent 963f762 commit 550b70f
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 20 deletions.
18 changes: 0 additions & 18 deletions mypy.ini
Expand Up @@ -104,24 +104,6 @@ ignore_errors = True
[mypy-torch._utils]
ignore_errors = True

[mypy-torch._overrides]
ignore_errors = True

[mypy-torch.utils.tensorboard._caffe2_graph]
ignore_errors = True

[mypy-torch.contrib._tensorboard_vis]
ignore_errors = True

[mypy-torch.nn.utils.prune]
ignore_errors = True

[mypy-torch.utils.show_pickle]
ignore_errors = True

[mypy-torch.utils.hipify.hipify_python]
ignore_errors = True

[mypy-torch.utils.benchmark.examples.*]
ignore_errors = True

Expand Down
4 changes: 4 additions & 0 deletions torch/_C/__init__.pyi.in
Expand Up @@ -344,6 +344,10 @@ def _propagate_and_assign_input_shapes(
propagate: _bool
) -> Graph: ...

# Defined in torch/csrc/jit/runtime/graph_executor.h
class GraphExecutorState:
...

# Defined in torch/torch/csrc/jit/ir/ir.h
class Graph:
def eraseInput(self, i: _int) -> None: ...
Expand Down
3 changes: 2 additions & 1 deletion torch/contrib/_tensorboard_vis.py
@@ -1,6 +1,7 @@
import time
from collections import defaultdict
from functools import partial
from typing import DefaultDict

import torch

Expand Down Expand Up @@ -104,7 +105,7 @@ def inline_graph(subgraph, name, node):
for out, val in zip(subgraph.outputs(), node.outputs()):
value_map[val.unique()] = rec_value_map[out.unique()]

op_id_counter = defaultdict(int)
op_id_counter: DefaultDict[str, int] = defaultdict(int)

def name_for(node):
kind = node.kind()[node.kind().index('::') + 2:]
Expand Down
4 changes: 3 additions & 1 deletion torch/utils/hipify/hipify_python.py
Expand Up @@ -782,7 +782,9 @@ def repl(m):
os.path.relpath(header_filepath, output_directory),
all_files, includes, stats, hip_clang_launch, is_pytorch_extension,
clean_ctx, show_progress)
return templ.format(os.path.relpath(HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"], header_dir))
value = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
assert value is not None
return templ.format(os.path.relpath(value, header_dir))

return m.group(0)
return repl
Expand Down

0 comments on commit 550b70f

Please sign in to comment.