diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 6999dadb9f9..9614826c61a 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -57,6 +57,22 @@ class BackendDetails(ABC): """ + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + # Allow direct subclasses of BackendDetails + if cls.__bases__ == (BackendDetails,): + return + + # Forbid subclasses whose ANY parent is already a child of BackendDetails + for base in cls.__bases__: + if issubclass(base, BackendDetails) and base is not BackendDetails: + raise TypeError( + f"ExecuTorch delegate doesn't support nested backend, '{base.__name__}' " + " should be a final backend implementation and should not be subclassed " + f"(attempted by '{cls.__name__}')." + ) + @staticmethod # all backends need to implement this method @enforcedmethod diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index 42e8ef16bd7..293ef8e43b4 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -12,6 +12,7 @@ import torch from executorch.exir import to_edge from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( AllNodePartitioner, ) @@ -1444,3 +1445,19 @@ def inputs(self): self.assertTrue( torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03) ) + + def test_prohibited_nested_backends(self): + class MyBackend(BackendDetails): + @staticmethod + def preprocess(edge_program, compile_specs): + return None + + with self.assertRaises(TypeError) as ctx: + + class MyOtherBackend(MyBackend): + pass + + self.assertIn( + "'MyBackend' should be a final backend implementation and should not be subclassed (attempted by 'MyOtherBackend')", + str(ctx.exception), + )