
How to test on a subset of TPUs in a TPU Pod #7714

Closed

Jiayi-Pan opened this issue Jul 19, 2024 · 8 comments

Comments

Jiayi-Pan commented Jul 19, 2024

❓ Questions and Help

We have quota for TPU pods (TPU v3-8N, N>1) but not for single-node machines (TPU v3-8). Single-node machines are very useful for debugging, but under the default settings, simply launching XLA code on a single node within a pod does not work: it waits for the other nodes to join.

From JAX’s documentation, I vaguely remember there’s an environment variable that allows you to run code on a subset of TPUs from a TPU pod. Do we have this feature in PyTorch XLA? If so, could you provide a pointer to this?

Jiayi-Pan (Author)

Solved. To run on a single tpu-v3-8 node:

export TPU_HOST_BOUNDS='1,1,1'

For more complicated subsets, configure the following vars (a sketch follows after this list):

TPU_CHIPS_PER_HOST_BOUNDS
TPU_HOST_BOUNDS
TPU_VISIBLE_DEVICES
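
For example, here is a minimal sketch (the values are illustrative for one v3 host; the exact bounds depend on your topology) that restricts a process to a single chip by setting these variables before torch_xla initializes:

import os

# Illustrative subset: expose only local chip 0 and treat this host as standalone.
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"  # chip bounds visible to this process
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"            # pretend this host is the whole pod
os.environ["TPU_VISIBLE_DEVICES"] = "0"            # only chip 0

# Import torch_xla only after the variables are set.
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)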

JackCaoG (Collaborator)

Yeah, that's the env var you need. I will close this issue if there are no further questions.

Jiayi-Pan (Author)

While export TPU_HOST_BOUNDS='1,1,1' works for simple code like

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)

it fails in multi-processing settings. On the first node of a tpu-v3-64 pod, with that env var set, the following code errors out:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch

def _mp_fn(index):
  device = xm.xla_device()
  data = torch.randn(2, 2, device=device)
  print(data)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

Output:

jiayipan@t1v-n-bc530acf-w-0:~/prismatic-video-lms$ python example.py 
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.110325   95807 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.110423   95807 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.110434   95807 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.114095   95805 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.114171   95805 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.114181   95805 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.118047   95802 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.118122   95802 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.118132   95802 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.142126   95806 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.142198   95806 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.142215   95806 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 59, in _ru
n_thread_per_device
    initializer_fn(local_rank, local_world_size)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 122, in in
itialize_multiprocess
    devices = xm.get_xla_supported_devices()
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 99, in get
_xla_supported_devices
    devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder 
grpc channel to 10.142.0.20:8479.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jiayipan/prismatic-video-lms/example.py", line 14, in <module>
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 211, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 171, in run_multiprocess
    replica_results = list(
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 172, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 570, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.142.0.20:8479.

Do you have any suggestions on how to fix this?

Jiayi-Pan reopened this Jul 20, 2024
JackCaoG (Collaborator)

Maybe take a look at https://gist.github.com/skye/f82ba45d2445bb19d53545538754f9a3? I believe for each subprocess you need to set a different TPU_VISIBLE_DEVICES.

Jiayi-Pan (Author) commented Jul 22, 2024

Thanks! I tried using this instead, which still doesn't work:

export TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1"

# Set the TPU process bounds
export TPU_PROCESS_BOUNDS="2,2,1"

# Set the TPU process addresses
export TPU_PROCESS_ADDRESSES="localhost:8476,localhost:8477,localhost:8478,localhost:8479"

# Set the visible TPU devices
export TPU_VISIBLE_DEVICES="0"  # "1", "2", "3"

# Set the TPU process port
export TPU_PROCESS_PORT="8476"  # "8477", "8478", "8479"

export CLOUD_TPU_TASK_ID=0

Does this mean we need to provide different env vars to each of the processes xmp.spawn creates? If so, how should we do it?

JackCaoG (Collaborator)

lol @will-cromar I need your help

will-cromar (Collaborator)

You're on the right track. There are two places where we can request information about TPU topology: GCE metadata or environment variables.

If you want to do multiprocessing on one host out of a pod, the best way to do that would be to set all of the topology environment variables as if you were running on one host:

TPU_SKIP_MDS_QUERY=1 # Don't query metadata
TPU_HOST_BOUNDS=1,1,1 # Pretend there's one host in the "pod"
TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 # 4 chips per host
TPU_WORKER_HOSTNAMES=localhost
WORKER_ID=0 # Since there's only one worker in this cluster, index is always 0

If you do that, then xmp.spawn will take care of TPU_PROCESS_BOUNDS, TPU_PROCESS_ADDRESSES, TPU_VISIBLE_DEVICES, etc. The logic for setting all of these lives in tpu.py if you're curious.
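
For example, a quick sanity-check sketch (assuming the single-host variables above are set) that prints the per-process values xmp.spawn fills in:

import os
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # Each spawned process should see the values assigned to it.
  for var in ("TPU_PROCESS_BOUNDS", "TPU_PROCESS_ADDRESSES", "TPU_VISIBLE_DEVICES"):
    print(f"process {index}: {var}={os.environ.get(var)}")

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())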

Just to be upfront, we can't support manually setting these topology settings in general. The configurations we support are already implemented through xmp.spawn.

Having said that, this particular configuration (skip metadata query and limit the workload to one host) is exactly the configuration used by Kaggle and Colab, which we do support, so you can expect that to keep working.

Jiayi-Pan (Author)

Thanks! After some debugging, I found there are a few minor errors in your env var settings. However, the following works:

export TPU_SKIP_MDS_QUERY=1 # Don't query metadata
export TPU_HOST_BOUNDS=1,1,1 # Pretend there's one host in the "pod"
export TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 # 4 chips per host
export TPU_WORKER_HOSTNAMES=localhost
export TPU_WORKER_ID=0 # Since there's only one worker in this cluster, index is always 0
export TPU_ACCELERATOR_TYPE=v3-8 # Use v3-8 as the accelerator type
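
For completeness, a minimal sketch putting it together: with the variables above exported in the shell, the multiprocessing example from earlier in this thread should run on just the local v3-8 host.

# Run with the single-host variables above already exported.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  data = torch.randn(2, 2, device=device)
  print(f"process {index} on {device}: {data}")

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())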
