2 changes: 1 addition & 1 deletion backends/xnnpack/test/ops/add.py
@@ -33,7 +33,7 @@ def forward(self, x):
 class AddConstant(torch.nn.Module):
     def __init__(self, constant):
         super().__init__()
-        self._constant = constant
+        self.register_buffer("_constant", constant, persistent=False)
 
     def forward(self, x):
         out1 = x + self._constant
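For context, a sketch (not taken from the diff) of what the non-persistent registration buys: the constant stays out of state_dict() and checkpoints, yet torch.export can still track it because it is a registered buffer rather than a bare attribute.

import torch

class AddConstant(torch.nn.Module):
    def __init__(self, constant: torch.Tensor) -> None:
        super().__init__()
        # Non-persistent: follows the module (.to(), export tracing),
        # but is excluded from state_dict() and checkpoints.
        self.register_buffer("_constant", constant, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self._constant

m = AddConstant(torch.ones(2, 2))
assert "_constant" not in m.state_dict()       # not saved/loaded
assert "_constant" in dict(m.named_buffers())  # still registered
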
4 changes: 2 additions & 2 deletions backends/xnnpack/test/ops/linear.py
@@ -40,8 +40,8 @@ def test_fp32_addmm(self):
 class AddMMModule(torch.nn.Module):
     def __init__(self, in_size, out_size):
         super().__init__()
-        self.mat = torch.randn(out_size, in_size)
-        self.bias = torch.randn(1, out_size)
+        self.mat = torch.nn.Parameter(torch.randn(out_size, in_size))
+        self.bias = torch.nn.Parameter(torch.randn(1, out_size))
 
     def forward(self, x):
         return torch.addmm(self.bias, x, torch.transpose(self.mat, 0, 1))
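Background (hedged sketch, independent of the test): plain tensor attributes are invisible to the module system, so torch.export has no parameter or buffer entry for them; wrapping them in torch.nn.Parameter registers them, and export then lifts them as parameter inputs with proper targets.

import torch

class AddMMModule(torch.nn.Module):
    def __init__(self, in_size: int, out_size: int) -> None:
        super().__init__()
        # nn.Parameter registers the tensors, so export lifts them as
        # parameters instead of treating them as loose constants.
        self.mat = torch.nn.Parameter(torch.randn(out_size, in_size))
        self.bias = torch.nn.Parameter(torch.randn(1, out_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.addmm(self.bias, x, torch.transpose(self.mat, 0, 1))

ep = torch.export.export(AddMMModule(3, 4), (torch.randn(2, 3),))
print(ep.graph_signature.parameters)  # expected: ['mat', 'bias']
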
9 changes: 6 additions & 3 deletions backends/xnnpack/test/test_xnnpack_utils_classes.py
@@ -18,10 +18,12 @@ def __init__(self, num_sequences, ops_per_sequence):
         super().__init__()
         self.num_ops = num_sequences * ops_per_sequence
         self.num_sequences = num_sequences
-        self.op_sequence = [[] for _ in range(num_sequences)]
-        for seq in range(num_sequences):
+
+        self.op_sequence = torch.nn.ModuleList()
+        for _ in range(num_sequences):
+            inner = torch.nn.ModuleList()
             for _ in range(ops_per_sequence):
-                self.op_sequence[seq].append(
+                inner.append(
                     torch.nn.Conv2d(
                         in_channels=1,
                         out_channels=1,
@@ -30,6 +32,7 @@ def __init__(self, num_sequences, ops_per_sequence):
                         bias=False,
                     )
                 )
+            self.op_sequence.append(inner)
 
     def forward(self, x):
         for seq in self.op_sequence:
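A quick illustration (sketch, not from the diff) of why the plain nested lists had to go: children stored in a Python list are not registered with the module, so their weights never appear in parameters() or state_dict(); nn.ModuleList registers each child.

import torch

class WithPlainList(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ops = [torch.nn.Conv2d(1, 1, 3, bias=False)]  # unregistered child

class WithModuleList(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ops = torch.nn.ModuleList([torch.nn.Conv2d(1, 1, 3, bias=False)])

print(len(list(WithPlainList().parameters())))   # 0 - weight not tracked
print(len(list(WithModuleList().parameters())))  # 1 - weight tracked
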
1 change: 1 addition & 0 deletions backends/xnnpack/xnnpack_preprocess.py
@@ -210,6 +210,7 @@ def preprocess(
             verifier=EXIREdgeDialectVerifier(
                 check_edge_ops=False, enable=False, class_only=True
             ),
+            constants=ep.constants,
         )
 
         # XNNPACK Delegate Specific Passes
3 changes: 2 additions & 1 deletion exir/backend/backend_api.py
@@ -350,7 +350,7 @@ def to_backend(
 
     # TODO(angelayi): Update this signature in a less manual way (maybe through
     # retracing)
-    new_signature, new_state_dict = _get_new_signature(
+    new_signature, new_state_dict, new_constants = _get_new_signature(
        edge_program, tagged_graph_module
     )
     return ExportedProgram(
@@ -362,4 +362,5 @@ def to_backend(
         module_call_graph=copy.deepcopy(edge_program.module_call_graph),
         example_inputs=None,
         verifier=edge_program.verifier,
+        constants=new_constants,
     )
23 changes: 19 additions & 4 deletions exir/lowered_backend_module.py
@@ -388,7 +388,11 @@ def arrange_graph_placeholders(
 # TODO Don't regenerate new signature manually.
 def _get_new_signature(
     original_program: ExportedProgram, gm: torch.fx.GraphModule
-) -> Tuple[ExportGraphSignature, Dict[str, Union[torch.Tensor, torch.nn.Parameter]]]:
+) -> Tuple[
+    ExportGraphSignature,
+    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
+    Dict[str, Union[torch.Tensor, torch.ScriptObject]],
+]:
     old_signature = original_program.graph_signature
 
     input_specs = []
@@ -397,6 +401,9 @@ def _get_new_signature(
         input_specs=input_specs, output_specs=output_specs
     )
     new_state_dict = {}
+    new_constants = {}
+
+    non_persistent_buffers = set(old_signature.non_persistent_buffers)
 
     for node in gm.graph.nodes:
         if node.op == "placeholder":
@@ -417,17 +424,24 @@ def _get_new_signature(
                 ]
             elif node.name in old_signature.inputs_to_buffers:
                 buffer_name = old_signature.inputs_to_buffers[node.name]
+                persistent = buffer_name not in non_persistent_buffers
                 # add buffer to graph signature
                 input_specs.append(
                     InputSpec(
                         kind=InputKind.BUFFER,
                         arg=TensorArgument(name=node.name),
                         target=buffer_name,
+                        persistent=persistent,
                     )
                 )
 
                 # add param to new_state_dict
-                new_state_dict[buffer_name] = original_program.state_dict[buffer_name]
+                if persistent:
+                    new_state_dict[buffer_name] = original_program.state_dict[
+                        buffer_name
+                    ]
+                else:
+                    new_constants[buffer_name] = original_program.constants[buffer_name]
             else:
                 # not param or buffer then user input
                 input_specs.append(
@@ -449,7 +463,7 @@ def _get_new_signature(
             )
         )
 
-    return new_signature, new_state_dict
+    return new_signature, new_state_dict, new_constants
 
 
 def create_exported_program_from_submodule(
@@ -472,7 +486,7 @@ def create_exported_program_from_submodule(
     submodule = arrange_graph_placeholders(submodule, owning_program)
 
     # Get updated graph signature
-    subgraph_signature, subgraph_state_dict = _get_new_signature(
+    subgraph_signature, subgraph_state_dict, subgraph_constants = _get_new_signature(
         owning_program, submodule
     )
 
@@ -484,6 +498,7 @@ def create_exported_program_from_submodule(
         range_constraints=copy.deepcopy(owning_program.range_constraints),
         module_call_graph=[],
         verifier=owning_program.verifier,
+        constants=subgraph_constants,
     )
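Background for the persistent split (a self-contained sketch of the upstream torch.export behavior this change adapts to, not code from the PR): persistent buffers live in the exported program's state_dict, non-persistent buffers are lifted into its constants table, and the graph signature records which buffers are non-persistent — exactly the information _get_new_signature now propagates to the lowered program.

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("pers", torch.ones(2), persistent=True)
        self.register_buffer("temp", torch.ones(2), persistent=False)

    def forward(self, x):
        return x + self.pers + self.temp

ep = export(M(), (torch.randn(2),))
print(ep.graph_signature.non_persistent_buffers)  # expected: ['temp']
print("pers" in ep.state_dict)                    # expected: True
print("temp" in ep.constants)                     # expected: True
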
1 change: 1 addition & 0 deletions exir/passes/constant_prop_pass.py
@@ -108,6 +108,7 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
                 kind=InputKind.BUFFER,
                 arg=TensorArgument(name=const_placeholder_node.name),
                 target=prop_constant_tensor_fqn,
+                persistent=True,
             )
             prop_constant_data.append(prop_constant_node_input_spec)
             buffers.append(prop_constant_tensor_fqn)
8 changes: 7 additions & 1 deletion exir/program/_program.py
@@ -78,7 +78,12 @@ def _get_updated_graph_signature(
                 else type(old_input_spec.arg)(node.name)
             )
             new_input_specs.append(
-                InputSpec(old_input_spec.kind, arg, old_input_spec.target)
+                InputSpec(
+                    old_input_spec.kind,
+                    arg,
+                    old_input_spec.target,
+                    persistent=old_input_spec.persistent,
+                )
             )
             i += 1
 
@@ -196,6 +201,7 @@ def lift_constant_tensor_pass(ep):
                     kind=InputKind.BUFFER,
                     arg=TensorArgument(name=const_placeholder_node.name),
                     target=constant_tensor_fqn,
+                    persistent=True,
                 )
             )
             buffers.append(constant_tensor_fqn)
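For reference (hedged, based on the torch.export graph-signature API of that vintage): InputSpec carries a persistent flag that is only meaningful for InputKind.BUFFER entries, which is why these passes now set it explicitly whenever they rebuild a buffer spec.

from torch.export.graph_signature import InputKind, InputSpec, TensorArgument

spec = InputSpec(
    kind=InputKind.BUFFER,
    arg=TensorArgument(name="b_lifted"),  # illustrative placeholder name
    target="my_buffer",                   # illustrative buffer FQN
    persistent=True,
)
print(spec.persistent)  # True
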
1 change: 1 addition & 0 deletions exir/serde/export_serialize.py
@@ -739,6 +739,7 @@ def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec:
                 buffer=InputToBufferSpec(
                     arg=TensorArgument(name=spec.arg.name),
                     buffer_name=spec.target,  # pyre-ignore
+                    persistent=spec.persistent,  # pyre-ignore
                 )
             )
         elif spec.kind == ep.InputKind.CONSTANT_TENSOR:
6 changes: 5 additions & 1 deletion exir/tests/test_passes.py
@@ -617,8 +617,12 @@ def test_compile_fix_broken_ops(self) -> None:
         model: torch.nn.Linear = torch.nn.Linear(5, 5)
 
         class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.model = model
+
             def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return model(inp)
+                return self.model(inp)
 
         f = Foo()
 
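An aside (sketch, not from the test): when forward reaches out to a module captured from the enclosing scope, export has no submodule to attribute the weights to and lifts them as constants; holding the module as an attribute of self keeps them registered as parameters of the traced program.

import torch
from torch.export import export

linear = torch.nn.Linear(5, 5)

class Foo(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = linear  # registered submodule, as in the fixed test

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

ep = export(Foo(), (torch.randn(2, 5),))
print(ep.graph_signature.parameters)  # expected: ['model.weight', 'model.bias']
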
4 changes: 2 additions & 2 deletions profiler/test/test_profiler_e2e.py
@@ -34,8 +34,8 @@
 class Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
-        self.b = 2 * torch.ones(2, 2, dtype=torch.float)
+        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.float))
+        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.float))
 
     def forward(self, x):
         a = torch.mul(self.a, x)
1 change: 0 additions & 1 deletion sdk/bundled_program/util/TARGETS
@@ -11,7 +11,6 @@ python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:lib",
-        "//executorch/exir:schema",
        "//executorch/sdk/bundled_program:config",
     ],
 )
32 changes: 28 additions & 4 deletions sdk/bundled_program/util/test_util.py
@@ -7,7 +7,7 @@
 # pyre-strict
 import random
 import string
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 
@@ -19,6 +19,7 @@
     MethodTestSuite,
 )
 from torch.export import export, WrapperModule
+from torch.export.unflatten import _assign_attr, _AttrKind
 
 # A hacky integer to deal with a mismatch between execution plan and complier.
 #
@@ -45,8 +46,8 @@ class SampleModel(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32)
-        self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32)
+        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))
+        self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.int32))
         self.method_names = ["encode", "decode"]
 
     def encode(
@@ -228,6 +229,28 @@ def get_random_test_suites_with_eager_model(
     return inputs_per_program, method_test_suites
 
 
+class StatefulWrapperModule(torch.nn.Module):
+    """A version of wrapper module that preserves parameters/buffers.
+
+    Use this if you are planning to wrap a non-forward method on an existing
+    module.
+    """
+
+    def __init__(self, base_mod, method) -> None:  # pyre-ignore
+        super().__init__()
+        state_dict = base_mod.state_dict()
+        for name, value in base_mod.named_parameters():
+            _assign_attr(value, self, name, _AttrKind.PARAMETER)
+        for name, value in base_mod.named_buffers():
+            _assign_attr(
+                value, self, name, _AttrKind.BUFFER, persistent=name in state_dict
+            )
+        self.fn = method  # pyre-ignore
+
+    def forward(self, *args, **kwargs):  # pyre-ignore
+        return self.fn(*args, **kwargs)
+
+
 def get_common_executorch_program() -> Tuple[
     ExecutorchProgramManager, List[MethodTestSuite]
 ]:
@@ -246,7 +269,8 @@ def get_common_executorch_program(
     # Trace to FX Graph and emit the program
     method_graphs = {
         m_name: export(
-            WrapperModule(getattr(eager_model, m_name)), capture_inputs[m_name]
+            StatefulWrapperModule(eager_model, getattr(eager_model, m_name)),
+            capture_inputs[m_name],
         )
         for m_name in eager_model.method_names
     }
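A usage sketch for the new wrapper (the Encoder module here is hypothetical, not part of the PR, and the snippet assumes the StatefulWrapperModule defined above is importable): wrapping a non-forward method while re-registering the owner's buffers on the wrapper lets export lift them as buffers instead of stray constants.

import torch
from torch.export import export

class Encoder(torch.nn.Module):  # hypothetical stand-in for SampleModel
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.int32))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.a

base = Encoder()
wrapped = StatefulWrapperModule(base, base.encode)
ep = export(wrapped, (torch.ones(2, 2, dtype=torch.int32),))
print(ep.graph_signature.buffers)  # expected: ['a']
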
4 changes: 2 additions & 2 deletions test/end2end/exported_module.py
@@ -155,9 +155,9 @@ def __init__(self, method):
         # These cleanup passes are required to convert the `add` op to its out
         # variant, along with some other transformations.
         for method_name, method_input in method_name_to_args.items():
-            module = WrapperModule(getattr(eager_module, method_name))
+            # if not isinstance(eager_module, torch.nn.Module):
             exported_methods[method_name] = export(
-                module,
+                eager_module,
                 method_input,
                 dynamic_shapes=method_name_to_dynamic_shapes[method_name]
                 if method_name_to_dynamic_shapes