diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py index 5076c9225d2eb..5040b8c97d050 100644 --- a/torch/ao/nn/quantized/modules/rnn.py +++ b/torch/ao/nn/quantized/modules/rnn.py @@ -1,4 +1,5 @@ -# mypy: allow-untyped-defs +from typing import Any + import torch @@ -35,11 +36,11 @@ class LSTM(torch.ao.nn.quantizable.LSTM): _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] - def _get_name(self): + def _get_name(self) -> str: return "QuantizedLSTM" @classmethod - def from_float(cls, *args, **kwargs): + def from_float(cls, *args: Any, **kwargs: Any) -> None: # The whole flow is float -> observed -> quantized # This class does observed -> quantized only raise NotImplementedError( @@ -49,7 +50,7 @@ def from_float(cls, *args, **kwargs): ) @classmethod - def from_observed(cls, other): + def from_observed(cls: type["LSTM"], other: torch.ao.nn.quantizable.LSTM) -> "LSTM": assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type] converted = torch.ao.quantization.convert( other, inplace=False, remove_qconfig=True