4 changes: 3 additions & 1 deletion backends/apple/coreml/test/test_coreml_quantizer.py

@@ -32,7 +32,9 @@ def quantize_and_compare(
     ) -> None:
         assert quantization_type in {"PTQ", "QAT"}
 
-        pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
+        pre_autograd_aten_dialect = export_for_training(
+            model, example_inputs, strict=True
+        ).module()
 
         quantization_config = LinearQuantizerConfig.from_dict(
             {
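The change repeated throughout this PR is the one visible in the hunk above: every torch.export.export_for_training call gains an explicit strict=True, pinning the TorchDynamo-based strict tracing mode instead of relying on the parameter's default (likely ahead of an upstream default change). A minimal sketch of the call on a toy module (the module and shapes are illustrative, not taken from this PR):

import torch
from torch.export import export_for_training


class TinyModel(torch.nn.Module):  # hypothetical stand-in for the models above
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()


example_inputs = (torch.randn(1, 4),)

# strict=True traces with TorchDynamo and errors out on constructs it cannot
# soundly capture, rather than using the non-strict Python-level tracer.
exported = export_for_training(TinyModel(), example_inputs, strict=True)

# .module() unwraps the pre-autograd torch.fx.GraphModule that the PT2E
# quantization entry points (prepare_pt2e / convert_pt2e) operate on.
graph_module = exported.module()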
2 changes: 1 addition & 1 deletion backends/apple/mps/test/test_mps_utils.py

@@ -207,7 +207,7 @@ def lower_module_and_test_output(
     expected_output = model(*sample_inputs)
 
     model = torch.export.export_for_training(
-        model, sample_inputs, dynamic_shapes=dynamic_shapes
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()
 
     edge_program = export_to_edge(
2 changes: 1 addition & 1 deletion backends/cadence/aot/compiler.py

@@ -82,7 +82,7 @@ def convert_pt2(
     remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
     model_gm = (
-        torch.export.export_for_training(model, inputs)
+        torch.export.export_for_training(model, inputs, strict=True)
        .run_decompositions(decomp_table)
        .module()
     )
2 changes: 1 addition & 1 deletion backends/cadence/aot/tests/test_remove_ops_passes.py

@@ -474,7 +474,7 @@ def forward(self, x):
         # Run the standard quant/convert steps, but without fusing
         # this leaves two redundant quant/dequant pairs to test with
         quantizer = CadenceDefaultQuantizer()
-        model_exp = export_for_training(M(), (inp,)).module()
+        model_exp = export_for_training(M(), (inp,), strict=True).module()
         prepared_model = prepare_pt2e(model_exp, quantizer)
         prepared_model(inp)
         converted_model = convert_pt2e(prepared_model)
8 changes: 6 additions & 2 deletions backends/example/test_example_delegate.py

@@ -46,7 +46,9 @@ def get_example_inputs():
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         # print("original model:", m)
         quantizer = ExampleQuantizer()
         # quantizer = XNNPACKQuantizer()

@@ -82,7 +84,9 @@ def test_delegate_mobilenet_v2(self):
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         quantizer = ExampleQuantizer()
 
         m = prepare_pt2e(m, quantizer)
6 changes: 3 additions & 3 deletions backends/mediatek/quantizer/annotator.py

@@ -44,7 +44,6 @@ def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
 
 
 def register_annotator(ops: List[OpOverload]):
-
     def decorator(annotator_fn: Callable):
         for op in ops:
             OP_TO_ANNOTATOR[op] = annotator_fn

@@ -147,7 +146,6 @@ def _annotate_fused_activation_pattern(
 
 
 def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
-
     class ExecuTorchPattern(torch.nn.Module):
         def forward(self, x):
             norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)

@@ -159,7 +157,9 @@ def forward(self, x):
             return norm, {}
 
     for pattern_cls in (ExecuTorchPattern, MTKPattern):
-        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
+        pattern_gm = export_for_training(
+            pattern_cls(), (torch.randn(3, 3),), strict=True
+        ).module()
         matcher = SubgraphMatcherWithNameNodeMap(
             pattern_gm, ignore_literals=True, remove_overlapping_matches=False
         )
2 changes: 1 addition & 1 deletion backends/qualcomm/tests/utils.py

@@ -567,7 +567,7 @@ def get_prepared_qat_module(
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
-        m = torch.export.export_for_training(module, inputs).module()
+        m = torch.export.export_for_training(module, inputs, strict=True).module()
 
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
@@ -58,10 +58,7 @@ def _test_duplicate_chain(
 
         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
 
         m = prepare_pt2e(m, quantizer)
         # Calibrate
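Many of the touched call sites sit inside the same PT2E quantization flow that the hunk above shows end to end: capture the program, insert observers, calibrate, convert. A minimal self-contained sketch of that flow, assuming the quantizer import path used by ExecuTorch's XNNPACK backend (the toy model and shapes are illustrative):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

# Import path is an assumption; this quantizer has also shipped as
# torch.ao.quantization.quantizer.xnnpack_quantizer.XNNPACKQuantizer.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(2, 8),)

# 1. Program capture: strict export, then unwrap the fx.GraphModule.
m = export_for_training(model, example_inputs, strict=True).module()

# 2. Insert observers according to the quantizer's annotations.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# 3. Calibrate by running representative inputs through the observed graph.
m(*example_inputs)

# 4. Replace observers with quantize/dequantize ops.
m = convert_pt2e(m)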
35 changes: 17 additions & 18 deletions backends/xnnpack/test/quantizer/test_pt2e_quantization.py

@@ -326,7 +326,7 @@ def test_disallow_eval_train(self) -> None:
         m.train()
 
         # After export: this is not OK
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):

@@ -380,7 +380,7 @@ def forward(self, x):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()  # pyre-ignore[23]
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
             bn_op = bn_train_op if train else bn_eval_op

@@ -449,10 +449,7 @@ def forward(self, x):
         quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         weight_meta = None
         for n in m.graph.nodes:  # pyre-ignore[16]
             if (

@@ -481,7 +478,7 @@ def test_reentrant(self) -> None:
             get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
         )
         m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
-            m.conv_bn_relu, example_inputs
+            m.conv_bn_relu, example_inputs, strict=True
         ).module()
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)  # pyre-ignore[6,8]
         m(*example_inputs)

@@ -490,7 +487,7 @@
         quantizer = XNNPACKQuantizer().set_module_type(
             torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         m = convert_pt2e(m)
 

@@ -553,7 +550,7 @@ def check_nn_module(node: torch.fx.Node) -> None:
         )
 
         m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
-            m.conv_bn_relu, example_inputs
+            m.conv_bn_relu, example_inputs, strict=True
         ).module()
         for node in m.conv_bn_relu.graph.nodes:  # pyre-ignore[16]
             if node.op not in ["placeholder", "output", "get_attr"]:

@@ -568,7 +565,7 @@ def test_speed(self) -> None:
 
     def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule:
         torch._dynamo.reset()
-        model = export_for_training(model, example_inputs).module()
+        model = export_for_training(model, example_inputs, strict=True).module()
         # Per channel quantization for weight
         # Dynamic quantization for activation
         # Please read a detail: https://fburl.com/code/30zds51q

@@ -625,7 +622,7 @@ def forward(self, x):
 
         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = M()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(),
         )

@@ -701,7 +698,6 @@ def test_save_load(self) -> None:
 
 
 class TestNumericDebugger(TestCase):
-
     def _extract_debug_handles(self, model) -> Dict[str, int]:
         debug_handle_map: Dict[str, int] = {}
 

@@ -731,7 +727,7 @@ def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
     def test_quantize_pt2e_preserve_handle(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
 

@@ -761,7 +757,7 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
     def test_extract_results_from_loggers(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]

@@ -779,18 +775,20 @@ def test_extract_results_from_loggers(self) -> None:
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
         comparison_results = compare_results(
-            ref_results, quant_results  # pyre-ignore[6]
+            ref_results,
+            quant_results,  # pyre-ignore[6]
         )
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(
-                    node_summary.results[0].sqnr, 35  # pyre-ignore[6]
+                    node_summary.results[0].sqnr,
+                    35,  # pyre-ignore[6]
                 )
 
     def test_extract_results_from_loggers_list_output(self) -> None:
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]

@@ -808,7 +806,8 @@ def test_extract_results_from_loggers_list_output(self) -> None:
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
         comparison_results = compare_results(
-            ref_results, quant_results  # pyre-ignore[6]
+            ref_results,
+            quant_results,  # pyre-ignore[6]
         )
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
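The TestNumericDebugger hunks above exercise the PT2E numeric-debugging utilities end to end: tag the exported program with debug handles, attach loggers to two copies of the model, run both, and compare per-node statistics (the tests assert an SQNR of at least 35). A condensed sketch of that flow, built only from calls visible in the diff; the import path is an assumption about where PyTorch's PT2E tooling keeps these helpers:

import torch
from torch.export import export_for_training

# Assumed import path for the numeric-debugger helpers.
from torch.ao.quantization.pt2e._numeric_debugger import (
    compare_results,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    prepare_for_propagation_comparison,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

ep = export_for_training(model, example_inputs, strict=True)
generate_numeric_debug_handle(ep)  # tag every node with a stable debug handle
m = ep.module()

# Attach loggers to two copies. Both copies are float here, so the comparison
# is trivial; in the tests the second copy goes through prepare_pt2e/convert_pt2e.
m_ref_logger = prepare_for_propagation_comparison(m)
m_quant_logger = prepare_for_propagation_comparison(m)
m_ref_logger(*example_inputs)
m_quant_logger(*example_inputs)

ref_results = extract_results_from_loggers(m_ref_logger)
quant_results = extract_results_from_loggers(m_quant_logger)
comparison_results = compare_results(ref_results, quant_results)

# Each summary carries per-handle results; sqnr is the statistic the tests bound.
for node_summary in comparison_results.values():
    if node_summary.results:
        print(node_summary.results[0].sqnr)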
5 changes: 1 addition & 4 deletions backends/xnnpack/test/quantizer/test_representation.py

@@ -33,10 +33,7 @@ def _test_representation(
     ) -> None:
         # resetting dynamo cache
         torch._dynamo.reset()
-        model = export_for_training(
-            model,
-            example_inputs,
-        ).module()
+        model = export_for_training(model, example_inputs, strict=True).module()
         model_copy = copy.deepcopy(model)
 
         model = prepare_pt2e(model, quantizer)  # pyre-ignore[6]
18 changes: 5 additions & 13 deletions backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

@@ -361,7 +361,7 @@ def forward(self, x):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.

@@ -497,10 +497,7 @@ def test_propagate_annotation(self):
         example_inputs = (torch.randn(1, 3, 5, 5),)
 
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)

@@ -766,8 +763,7 @@ def forward(self, input_tensor, hidden_tensor):
 
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
             quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(

@@ -829,8 +825,7 @@ def forward(self, input_tensor, hidden_tensor):
 
         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
             quantizer = XNNPACKQuantizer()
             quantization_config = get_symmetric_quantization_config(

@@ -1039,10 +1034,7 @@ def test_resnet18(self):
         m = torchvision.models.resnet18().eval()
         m_copy = copy.deepcopy(m)
         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
 
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
5 changes: 1 addition & 4 deletions backends/xnnpack/test/test_xnnpack_utils.py

@@ -317,10 +317,7 @@ def quantize_and_test_model_with_quantizer(
         module.eval()
         # program capture
 
-        m = export_for_training(
-            module,
-            example_inputs,
-        ).module()
+        m = export_for_training(module, example_inputs, strict=True).module()
 
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config()
2 changes: 1 addition & 1 deletion backends/xnnpack/test/tester/tester.py

@@ -166,7 +166,7 @@ def run(
         self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
         assert inputs is not None
-        captured_graph = export_for_training(artifact, inputs).module()
+        captured_graph = export_for_training(artifact, inputs, strict=True).module()
 
         assert isinstance(captured_graph, torch.fx.GraphModule)
         prepared = prepare_pt2e(captured_graph, self.quantizer)
@@ -190,7 +190,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 from torch.export import export_for_training
 
 example_args = (torch.randn(1, 3, 256, 256),)
-pre_autograd_aten_dialect = export_for_training(SimpleConv(), example_args).module()
+pre_autograd_aten_dialect = export_for_training(
+    SimpleConv(), example_args, strict=True
+).module()
 print("Pre-Autograd ATen Dialect Graph")
 print(pre_autograd_aten_dialect)
 

@@ -555,7 +557,7 @@ def forward(self, x):
 
 
 example_args = (torch.randn(3, 4),)
-pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
+pre_autograd_aten_dialect = export_for_training(M(), example_args, strict=True).module()
 # Optionally do quantization:
 # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
 aten_dialect = export(pre_autograd_aten_dialect, example_args, strict=True)
4 changes: 3 additions & 1 deletion examples/apple/mps/scripts/mps_example.py

@@ -166,7 +166,9 @@ def get_model_config(args):
 
     # pre-autograd export. eventually this will become torch.export
     with torch.no_grad():
-        model = torch.export.export_for_training(model, example_inputs).module()
+        model = torch.export.export_for_training(
+            model, example_inputs, strict=True
+        ).module()
     edge: EdgeProgramManager = export_to_edge(
         model,
         example_inputs,
9 changes: 6 additions & 3 deletions examples/arm/aot_arm_compiler.py

@@ -224,7 +224,6 @@ def forward(self, x):
 
 
 class MultipleOutputsModule(torch.nn.Module):
-
     def forward(self, x: torch.Tensor, y: torch.Tensor):
         return (x * y, x.sum(dim=-1, keepdim=True))
 

@@ -648,7 +647,9 @@ def to_edge_TOSA_delegate(
     )
     model_int8 = model
     # Wrap quantized model back into an exported_program
-    exported_program = torch.export.export_for_training(model, example_inputs)
+    exported_program = torch.export.export_for_training(
+        model, example_inputs, strict=True
+    )
 
     if args.intermediates:
         os.makedirs(args.intermediates, exist_ok=True)

@@ -681,7 +682,9 @@ def to_edge_TOSA_delegate(
 
     # export_for_training under the assumption we quantize, the exported form also works
     # in to_edge if we don't quantize
-    exported_program = torch.export.export_for_training(model, example_inputs)
+    exported_program = torch.export.export_for_training(
+        model, example_inputs, strict=True
+    )
     model = exported_program.module()
     model_fp32 = model