diff --git a/backends/test/suite/operators/test_lstm.py b/backends/test/suite/operators/test_lstm.py
index 91dd73c9052..11632e1e055 100644
--- a/backends/test/suite/operators/test_lstm.py
+++ b/backends/test/suite/operators/test_lstm.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -15,6 +16,11 @@
     operator_test,
     OperatorTest,
 )
+from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM
+
+
+def _get_lstm_cls(use_quantizable_lstm: bool):
+    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM
 
 
 class Model(torch.nn.Module):
@@ -27,9 +33,11 @@ def __init__(
         batch_first=True,
         dropout=0.0,
         bidirectional=False,
+        use_quantizable_lstm: bool = False,
     ):
         super().__init__()
-        self.lstm = torch.nn.LSTM(
+        lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+        self.lstm = lstm_cls(
             input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
@@ -47,106 +55,133 @@ def forward(self, x):
 class LSTM(OperatorTest):
     @dtype_test
     def test_lstm_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2).to(dtype),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),  # (batch=1, seq_len, input_size)
             flow,
         )
 
     @dtype_test
     def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, bias=False).to(dtype),
+            Model(
+                num_layers=2, bias=False, use_quantizable_lstm=use_quantizable_lstm
+            ).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),
             flow,
         )
 
     def test_lstm_feature_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(input_size=32, hidden_size=16),
+            Model(
+                input_size=32,
+                hidden_size=16,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 8, 32),),  # (batch=1, seq_len, input_size)
             flow,
         )
         self._test_op(
-            Model(input_size=128, hidden_size=64),
+            Model(
+                input_size=128,
+                hidden_size=64,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 12, 128),),
             flow,
         )
         self._test_op(
-            Model(input_size=256, hidden_size=128),
+            Model(
+                input_size=256,
+                hidden_size=128,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 6, 256),),
             flow,
         )
         self._test_op(
-            Model(input_size=16, hidden_size=32),
+            Model(
+                input_size=16,
+                hidden_size=32,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 5, 16),),
             flow,
         )
 
     def test_lstm_batch_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(8, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(32, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(100, 10, 64),),
             flow,
         )
 
     def test_lstm_seq_lengths(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 5, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 20, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 50, 64),),
             flow,
         )
 
     def test_lstm_batch_first_false(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(batch_first=False),
+            Model(batch_first=False, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(10, 1, 64),),  # (seq_len, batch=1, input_size)
             flow,
         )
 
     def test_lstm_num_layers(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(num_layers=3),
+            Model(num_layers=3, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_bidirectional(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(bidirectional=True),
+            Model(bidirectional=True, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_with_dropout(self, flow: TestFlow) -> None:
         # Note: Dropout is only effective with num_layers > 1
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, dropout=0.2),
+            Model(num_layers=2, dropout=0.2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
@@ -154,9 +189,10 @@ def test_lstm_with_dropout(self, flow: TestFlow) -> None:
     def test_lstm_with_initial_states(self, flow: TestFlow) -> None:
         # Create a model that accepts initial states
         class ModelWithStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -169,9 +205,10 @@ def forward(self, x, h0, c0):
         batch_size = 1
         num_layers = 2
         hidden_size = 32
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithStates(),
+            ModelWithStates(use_quantizable_lstm=use_quantizable_lstm),
             (
                 torch.randn(batch_size, 10, 64),  # input
                 torch.randn(num_layers, batch_size, hidden_size),  # h0
@@ -183,9 +220,10 @@ def forward(self, x, h0, c0):
     def test_lstm_return_hidden_states(self, flow: TestFlow) -> None:
         # Create a model that returns both output and hidden states
         class ModelWithHiddenStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -200,9 +238,10 @@ def forward(self, x):
         batch_size = 1
         seq_len = 10
         input_size = 64
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithHiddenStates(),
+            ModelWithHiddenStates(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(batch_size, seq_len, input_size),),
             flow,
        )
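For reference, a minimal standalone sketch of the class-swap pattern this diff applies throughout. It assumes only that torch.nn.LSTM and the quantizable LSTM share the same constructor and forward contract; the suite's TestFlow/quantizer wiring is not reproduced here.

# Minimal sketch of the _get_lstm_cls pattern, outside the test suite.
# Assumption: both classes accept the same constructor kwargs and return
# (output, (h_n, c_n)) from forward, so the swap is transparent to callers.
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM


def _get_lstm_cls(use_quantizable_lstm: bool):
    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM


x = torch.randn(1, 10, 64)  # (batch=1, seq_len=10, input_size=64)
for use_quantizable_lstm in (False, True):
    lstm = _get_lstm_cls(use_quantizable_lstm)(
        input_size=64, hidden_size=32, num_layers=2, batch_first=True
    )
    out, (h_n, c_n) = lstm(x)  # identical call signature for both classes
    assert out.shape == (1, 10, 32)

Routing the choice through a single helper keyed off flow.quantize means each test site only threads one boolean, and float-path behavior is unchanged when use_quantizable_lstm is False.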