Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaudi: Fix the pipeline failed issue with hpu device #36990

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

yuanwu2017
Copy link
Contributor

@yuanwu2017 yuanwu2017 commented Mar 26, 2025

What does this PR do?

Fix the pipeline failed issue when using the hpu device.

INFO:datasets:PyTorch version 2.6.0+hpu.1.20.0.543.git4952fce available.
INFO:datasets:TensorFlow version 2.15.1 available.
INFO:datasets:JAX version 0.4.13 available.
2025-03-26 06:21:23.298102: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-26 06:21:23.300227: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-26 06:21:23.323606: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-26 06:21:23.323639: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-26 06:21:23.324597: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-26 06:21:23.329263: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-26 06:21:23.329421: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-26 06:21:23.975243: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
INFO:root:args = Namespace(model_id='facebook/mms-lid-256', autocast_dtype='float32', ipex_optimize=False, jit=False, torch_compile=False, model_dtype='bfloat16', backend='inductor', device='hpu', batch_size=1, num_beams=4, input_tokens=512, output_tokens=32, do_sample=False, ipex_optimize_transformers=False, warm_up_steps=10, run_steps=10, optimum_intel=False, compare_outputs=False, quant_algo='None', quant_dtype='None', tp_plan=None, local_rank=0)
Device set to use hpu
Traceback (most recent call last):
  File "/workspace/HuggingFace/tests/workloads/audio-classification/run_audio-classification.py", line 60, in <module>
    generator = pipeline(
  File "/workspace/transformers/src/transformers/pipelines/__init__.py", line 1180, in pipeline
    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
  File "/workspace/transformers/src/transformers/pipelines/audio_classification.py", line 99, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/transformers/src/transformers/pipelines/base.py", line 988, in __init__
    self.model.to(self.device)
  File "/workspace/transformers/src/transformers/modeling_utils.py", line 3725, in to
    return super().to(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1346, in to
    return self._apply(convert)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 903, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 903, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 903, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 930, in _apply
    param_applied = fn(param)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1332, in convert
    return t.to(
ModuleNotFoundError: No module named 'torch.hpu'
~

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: yuanwu <yuan.wu@intel.com>
@github-actions github-actions bot marked this pull request as draft March 26, 2025 05:55
Copy link

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

Signed-off-by: yuanwu <yuan.wu@intel.com>
@yuanwu2017
Copy link
Contributor Author

For the make fixup issue, Line 812 and line 813 cannot be changed in order. Because when import habana_frameworks.torch, some torch functions will be replaced

warning: Invalid rule code provided to `# noqa` at tests/models/big_bird/test_modeling_big_bird.py:913: E321
src/transformers/utils/import_utils.py:812:5: I001 [*] Import block is un-sorted or un-formatted
    |
810 |           return False
811 |
812 | /     import torch
813 | |     import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401
    | |______________________________________________________________^ I001
814 |
815 |       if not hasattr(torch, "hpu") or not torch.hpu.is_available():
    |
    = help: Organize imports

Found 1 error.
[*] 1 fixable with the `--fix` option.

@yuanwu2017 yuanwu2017 marked this pull request as ready for review March 26, 2025 06:17
@yuanwu2017
Copy link
Contributor Author

@IlyasMoutawwakil Please help to review.

@IlyasMoutawwakil
Copy link
Member

I believe that starting from Pytorch 2.6 and Synapse 1.20 (the one we targeted for upstream), torch.hpu doesn't need patching by habana_frameworks.

@IlyasMoutawwakil
Copy link
Member

IlyasMoutawwakil commented Mar 26, 2025

what's the actual issue here ? please provide the code to reproduce it, the error alone is not enough

@yuanwu2017
Copy link
Contributor Author

Dockerfile:

# Those arguments are required to build the image
ARG HABANA_VERSION=1.20.0
ARG PYTORCH_VERSION=2.6.0
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base

# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

WORKDIR /usr/src

ENV:
image

test.py

from transformers import pipeline, AutoTokenizer
pipe = pipeline("text-classification", device="hpu")
out = pipe("This restaurant is awesome")
print (out)

Command:
python test.py

image

@IlyasMoutawwakil
Copy link
Member

IlyasMoutawwakil commented Mar 26, 2025

Okay I see what's happening here, I only targeted non-lazy mode when integrating hpu in transformers, and weirdly, if you have PT_HPU_LAZY_MODE=0 in you environment, you don't need to import habana_frameworks.torch to use torch.hpu.

root@05f434bf385a:/home/ubuntu/workspace/transformers# python
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.hpu
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2681, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'hpu'

vs

root@05f434bf385a:/home/ubuntu/workspace/transformers# PT_HPU_LAZY_MODE=0 python
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
Calling add_step_closure function does not have any effect. It's lazy mode only functionality. (warning logged once)
Calling mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
Calling iter_mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
>>> torch.hpu
<module 'habana_frameworks.torch.hpu' from '/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/__init__.py'>
>>> 

This is a very implicit behavior that I can't see documented anywhere so the least we could do is make it explicit here with something like this:

    import torch

    if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
        # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu
        import habana_frameworks.torch

    if not hasattr(torch, "hpu") or not torch.hpu.is_available():
        return False

@yuanwu2017
Copy link
Contributor Author

Okay I see what's happening here, I only targeted non-lazy mode when integrating hpu in transformers, and weirdly, if you have PT_HPU_LAZY_MODE=0 in you environment, you don't need to import habana_frameworks.torch to use torch.hpu.

root@05f434bf385a:/home/ubuntu/workspace/transformers# python
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.hpu
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2681, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'hpu'

vs

root@05f434bf385a:/home/ubuntu/workspace/transformers# PT_HPU_LAZY_MODE=0 python
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
Calling add_step_closure function does not have any effect. It's lazy mode only functionality. (warning logged once)
Calling mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
Calling iter_mark_step function does not have any effect. It's lazy mode only functionality. (warning logged once)
>>> torch.hpu
<module 'habana_frameworks.torch.hpu' from '/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/__init__.py'>
>>> 

This is a very implicit behavior that I can't see documented anywhere so the least we could do is make it explicit here with something like this:

    import torch

    if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
        # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu
        import habana_frameworks.torch

    if not hasattr(torch, "hpu") or not torch.hpu.is_available():
        return False

Ok.

Signed-off-by: yuanwu <yuan.wu@intel.com>
@yuanwu2017
Copy link
Contributor Author

Done

@IlyasMoutawwakil
Copy link
Member

@yuanwu2017 please run make style

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants