How to test on a subset of TPUs in a TPU Pod #7714
Comments
Solved. To run on a single tpu-v3-8 node, setting the right topology environment variable is enough. For more complicated subsets, configure the following variables (a rough sketch is below):
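The exact variables from this comment are not preserved above, so here is only a rough sketch of the idea: restrict one process to a subset of the chips on a host. The variable names are drawn from the libtpu/torch_xla topology settings, but the particular values and the chip choice are illustrative assumptions.

```python
# Rough sketch only: names are libtpu/torch_xla topology variables, values are assumptions.
import os

# Hypothetical example: expose 2 of the 4 chips on this host as a 2x1x1 slice.
os.environ["TPU_VISIBLE_CHIPS"] = "0,1"               # chips this process may use
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"            # a single process in the mesh
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "2,1,1"  # chip layout owned by that process

import torch_xla.core.xla_model as xm  # import only after the variables are set
print(xm.xla_device())
```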
Yeah, that's the env var you need. I will close this issue if there are no further questions.
While the single-process setting works, it hangs in the multi-processing setting on the first node of the pod.
Output:
Do you have any suggestions on how to fix this?
Maybe take a look at https://gist.github.com/skye/f82ba45d2445bb19d53545538754f9a3? I believe you need to set different environment variables for each subprocess.
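For reference, the linked gist boils down to giving each subprocess its own chip-visibility settings before it initializes the TPU runtime. A minimal sketch of that pattern follows; the variable names, the one-chip-per-process layout, and `worker.py` are assumptions for illustration, not the gist verbatim.

```python
# Sketch of "different env vars per subprocess": each child process claims one chip.
# Variable names/values and worker.py are illustrative assumptions.
import os
import subprocess

procs = []
for chip in range(4):  # hypothetically, one subprocess per chip on a v3-8 host
    env = dict(os.environ)
    env["TPU_VISIBLE_CHIPS"] = str(chip)              # each child sees a different chip
    env["TPU_PROCESS_BOUNDS"] = "1,1,1"               # each child is its own one-chip slice
    env["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
    procs.append(subprocess.Popen(["python", "worker.py"], env=env))

for p in procs:
    p.wait()
```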
Thanks! I tried using this instead, which still doesn't work.
Does it mean we need to provide different env vars to each of the processes?
lol @will-cromar I need your help
You're on the right track. There are two places where we can request information about TPU topology: GCE metadata or environment variables. If you want to do multiprocessing on one host out of a pod, the best way to do that would be to set all of the topology environment variables as if you were running on one host:
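The concrete list of variables is not preserved in this copy of the thread. As a sketch, "as if you were running on one host" for a v3-8 host might look roughly like the following; the names are libtpu/torch_xla topology variables, but the exact set and values here are assumptions.

```python
# Sketch: make one pod host look like a standalone v3-8 VM.
# Names are libtpu/torch_xla topology variables; the exact set and values are assumptions.
import os

os.environ["TPU_SKIP_MDS_QUERY"] = "1"                # don't read pod topology from GCE metadata
os.environ["TPU_ACCELERATOR_TYPE"] = "v3-8"           # advertise a single-host slice
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"               # one host
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "2,2,1"  # the 4 chips on that host
os.environ["TPU_WORKER_ID"] = "0"
os.environ["TPU_WORKER_HOSTNAMES"] = "localhost"

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    import torch_xla.core.xla_model as xm
    print(index, xm.xla_device())

if __name__ == "__main__":
    xmp.spawn(_mp_fn)  # multiprocessing should now behave like a standalone v3-8 host
```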
If you do that, then multiprocessing on that host should behave as it does on a standalone single-host TPU VM. Just to be upfront, we can't support manually setting these topology settings in general; the configurations we support are already implemented in the library. Having said that, this particular configuration (skip the metadata query and limit the workload to one host) is exactly the configuration used by Kaggle and Colab, which we do support, so you can expect it to keep working.
Thanks! After some debugging, I found there are a few minor errors in your env var settings.
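The corrected values themselves are not preserved above. As a quick sanity check on whichever values you end up with, here is a minimal sketch, assuming a recent torch_xla release with the `runtime` module:

```python
# Run after exporting the topology variables you settled on.
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm

print("runtime devices:", xr.global_runtime_device_count())  # should match a single v3-8 host
print("default device:", xm.xla_device())
```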
❓ Questions and Help
We have some quota for TPU pods (TPU v3-8N, N>1) but not for single-node machines (TPU v3-8). As everyone knows, single-node machines are really useful for debugging. However, under the default settings, simply launching the XLA code on a single node within a pod won't work -- it will wait for other nodes to join.
From JAX’s documentation, I vaguely remember there’s an environment variable that allows you to run code on a subset of TPUs from a TPU pod. Do we have this feature in PyTorch XLA? If so, could you provide a pointer to this?