diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 9a45195e58f..e68bd8a557d 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -196,7 +196,7 @@ def forward(self, x): model = model.eval() # pre-autograd export. eventually this will become torch.export - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export_for_training(model, example_inputs).module() # Quantize if required if args.quantize: