From 7eb392d73f897f51b0be97749ca6231ee38e58a4 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Dec 2020 19:02:30 -0800 Subject: [PATCH] Fix TCPStore type coercion (#49685) Summary: Fixes https://github.com/pytorch/pytorch/issues/49052 The TCPStore example with 4 arguments was working because the timedelta value, which landed in the `is_master` position, was being implicitly converted to a bool. Modified the pybind definition and updated documentation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/49685 Test Plan: ``` import torch.distributed as dist from datetime import timedelta dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) ``` Now fails with ``` TypeError: __init__(): incompatible constructor arguments. The following argument types are supported: 1. torch._C._distributed_c10d.TCPStore(host_name: str, port: int, world_size: int, is_master: bool, timeout: datetime.timedelta = datetime.timedelta(seconds=300)) Invoked with: '127.0.0.1', 0, True, datetime.timedelta(seconds=30) ``` Reviewed By: mrshenli, ngimel Differential Revision: D25668021 Pulled By: H-Huang fbshipit-source-id: ce40b8648d0a414f0255666fbc680f1a66fae090 --- torch/csrc/distributed/c10d/init.cpp | 35 +++++++++++++++++++--------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 4db2f2600837..b31d44a1d295 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -383,7 +383,8 @@ value with the new supplied ``value``. Example:: >>> import torch.distributed as dist - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> from datetime import timedelta + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # Should return "first_value" >>> store.get("first_key") @@ -411,7 +412,8 @@ when initializing the store, before throwing an exception.
Example:: >>> import torch.distributed as dist - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> from datetime import timedelta + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # Should return "first_value" >>> store.get("first_key") @@ -434,8 +436,9 @@ in an exception. Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.add("first_key", 1) >>> store.add("first_key", 6) >>> # Should return 7 @@ -461,8 +464,9 @@ Deletes the key-value pair associated with ``key`` from the store. Returns Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, HashStore can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key") >>> # This should return true >>> store.delete_key("first_key") @@ -487,8 +491,9 @@ the workers using the store. Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set("first_key", "first_value") >>> # This should return 2 >>> store.num_keys() @@ -506,8 +511,9 @@ Sets the store's default timeout. 
This timeout is used during initialization and Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> store.set_timeout(timedelta(seconds=10)) >>> # This will throw an exception after 10 seconds >>> store.wait(["bad_key"]) @@ -528,8 +534,9 @@ will throw an exception. Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> # This will throw an exception after 30 seconds >>> store.wait(["bad_key"]) )") @@ -551,8 +558,9 @@ if the keys have not been set by the supplied ``timeout``. Example:: >>> import torch.distributed as dist + >>> from datetime import timedelta >>> # Using TCPStore as an example, other store types can also be used - >>> store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) + >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) >>> # This will throw an exception after 10 seconds >>> store.wait(["bad_key"], timedelta(seconds=10)) )"); @@ -617,8 +625,11 @@ pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. 
Example:: >>> import torch.distributed as dist - >>> server_store = dist.TCPStore("127.0.0.1", 0, True, timedelta(seconds=30)) - >>> client_store = dist.TCPStore("127.0.0.1", 0, False) + >>> from datetime import timedelta + >>> # Run on process 1 (server) + >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30)) + >>> # Run on process 2 (client) + >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False) >>> # Use any of the store methods from either the client or server after initialization >>> server_store.set("first_key", "first_value") >>> client_store.get("first_key") @@ -633,7 +644,9 @@ Example:: py::arg("host_name"), py::arg("port"), py::arg("world_size"), - py::arg("is_master"), + // using noconvert() requires this argument to be True or False + // prevents accidental implicit conversion to bool + py::arg("is_master").noconvert(), py::arg("timeout") = std::chrono::milliseconds(::c10d::Store::kDefaultTimeout));