diff --git a/torchao/utils.py b/torchao/utils.py index 652e7f33f1..daf7eab83c 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -49,6 +49,7 @@ # Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711 +@functools.cache def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: """ Returns the unique device for a module, or None if no device is found.