@@ -683,7 +683,6 @@ <h1>Source code for torch.distributed.distributed_c10d</h1><div class="highlight
683683
684684 < span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> backend</ span > < span class ="p "> :</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="n "> Backend</ span > < span class ="p "> ]):</ span >
685685 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device_backend_map</ span > < span class ="p "> :</ span > < span class ="n "> Dict</ span > < span class ="p "> [</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> Backend</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {}</ span >
686- < span class ="c1 "> # error check to make sure the config string is valid</ span >
687686
688687 < span class ="c1 "> # Cases for when backend is a single string (without device types)</ span >
689688 < span class ="k "> if</ span > < span class ="n "> backend</ span > < span class ="o "> ==</ span > < span class ="n "> Backend</ span > < span class ="o "> .</ span > < span class ="n "> UNDEFINED</ span > < span class ="p "> :</ span >
@@ -700,13 +699,24 @@ <h1>Source code for torch.distributed.distributed_c10d</h1><div class="highlight
700699 < span class ="s2 "> "cuda"</ span > < span class ="p "> :</ span > < span class ="n "> backend_val</ span > < span class ="p "> ,</ span >
701700 < span class ="p "> }</ span >
702701 < span class ="k "> else</ span > < span class ="p "> :</ span >
703- < span class ="c1 "> # custom backend string in format of "{device_type1}:{backend1},{device_type2}:{backend2}"</ span >
704- < span class ="c1 "> # TODO</ span >
705- < span class ="k "> pass</ span >
706-
707- < span class ="n "> required_devices</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="s2 "> "cpu"</ span > < span class ="p "> ,</ span > < span class ="s2 "> "cuda"</ span > < span class ="p "> ]</ span >
708- < span class ="k "> for</ span > < span class ="n "> device</ span > < span class ="ow "> in</ span > < span class ="n "> required_devices</ span > < span class ="p "> :</ span >
709- < span class ="k "> assert</ span > < span class ="n "> device</ span > < span class ="ow "> in</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device_backend_map</ span >
702+ < span class ="c1 "> # make sure the backend string is in the correct format</ span >
703+ < span class ="c1 "> # "{device_type1}:{backend1},{device_type2}:{backend2}"</ span >
704+ < span class ="c1 "> # e.g. "cpu:gloo,cuda:nccl"</ span >
705+ < span class ="n "> backend_str_error_message</ span > < span class ="o "> =</ span > < span class ="sa "> f</ span > < span class ="s2 "> """The custom backend string argument is invalid: </ span > < span class ="si "> {</ span > < span class ="n "> backend</ span > < span class ="si "> }</ span > < span class ="s2 "> .</ span >
706+ < span class ="s2 "> Custom backend string is an experimental feature where the backend string must be in the format:</ span >
707+ < span class ="s2 "> "<device_type1>:<backend1>,<device_type2>:<backend2>...". e.g. 'cpu:gloo,cuda:nccl'"""</ span >
708+
709+ < span class ="c1 "> # parse the backend string and populate the device_backend_map</ span >
710+ < span class ="k "> for</ span > < span class ="n "> device_backend_pair_str</ span > < span class ="ow "> in</ span > < span class ="n "> backend</ span > < span class ="o "> .</ span > < span class ="n "> lower</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> (</ span > < span class ="s2 "> ","</ span > < span class ="p "> ):</ span >
711+ < span class ="n "> device_backend_pair</ span > < span class ="o "> =</ span > < span class ="n "> device_backend_pair_str</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> (</ span > < span class ="s2 "> ":"</ span > < span class ="p "> )</ span >
712+ < span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> device_backend_pair</ span > < span class ="p "> )</ span > < span class ="o "> !=</ span > < span class ="mi "> 2</ span > < span class ="p "> :</ span >
713+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Invalid device:backend pairing: </ span > < span class ="se "> \</ span >
714+ < span class ="s2 "> </ span > < span class ="si "> {</ span > < span class ="n "> device_backend_pair_str</ span > < span class ="si "> }</ span > < span class ="s2 "> . </ span > < span class ="si "> {</ span > < span class ="n "> backend_str_error_message</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
715+ < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> backend</ span > < span class ="o "> =</ span > < span class ="n "> device_backend_pair</ span >
716+ < span class ="k "> if</ span > < span class ="n "> device</ span > < span class ="ow "> in</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device_backend_map</ span > < span class ="p "> :</ span >
717+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "Duplicate device type </ span > < span class ="si "> {</ span > < span class ="n "> device</ span > < span class ="si "> }</ span > < span class ="s2 "> </ span > < span class ="se "> \</ span >
718+ < span class ="s2 "> in backend string: </ span > < span class ="si "> {</ span > < span class ="n "> backend</ span > < span class ="si "> }</ span > < span class ="s2 "> . </ span > < span class ="si "> {</ span > < span class ="n "> backend_str_error_message</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
719+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> device_backend_map</ span > < span class ="p "> [</ span > < span class ="n "> device</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> Backend</ span > < span class ="p "> (</ span > < span class ="n "> backend</ span > < span class ="p "> )</ span >
710720
711721 < span class ="k "> def</ span > < span class ="fm "> __repr__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ):</ span >
712722 < span class ="c1 "> # string with all the device:backend pairs separared by commas</ span >
@@ -1293,7 +1303,9 @@ <h1>Source code for torch.distributed.distributed_c10d</h1><div class="highlight
12931303< span class ="sd "> .. note:: Support for multiple backends is experimental. Currently when no backend is</ span >
12941304< span class ="sd "> specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend</ span >
12951305< span class ="sd "> will be used for collectives with CPU tensors and the ``nccl`` backend will be used</ span >
1296- < span class ="sd "> for collectives with CUDA tensors.</ span >
1306+ < span class ="sd "> for collectives with CUDA tensors. A custom backend can be specified by passing in</ span >
1307+ < span class ="sd "> a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g.</ span >
1308+ < span class ="sd "> "cpu:gloo,cuda:custom_backend".</ span >
12971309
12981310< span class ="sd "> """</ span >
12991311 < span class ="k "> global</ span > < span class ="n "> _world</ span >
@@ -1444,6 +1456,9 @@ <h1>Source code for torch.distributed.distributed_c10d</h1><div class="highlight
14441456 < span class ="n "> backend_type</ span > < span class ="o "> =</ span > < span class ="n "> ProcessGroup</ span > < span class ="o "> .</ span > < span class ="n "> BackendType</ span > < span class ="o "> .</ span > < span class ="n "> MPI</ span >
14451457 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> backend_class</ span > < span class ="p "> :</ span >
14461458 < span class ="k "> return</ span > < span class ="n "> GroupMember</ span > < span class ="o "> .</ span > < span class ="n "> NON_GROUP_MEMBER</ span >
1459+ < span class ="c1 "> # create new process group with accurate rank and size</ span >
1460+ < span class ="k "> if</ span > < span class ="n "> pg</ span > < span class ="o "> .</ span > < span class ="n "> rank</ span > < span class ="p "> ()</ span > < span class ="o "> ==</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="ow "> and</ span > < span class ="n "> pg</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="p "> ()</ span > < span class ="o "> ==</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1461+ < span class ="n "> pg</ span > < span class ="o "> =</ span > < span class ="n "> ProcessGroup</ span > < span class ="p "> (</ span > < span class ="n "> backend_prefix_store</ span > < span class ="p "> ,</ span > < span class ="n "> backend_class</ span > < span class ="o "> .</ span > < span class ="n "> rank</ span > < span class ="p "> (),</ span > < span class ="n "> backend_class</ span > < span class ="o "> .</ span > < span class ="n "> size</ span > < span class ="p "> (),</ span > < span class ="n "> base_pg_options</ span > < span class ="p "> )</ span >
14471462 < span class ="k "> elif</ span > < span class ="n "> backend_str</ span > < span class ="o "> ==</ span > < span class ="n "> Backend</ span > < span class ="o "> .</ span > < span class ="n "> GLOO</ span > < span class ="p "> :</ span >
14481463 < span class ="c1 "> # TODO: remove this check after lazy initialization is supported</ span >
14491464 < span class ="c1 "> # if pg_options is not None:</ span >
@@ -1527,15 +1542,15 @@ <h1>Source code for torch.distributed.distributed_c10d</h1><div class="highlight
15271542 < span class ="n "> timeout</ span > < span class ="o "> =</ span > < span class ="n "> timeout</ span > < span class ="p "> ,</ span >
15281543 < span class ="p "> )</ span >
15291544
1530- < span class ="c1 "> # only create single backend pg when backend is set to gloo, nccl, mpi, and ucc </ span >
1531- < span class ="k "> if</ span > < span class ="n " > backend </ span > < span class ="ow " > in </ span > < span class =" p "> [ </ span > < span class ="n " > Backend </ span > < span class ="o " > . </ span > < span class =" n " > GLOO </ span > < span class =" p "> , </ span > < span class ="n "> Backend </ span > < span class ="o "> .</ span > < span class ="n "> NCCL </ span > < span class ="p "> , </ span > < span class ="n " > Backend </ span > < span class =" o "> .</ span > < span class ="n "> UCC </ span > < span class ="p "> , </ span > < span class ="n " > Backend </ span > < span class =" o " > . </ span > < span class ="n " > MPI </ span > < span class ="p "> ] :</ span >
1545+ < span class ="c1 "> # register only a single backend when all get_device_backend_map values are the same </ span >
1546+ < span class ="k "> if</ span > < span class ="nb " > len </ span > < span class ="p "> ( </ span > < span class ="nb " > set </ span > < span class ="p "> ( </ span > < span class ="n "> backend_config </ span > < span class ="o "> .</ span > < span class ="n "> get_device_backend_map </ span > < span class ="p "> () </ span > < span class ="o "> .</ span > < span class ="n "> values </ span > < span class ="p "> ())) </ span > < span class ="o " > == </ span > < span class ="mi " > 1 </ span > < span class ="p "> :</ span >
15321547 < span class ="k "> for</ span > < span class ="n "> device</ span > < span class ="ow "> in</ span > < span class ="n "> backend_config</ span > < span class ="o "> .</ span > < span class ="n "> get_device_backend_map</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="n "> keys</ span > < span class ="p "> ():</ span >
15331548 < span class ="n "> pg</ span > < span class ="o "> .</ span > < span class ="n "> _register_backend</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ),</ span > < span class ="n "> backend_type</ span > < span class ="p "> ,</ span > < span class ="n "> backend_class</ span > < span class ="p "> )</ span >
15341549
15351550 < span class ="c1 "> # break out of outer loop to not create any more backends</ span >
15361551 < span class ="k "> break</ span >
1537- < span class =" k " > else </ span > < span class =" p " > : </ span >
1538- < span class ="n "> pg</ span > < span class ="o "> .</ span > < span class ="n "> _register_backend</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ),</ span > < span class ="n "> backend_type</ span > < span class ="p "> ,</ span > < span class ="n "> backend_class</ span > < span class ="p "> )</ span >
1552+
1553+ < span class ="n "> pg</ span > < span class ="o "> .</ span > < span class ="n "> _register_backend</ span > < span class ="p "> (</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ),</ span > < span class ="n "> backend_type</ span > < span class ="p "> ,</ span > < span class ="n "> backend_class</ span > < span class ="p "> )</ span >
15391554
15401555 < span class ="c1 "> # update global state</ span >
15411556 < span class ="n "> _world</ span > < span class ="o "> .</ span > < span class ="n "> pg_map</ span > < span class ="p "> [</ span > < span class ="n "> pg</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> (</ span > < span class ="n "> backend</ span > < span class ="p "> ,</ span > < span class ="n "> prefix_store</ span > < span class ="p "> )</ span >
0 commit comments