Facing RunTime Attribute error while running different Flax models for RoFormer #36854

Open
2 of 4 tasks
ctr-pmuruganTT opened this issue Mar 20, 2025 · 0 comments

ctr-pmuruganTT commented Mar 20, 2025

When running the FlaxRoFormerForMaskedLM model, I get the following error:

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'split'.

The error is raised at transformers/models/roformer/modeling_flax_roformer.py:265.

The function that raises it begins as follows:

def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    sin, cos = sinusoidal_pos.split(2, axis=-1)

When I change this line from sinusoidal_pos.split(2, axis=-1) to sinusoidal_pos._split(2, axis=-1), the error no longer occurs.
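A possible alternative to calling the private `_split` method is the public `jnp.split` function from `jax.numpy`, which takes the array as its first argument. The sketch below is only a minimal illustration, not the library's actual fix; the helper name `split_sin_cos` and the example shapes are assumptions made for this example.

```python
import jax.numpy as jnp


def split_sin_cos(sinusoidal_pos):
    """Hypothetical helper: split the last axis into two equal halves.

    Uses the public jnp.split function rather than an array method,
    so it does not rely on the array object exposing `.split`.
    """
    sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
    return sin, cos


# Illustrative input only: (batch, sequence, features) with an even last axis.
pos = jnp.arange(2 * 3 * 8, dtype=jnp.float32).reshape(2, 3, 8)
sin, cos = split_sin_cos(pos)
print(sin.shape, cos.shape)
```

Each half keeps the leading dimensions and receives half of the last axis, which is what the original line appears to expect from its `sin, cos` unpacking.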

System Info

My environment details are:

  • transformers version: 4.49.0
  • Platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
  • Jax version: 0.4.36
  • JaxLib version: 0.4.36

A screenshot of the error is attached for reference.

Who can help?

@gante @Rocketknight1

I see the same error with these models as well:

  • FlaxRoFormerForMultipleChoice
  • FlaxRoFormerForSequenceClassification
  • FlaxRoFormerForTokenClassification
  • FlaxRoFormerForQuestionAnswering

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the error:

Run the following code in any Python environment:

from transformers import AutoTokenizer, FlaxRoFormerForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits

Expected behavior

The model should run and produce output without errors.
