Facing RunTime Attribute error while running different Flax models for RoFormer #36854

Closed
@ctr-pmuruganTT

Description

When running the FlaxRoFormerForMaskedLM model, I encountered the following error:

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'split'.

The error is raised at transformers/models/roformer/modeling_flax_roformer.py:265.

The function responsible for the error in that file is shown below:

def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
    sin, cos = sinusoidal_pos.split(2, axis=-1)

When I change this line from sinusoidal_pos.split(2, axis=-1) to sinusoidal_pos._split(2, axis=-1), the error no longer occurs.
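Note that _split starts with an underscore, which by Python convention marks it as private, so it may change without notice. A minimal sketch of a possible alternative, assuming the public function jnp.split is an acceptable substitute for the method call. The array below is a made-up stand-in for sinusoidal_pos, not the model's real tensor:

```python
import jax.numpy as jnp

# Hypothetical stand-in for sinusoidal_pos: the last axis holds the sin half
# followed by the cos half, so its length must be even.
sinusoidal_pos = jnp.arange(16.0).reshape(1, 1, 2, 8)

# jnp.split is the public function form. Given an integer, it cuts the axis
# into that many equal sections and returns them as a list.
sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)

print(sin.shape, cos.shape)
```

This keeps the same two-way split along the last axis without relying on a private method. Whether this is how the library should be patched is for the maintainers to decide.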

System Info

My environment details are as follows:

  • transformers version: 4.49.0
  • Platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
  • Jax version: 0.4.36
  • JaxLib version: 0.4.36

I am attaching a screenshot for reference

Image

Who can help?

@gante @Rocketknight1

I am facing the same issue with these models as well:

  • FlaxRoFormerForMultipleChoice
  • FlaxRoFormerForSequenceClassification
  • FlaxRoFormerForTokenClassification
  • FlaxRoFormerForQuestionAnswering

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the error:

Run the following code in any Python environment:

from transformers import AutoTokenizer, FlaxRoFormerForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="jax")

outputs = model(**inputs)
logits = outputs.logits
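To check whether the error depends on the installed JAX version rather than on the checkpoint, the failing call can be isolated without loading any model. A minimal sketch, assuming only that jax is installed; the array x is a hypothetical placeholder:

```python
import jax
import jax.numpy as jnp

# Hypothetical placeholder array with an even-length last axis.
x = jnp.ones((2, 4))

print("jax version:", jax.__version__)

# Reports whether this JAX build exposes split as an array method.
has_split_method = hasattr(x, "split")
print("array has .split method:", has_split_method)

# The function form does not depend on the array method being present.
halves = jnp.split(x, 2, axis=-1)
print([h.shape for h in halves])
```

If the second line prints False, the AttributeError above is expected for any code that calls .split on a JAX array in that environment.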

Expected behavior

The model should run and produce output without errors.
