Skip to content

Commit

Permalink
Make conversion true id to int id in RW run (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
pedugnat committed Aug 9, 2022
1 parent c96831e commit bf5da42
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion dynnode2vec/biased_random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _generate_walk_simple(

def run(
self,
nodes: List[int],
nodes: List[Any],
*,
n_walks: int = 10,
walk_length: int = 10,
Expand All @@ -160,6 +160,8 @@ def run(
"""
rn = random.Random(seed)

nodes = self.convert_true_ids_to_int_ids(nodes)

# weights are multiplied by inverse p and q
ip, iq = 1.0 / p, 1.0 / q

Expand Down
5 changes: 1 addition & 4 deletions dynnode2vec/dynnode2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,8 @@ def generate_updated_walks(
# that changed compared to the previous time step
delta_nodes = self.get_delta_nodes(current_graph, previous_graph)

brw = BiasedRandomWalk(current_graph)
delta_nodes = brw.convert_true_ids_to_int_ids(delta_nodes)

# run walks for updated nodes only
updated_walks = brw.run(
updated_walks = BiasedRandomWalk(current_graph).run(
nodes=delta_nodes,
walk_length=self.walk_length,
n_walks=self.n_walks_per_node,
Expand Down
3 changes: 3 additions & 0 deletions dynnode2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def generate_dynamic_graphs(
# Create a random graph
graph = nx.fast_gnp_random_graph(n=n_base_nodes, p=base_density)

# add one to each node to avoid the perfect case where true_ids match int_ids
graph = nx.relabel_nodes(graph, mapping={n: n + 1 for n in graph.nodes()})

# initialize graphs list with first graph
graphs = [graph.copy()]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_biased_random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ def test_run(graphs, p, q, weighted):
random_walks = brw.run(graph.nodes(), p=p, q=q, weighted=weighted)

assert all(isinstance(walk, list) for walk in random_walks)
assert all(n in brw.graph.nodes() for walk in random_walks for n in walk)
assert all(n in graph.nodes() for walk in random_walks for n in walk)

0 comments on commit bf5da42

Please sign in to comment.