graph.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import warnings
from typing import Any, Optional

import torch
import torch.fx as fx
from executorch.exir.tensor import TensorSpec


class ExportGraph:
    """
    ExportGraph serves as a layer between the EXIR and FX Graph APIs.
    It enforces EXIR-specific invariants (e.g. that nodes carry specs).
    """

    owning_module: fx.GraphModule
    _graph: fx.Graph

    def __init__(self, owning_module: fx.GraphModule, graph: fx.Graph) -> None:
        self.owning_module = owning_module
        self._graph = graph

    @property
    def nodes(self) -> fx.graph._node_list:
        """
        Gets the list of Nodes that constitute this Graph.
        """
        return self._graph.nodes

    def erase_node(self, to_erase: fx.Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.
        """
        return self._graph.erase_node(to_erase)
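
    # Illustrative usage (an assumption, not part of the original file):
    # a node must have no remaining users before it can be erased, so
    # callers typically redirect its users first, e.g.:
    #
    #     old_node.replace_all_uses_with(new_node)
    #     export_graph.erase_node(old_node)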

    def inserting_before(self, n: Optional[fx.Node] = None) -> fx.graph._InsertPoint:
        """
        Sets the point at which new nodes will be inserted into the graph.
        """
        return self._graph.inserting_before(n)
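
    # Illustrative usage (an assumption): as with ``fx.Graph.inserting_before``,
    # the returned insert point is typically used as a context manager so that
    # newly created nodes land before ``n``, e.g.:
    #
    #     with export_graph.inserting_before(some_node):
    #         attr_node = export_graph.get_attr("some_buffer")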

    # pyre-ignore
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> fx.Node:
        """
        Inserts a ``get_attr`` node into the Graph.
        """
        node = self._graph.get_attr(qualified_name, type_expr)

        # Gets the actual value of the attribute if it exists so that we can
        # use it to set the 'spec' metadata
        def _maybe_get_attr_value(
            mod: torch.nn.Module, qualified_name: str
        ) -> Optional[torch.Tensor]:
            module_path, _, name = qualified_name.rpartition(".")

            try:
                submod: torch.nn.Module = mod.get_submodule(module_path)
            except AttributeError:
                warnings.warn(f"Failed to fetch module {module_path}!", stacklevel=1)
                return None

            # See if the value is a buffer
            if name in submod._buffers:
                return submod._buffers[name]

            # See if the value is a parameter
            if hasattr(submod, name):
                attr = getattr(submod, name)
                if isinstance(attr, torch.nn.Parameter):
                    return attr

            return None

        buffer = _maybe_get_attr_value(self.owning_module, qualified_name)
        if buffer is not None:
            node.meta["spec"] = TensorSpec.from_tensor(buffer, True)

        return node
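

# Minimal usage sketch (illustrative; ``_Demo`` is an assumed example module,
# not part of the original file). It traces a module that owns a registered
# buffer, wraps the traced graph in ``ExportGraph``, and shows ``get_attr``
# attaching a ``spec`` entry to the new node's metadata.
if __name__ == "__main__":

    class _Demo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("scale", torch.ones(2))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale

    gm = fx.symbolic_trace(_Demo())
    export_graph = ExportGraph(gm, gm.graph)

    # A dotless qualified name resolves against the root module itself; since
    # "scale" is a registered buffer, get_attr records a TensorSpec for it.
    attr_node = export_graph.get_attr("scale")
    print(attr_node.meta.get("spec"))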