torch.distributed support on MacOS is missing #20380

Closed
yaroslavvb opened this issue May 10, 2019 · 24 comments
Labels
feature A request for a proper, new feature. module: macos Mac OS related issues oncall: distributed Add this issue/PR to distributed oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@yaroslavvb
Contributor

yaroslavvb commented May 10, 2019

Currently, trying to use torch.distributed on macOS fails because the torch.distributed namespace is empty.
I vaguely recall it working a year ago.

This is useful for quick sanity checks on my MacBook before deploying to a cluster.

```
Python 3.7.3 (default, Mar 27 2019, 16:54:48)
[Clang 4.0.1 (tags/RELEASE_401/final)] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.version.__version__)
1.1.0
>>> import torch.distributed as dist
>>> dist.init_process_group
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'torch.distributed' has no attribute 'init_process_group'
```
@pytorchbot pytorchbot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 10, 2019
@jeffreyksmithjr jeffreyksmithjr added feature A request for a proper, new feature. module: macos Mac OS related issues triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels May 13, 2019
@jeffreyksmithjr
Contributor

Cc @pietern @mrshenli

@soumith
Member

soumith commented May 13, 2019

@yaroslavvb from what I remember, we don't plan to add distributed support on OSX or Windows.

@yaroslavvb
Contributor Author

yaroslavvb commented Jul 10, 2019

How hard would it be to add 1-worker DDP support on Mac? For instance, -m torch.distributed.launch already works.

This would make it easier for people who prototype on Macs before deploying to Linux. Currently it's a bit annoying, with if sys.platform == "darwin" checks cropping up everywhere (example).

cc @jspisak, who would know if there are enough people using PyTorch on Mac. From my Bay Area vantage, it seems to be everybody.
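
One way to avoid scattering platform checks is to probe for the feature itself. The following is a minimal sketch, assuming `torch.distributed.is_available()` reports whether the installed build includes distributed support. The helper name `distributed_usable` is hypothetical and not part of PyTorch.

```python
def distributed_usable() -> bool:
    """Return True only if this PyTorch build exposes torch.distributed."""
    try:
        import torch.distributed as dist
    except ImportError:
        # PyTorch itself is not installed in this environment.
        return False
    # On builds without distributed support the namespace is nearly empty,
    # so also confirm that the entry point we intend to call exists.
    return bool(dist.is_available()) and hasattr(dist, "init_process_group")
```

Calling code could then branch on `distributed_usable()` once, rather than on `sys.platform` in several places.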

@yaroslavvb yaroslavvb reopened this Jul 10, 2019
@pietern
Contributor

pietern commented Jul 11, 2019

@yaroslavvb Good point. I agree. This has been raised before and I agree we should revisit.

Adding macOS support (and Windows support, for that matter) will be contingent on making the Gloo TCP transport work with something other than epoll(2), which is the Linux-specific part. This can be done either by adding kqueue support to gloo/transport/tcp/device.cc, or by adding another transport that doesn't have any platform specifics at all (e.g. through libuv).
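
To illustrate the platform split described here, a small sketch in Python (this is not Gloo's C++ code). The standard `select` module exposes `epoll` only on Linux and `kqueue` only on BSD-derived systems such as macOS, while `selectors.DefaultSelector` picks whichever mechanism is available, roughly the role a portable layer like libuv plays. The function names `readiness_backend` and `wait_readable_once` are made up for this example.

```python
import select
import selectors
import socket


def readiness_backend() -> str:
    """Name the I/O readiness mechanism this platform's select module offers."""
    if hasattr(select, "epoll"):
        return "epoll"   # Linux
    if hasattr(select, "kqueue"):
        return "kqueue"  # macOS and the BSDs
    return "select"      # portable fallback


def wait_readable_once() -> bool:
    """Register one end of a socket pair and wait until it becomes readable."""
    a, b = socket.socketpair()
    sel = selectors.DefaultSelector()  # chooses epoll/kqueue/etc. automatically
    try:
        sel.register(a, selectors.EVENT_READ)
        b.send(b"x")  # makes `a` readable
        events = sel.select(timeout=1.0)
        return any(key.fileobj is a for key, _ in events)
    finally:
        sel.close()
        a.close()
        b.close()
```

The same calling code runs on Linux and macOS because the platform-specific choice is hidden behind one interface.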

@pietern pietern self-assigned this Jul 11, 2019
@Ownmarc

Ownmarc commented Jul 31, 2019

+1 for Windows support for distributed training! I currently get "AttributeError: module 'torch.distributed' has no attribute 'init_process_group'" when trying to run distributed training on Windows. The exact same code works fine on Linux.

@pietern
Contributor

pietern commented Jul 31, 2019

@Ownmarc In what kind of environment would you use this? Multiple machines, or only multiple processes on a single machine? I have not heard of folks wanting to use torch.distributed on Windows, so I am curious to hear more about your intended usage.

@Ownmarc

Ownmarc commented Jul 31, 2019

> @Ownmarc In what kind of environment would you use this? Multiple machines, or only multiple processes on a single machine? I have not heard of folks wanting to use torch.distributed on Windows, so I am curious to hear more about your intended usage.

To take advantage of 2 GPUs in the same machine.

@pietern here is another person. I am using this repo too: ultralytics/yolov3#336 (comment)

@pietern
Contributor

pietern commented Sep 16, 2019

@yaroslavvb This is now available in the nightlies. It is full-blown support, so you should be able to run DDP on a big stack of MacBooks, if you wanted to... :D

@pietern pietern closed this as completed Sep 16, 2019
@jspisak
Contributor

jspisak commented Sep 16, 2019

Great work Pieter!!

@Ownmarc

Ownmarc commented Sep 16, 2019

@pietern is this going to support Windows too?

@pietern
Contributor

pietern commented Sep 16, 2019

@Ownmarc It could, because the implementation is based on libuv. Is this something you'd be interested in contributing? I think the majority of time spent will be in 1) getting the Gloo build to work, 2) getting the PyTorch build to work (there are only 2 ifdef'ed pieces of code in ProcessGroupGloo.cpp), and 3) getting CI for torch.distributed to work well on Windows.

@Ownmarc

Ownmarc commented Sep 16, 2019

@pietern, I just took a look at ProcessGroupGloo.cpp and was kind of lost. Unfortunately, I do not know much other than Python. I am still in my early days of programming and just recently switched from TensorFlow to PyTorch!

@yaroslavvb
Contributor Author

Time to put all 4 cores of my MacBook Pro to good work!

Seriously though, this is great for code consistency, time to finally banish nn.DataParallel from my code :)

@pietern
Contributor

pietern commented Sep 17, 2019

@Ownmarc Do you compile PyTorch from source on Windows? If so, I could guide you through some of the steps to do this, but do realize it likely won't be a walk in the park.

@TimZaman

@pietern awesome this is exactly what I needed!

@jarednielsen

Perfect, thank you @pietern!

@tbwxmu

tbwxmu commented Dec 5, 2019

@Ownmarc, have you solved this problem on Windows? I have tried the preview version and still get the same problem.

@pietern
Contributor

pietern commented Dec 9, 2019

@tbwxmu This issue tracked support for macOS, not for Windows.

@Ownmarc

Ownmarc commented Dec 17, 2019

Opened it officially for Windows; let's see if I am the only one with more than 1 GPU on Windows, haha.
#31363

@neggert

neggert commented Mar 13, 2020

@pietern Is GLOO supported on MacOS? This code hangs for me at init_process_group on a Mac, but seems to work fine on Linux.

```python
import datetime
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size):
    # mp.spawn passes the process index as the first argument (rank).
    print(f"starting rank {rank}")
    dist.init_process_group(
        "gloo",
        init_method="tcp://localhost:12345",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=10)
    )
    print(f"started rank {rank}")
    cleanup()


def cleanup():
    dist.destroy_process_group()


if __name__ == "__main__":
    # args=(2,) supplies world_size; nprocs=2 starts ranks 0 and 1.
    mp.spawn(setup, args=(2, ), nprocs=2, join=True)
```

@NihalHarish

NihalHarish commented Apr 9, 2020

Has support been reverted with the 1.5 release?

I was testing the v1.5.0-rc2 release candidate on OSX and see that the output for:

torch.distributed.is_available() is False

@y78h11b09

> @Ownmarc It could, because the implementation is based on libuv. Is this something you'd be interested in contributing? I think the majority of time spent will be in 1) getting the Gloo build to work, 2) getting the PyTorch build to work (there are only 2 ifdef'ed pieces of code in ProcessGroupGloo.cpp), and 3) getting CI for torch.distributed to work well on Windows.

Can you share your work?
I think it would be great for PyTorch to support that on Windows with NCCL!

@Ownmarc

Ownmarc commented Aug 8, 2020

@y78h11b09 I am not good enough with C++ to implement this, but I think Windows said recently that they would support PyTorch, so maybe they will implement this! :D

@mrshenli
Contributor

mrshenli commented Aug 9, 2020

For Windows support, please check this RFC (#42095)

Hey @neggert, yes, PyTorch + Gloo works on macOS, but you will need to compile from source using the following steps:

  1. Follow the README in https://github.com/pytorch/pytorch to set up conda and dependencies.
  2. Then conda install libuv and pkg-config.
  3. Then run: time env MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ BUILD_CAFFE2_OPS=0 USE_CUDA=0 USE_MKLDNN=0 USE_DISTRIBUTED=1 python setup.py develop
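
After a source build along these lines, a single-process smoke test might look like the sketch below. It assumes the build succeeded with `USE_DISTRIBUTED=1`. The helper names `free_port` and `smoke_test` are hypothetical, and the loopback address, one-process world size, and 30-second timeout are choices made for this example only.

```python
import datetime
import socket


def free_port() -> int:
    """Ask the OS for a currently unused TCP port on the loopback interface."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def smoke_test() -> float:
    """Start a 1-process Gloo group, all-reduce a tensor, and tear down."""
    # Imported lazily so this module loads even where PyTorch is absent.
    import torch
    import torch.distributed as dist

    if not dist.is_available():
        raise RuntimeError("this PyTorch build lacks torch.distributed")

    dist.init_process_group(
        "gloo",
        init_method=f"tcp://127.0.0.1:{free_port()}",
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=30),
    )
    try:
        t = torch.ones(1)
        dist.all_reduce(t)  # sums across ranks; only one rank participates here
        return t.item()
    finally:
        dist.destroy_process_group()
```

If `smoke_test()` returns without raising or hanging, the Gloo backend is at least able to initialize and run a collective in one process. A multi-process run is a separate check.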
