Add DistributedDataParallel #1715
Conversation
apaszke commented Jun 4, 2017 (edited)
```
@@ -476,7 +476,7 @@ THDGroup DataChannelMPI::newGroup(const std::vector<rank_type>& ranks) {
   MPI_Group_incl(world_group, int_ranks.size(), int_ranks.data(), &ranks_group);

   MPI_Comm new_comm;
-  MPI_Comm_create_group(MPI_COMM_WORLD, ranks_group, 0, &new_comm);
+  //MPI_Comm_create_group(MPI_COMM_WORLD, ranks_group, 0, &new_comm);
```
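For context on the call being commented out, here is a rough mpi4py equivalent, a hedged sketch only: `MPI_Comm_create_group` builds a communicator collectively over just the member ranks rather than all of `MPI_COMM_WORLD`. The rank list is illustrative, not taken from this PR.
```
from mpi4py import MPI

world_group = MPI.COMM_WORLD.Get_group()
sub_group = world_group.Incl([0, 1])  # ranks to include (illustrative)
# Only member ranks participate; analogous to MPI_Comm_create_group with tag 0.
if MPI.COMM_WORLD.Get_rank() in (0, 1):
    new_comm = MPI.COMM_WORLD.Create_group(sub_group, tag=0)
```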
Force-pushed from 4b93573 to cbb3fdd.
The parts I understand look good. A few questions about things that confused me:
torch/lib/THD/base/Cuda.cpp (Outdated)
```
@@ -0,0 +1,23 @@
#include "Cuda.hpp"

THCState** _THDCudaState;
```
torch/lib/THD/base/Cuda.hpp (Outdated)
```
int THDGetStreamId(cudaStream_t stream);

#include "Cuda.h"
```
```
 SYSCHECK(fd = open("/dev/urandom", O_RDONLY));
-SYSCHECK(read(fd, &seed, sizeof(seed)));
+SYSCHECK(bytes_read = read(fd, &seed, sizeof(seed)));
```
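The point of the change above is that `read()` may legally return fewer bytes than requested, so its return value must be captured and checked rather than discarded. A minimal sketch of the same idea in Python (the 8-byte seed size is an assumption for illustration):
```
import os
import struct

fd = os.open("/dev/urandom", os.O_RDONLY)
try:
    data = os.read(fd, 8)   # may return fewer than 8 bytes
    if len(data) != 8:      # the check that capturing bytes_read enables
        raise OSError("short read from /dev/urandom")
    seed = struct.unpack("<Q", data)[0]
finally:
    os.close(fd)
```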
torch/nn/parallel/distributed.py (Outdated)
```
Example::

    >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
```
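Since the docstring fragment above still references `DataParallel`, a hedged sketch of how the module this PR introduces is typically constructed may help; this follows the modern `torch.nn.parallel.DistributedDataParallel` API and assumes the process group has already been initialized:
```
import torch.distributed as dist
import torch.nn as nn

# Assumed setup, done once per process before wrapping the model:
# dist.init_process_group(backend="gloo", init_method="tcp://...",
#                         world_size=..., rank=...)
model = nn.Linear(10, 10)
net = nn.parallel.DistributedDataParallel(model)  # syncs grads via all_reduce
```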
torch/utils/data/distributed.py (Outdated)
```
from .sampler import Sampler


class DistributedSampler(Sampler):
```
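A minimal usage sketch for the sampler defined in this file, written against the modern API (explicit `num_replicas`/`rank` are passed so the snippet runs without an initialized process group; normally both default from the group):
```
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 10))
sampler = DistributedSampler(dataset, num_replicas=4, rank=0)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # vary the shuffling seed across epochs
    for (batch,) in loader:
        pass  # each rank sees a disjoint 1/num_replicas shard of the data
```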
Pushed fixes for review comments in a separate commit. I will reset it and squash the changes into appropriate commits when they're accepted.
Force-pushed from 51aa47a to b2e3d79.
test/run_test.sh (Outdated)
```
@@ -20,72 +20,37 @@ fi

pushd "$(dirname "$0")"

echo "Running torch tests"
$PYCMD test_torch.py $@
```
* Add keepdim
* Fix DataChannel signature
* Fix incorrect locking
* Use current stream in DataChannelGloo
@apaszke Hi, as far as I understand, the […] I am surprised that the […]
@acgtyrant DistributedDataParallel has only one synchronization point: all gradients are all_reduced across machines, so every machine ends up with the same copy of the gradients. Then each machine performs its own optimization step.
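A minimal sketch of that single synchronization, using today's `torch.distributed` API rather than this PR's internals; `average_gradients` is an illustrative helper, not part of the library:
```
import torch.distributed as dist

def average_gradients(model):
    # One synchronization point per step: sum every gradient across all
    # processes, then divide so each replica holds the same average.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

# In the training loop, after loss.backward():
#   average_gradients(model)
#   optimizer.step()  # identical update on every machine
```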
`found_non_rfactor_reduction` is used to detect the error case where all reduction dims are marked as rfactor. However, this code does not actually find non-rfactor reductions; it finds arbitrary reductions. Fortunately, other parts of our code detect the same error, so this bug has no real effect. Still, I think we should fix it.
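An illustrative Python sketch of the distinction (the class and attribute names are hypothetical stand-ins, not nvfuser's actual API): the buggy check fires on any reduction axis, while the intended check must also exclude axes that are rfactor products.
```
from dataclasses import dataclass

@dataclass
class IterDomain:             # stand-in for nvfuser's IterDomain
    is_reduction: bool
    is_rfactor_product: bool

def found_non_rfactor_reduction(domains):
    # Buggy version matched any reduction: any(d.is_reduction for d in domains)
    # Intended version also excludes rfactor-product axes:
    return any(d.is_reduction and not d.is_rfactor_product for d in domains)

# All reduction dims marked as rfactor -> exactly the error case the check targets.
domains = [IterDomain(is_reduction=True, is_rfactor_product=True)]
assert not found_non_rfactor_reduction(domains)
```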
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support for cp.async and cp.async.wait: csarofeen#1619
2. Emulate Ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extend the infrastructure to support mma operators on Turing and Ampere archs: csarofeen#1440

Commits from the csarofeen branch that are actually in this PR:
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```

Pull Request resolved: #78244
Approved by: https://github.com/csarofeen, https://github.com/malfet