@@ -160,16 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaBundle(
-    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
         T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
         T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
+        T.Truncate(truncate_length),
         T.AddToken(token=0, begin=True),
         T.AddToken(token=2, begin=False),
-    ),
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+XLMR_BASE_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
+    transform=lambda: xlmr_transform(254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
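A minimal usage sketch for the helper introduced above (not part of the commit): it assumes the module-level names already present in this file (T, urljoin, _TEXT_BUCKET, load_state_dict_from_url), and the sample sentence is invented for illustration.

    # Hypothetical call site for xlmr_transform; the input string is illustrative only.
    transform = xlmr_transform(254)
    # The returned T.Sequential tokenizes with SentencePiece, maps tokens to ids via the
    # XLM-R vocab, truncates to at most 254 tokens, then adds token 0 at the start and
    # token 2 at the end (the <s>/</s> ids in this vocab).
    token_ids = transform("Hello world")
    # token_ids is a list of ints beginning with 0 and ending with 2, so it is never
    # longer than 254 + 2 = 256 entries per input string.

With the helpers in place, the truncate length becomes the only per-bundle difference, which is why each bundle below reduces to a one-line lambda.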
@@ -193,13 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
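A hedged end-to-end sketch of how the refactored bundles are consumed (not part of the commit): the input batch is invented, and the calls assume torchtext's existing bundle interface (get_model(), transform(), torchtext.functional.to_tensor), which this diff does not change.

    import torchtext
    import torchtext.functional as F

    bundle = torchtext.models.ROBERTA_BASE_ENCODER
    model = bundle.get_model()               # RobertaModel with the base encoder weights
    transform = bundle.transform()           # now produced by roberta_transform(254)
    batch = ["Hello world", "How are you!"]  # illustrative input
    # Convert the per-sentence id lists into a padded tensor; pad id 1 matches the
    # encoder conf's padding_idx.
    model_input = F.to_tensor(transform(batch), padding_value=1)
    output = model(model_input)              # per-token features, one row per sentence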