Description
When running the FlaxRoFormerForMaskedLM model, I encountered the following error:
AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'split'
The error is raised at transformers/models/roformer/modeling_flax_roformer.py:265, in the following function:
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    sin, cos = sinusoidal_pos.split(2, axis=-1)
When I change this line from sinusoidal_pos.split(2, axis=-1) to sinusoidal_pos._split(2, axis=-1), the error no longer occurs. In other words, replacing split() with _split() resolves the issue for me.
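Relying on the private _split method may break again in a future JAX release. The same split can be written with the public jax.numpy.split function instead. This is a minimal sketch, assuming sinusoidal_pos is a JAX array whose last dimension is even; the helper name split_sinusoidal_pos and the example input are hypothetical and not part of transformers:

```python
import jax.numpy as jnp


def split_sinusoidal_pos(sinusoidal_pos):
    # Public-API replacement for sinusoidal_pos.split(2, axis=-1):
    # jnp.split divides the last axis into two equal halves (sin, cos).
    sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
    return sin, cos


# Hypothetical input of shape (seq_len=3, dim=4); the last dimension must be even.
sinusoidal_pos = jnp.arange(12.0).reshape(3, 4)
sin, cos = split_sinusoidal_pos(sinusoidal_pos)
print(sin.shape, cos.shape)  # (3, 2) (3, 2)
```

Since jnp.split is part of the public jax.numpy API, a change along these lines in modeling_flax_roformer.py would avoid depending on either the array method or its private counterpart.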
System Info
My environment details are as follows:
- transformers version: 4.49.0
- Platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.29.3
- Safetensors version: 0.5.3
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.6.0+cu124 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
- Jax version: 0.4.36
- JaxLib version: 0.4.36
I have attached a screenshot for reference.

Who can help?
I am facing the same issue with other RoFormer models, such as:
- FlaxRoFormerForMultipleChoice
- FlaxRoFormerForSequenceClassification
- FlaxRoFormerForTokenClassification
- FlaxRoFormerForQuestionAnswering
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Steps to reproduce the error:
Run the following code in any Python environment:
from transformers import AutoTokenizer, FlaxRoFormerForMaskedLM
tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
inputs = tokenizer("The capital of France is [MASK].", return_tensors="jax")
outputs = model(**inputs)
logits = outputs.logits
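The failure can also be isolated without downloading the checkpoint. This is a minimal diagnostic sketch, assuming only that jax is installed. It prints the JAX version, checks whether arrays still expose a split method, and falls back to jax.numpy.split if they do not. The variable x is a hypothetical stand-in for sinusoidal_pos, and whether the AttributeError branch runs depends on the installed JAX release:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for sinusoidal_pos; any jax.Array behaves the same way.
x = jnp.zeros((4, 8))

print("JAX version:", jax.__version__)
# Does this JAX release still expose a public .split method on arrays?
has_split = hasattr(x, "split")
print("jax.Array has .split:", has_split)

try:
    # Same call pattern as the failing line in modeling_flax_roformer.py.
    halves = x.split(2, axis=-1)
except AttributeError as err:
    print("AttributeError:", err)
    # Public fallback that does not depend on the array method existing.
    halves = jnp.split(x, 2, axis=-1)

print([h.shape for h in halves])
```

With JAX/JaxLib 0.4.36, as listed under System Info, the try branch is expected to hit the same AttributeError reported above.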
Expected behavior
The model should run without raising an error and return its output (the masked-LM logits).