6 changes: 2 additions & 4 deletions backends/apple/coreml/test/test_coreml_quantizer.py
@@ -15,7 +15,7 @@
 )

 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch.export import export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
     ) -> None:
         assert quantization_type in {"PTQ", "QAT"}

-        pre_autograd_aten_dialect = export_for_training(
-            model, example_inputs, strict=True
-        ).module()
+        pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()

         quantization_config = LinearQuantizerConfig.from_dict(
             {
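The change repeated throughout this PR is mechanical: `torch.export.export_for_training` is deprecated (see the warning-suppression code removed from `backends/test/suite/runner.py` below), and these call sites switch to `torch.export.export` as a drop-in replacement that yields the same pre-autograd ATen-dialect graph. A minimal before/after sketch; `TinyModel` and its shapes are hypothetical stand-ins:

```python
import torch
from torch.export import export


class TinyModel(torch.nn.Module):  # hypothetical stand-in module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()


model = TinyModel().eval()
example_inputs = (torch.randn(2, 4),)

# Before (deprecated):
#   gm = export_for_training(model, example_inputs, strict=True).module()
# After: export() captures the pre-autograd ATen graph directly.
gm = export(model, example_inputs, strict=True).module()
assert isinstance(gm, torch.fx.GraphModule)
```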
2 changes: 1 addition & 1 deletion backends/apple/mps/test/test_mps_utils.py
@@ -206,7 +206,7 @@ def lower_module_and_test_output(

         expected_output = model(*sample_inputs)

-        model = torch.export.export_for_training(
+        model = torch.export.export(
             model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()

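As the MPS hunk above shows, `export` accepts the same `dynamic_shapes` argument, so call sites exporting with dynamic dimensions migrate unchanged. A small sketch; the model, shapes, and bounds are hypothetical:

```python
import torch
from torch.export import Dim, export

model = torch.nn.Linear(8, 8).eval()
sample_inputs = (torch.randn(4, 8),)

# Mark dim 0 of the single positional input as dynamic (batch size 1..32).
batch = Dim("batch", min=1, max=32)
gm = export(
    model, sample_inputs, dynamic_shapes=({0: batch},), strict=True
).module()
```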
10 changes: 3 additions & 7 deletions backends/cortex_m/test/test_quantize_op_fusion_pass.py
@@ -23,7 +23,7 @@
     get_node_args,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.export import export, export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


@@ -42,9 +42,7 @@ def _prepare_quantized_model(self, model_class):
         model = model_class()

         # Export and quantize
-        exported_model = export_for_training(
-            model.eval(), self.example_inputs, strict=True
-        ).module()
+        exported_model = export(model.eval(), self.example_inputs, strict=True).module()
         prepared_model = prepare_pt2e(exported_model, AddQuantizer())
         quantized_model = convert_pt2e(prepared_model)

@@ -242,9 +240,7 @@ def forward(self, x, y):
         inputs = (torch.randn(shape), torch.randn(shape))

         model = SingleAddModel()
-        exported_model = export_for_training(
-            model.eval(), inputs, strict=True
-        ).module()
+        exported_model = export(model.eval(), inputs, strict=True).module()
         prepared_model = prepare_pt2e(exported_model, AddQuantizer())
         quantized_model = convert_pt2e(prepared_model)

8 changes: 2 additions & 6 deletions backends/example/test_example_delegate.py
@@ -46,9 +46,7 @@ def get_example_inputs():
         )

         m = model.eval()
-        m = torch.export.export_for_training(
-            m, copy.deepcopy(example_inputs), strict=True
-        ).module()
+        m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
         # print("original model:", m)
         quantizer = ExampleQuantizer()
         # quantizer = XNNPACKQuantizer()
@@ -84,9 +82,7 @@ def test_delegate_mobilenet_v2(self):
         )

         m = model.eval()
-        m = torch.export.export_for_training(
-            m, copy.deepcopy(example_inputs), strict=True
-        ).module()
+        m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
         quantizer = ExampleQuantizer()

         m = prepare_pt2e(m, quantizer)
6 changes: 2 additions & 4 deletions backends/mediatek/quantizer/annotator.py
@@ -10,7 +10,7 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor

-from torch.export import export_for_training
+from torch.export import export
 from torch.fx import Graph, Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
@@ -158,9 +158,7 @@ def forward(self, x):
             return norm, {}

     for pattern_cls in (ExecuTorchPattern, MTKPattern):
-        pattern_gm = export_for_training(
-            pattern_cls(), (torch.randn(3, 3),), strict=True
-        ).module()
+        pattern_gm = export(pattern_cls(), (torch.randn(3, 3),), strict=True).module()
         matcher = SubgraphMatcherWithNameNodeMap(
             pattern_gm, ignore_literals=True, remove_overlapping_matches=False
         )
2 changes: 1 addition & 1 deletion backends/qualcomm/tests/utils.py
@@ -576,7 +576,7 @@ def get_prepared_qat_module(
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
         submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
-        m = torch.export.export_for_training(module, inputs, strict=True).module()
+        m = torch.export.export(module, inputs, strict=True).module()

         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
4 changes: 2 additions & 2 deletions backends/test/harness/stages/quantize.py
@@ -7,7 +7,7 @@
     DuplicateDynamicQuantChainPass,
 )

-from torch.export import export_for_training
+from torch.export import export

 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
@@ -47,7 +47,7 @@ def run(
         assert inputs is not None
         if self.is_qat:
             artifact.train()
-        captured_graph = export_for_training(artifact, inputs, strict=True).module()
+        captured_graph = export(artifact, inputs, strict=True).module()

         assert isinstance(captured_graph, torch.fx.GraphModule)

5 changes: 0 additions & 5 deletions backends/test/suite/runner.py
@@ -5,7 +5,6 @@
 import re
 import time
 import unittest
-import warnings

 from datetime import timedelta
 from typing import Any
@@ -283,10 +282,6 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
 def runner_main():
     args = parse_args()

-    # Suppress deprecation warnings for export_for_training, as it generates a
-    # lot of log spam. We don't really need the warning here.
-    warnings.simplefilter("ignore", category=FutureWarning)
-
     seed = args.seed or random.randint(0, 100_000_000)
     print(f"Running with seed {seed}.")

@@ -58,7 +58,7 @@ def _test_duplicate_chain(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
+        m = torch.export.export(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
5 changes: 2 additions & 3 deletions backends/vulkan/test/test_vulkan_passes.py
@@ -58,7 +58,7 @@ def quantize_and_lower_module(
         _check_ir_validity=False,
     )

-    program = torch.export.export_for_training(
+    program = torch.export.export(
         model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()

@@ -95,7 +95,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) ->


 class TestVulkanPasses(unittest.TestCase):
-
     def test_fuse_int8pack_mm(self):
         K = 256
         N = 256
@@ -184,7 +183,7 @@ def test_fuse_linear_qta8a_qga4w(self):
             _check_ir_validity=False,
         )

-        program = torch.export.export_for_training(
+        program = torch.export.export(
             quantized_model, sample_inputs, strict=True
         ).module()

8 changes: 3 additions & 5 deletions backends/vulkan/test/utils.py
@@ -35,7 +35,7 @@
     _load_for_executorch_from_buffer,
 )
 from executorch.extension.pytree import tree_flatten
-from torch.export import export, export_for_training
+from torch.export import export

 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -53,7 +53,7 @@ def get_exported_graph(
     dynamic_shapes=None,
     qmode=QuantizationMode.NONE,
 ) -> torch.fx.GraphModule:
-    export_training_graph = export_for_training(
+    export_training_graph = export(
         model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()

@@ -590,9 +590,7 @@ def op_ablation_test( # noqa: C901
     logger.info("Starting fast binary search operator ablation test...")

     # Step 1: Export model to get edge_program and extract operators
-    export_training_graph = export_for_training(
-        model, sample_inputs, strict=True
-    ).module()
+    export_training_graph = export(model, sample_inputs, strict=True).module()
     program = export(
         export_training_graph,
         sample_inputs,
4 changes: 2 additions & 2 deletions backends/xnnpack/test/ops/test_check_quant_params.py
@@ -9,7 +9,7 @@
 )
 from executorch.backends.xnnpack.utils.utils import get_param_tensor
 from executorch.exir import to_edge_transform_and_lower
-from torch.export import export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


@@ -52,7 +52,7 @@ def _test_check_quant_message(self, ep_modifier, expected_message):
         torch._dynamo.reset()
         mod = torch.nn.Linear(10, 10)
         quantizer = XNNPACKQuantizer()
-        captured = export_for_training(mod, (torch.randn(1, 10),), strict=True).module()
+        captured = export(mod, (torch.randn(1, 10),), strict=True).module()
         quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
         prepared = prepare_pt2e(captured, quantizer)

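The surrounding hunks all follow the same PT2E post-training quantization recipe, with only the capture call changing. For orientation, a self-contained sketch of that flow under the new API; the `XNNPACKQuantizer` import path is an assumption based on this test file and may differ across executorch versions:

```python
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed import path; adjust to your executorch version.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

mod = torch.nn.Linear(10, 10).eval()
example_inputs = (torch.randn(1, 10),)

# 1. Capture with the new export() API.
captured = export(mod, example_inputs, strict=True).module()
# 2. Insert observers, calibrate, and convert to a quantized graph.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibrate on representative inputs
quantized = convert_pt2e(prepared)
```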
27 changes: 13 additions & 14 deletions backends/xnnpack/test/quantizer/test_pt2e_quantization.py
@@ -22,7 +22,7 @@
     weight_observer_range_neg_127_to_127,
 )
 from torch.ao.quantization.qconfig_mapping import QConfigMapping
-from torch.export import export_for_training
+from torch.export import export
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     TestHelperModules,
@@ -58,7 +58,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
         # resetting dynamo cache
         torch._dynamo.reset()

-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()
         if is_qat:
             m = prepare_qat_pt2e(m, quantizer)
         else:
@@ -351,7 +351,7 @@ def test_disallow_eval_train(self) -> None:
         m.train()

         # After export: this is not OK
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -405,7 +405,7 @@ def forward(self, x):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()  # pyre-ignore[23]
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()

         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None:
             bn_op = bn_train_op if train else bn_eval_op
@@ -474,7 +474,7 @@ def forward(self, x):
         quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()
         weight_meta = None
         for n in m.graph.nodes:  # pyre-ignore[16]
             if (
@@ -503,7 +503,7 @@ def test_reentrant(self) -> None:
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
         )
-        m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
+        m.conv_bn_relu = export(  # pyre-ignore[8]
             m.conv_bn_relu, example_inputs, strict=True
         ).module()
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)  # pyre-ignore[6,8]
@@ -513,7 +513,7 @@ def test_reentrant(self) -> None:
         quantizer = XNNPACKQuantizer().set_module_type(
             torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
         m = convert_pt2e(m)

@@ -575,7 +575,7 @@ def check_nn_module(node: torch.fx.Node) -> None:
                 "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
             )

-        m.conv_bn_relu = export_for_training(  # pyre-ignore[8]
+        m.conv_bn_relu = export(  # pyre-ignore[8]
             m.conv_bn_relu, example_inputs, strict=True
         ).module()
         for node in m.conv_bn_relu.graph.nodes:  # pyre-ignore[16]
@@ -591,7 +591,7 @@ def test_speed(self) -> None:

         def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule:
             torch._dynamo.reset()
-            model = export_for_training(model, example_inputs, strict=True).module()
+            model = export(model, example_inputs, strict=True).module()
             # Per channel quantization for weight
             # Dynamic quantization for activation
             # Please read a detail: https://fburl.com/code/30zds51q
@@ -648,7 +648,7 @@ def forward(self, x):

         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = M()
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = export(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(),
         )
@@ -724,11 +724,10 @@ def test_save_load(self) -> None:


 class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase):
-
     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = export(m, example_inputs, strict=True)
         m = ep.module()

         quantizer = XNNPACKQuantizer().set_global(
@@ -768,7 +767,7 @@ def test_quantize_pt2e_preserve_handle(self):
     def test_extract_results_from_loggers(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = export(m, example_inputs, strict=True)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)

@@ -792,7 +791,7 @@ def test_extract_results_from_loggers(self):
     def test_extract_results_from_loggers_list_output(self):
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = export(m, example_inputs, strict=True)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)

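The QAT paths in this file change the same way: the model is captured in train mode and handed to `prepare_qat_pt2e` instead of `prepare_pt2e`. A minimal sketch, assuming the same quantizer import path as in the note above; the toy model and single forward call stand in for a real fine-tuning loop:

```python
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

# Assumed import path; adjust to your executorch version.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).train()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Capture in train mode, insert fake-quant observers, fine-tune, convert.
m = export(m, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)  # stand-in for a real QAT fine-tuning loop
m = convert_pt2e(m)
```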
4 changes: 2 additions & 2 deletions backends/xnnpack/test/quantizer/test_representation.py
@@ -8,7 +8,7 @@
     XNNPACKQuantizer,
 )
 from torch._higher_order_ops.out_dtype import out_dtype  # noqa: F401
-from torch.export import export_for_training
+from torch.export import export
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -33,7 +33,7 @@ def _test_representation(
     ) -> None:
         # resetting dynamo cache
         torch._dynamo.reset()
-        model = export_for_training(model, example_inputs, strict=True).module()
+        model = export(model, example_inputs, strict=True).module()
         model_copy = copy.deepcopy(model)

         model = prepare_pt2e(model, quantizer)  # pyre-ignore[6]