Skip to content

Commit

Permalink
Fix CI error with socket bind (#8981)
Browse files Browse the repository at this point in the history
Fixed a recurring issue in distributed tests where a process tries to bind
to a socket that is already in use:
`torch.distributed.DistNetworkError: The server socket has failed to
listen on any local network address. The server socket has failed to
bind to [::]:54902 (errno: 98 - Address already in use). The server
socket has failed to bind to 0.0.0.0:54902 (errno: 98 - Address already
in use).`

by changing the bind address from `localhost` to an empty address (which binds
to all local interfaces, so any local address can access the TCP port) and by
setting the socket option `SO_REUSEADDR`:
```
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('', 0))
        port = s.getsockname()[1]
```

As a bonus, I also made a minor change in `test_model_hub.py` to address
a CI failure.

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
JakubPietrakIntel and rusty1s committed Feb 28, 2024
1 parent 9068831 commit ee7364c
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 32 deletions.
18 changes: 10 additions & 8 deletions test/distributed/test_dist_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ def test_dist_link_neighbor_loader_homo(
):
addr = '127.0.0.1'
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
sock.bind((addr, 0))
port = sock.getsockname()[1]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

data = FakeDataset(
num_graphs=1,
Expand Down Expand Up @@ -198,10 +199,11 @@ def test_dist_link_neighbor_loader_hetero(
):
mp_context = torch.multiprocessing.get_context('spawn')
addr = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
sock.bind((addr, 0))
port = sock.getsockname()[1]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

data = FakeHeteroDataset(
num_graphs=1,
Expand Down
21 changes: 15 additions & 6 deletions test/distributed/test_dist_link_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,9 @@ def dist_link_neighbor_sampler_temporal_hetero(
def test_dist_link_neighbor_sampler(disjoint):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand All @@ -466,7 +468,9 @@ def test_dist_link_neighbor_sampler(disjoint):
def test_dist_link_neighbor_sampler_temporal(seed_time, temporal_strategy):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -497,7 +501,9 @@ def test_dist_link_neighbor_sampler_edge_level_temporal(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand All @@ -522,8 +528,9 @@ def test_dist_link_neighbor_sampler_edge_level_temporal(
def test_dist_link_neighbor_sampler_hetero(tmp_path, disjoint):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -571,8 +578,9 @@ def test_dist_link_neighbor_sampler_temporal_hetero(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -623,8 +631,9 @@ def test_dist_link_neighbor_sampler_edge_level_temporal_hetero(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down
18 changes: 10 additions & 8 deletions test/distributed/test_dist_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ def test_dist_neighbor_loader_homo(
):
mp_context = torch.multiprocessing.get_context('spawn')
addr = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
sock.bind((addr, 0))
port = sock.getsockname()[1]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

data = FakeDataset(
num_graphs=1,
Expand Down Expand Up @@ -196,10 +197,11 @@ def test_dist_neighbor_loader_hetero(
):
mp_context = torch.multiprocessing.get_context('spawn')
addr = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
sock.bind((addr, 0))
port = sock.getsockname()[1]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('', 0))
port = s.getsockname()[1]

data = FakeHeteroDataset(
num_graphs=1,
Expand Down
18 changes: 12 additions & 6 deletions test/distributed/test_dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,9 @@ def dist_neighbor_sampler_temporal_hetero(
def test_dist_neighbor_sampler(disjoint):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand All @@ -418,8 +419,9 @@ def test_dist_neighbor_sampler(disjoint):
def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -450,8 +452,9 @@ def test_dist_neighbor_sampler_edge_level_temporal(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand All @@ -476,8 +479,9 @@ def test_dist_neighbor_sampler_edge_level_temporal(
def test_dist_neighbor_sampler_hetero(tmp_path, disjoint):
mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -522,8 +526,9 @@ def test_dist_neighbor_sampler_temporal_hetero(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down Expand Up @@ -574,8 +579,9 @@ def test_dist_neighbor_sampler_edge_level_temporal_hetero(

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
s.bind(('', 0))
port = s.getsockname()[1]

world_size = 2
Expand Down
9 changes: 5 additions & 4 deletions test/distributed/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,11 @@ def test_dist_feature_lookup():
feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x')

mp_context = torch.multiprocessing.get_context('spawn')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.settimeout(1)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]

w0 = mp_context.Process(target=run_rpc_feature_test,
args=(2, 0, feature0, partition_book, port))
Expand Down
1 change: 1 addition & 0 deletions test/nn/test_model_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_from_pretrained(model, tmp_path):

model = model.from_pretrained(save_directory)
assert isinstance(model, DummyModel)
assert model.config == CONFIG


@withPackage('huggingface_hub')
Expand Down

0 comments on commit ee7364c

Please sign in to comment.