Cherry-pick 2.1 release branch into XRT branch through 9/14 #5574
Conversation
* Sharding should be per output of an IR node, instead of per IR node * Update sharding_hash method * Add test for sharding on IR with multiple outputs * fix cpu test * Fix a bug in getSharding
* Make the Python API respect the virtual device when SPMD is enabled * fix typo
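A minimal sketch of the behavior this change targets, assuming the 2.1-era runtime API (`torch_xla.runtime.use_spmd`); the expectation, not a documented guarantee, is that once SPMD is enabled the Python device API hands back the single virtual device rather than a physical one.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Enable SPMD mode; after this, xm.xla_device() is expected to resolve to
# the single virtual device used for sharded execution (an assumption
# about this branch's behavior, not a documented guarantee).
xr.use_spmd()
device = xm.xla_device()

t = torch.randn(4, 4, device=device)
print(device, t.device)
```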
* Also dump output sharding to the HLO file * Only dump output sharding if the dump format is HLO * add test * fix typo
* Make all-reduce a no-op when world size is 1 * Fix torch.distributed test
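As a rough illustration of the intent (not the actual patch), a hedged sketch of skipping the collective when there is nothing to reduce; `all_reduce_or_noop` is a hypothetical helper name.

```python
import torch
import torch.distributed as dist

def all_reduce_or_noop(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: with a world size of 1 there are no peers to
    # reduce with, so the collective can be skipped entirely.
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return tensor
    dist.all_reduce(tensor)
    return tensor
```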
* Fix AMP dtype setting for GPU * fix unit test * fix lint * minor
* Add python test for SPMD+Runtime Python API * replace test name * Update test_xla_spmd_python_api_interaction.py
…5352) * Check the actual device instead of querying the env var for the virtual device * revert unneeded change * minor changes
* tweak `atol` and `rtol`
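For context, tolerance tweaks like this typically show up in test code as follows; the exact values below are illustrative only, not the ones used in the suite.

```python
import torch

expected = torch.ones(8)
actual = expected + 1e-3  # simulate a small numerics drift on accelerators

# Loosened tolerances; the concrete atol/rtol values in the real tests differ.
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
```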
* Skip `DynamoTrainingBasicTest.test_resnet18` on TPU
* Add kokoro presubmit for stablehlo tests
…5367) * [BE] use self.assertEquals instead of str equality in test_zero1.py * Use our own assertEqual * Remove print statements
* Fix ReplicateShardedData for int type * add test
Update dynamo.md to remove note about fallback ops since they're supported now
* Tweak `atol` and `rtol` for `test_simple_model_with_different_input_shape` on TPU (#5373)
…mutability (#5382) * [TEST ONLY] print statements for test_zero1.py to debug * Try fix * Rectify test_zero1.py to account for state_dict modification * Fix lint
#5384) * Add GPU doc for how to build PyTorch/XLA from source with GPU support * fix typo * fix comments * fix comments
* Add more support for in-place ops in dynamo bridge; run linter * Add check to explicitly sync self tensors; remove debugging lines; update unit tests to a model * Clean up some code; surround in an if-statement; update metrics for fallback-related dynamo tests; update cloned args logic; revert "Update metrics for fallback related dynamo tests" (this reverts commit 3855f43) * Update single_node flag back to False
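A hedged sketch of the user-facing pattern this work covers: an in-place mutation of an input tensor inside a `torch.compile`d function on an XLA device. The backend string `"openxla"` is an assumption about this branch's era.

```python
import torch
import torch_xla.core.xla_model as xm

def step(t):
    # In-place update of the input; the dynamo bridge needs to sync the
    # mutated "self" tensor back so the caller sees the change.
    t.add_(1.0)
    return t

device = xm.xla_device()
compiled = torch.compile(step, backend="openxla")

x = torch.zeros(4, device=device)
compiled(x)
print(x.cpu())  # expected to be all ones if the mutation was synced
```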
Add dynamo test in TPU CI
Summary: During the LLaMA2 experiments, I discovered that manually marking 1D tensors as replicated can save a lot of memory. Then I discovered that an explicitly replicated spec gets dropped after mark_step. That is caused by PrepareOutputShardingPropagation, which explicitly clears the sharding spec for replicated output. So, I went ahead and fixed that. Further, I did some experiments with propagating replicated output, and that drops the requirement of manually replicating 1D tensors. Hence, I made this change. I'm still not quite sure why, and will follow up later.
Test Plan: PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py
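A minimal sketch of the scenario described above, assuming the 2.1-era experimental SPMD API (`torch_xla.experimental.xla_sharding`): explicitly mark a 1D tensor as replicated and check whether the sharding spec survives `mark_step`.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,))

# Explicitly mark a 1D tensor as replicated (None means a replicated dim).
t = torch.ones(128, device=xm.xla_device())
xs.mark_sharding(t, mesh, (None,))

xm.mark_step()
# Before the fix the replicated spec was dropped here; afterwards it should
# still be reported (the exact output string is an assumption).
print(torch_xla._XLAC._get_xla_sharding_spec(t))
```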
Update more places * Add torch_pin
* Update project metadata and remove useless files * Update README * Add manylinux platform tag * formatting
Something actually unconditionally calls
I'm able to build the current commit against PyTorch on the 2.1 branch and run MNIST on TPU with XRT 🎊
I disabled this test since it is missing from this branch; I don't know where it got lost, but there's no reason to use stablehlo with XRT anyway.
Also, I was hitting a weird build failure on this branch until I updated the CI cache silo name. I wonder if that's why the wheel build is failing. I'll try updating the silo name in this branch.
Skipped commits that update the Bazel workspace or are incompatible with XRT:
I had to make substantial edits (i.e., not just renaming imports) to the following commits to make them build against our pins:
Last commit picked: ee72332