diff --git a/extension/training/examples/XOR/targets.bzl b/extension/training/examples/XOR/targets.bzl index 26d0f40d90b..4a85c34c1bb 100644 --- a/extension/training/examples/XOR/targets.bzl +++ b/extension/training/examples/XOR/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", "//executorch/kernels/portable:generated_lib", + "//executorch/extension/flat_tensor/serialize:serialize_cpp" ], external_deps = ["gflags"], define_static_target = True, diff --git a/extension/training/examples/XOR/test/test_export.py b/extension/training/examples/XOR/test/test_export.py index 26a24607d9e..82c9087e84b 100644 --- a/extension/training/examples/XOR/test/test_export.py +++ b/extension/training/examples/XOR/test/test_export.py @@ -13,6 +13,7 @@ class TestXORExport(unittest.TestCase): def test(self): - _ = _export_model() + ep = _export_model() + self.assertTrue(ep is not None) # Expect that we reach this far without an exception being thrown. self.assertTrue(True) diff --git a/extension/training/examples/XOR/train.cpp b/extension/training/examples/XOR/train.cpp index bca433fd889..746daebbf1b 100644 --- a/extension/training/examples/XOR/train.cpp +++ b/extension/training/examples/XOR/train.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include #include @@ -105,4 +106,11 @@ int main(int argc, char** argv) { } optimizer.step(mod.named_gradients("forward").get()); } + std::map param_map; + for (auto& param : param_res.get()) { + param_map.insert(std::pair{ + std::string(param.first.data()), param.second}); + } + + executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16); }