
TPU crash during importing Trainer from transformers #6990

Closed
asiff00 opened this issue Apr 29, 2024 · 6 comments
asiff00 commented Apr 29, 2024

🐛 Bug

The Colab/Kaggle notebook crashes while trying to import 'Trainer' from the transformers library.


To Reproduce

    !pip install transformers
    !pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
    from transformers import Trainer

or

    !pip install transformers
    !pip install torch_xla[tpu]
    from transformers import Trainer

or

    !pip install transformers
    !pip install torch_xla
    from transformers import Trainer

Steps to reproduce the behavior:

  1. Install xla
  2. Import `Trainer` from the transformers library.
  3. Your environment crashes with the following error:
    ERROR: Unknown command line flag 'xla_latency_hiding_scheduler_rerun'

Environment

  • Reproducible on XLA backend [TPU]:
  • torch_xla version: 2.2.0+libtpu
JackCaoG (Collaborator) commented:

The flag comes from https://github.com/pytorch/xla/blob/r2.2/torch_xla/__init__.py#L43-L44. I am trying to get my Kaggle TPU to see if I can repro this.

JackCaoG (Collaborator) commented:

OK, I was able to confirm that it did crash. I tried installing the new torch 2.3 on my TPU VM with

    pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html

and this seems to work

>>> import torch
>>> import torch_xla
>>> t1 = torch.randn(5,5, device='xla:0')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1714426096.507062 2529487 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jackcao/.local/lib/python3.8/site-packages/libtpu/libtpu.so
I0000 00:00:1714426096.507156 2529487 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1714426096.507162 2529487 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
>>> t1
tensor([[ 1.1453, -0.9900,  0.5783,  1.7081,  1.1962],
        [ 0.6340,  1.6611,  0.2455, -0.7434,  3.1036],
        [-1.1664,  0.5326,  1.7286, -0.7094,  1.1267],
        [ 1.2665, -0.2168, -3.1145, -1.9214, -1.2044],
        [ 1.8507,  0.0055,  1.2275, -0.2037, -0.7610]], device='xla:0')
>>> torch_xla.__version__
'2.3.0'
>>> from transformers import Trainer
>>> Trainer
<class 'transformers.trainer.Trainer'>

There is also an in-flight PR to update the default torch version to 2.3. Do you mind manually installing 2.3 for now?
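As a sanity check after reinstalling, the reported version string can be compared against the fixed release; a minimal sketch (the helper `needs_upgrade` and the 2.3.0 cutoff are assumptions for illustration, not part of torch_xla's API):

```python
# Check whether an installed torch_xla version predates the 2.3 release
# that this thread reports as working. Local-version suffixes such as
# "+libtpu" are stripped before comparing.

def needs_upgrade(version_str: str, minimum=(2, 3, 0)) -> bool:
    """Return True if version_str is older than `minimum`."""
    base = version_str.split("+")[0]               # "2.2.0+libtpu" -> "2.2.0"
    parts = tuple(int(p) for p in base.split("."))
    return parts < minimum

print(needs_upgrade("2.2.0+libtpu"))  # version from this report -> True
print(needs_upgrade("2.3.0"))         # fixed release -> False
```

In a notebook, `torch_xla.__version__` would be the string to pass in.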

JackCaoG (Collaborator) commented:

Ah, I know: Kaggle preinstalls TensorFlow, and HF transformers will try to import TensorFlow, which loads TensorFlow's libtpu, which is not compatible with PyTorch/XLA.

    !yes | pip3 uninstall tensorflow

fixed the issue on my end.
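A notebook can guard against this conflict before importing transformers; a minimal sketch (the helper name `tensorflow_present` and the warning text are assumptions for illustration, not the warning pytorch/xla later added):

```python
import importlib.util

def tensorflow_present() -> bool:
    """Return True if a 'tensorflow' package is importable in this environment."""
    return importlib.util.find_spec("tensorflow") is not None

# Warn before transformers gets a chance to import TensorFlow and load
# its bundled libtpu, which conflicts with torch_xla's libtpu on TPU.
if tensorflow_present():
    print("Warning: tensorflow is installed; its bundled libtpu may conflict "
          "with torch_xla on TPU. Consider `pip uninstall tensorflow`.")
```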

JackCaoG (Collaborator) commented:

I will assign this bug to @will-cromar to add a warning message that makes this clearer in future releases.

asiff00 (Author) commented Apr 29, 2024


This specific problem was solved with (#6990 (comment)):

    !yes | pip3 uninstall tensorflow
    !pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html

I'll test a few other transformers classes before closing it.

will-cromar (Collaborator) commented:

Let us know if you're still having issues.
