
PyTorch not able to access all cores #1576

Closed
tmabraham opened this issue Feb 1, 2020 · 15 comments
@tmabraham

Hello all,

I set up a GCP TPU v3-8 instance and a VM instance in the same region with the torch-nightly build. After setting up PyTorch XLA, I noticed that xm.xrt_world_size() returns zero, but xm.get_xla_supported_devices() returns all 8 devices. When I try to run TPU training code, it won't spawn 8 processes, saying that nprocs=8 is not allowed because it is more than xm.xrt_world_size().

Was the GCP instance set up wrong?

@jysohn23
Collaborator

jysohn23 commented Feb 1, 2020

Can you try running https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py ? I don't think your setup is wrong, just that the way you're using the multiprocessing API may not be correct. Make sure to specify the correct number of processes to spawn:

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)

For v3-8 it should be FLAGS.num_cores=8 to use all cores.
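For reference, here is a minimal sketch of that pattern, assuming a hypothetical _mp_fn and a plain dict standing in for FLAGS (the real training loop lives in test_train_mp_mnist.py):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    # index is this process's ordinal (0..7 on a v3-8); each process drives one core.
    device = xm.xla_device()
    print(index, device, xm.xrt_world_size())

if __name__ == '__main__':
    flags = {'num_cores': 8}  # stand-in for the script's FLAGS object
    xmp.spawn(_mp_fn, args=(flags,), nprocs=flags['num_cores'])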

@tmabraham
Author

@jysohn23 Interestingly this code is working.

However, if I type python in the command line to open up python in the terminal, then type

import torch_xla.core.xla_model as xm
xm.xrt_world_size()

it returns 1

@jysohn23
Collaborator

jysohn23 commented Feb 1, 2020

Yes, that's expected, since your world is only a single process in that case. xrt_world_size() basically corresponds to how many processes are running to feed the TPU training. So only after xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores) will xrt_world_size() be equal to FLAGS.num_cores.

@jysohn23 jysohn23 self-assigned this Feb 1, 2020
@tmabraham
Author

@jysohn23

Ah ok I didn't know that.

Here is the error I am getting with my code:

Traceback (most recent call last):
  File "pytorch_bert_tpu.py", line 297, in <module>
    xmp.spawn(run,nprocs=8)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 182, in spawn
    start_method=start_method)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 109, in _setup_replication
    xm.set_replication(str(device), [str(device)])
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/core/xla_model.py", line 199, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/anaconda3/envs/torch-xla-nightly/lib/python3.6/site-packages/torch_xla/core/xla_model.py", line 186, in xla_replication_devices
    .format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

Any idea what it could be due to?

@dlibenzi
Collaborator

dlibenzi commented Feb 1, 2020

That means you are trying to replicate when the system sees 8 local devices, but you are using only 1.
That should not happen if you follow the usage examples we have for multi-processing.
Once you enter the multi-processing function (_mp_fn below), you will only see one local device:

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)

@tmabraham
Author

@dlibenzi
@jysohn23

I actually did set nprocs=8, but I realized I accidentally called xm.xla_device() before spawning the processes. This is what led to the error.

Thanks for your help!
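For anyone who runs into the same RuntimeError, here is a minimal sketch of the wrong versus working pattern (run is a placeholder training function, not the original code):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def run(index):
    # Correct: touch the device only inside the spawned function;
    # each process then sees exactly one core.
    device = xm.xla_device()
    ...

if __name__ == '__main__':
    # Wrong: calling xm.xla_device() here, in the parent process before
    # xmp.spawn, triggers "Cannot replicate if number of devices (1) is
    # different from 8".
    xmp.spawn(run, nprocs=8)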

@sharkdeng

@jysohn23 Interestingly this code is working.

However, if I type python in the command line to open up python in the terminal, then type

import torch_xla.core.xla_model as xm
xm.xrt_world_size()

it returns 1

Me too. Why?

@dlibenzi
Collaborator

Without xmp.spawn() you are in a single-process, many-devices mode, where you have to use thread-based parallelism (which we do not suggest, as it is on its way to deprecation), and in that mode the world size is 1.
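As a small illustration of that mode (an interactive session with no xmp.spawn; the exact device list depends on the setup):

import torch_xla.core.xla_model as xm

# Without xmp.spawn() the "world" is just this one process.
print(xm.xrt_world_size())             # 1
print(xm.get_xla_supported_devices())  # still lists all 8 cores on a v3-8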

@world2vec

world2vec commented Jul 28, 2020

Hi,
I have the same problem. In my case:

  1. train model1 with xmp.spawn,
  2. predict with model1 without xmp.spawn, which calls xm.xla_device(),
  3. train model2 with xmp.spawn (which fails because of this problem).

I did not use xmp.spawn for prediction at step 2 because I did not find a way to return the prediction result from xmp.spawn (saving the result to disk is not an option for me).
So could you kindly let me know:

  1. Is there a way to return the prediction result from xmp.spawn (while still using all cores)?
  2. How can I reset the device after step 2?

Let me know if you want a new issue opened.

Thanks.

@tmabraham
Author

@world2vec I think this is similar to #2268, and the answer is that it's not possible to use all cores after you have used only one core.

@world2vec

@tmabraham
Thanks.
@dlibenzi
Is there a way to get the prediction result back from xmp.spawn directly?

@taylanbil
Collaborator

@world2vec you should be able to predict inside the spawned processes; why do you need to predict outside an xmp.spawn call? Alternatively, you could predict inside a separate spawned process. If you want to use only one device while predicting, you could spawn 1 subprocess, as in:

import torch_xla.distributed.xla_multiprocessing as xmp

def train(index, *args):
    # spawned functions receive the process index as their first argument
    ...

def predict(index, *args):
    ...

if __name__ == '__main__':
    xmp.spawn(train, args=(), nprocs=8)
    xmp.spawn(predict, args=(), nprocs=1)

@world2vec

world2vec commented Jul 29, 2020

@taylanbil
I thought xmp.spawn(predict, args=(), nprocs=1) was the same as calling predict() directly.
So specifying nprocs=1 will not raise the same issue?

xmp.spawn(train, args=(model1_name,), nprocs=8)
xmp.spawn(predict, args=(model1_name,), nprocs=1)
xmp.spawn(train, args=(model2_name,), nprocs=8)

By the way, I would prefer to get the output back from the spawn call directly:
predict_rst = xmp.spawn(predict, args=(model1_name,), nprocs=8)

I have another function, post_process, that processes the outputs from model1 and model2. As the names suggest, predict predicts, post_process post-processes, and train trains. Technically we could put all the code into one function or a toy notebook, but that is not good.

@taylanbil
Collaborator

Actually, I tried to get this to work with a simple example and I'm running into problems. I don't suggest doing this.

This is all doable with one spawn; is there a reason why that's not OK for you?

def main(index):  # this is run on all 8 cores
    # train model 1 on sharded data
    # predict using model 1 on unsharded data (unnecessary, duplicate computation but not a big deal)
    # train model 2 on sharded data
    ...

xmp.spawn(main, args=(), nprocs=8)
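A slightly more concrete sketch of that one-spawn layout, where train_model, predict_model and post_process are hypothetical stand-ins for the real code:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Hypothetical stand-ins for the real training / prediction / post-processing code.
def train_model(name, device):
    xm.master_print('training {} on {}'.format(name, device))
    return name  # would return a trained model

def predict_model(model, device):
    xm.master_print('predicting with {} on {}'.format(model, device))
    return []    # would return predictions

def post_process(preds, model):
    xm.master_print('post-processing {} predictions with {}'.format(len(preds), model))

def main(index):
    device = xm.xla_device()  # each of the 8 processes drives one TPU core
    model1 = train_model('model1', device)
    preds = predict_model(model1, device)  # duplicated across processes, but not a big deal
    model2 = train_model('model2', device)
    post_process(preds, model2)

if __name__ == '__main__':
    xmp.spawn(main, args=(), nprocs=8)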

@world2vec

Well, I will say that this means wrapping everything into one function every time I use xmp.spawn, but technically, yes, we can.
