Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ jobs:
pip install fsspec
pip install rich
pip install flax
pip install sentencepiece
- name: Extra CI deps
if: inputs.has_code_changes == 'true'
shell: bash
Expand Down
4 changes: 4 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ function run_xla_op_tests3 {

function run_xla_op_tests4 {
run_test "$_TEST_DIR/test_jax_interop.py"
# issue #9691: random crashes with sentencepiece protobuf; run multiple times to trigger
for i in $(seq 1 5); do
run_test "$_TEST_DIR/test_sentencepiece_interop.py"
done
}

function run_xla_op_tests5 {
Expand Down
19 changes: 19 additions & 0 deletions test/test_sentencepiece_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest


class SentencepieceInterop(unittest.TestCase):

def test_sentencepiece_interop(self):
import os
if not os.path.exists("/tmp/test_model.model"):
import urllib.request
urllib.request.urlretrieve(
"https://github.com/google/sentencepiece/raw/refs/heads/master/python/test/test_model.model",
"/tmp/test_model.model")
import torch_xla
import sentencepiece as spm
sp_model = spm.SentencePieceProcessor("/tmp/test_model.model")


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@

import torch

# issue 9691: ensure sentencepiece protobuf init happen between
# torch/torch-xla protobuf inits to work-around protobuf crash
try:
import sentencepiece as spm
sp_model = spm.SentencePieceProcessor()
sp_model.load('')
except:
pass

import _XLAC
from ._internal import tpu
from .version import __version__
Expand Down
1 change: 1 addition & 0 deletions torchax/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
torch==2.8.0 ; sys_platform == 'darwin' # macOS
torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
jax==0.7.2 # N.B.: torchax breaks on newer JAX versions that would be pulled from `flax` dependencies
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails for Python 3.10. I think jax<0.8.0 would do the job.
Not sure what are the exact requirements for TorchAX, though. Maybe @qihqi can weigh in, here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Why is it pinning torch to 2.8?

flax==0.10.6
Loading