From 71f346607e2f7a1a8addd41536870b7a1bc8548b Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 23 Aug 2023 10:37:39 -0700 Subject: [PATCH] Enable ResNet-18 (#107) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/107 Differential Revision: D48591055 fbshipit-source-id: a1c0bed02269b36ec607fdba8f36a679e8f2a403 --- examples/export/test/test_export.py | 8 ++++++++ examples/models/TARGETS | 2 +- examples/models/models.py | 9 ++++++++- examples/models/{resnet50 => resnet}/TARGETS | 4 ++-- .../models/{resnet50 => resnet}/__init__.py | 3 ++- examples/models/{resnet50 => resnet}/export.py | 17 +++++++++++++++++ 6 files changed, 38 insertions(+), 5 deletions(-) rename examples/models/{resnet50 => resnet}/TARGETS (73%) rename examples/models/{resnet50 => resnet}/__init__.py (78%) rename examples/models/{resnet50 => resnet}/export.py (64%) diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py index 2af84f20653..e4cb98bcffa 100644 --- a/examples/export/test/test_export.py +++ b/examples/export/test/test_export.py @@ -96,6 +96,14 @@ def test_ic3_export_to_executorch(self): eager_model, example_inputs, self.validate_tensor_allclose ) + def test_resnet18_export_to_executorch(self): + eager_model, example_inputs = MODEL_NAME_TO_MODEL["resnet18"]() + eager_model = eager_model.eval() + + self._assert_eager_lowered_same_result( + eager_model, example_inputs, self.validate_tensor_allclose + ) + def test_resnet50_export_to_executorch(self): eager_model, example_inputs = MODEL_NAME_TO_MODEL["resnet50"]() eager_model = eager_model.eval() diff --git a/examples/models/TARGETS b/examples/models/TARGETS index 4a518a2566d..d7284beaed8 100644 --- a/examples/models/TARGETS +++ b/examples/models/TARGETS @@ -12,7 +12,7 @@ python_library( "//executorch/examples/models/inception_v4:ic4_export", "//executorch/examples/models/mobilenet_v2:mv2_export", "//executorch/examples/models/mobilenet_v3:mv3_export", - "//executorch/examples/models/resnet50:resnet50_export", + "//executorch/examples/models/resnet:resnet_export", "//executorch/examples/models/torchvision_vit:vit_export", "//executorch/examples/models/wav2letter:w2l_export", "//executorch/exir/backend:compile_spec_schema", diff --git a/examples/models/models.py b/examples/models/models.py index 8373dd199a4..94800fcf6da 100644 --- a/examples/models/models.py +++ b/examples/models/models.py @@ -114,8 +114,14 @@ def gen_inception_v4_model_and_inputs() -> Tuple[torch.nn.Module, Any]: return InceptionV4Model.get_model(), InceptionV4Model.get_example_inputs() +def gen_resnet18_model_and_inputs() -> Tuple[torch.nn.Module, Any]: + from ..models.resnet import ResNet18Model + + return ResNet18Model.get_model(), ResNet18Model.get_example_inputs() + + def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]: - from ..models.resnet50 import ResNet50Model + from ..models.resnet import ResNet50Model return ResNet50Model.get_model(), ResNet50Model.get_example_inputs() @@ -131,5 +137,6 @@ def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]: "w2l": gen_wav2letter_model_and_inputs, "ic3": gen_inception_v3_model_and_inputs, "ic4": gen_inception_v4_model_and_inputs, + "resnet18": gen_resnet18_model_and_inputs, "resnet50": gen_resnet50_model_and_inputs, } diff --git a/examples/models/resnet50/TARGETS b/examples/models/resnet/TARGETS similarity index 73% rename from examples/models/resnet50/TARGETS rename to examples/models/resnet/TARGETS index 8b19fc6ddde..33738e8dce3 100644 --- a/examples/models/resnet50/TARGETS +++ b/examples/models/resnet/TARGETS @@ -1,12 +1,12 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( - name = "resnet50_export", + name = "resnet_export", srcs = [ "__init__.py", "export.py", ], - base_module = "executorch.examples.models.resnet50", + base_module = "executorch.examples.models.resnet", deps = [ "//caffe2:torch", "//pytorch/vision:torchvision", diff --git a/examples/models/resnet50/__init__.py b/examples/models/resnet/__init__.py similarity index 78% rename from examples/models/resnet50/__init__.py rename to examples/models/resnet/__init__.py index 47f685e33fb..459641bf86c 100644 --- a/examples/models/resnet50/__init__.py +++ b/examples/models/resnet/__init__.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .export import ResNet50Model +from .export import ResNet18Model, ResNet50Model __all__ = [ + ResNet18Model, ResNet50Model, ] diff --git a/examples/models/resnet50/export.py b/examples/models/resnet/export.py similarity index 64% rename from examples/models/resnet50/export.py rename to examples/models/resnet/export.py index f98f6622841..9e77331852b 100644 --- a/examples/models/resnet50/export.py +++ b/examples/models/resnet/export.py @@ -13,6 +13,23 @@ logging.basicConfig(format=FORMAT) +class ResNet18Model: + def __init__(self): + pass + + @staticmethod + def get_model(): + logging.info("loading torchvision resnet18 model") + resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) + logging.info("loaded torchvision resnet18 model") + return resnet18 + + @staticmethod + def get_example_inputs(): + input_shape = (1, 3, 224, 224) + return (torch.randn(input_shape),) + + class ResNet50Model: def __init__(self): pass