Skip to content

Commit

Permalink
tests: added basic for layouts (#293)
Browse files Browse the repository at this point in the history
* tests: added basic for layouts
* added seed argument to all layout functions
  • Loading branch information
maximelucas committed Mar 14, 2023
1 parent c72af43 commit e37c299
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 6 deletions.
84 changes: 84 additions & 0 deletions tests/drawing/test_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

import numpy as np

import xgi
from xgi.exception import XGIError


def test_random_layout():

H = xgi.random_hypergraph(10, [0.2], seed=1)

# seed
pos1 = xgi.random_layout(H, seed=1)
pos2 = xgi.random_layout(H, seed=2)
pos3 = xgi.random_layout(H, seed=2)
assert pos1.keys() == pos2.keys()
assert pos2.keys() == pos3.keys()
assert not np.allclose(list(pos1.values()), list(pos2.values()))
assert np.allclose(list(pos2.values()), list(pos3.values()))

assert len(pos1) == H.num_nodes


def test_pairwise_spring_layout():

H = xgi.random_hypergraph(10, [0.2], seed=1)

# seed
pos1 = xgi.pairwise_spring_layout(H, seed=1)
pos2 = xgi.pairwise_spring_layout(H, seed=2)
pos3 = xgi.pairwise_spring_layout(H, seed=2)
assert pos1.keys() == pos2.keys()
assert pos2.keys() == pos3.keys()
assert not np.allclose(list(pos1.values()), list(pos2.values()))
assert np.allclose(list(pos2.values()), list(pos3.values()))

assert len(pos1) == H.num_nodes


def test_barycenter_spring_layout():

H = xgi.random_hypergraph(10, [0.2], seed=1)

# seed
pos1 = xgi.barycenter_spring_layout(H, seed=1)
pos2 = xgi.barycenter_spring_layout(H, seed=2)
pos3 = xgi.barycenter_spring_layout(H, seed=2)
assert pos1.keys() == pos2.keys()
assert pos2.keys() == pos3.keys()
assert not np.allclose(list(pos1.values()), list(pos2.values()))
assert np.allclose(list(pos2.values()), list(pos3.values()))

assert len(pos1) == H.num_nodes

# phantom
pos4, G = xgi.barycenter_spring_layout(H, return_phantom_graph=True, seed=1)
pos5 = xgi.barycenter_spring_layout(H, return_phantom_graph=False, seed=1)
assert pos4.keys() == pos5.keys()
assert np.allclose(list(pos4.values()), list(pos5.values()))


def test_weighted_barycenter_spring_layout():

H = xgi.random_hypergraph(10, [0.2], seed=1)

# seed
pos1 = xgi.weighted_barycenter_spring_layout(H, seed=1)
pos2 = xgi.weighted_barycenter_spring_layout(H, seed=2)
pos3 = xgi.weighted_barycenter_spring_layout(H, seed=2)
assert pos1.keys() == pos2.keys()
assert pos2.keys() == pos3.keys()
assert not np.allclose(list(pos1.values()), list(pos2.values()))
assert np.allclose(list(pos2.values()), list(pos3.values()))

assert len(pos1) == H.num_nodes

# phantom
pos4, G = xgi.weighted_barycenter_spring_layout(
H, return_phantom_graph=True, seed=1
)
pos5 = xgi.weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=1)
assert pos4.keys() == pos5.keys()
assert np.allclose(list(pos4.values()), list(pos5.values()))
45 changes: 39 additions & 6 deletions xgi/drawing/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Algorithms to compute node positions for drawing."""

import random

import networkx as nx

from .. import convert
Expand Down Expand Up @@ -61,7 +63,7 @@ def random_layout(H, center=None, dim=2, seed=None):
return pos


def pairwise_spring_layout(H):
def pairwise_spring_layout(H, seed=None):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using the graph projection of the hypergraph
Expand All @@ -71,6 +73,13 @@ def pairwise_spring_layout(H):
----------
H : Hypergraph or SimplicialComplex
A position will be assigned to every node in H.
seed : int, RandomState instance or None optional (default=None)
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
if numpy.random.RandomState instance, `seed` is the random
number generator,
if None, the random number generator is the RandomState instance used
by numpy.random.
Returns
-------
Expand All @@ -95,14 +104,18 @@ def pairwise_spring_layout(H):
>>> H = xgi.random_hypergraph(N, ps)
>>> pos = xgi.pairwise_spring_layout(H)
"""

if seed is not None:
random.seed(seed)

if isinstance(H, SimplicialComplex):
H = convert.from_simplicial_complex_to_hypergraph(H)
G = convert.convert_to_graph(H)
pos = nx.spring_layout(G)
pos = nx.spring_layout(G, seed=seed)
return pos


def barycenter_spring_layout(H, return_phantom_graph=False):
def barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
Expand All @@ -115,6 +128,13 @@ def barycenter_spring_layout(H, return_phantom_graph=False):
----------
H : xgi Hypergraph or SimplicialComplex
A position will be assigned to every node in H.
seed : int, RandomState instance or None optional (default=None)
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
if numpy.random.RandomState instance, `seed` is the random
number generator,
if None, the random number generator is the RandomState instance used
by numpy.random.
Returns
-------
Expand All @@ -135,6 +155,9 @@ def barycenter_spring_layout(H, return_phantom_graph=False):
>>> pos = xgi.barycenter_spring_layout(H)
"""
if seed is not None:
random.seed(seed)

if isinstance(H, SimplicialComplex):
H = convert.from_simplicial_complex_to_hypergraph(H)

Expand Down Expand Up @@ -167,7 +190,7 @@ def barycenter_spring_layout(H, return_phantom_graph=False):
phantom_node_id += 1

# Creating a dictionary for the position of the nodes with the standard spring layout
pos_with_phantom_nodes = nx.spring_layout(G)
pos_with_phantom_nodes = nx.spring_layout(G, seed=seed)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand All @@ -178,7 +201,7 @@ def barycenter_spring_layout(H, return_phantom_graph=False):
return pos


def weighted_barycenter_spring_layout(H, return_phantom_graph=False):
def weighted_barycenter_spring_layout(H, return_phantom_graph=False, seed=None):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
Expand All @@ -194,6 +217,13 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False):
----------
H : Hypergraph or SimplicialComplex
A position will be assigned to every node in H.
seed : int, RandomState instance or None optional (default=None)
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
if numpy.random.RandomState instance, `seed` is the random
number generator,
if None, the random number generator is the RandomState instance used
by numpy.random.
Returns
-------
Expand All @@ -213,6 +243,9 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False):
>>> H = xgi.random_hypergraph(N, ps)
>>> pos = xgi.weighted_barycenter_spring_layout(H)
"""
if seed is not None:
random.seed(seed)

if isinstance(H, SimplicialComplex):
H = convert.from_simplicial_complex_to_hypergraph(H)

Expand Down Expand Up @@ -245,7 +278,7 @@ def weighted_barycenter_spring_layout(H, return_phantom_graph=False):
phantom_node_id += 1

# Creating a dictionary for the position of the nodes with the standard spring layout
pos_with_phantom_nodes = nx.spring_layout(G, weight="weight")
pos_with_phantom_nodes = nx.spring_layout(G, weight="weight", seed=seed)

# Retaining only the positions of the real nodes
pos = {k: pos_with_phantom_nodes[k] for k in list(H.nodes)}
Expand Down

0 comments on commit e37c299

Please sign in to comment.