Skip to content

Commit

Permalink
Add shift-right method to T5 model
Browse files Browse the repository at this point in the history
  • Loading branch information
yohann-benchetrit committed Mar 28, 2023
1 parent 22a998e commit 10dc158
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/torchtext_unittest/models/t5_models_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,24 @@ def _train(model):

_train(model)
self.assertNotEqual(model.state_dict(), current_state_dict)

def test_shit_right(self) -> None:
from torchtext.models import T5Conf, T5Model

valid_cases_input = [[[1, 2], [3, 4]], [[1]]]
valid_cases_expected = [[[0, 1], [0, 3]], [[0]]]

invalid_cases_input = [[0], [], [[]]]

dummy_encoder_conf = T5Conf()
dummy_t5_encoder = T5Model(dummy_encoder_conf)

for input_ids, expected in zip(valid_cases_input, valid_cases_expected):
input_ids = torch.Tensor(input_ids)
expected = torch.Tensor(expected)
self.assertEqual(dummy_t5_encoder._shift_right(input_ids), expected)

for input_ids in invalid_cases_input:
input_ids = torch.Tensor(input_ids)
with self.assertRaises(IndexError):
dummy_t5_encoder._shift_right(input_ids)
11 changes: 11 additions & 0 deletions torchtext/models/t5/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,17 @@ def _reorder_cache(self, past: List[PAST_KEY_VALUES_TYPE], beam_idx: Tensor) ->
reordered_decoder_past.append(reordered_layer_past_states)
return reordered_decoder_past

@torch.jit.export
def _shift_right(self, input_ids: Tensor) -> Tensor:
"""Shift all input sequences to the right"""
shifted_input_ids = torch.zeros_like(input_ids)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()

# T5 implemention uses padding idx to start sequence.
shifted_input_ids[:, 0] = self.config.padding_idx

return shifted_input_ids

@torch.jit.export
def prepare_inputs_for_generation(
self,
Expand Down

0 comments on commit 10dc158

Please sign in to comment.