From 4708c2077e97d139fb2f1d613873cd7ec167a882 Mon Sep 17 00:00:00 2001 From: Yohann Benchetrit Date: Tue, 28 Mar 2023 15:28:52 +0200 Subject: [PATCH] Add shift-right method to T5 model --- .../models/t5_models_test_impl.py | 22 +++++++++++++++++++ torchtext/models/t5/model.py | 11 ++++++++++ 2 files changed, 33 insertions(+) diff --git a/test/torchtext_unittest/models/t5_models_test_impl.py b/test/torchtext_unittest/models/t5_models_test_impl.py index bd36f32715..ab353d288a 100644 --- a/test/torchtext_unittest/models/t5_models_test_impl.py +++ b/test/torchtext_unittest/models/t5_models_test_impl.py @@ -184,3 +184,25 @@ def _train(model): _train(model) self.assertNotEqual(model.state_dict(), current_state_dict) + + def test_shift_right(self) -> None: + from torchtext.models import T5Conf, T5Model + + dummy_encoder_conf = T5Conf() + dummy_t5_encoder = T5Model(dummy_encoder_conf) + padding_idx = dummy_t5_encoder.padding_idx + + valid_cases_input = [[[1, 2], [3, 4]], [[1]]] + valid_cases_expected = [[[padding_idx, 1], [padding_idx, 3]], [[padding_idx]]] + + invalid_cases_input = [[0], [], [[]]] + + for input_ids, expected in zip(valid_cases_input, valid_cases_expected): + input_ids = torch.Tensor(input_ids) + expected = torch.Tensor(expected) + self.assertEqual(dummy_t5_encoder._shift_right(input_ids), expected) + + for input_ids in invalid_cases_input: + input_ids = torch.Tensor(input_ids) + with self.assertRaises(IndexError): + dummy_t5_encoder._shift_right(input_ids) diff --git a/torchtext/models/t5/model.py b/torchtext/models/t5/model.py index 6ba55089c5..ad85280f5a 100644 --- a/torchtext/models/t5/model.py +++ b/torchtext/models/t5/model.py @@ -193,6 +193,17 @@ def _reorder_cache(self, past: List[PAST_KEY_VALUES_TYPE], beam_idx: Tensor) -> reordered_decoder_past.append(reordered_layer_past_states) return reordered_decoder_past + @torch.jit.export + def _shift_right(self, input_ids: Tensor) -> Tensor: + """Shift all input sequences to the right""" + shifted_input_ids = torch.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + + # T5 implemention uses padding idx to start sequence. + shifted_input_ids[:, 0] = self.padding_idx + + return shifted_input_ids + @torch.jit.export def prepare_inputs_for_generation( self,