diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 42629815717..666de4bda03 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -109,6 +109,7 @@ python_unittest( deps = [ "//caffe2:torch", "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", ], ) @@ -209,6 +210,7 @@ python_unittest( "//executorch/exir/passes:debug_handle_generator_pass", "//executorch/exir/passes:insert_write_back_for_buffers_pass", "//executorch/exir/passes:lib", + "//executorch/exir/passes:memory_format_ops_pass", "//executorch/exir/passes:normalize_view_copy_base_pass", "//executorch/exir/passes:remove_graph_asserts_pass", "//executorch/exir/passes:remove_mixed_type_operators", diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py index 0aa724479bf..7c80439610b 100644 --- a/exir/tests/test_joint_graph.py +++ b/exir/tests/test_joint_graph.py @@ -11,6 +11,10 @@ import torch._dynamo from executorch.exir import to_edge + +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) from torch.export._trace import _export from torch.export.experimental import _export_forward_backward from torch.export.exported_program import OutputKind @@ -89,3 +93,18 @@ def forward(self, x, y): .val.allocation_info.memory_offset_low, 48, ) + + loss = m(*example_inputs) + loss.backward() + et_mod = _load_for_executorch_from_buffer(et.buffer) + et_outputs = et_mod.forward( + example_inputs + ) # ET outputs are [loss, grads, weights] + + self.assertTrue(torch.allclose(loss, et_outputs[0])) + self.assertTrue( + torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore[6] + ) + self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2])) + self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3])) + self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))