# Patch summary
# -------------
# torchtext/data/functional.py: the fallback branch of load_sp_model() (reached
#   when `spm` is neither a `str` nor an `io.BufferedReader`) now raises
#   TypeError instead of RuntimeError. The message names the offending type via
#   type(spm).__name__ and lists the supported types.
# test/data/test_functional.py: adds test_sentencepiece_unsupported_input_type,
#   which calls load_sp_model(dict()) and expects a TypeError whose message
#   matches the new text.
#
# NOTE(review): the exception class changes from RuntimeError to TypeError;
#   presumably no caller relies on catching RuntimeError here -- confirm.
# NOTE(review): the expected message is passed to assertRaisesRegex, so the '.'
#   characters in it are regex wildcards rather than literal dots.
# NOTE(review): line breaks and indentation below were reconstructed from a
#   whitespace-collapsed copy of this patch (hunk line counts were used as a
#   cross-check); confirm the whitespace against the target files before applying.
diff --git a/test/data/test_functional.py b/test/data/test_functional.py
index 174f641733..3e32835a9e 100644
--- a/test/data/test_functional.py
+++ b/test/data/test_functional.py
@@ -80,6 +80,14 @@ def test_sentencepiece_tokenizer(self):
         self.assertEqual(list(spm_generator([test_sample]))[0],
                          ref_results)
 
+    def test_sentencepiece_unsupported_input_type(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            'Unsupported type for spm argument: dict. '
+            'Supported types are: str, io.BufferedReader'
+        ):
+            load_sp_model(dict())
+
     # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
     @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_BasicEnglishNormalize(self):
diff --git a/torchtext/data/functional.py b/torchtext/data/functional.py
index 3333c2dc63..6e20c8e667 100644
--- a/torchtext/data/functional.py
+++ b/torchtext/data/functional.py
@@ -58,7 +58,12 @@ def load_sp_model(spm):
     elif isinstance(spm, io.BufferedReader):
         return torch.ops.torchtext.load_sp_model_string(spm.read())
     else:
-        raise RuntimeError('the input of the load_sp_model func is not supported.')
+        raise TypeError(
+            f'Unsupported type for spm argument: {type(spm).__name__}. ' +
+            'Supported types are: ' +
+            ', '.join([
+                'str', 'io.BufferedReader'
+            ]))
 
 
 def sentencepiece_numericalizer(sp_model):