Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 23, 2022
1 parent 1f733e8 commit fd6ec66
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
edge_label_time=edge_time,
batch_size=32,
time_attr='time',
- neg_sampling_ratio=1.0,
+ neg_sampling_ratio=0.5,
drop_last=True,
)
for batch in loader:
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class LinkLoader(torch.utils.data.DataLoader):
neg_sampling (NegativeSamplingConfig, optional): The negative sampling
strategy. Can be either :obj:`"binary"` or :obj:`"triplet"`, and
can be further customized by an additional :obj:`amount` argument
- to control the number of negatives to sample.
+ to control the ratio of sampled negatives to positive edges.
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
In case :obj:`edge_label` does not exist, it will be automatically
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class LinkNeighborLoader(LinkLoader):
neg_sampling (NegativeSamplingConfig, optional): The negative sampling
strategy. Can be either :obj:`"binary"` or :obj:`"triplet"`, and
can be further customized by an additional :obj:`amount` argument
- to control the number of negatives to sample.
+ to control the ratio of sampled negatives to positive edges.
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
In case :obj:`edge_label` does not exist, it will be automatically
Expand Down
21 changes: 11 additions & 10 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def edge_sample(
# of nodes/edges to `src`, `dst`, `src_time`, `dst_time`.
# Later on, we can easily reconstruct what belongs to positive and
# negative examples by slicing via `num_pos`.
- num_neg = round(num_pos * neg_sampling.amount)
+ num_neg = math.ceil(num_pos * neg_sampling.amount)

if neg_sampling.is_binary():
# In the "binary" case, we randomly sample negative pairs of nodes.
Expand All @@ -420,19 +420,17 @@ def edge_sample(
else:
src_node_time = node_time

- src_neg = neg_sample(src, math.ceil(neg_sampling.amount),
- num_src_nodes, src_time, src_node_time)
- src_neg = src_neg[:num_neg]
+ src_neg = neg_sample(src, neg_sampling.amount, num_src_nodes,
+ src_time, src_node_time)
src = torch.cat([src, src_neg], dim=0)

if isinstance(node_time, dict):
dst_node_time = node_time.get(input_type[-1])
else:
dst_node_time = node_time

- dst_neg = neg_sample(dst, math.ceil(neg_sampling.amount),
- num_dst_nodes, dst_time, dst_node_time)
- dst_neg = dst_neg[:num_neg]
+ dst_neg = neg_sample(dst, neg_sampling.amount, num_dst_nodes,
+ dst_time, dst_node_time)
dst = torch.cat([dst, dst_neg], dim=0)

if edge_label is None:
Expand Down Expand Up @@ -595,12 +593,14 @@ def edge_sample(
return out


- def neg_sample(seed: Tensor, num_samples: int, num_nodes: int,
+ def neg_sample(seed: Tensor, num_samples: Union[int, float], num_nodes: int,
seed_time: Optional[Tensor],
node_time: Optional[Tensor]) -> Tensor:
+ num_neg = math.ceil(seed.numel() * num_samples)

# TODO: Do not sample false negatives.
if node_time is None:
- return torch.randint(num_nodes, (seed.numel() * num_samples, ))
+ return torch.randint(num_nodes, (num_neg, ))

# If we are in a temporal-sampling scenario, we need to respect the
# timestamp of the given nodes we can use as negative examples.
Expand All @@ -611,6 +611,7 @@ def neg_sample(seed: Tensor, num_samples: int, num_nodes: int,
# each seed.
# TODO See if this greedy algorithm here can be improved.
assert seed_time is not None
+ num_samples = math.ceil(num_samples)
seed_time = seed_time.view(1, -1).expand(num_samples, -1)
out = torch.randint(num_nodes, (num_samples, seed.numel()))
mask = node_time[out] >= seed_time
Expand All @@ -630,4 +631,4 @@ def neg_sample(seed: Tensor, num_samples: int, num_nodes: int,
# to the node with minimum timestamp.
out[mask] = node_time.argmin()

- return out.view(-1)
+ return out.view(-1)[:num_neg]

0 comments on commit fd6ec66

Please sign in to comment.