Distributed biased sampling #8934

Open · wants to merge 5 commits into base `master`
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))
- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624), [#8722](https://github.com/pyg-team/pytorch_geometric/pull/8722))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), [#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624), [#8722](https://github.com/pyg-team/pytorch_geometric/pull/8722), [#8934](https://github.com/pyg-team/pytorch_geometric/pull/8934))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
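For orientation before the test and implementation diffs: this PR threads a `weight_attr` argument through the distributed stack so that neighbor sampling can be biased by per-edge weights, mirroring the single-machine API that the tests below compare against. A minimal sketch of that single-machine usage, assuming a `pyg-lib` build with weighted sampling support (the toy graph and weights are illustrative):

```python
import torch

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# A toy graph with one weight per edge:
data = Data(
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
    num_nodes=4,
)
data.edge_weight = torch.tensor([0.1, 1.0, 5.0, 0.5])

# `weight_attr` names the edge attribute to bias sampling by; neighbors
# are then drawn proportionally to their edge weight instead of uniformly:
loader = NeighborLoader(
    data,
    num_neighbors=[2, 2],
    weight_attr='edge_weight',
    input_nodes=torch.arange(4),
    batch_size=2,
)

for batch in loader:
    print(batch.n_id)
```

The diffs below make the same `weight_attr` work when the graph is partitioned across machines and sampled via `DistNeighborSampler`.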
124 changes: 100 additions & 24 deletions test/distributed/test_dist_link_neighbor_sampler.py
@@ -23,7 +23,7 @@
from torch_geometric.typing import EdgeType


def create_data(rank, world_size, time_attr: Optional[str] = None):
def create_data(rank, world_size, attr_name: Optional[str] = None):
if rank == 0: # Partition 0:
node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])
edge_index = torch.tensor([ # Sorted by destination.
@@ -56,11 +56,12 @@ def create_data(rank, world_size, time_attr: Optional[str] = None):
])
data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)

if time_attr == 'time': # Create node-level time data:
if attr_name == 'time': # Create node-level time data:
data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
feature_store.put_tensor(data.time, group_name=None, attr_name='time')
feature_store.put_tensor(data.time, group_name=None,
attr_name=attr_name)

elif time_attr == 'edge_time': # Create edge-level time data:
elif attr_name == 'edge_time': # Create edge-level time data:
data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11])

if rank == 0:
@@ -69,7 +70,17 @@ def create_data(rank, world_size, time_attr: Optional[str] = None):
edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11])

feature_store.put_tensor(edge_time, group_name=None,
attr_name=time_attr)
attr_name=attr_name)
elif attr_name == 'edge_weight': # Create edge-level weight data:
data.edge_weight = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11])

if rank == 0:
edge_weight = torch.tensor([0, 1, 2, 3, 4, 5, 11])
if rank == 1:
edge_weight = torch.tensor([4, 7, 7, 7, 7, 7, 11])

feature_store.put_tensor(edge_weight, group_name=None,
attr_name=attr_name)

return (feature_store, graph_store), data

@@ -87,8 +98,9 @@ def dist_link_neighbor_sampler(
rank: int,
master_port: int,
disjoint: bool = False,
weight_attr: Optional[str] = None,
):
dist_data, data = create_data(rank, world_size)
dist_data, data = create_data(rank, world_size, weight_attr)

current_ctx = DistContext(
rank=rank,
@@ -104,6 +116,7 @@
num_neighbors=[-1, -1],
shuffle=False,
disjoint=disjoint,
weight_attr=weight_attr,
)

# Close RPC & worker group at exit:
@@ -135,7 +148,7 @@ def dist_link_neighbor_sampler(

# evaluate distributed edge sample function
out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(
inputs, dist_sampler.node_sample, data.num_nodes, disjoint))
inputs, dist_sampler.node_sample, data.num_nodes, disjoint=disjoint))

sampler = NeighborSampler(data=data, num_neighbors=[-1, -1],
disjoint=disjoint)
@@ -256,6 +269,7 @@ def dist_link_neighbor_sampler_hetero(
master_port: int,
input_type: EdgeType,
disjoint: bool = False,
weight_attr: Optional[str] = None,
):
dist_data, other_graph_store = create_hetero_data(tmp_path, rank)

@@ -267,14 +281,10 @@
group_name='dist-sampler-test',
)

dist_sampler = DistNeighborSampler(
data=dist_data,
current_ctx=current_ctx,
rpc_worker_names={},
num_neighbors=[-1],
shuffle=False,
disjoint=disjoint,
)
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=[-1],
        shuffle=False,
        disjoint=disjoint,
        weight_attr=weight_attr,
    )

# close RPC & worker group at exit:
atexit.register(shutdown_rpc)
@@ -300,8 +310,8 @@ def dist_link_neighbor_sampler_hetero(
col_1 = edge_label_index2[1][0]

# Seed edges:
input_row = torch.tensor([row_0, row_1])
input_col = torch.tensor([col_0, col_1])
input_row = torch.tensor([row_0, row_1], dtype=torch.int64)
input_col = torch.tensor([col_0, col_1], dtype=torch.int64)

inputs = EdgeSamplerInput(
input_id=None,
@@ -312,20 +322,17 @@

# Evaluate distributed `node_sample` function:
out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(
inputs, dist_sampler.node_sample, data.num_nodes, disjoint))
inputs, dist_sampler.node_sample, data.num_nodes, disjoint=disjoint))

sampler = NeighborSampler(
data=data,
num_neighbors=[-1],
disjoint=disjoint,
)
    sampler = NeighborSampler(
        data=data,
        num_neighbors=[-1],
        disjoint=disjoint,
        weight_attr=weight_attr,
    )

# Evaluate edge sample function:
out = edge_sample(
inputs,
sampler._sample,
data.num_nodes,
disjoint,
disjoint=disjoint,
)

# Compare distributed output with single machine output:
@@ -523,6 +530,30 @@ def test_dist_link_neighbor_sampler_edge_level_temporal(
w1.join()


@onlyDistributedTest
def test_dist_link_neighbor_sampler_biased():
    mp_context = torch.multiprocessing.get_context('spawn')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        port = s.getsockname()[1]

    world_size = 2
    w0 = mp_context.Process(
        target=dist_link_neighbor_sampler,
        args=(world_size, 0, port, False, 'edge_weight'),
    )

    w1 = mp_context.Process(
        target=dist_link_neighbor_sampler,
        args=(world_size, 1, port, False, 'edge_weight'),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()


@onlyDistributedTest
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_link_neighbor_sampler_hetero(tmp_path, disjoint):
Expand Down Expand Up @@ -671,3 +702,48 @@ def test_dist_link_neighbor_sampler_edge_level_temporal_hetero(
w1.start()
w0.join()
w1.join()


@onlyDistributedTest
def test_dist_link_neighbor_sampler_biased_hetero(tmp_path):
    mp_context = torch.multiprocessing.get_context('spawn')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))
        port = s.getsockname()[1]

    world_size = 2
    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]
    data = T.ToUndirected()(data)

    # Add weight information to the data:
    for i, edge_type in enumerate(data.edge_types):
        data[edge_type].edge_weight = torch.full(
            (data[edge_type].num_edges, ), i, dtype=torch.int64)

    partitioner = Partitioner(data, world_size, tmp_path)
    partitioner.generate_partition()

    w0 = mp_context.Process(
        target=dist_link_neighbor_sampler_hetero,
        args=(data, tmp_path, world_size, 0, port, ('v0', 'e0', 'v0'), False,
              'edge_weight'),
    )

    w1 = mp_context.Process(
        target=dist_link_neighbor_sampler_hetero,
        args=(data, tmp_path, world_size, 1, port, ('v0', 'e0', 'v1'), False,
              'edge_weight'),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()
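The tests above validate the distributed sampler against the single-machine `NeighborSampler` under identical `edge_weight` inputs. As a mental model only (not the fused `pyg.dist_neighbor_sample` kernel), biased sampling draws each seed's in-neighbors with probability proportional to the corresponding edge weights; a self-contained sketch over a CSC graph, with all names hypothetical:

```python
import torch


def biased_one_hop(colptr, row, edge_weight, seed, k):
    # For each seed node, draw up to `k` distinct in-neighbors with
    # probability proportional to `edge_weight`; zero-weight edges are
    # never picked. Purely illustrative: the real kernel does this per
    # partition and merges the results over RPC.
    nodes, edges = [], []
    for v in seed.tolist():
        lo, hi = int(colptr[v]), int(colptr[v + 1])
        if lo == hi:  # `v` has no in-neighbors in this partition.
            continue
        w = edge_weight[lo:hi].float()
        n = min(k, int((w > 0).sum()))
        if n == 0:
            continue
        idx = lo + torch.multinomial(w, n, replacement=False)
        nodes.append(row[idx])
        edges.append(idx)
    return torch.cat(nodes), torch.cat(edges)


# Node 2 has in-edges from nodes 0 and 1 with weights 1 and 3, so with
# k=1, neighbor 1 is drawn three times as often as neighbor 0:
colptr = torch.tensor([0, 0, 0, 2])
row = torch.tensor([0, 1])
weight = torch.tensor([1.0, 3.0])
print(biased_one_hop(colptr, row, weight, torch.tensor([2]), 1))
```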
9 changes: 7 additions & 2 deletions torch_geometric/distributed/dist_neighbor_sampler.py
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ def __init__(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
weight_attr: Optional[str] = None,
concurrency: int = 1,
device: Optional[torch.device] = None,
**kwargs,
@@ -98,6 +99,7 @@
self.disjoint = disjoint
self.temporal_strategy = temporal_strategy
self.time_attr = time_attr
self.weight_attr = weight_attr
self.temporal = time_attr is not None
self.with_edge_attr = self.feature_store.has_edge_attr()
self.csc = True
@@ -111,13 +113,15 @@ def init_sampler_instance(self):
disjoint=self.disjoint,
temporal_strategy=self.temporal_strategy,
time_attr=self.time_attr,
weight_attr=self.weight_attr,
)

self.num_hops = self._sampler.num_neighbors.num_hops
self.node_types = self._sampler.node_types
self.edge_types = self._sampler.edge_types
self.node_time = self._sampler.node_time
self.edge_time = self._sampler.edge_time
self.edge_weight = self._sampler.edge_weight

def register_sampler_rpc(self) -> None:
partition2workers = rpc_partition_to_workers(
@@ -952,6 +956,7 @@ def _sample_one_hop(
row = self._sampler.row
node_time = self.node_time
edge_time = self.edge_time
edge_weight = self.edge_weight
else:
# Given edge type, get input data and evaluate sample function:
rel_type = '__'.join(edge_type)
@@ -960,7 +965,7 @@
# `node_time` is a destination node time:
node_time = (self.node_time or {}).get(edge_type[0], None)
edge_time = (self.edge_time or {}).get(edge_type, None)

edge_weight = (self.edge_weight or {}).get(edge_type, None)

out = torch.ops.pyg.dist_neighbor_sample(
colptr,
row,
@@ -969,7 +974,7 @@
node_time,
edge_time,
seed_time,
None, # TODO: edge_weight
edge_weight,
True, # csc
self.replace,
self.subgraph_type != SubgraphType.induced,
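For the heterogeneous path above, `self.edge_weight` (like `node_time` and `edge_time`) is a per-edge-type dictionary, so each relation can carry its own bias. A small sketch of the `(x or {}).get(...)` lookup idiom used in `_sample_one_hop`, with illustrative values:

```python
from typing import Dict, Optional, Tuple

import torch

EdgeType = Tuple[str, str, str]

# What `init_sampler_instance()` might leave behind; `None` would mean
# the sampler was built without `weight_attr` at all:
edge_weight: Optional[Dict[EdgeType, torch.Tensor]] = {
    ('v0', 'e0', 'v0'): torch.tensor([1.0, 2.0]),
    ('v0', 'e0', 'v1'): torch.tensor([0.5, 0.5, 4.0]),
}

# One expression covers both "no weights configured" (edge_weight is
# None) and "no weights for this relation" (missing key):
w = (edge_weight or {}).get(('v0', 'e0', 'v0'), None)
assert w is not None and w.numel() == 2
```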
8 changes: 8 additions & 0 deletions torch_geometric/distributed/local_feature_store.py
Original file line number Diff line number Diff line change
@@ -441,6 +441,10 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
feat_store.put_tensor(edge_feats['edge_time'],
group_name=(None, None),
attr_name='edge_time')
if 'edge_weight' in edge_feats:
feat_store.put_tensor(edge_feats['edge_weight'],
group_name=(None, None),
attr_name='edge_weight')

if meta['is_hetero'] and node_feats is not None:
for node_type, node_feat in node_feats.items():
@@ -467,5 +471,9 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
feat_store.put_tensor(edge_feat['edge_time'],
group_name=edge_type,
attr_name='edge_time')
if 'edge_weight' in edge_feat:
feat_store.put_tensor(edge_feat['edge_weight'],
group_name=edge_type,
attr_name='edge_weight')

return feat_store
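The hunks above register `edge_weight` under the same keys already used for `edge_time`, so the sampler can fetch weights through the regular feature-store interface. A hedged round-trip sketch; the `(None, None)` group name for the homogeneous case follows the pattern visible in the diff, while the exact `get_tensor` retrieval signature is an assumption:

```python
import torch

from torch_geometric.distributed import LocalFeatureStore

store = LocalFeatureStore()

# Homogeneous partitions key edge-level tensors by group `(None, None)`:
weight = torch.tensor([0, 1, 2, 3, 4, 5, 11])
store.put_tensor(weight, group_name=(None, None), attr_name='edge_weight')

# Read it back the way a sampler would:
out = store.get_tensor(group_name=(None, None), attr_name='edge_weight')
assert torch.equal(out, weight)
```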
24 changes: 24 additions & 0 deletions torch_geometric/distributed/partition.py
Original file line number Diff line number Diff line change
@@ -109,6 +109,15 @@ def is_edge_level_time(self) -> bool:

return self.data.is_edge_attr('time')

@property
def is_edge_level_weight(self) -> bool:
    if 'edge_weight' in self.data:
        return True

    if self.is_hetero:
        return any('edge_weight' in store
                   for store in self.data.edge_stores)

    return False

@property
def node_types(self) -> Optional[List[NodeType]]:
return self.data.node_types if self.is_hetero else None
@@ -197,6 +206,11 @@ def generate_partition(self):
elif self.is_node_level_time:
src_node_time = time_data[src]

edge_weight = None
if self.is_edge_level_weight and 'edge_weight' in part_data:
    edge_weight = part_data.edge_weight[mask]

offsetted_row = global_row - node_offset[src]
offsetted_col = global_col - node_offset[dst]
# Sort by column to avoid keeping track of permutations in
@@ -235,6 +249,9 @@
})
if self.is_edge_level_time:
efeat[edge_type].update({'edge_time': edge_time[perm]})
if self.is_edge_level_weight and edge_weight is not None:
    efeat[edge_type].update(
        {'edge_weight': edge_weight[perm]})

torch.save(efeat, osp.join(path, 'edge_feats.pt'))
torch.save(graph, osp.join(path, 'graph.pt'))
@@ -303,6 +320,11 @@ def generate_partition(self):
elif self.is_node_level_time:
node_time = data.time

edge_weight = None
if self.is_edge_level_weight and 'edge_weight' in part_data:
    edge_weight = part_data.edge_weight

# Sort by column to avoid keeping track of permutations in
# `NeighborSampler` when converting to CSC format:
global_row, global_col, perm = sort_csc(
@@ -346,6 +368,8 @@
})
if self.is_edge_level_time:
efeat.update({'edge_time': edge_time[perm]})
if self.is_edge_level_weight:
efeat.update({'edge_weight': edge_weight[perm]})

torch.save(efeat, osp.join(path, 'edge_feats.pt'))

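Putting the partitioning pieces together: weights attached to the input graph are sorted with the same CSC permutation as the edges and saved into each partition's `edge_feats.pt`. A minimal end-to-end sketch in the spirit of the heterogeneous test above (the output path and dataset sizes are illustrative):

```python
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import FakeHeteroDataset
from torch_geometric.distributed import Partitioner

data = T.ToUndirected()(FakeHeteroDataset(num_graphs=1)[0])

# Tag every relation with a constant weight so the partitions are easy
# to eyeball:
for i, edge_type in enumerate(data.edge_types):
    data[edge_type].edge_weight = torch.full(
        (data[edge_type].num_edges, ), i, dtype=torch.int64)

# After this, each partition directory contains an `edge_feats.pt` with
# an `edge_weight` entry permuted consistently with the CSC sort:
partitioner = Partitioner(data, num_parts=2, root='/tmp/parts')
partitioner.generate_partition()
```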