Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return mappable to allow colorbar #441

Merged
merged 36 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4503e9c
feat: return mappable to allow colorbar
maximelucas Aug 1, 2023
8e3eaf1
refact: simplified and improved draw_nodes by fully using plt.scatter
maximelucas Aug 1, 2023
7aca0e0
docs: updated docstrings for draw_nodes. Removed node_ec_cmap
maximelucas Aug 2, 2023
4ad68f9
fix: docstring
maximelucas Aug 2, 2023
a8070dc
fix: remove unused code
maximelucas Aug 2, 2023
2e1e715
feat: accept dict as input in draw
maximelucas Aug 16, 2023
f53a463
add check for negative input values #272
maximelucas Aug 16, 2023
e832f36
feat: added node_shape argument in draw_nodes
maximelucas Aug 16, 2023
26e4a7c
fix: updated draw and other version to match new draw_nodes. fix: mos…
maximelucas Aug 16, 2023
c1aa9ab
fix: last test
maximelucas Aug 16, 2023
88661c3
docs: added colorbar note
maximelucas Aug 16, 2023
e241d00
tests: fixed and added more. fix: now plotting nodes with non-finite …
maximelucas Aug 16, 2023
e8572aa
style: back and isort
maximelucas Aug 16, 2023
a9e81ea
fix: raise for negative node size
maximelucas Aug 16, 2023
fd47b34
fix: raise for negative node lw
maximelucas Aug 16, 2023
3c72b5e
test: added for cmap and vmin/vmax in draw_nodes
maximelucas Aug 16, 2023
20c2bb2
tuto: new about draw_nodes
maximelucas Aug 16, 2023
5f2a1ec
tuto: added non-finite color values
maximelucas Aug 16, 2023
fd0a21e
fix: import numpy in tuto
maximelucas Aug 16, 2023
6fff96c
test: try to fix import
maximelucas Aug 18, 2023
ed4e728
docs: added docstring to update_lims
maximelucas Aug 18, 2023
7c5826e
refact: new _draw_arg_to_arr function
maximelucas Aug 18, 2023
b3e4897
fix: add import
maximelucas Aug 18, 2023
0928229
trying to fix import problem
maximelucas Aug 18, 2023
1b25953
lint: remove unused imports in draw
maximelucas Aug 18, 2023
a9f917a
trying to fix tests again: remove failing tests
maximelucas Aug 18, 2023
0ecf0cf
crazy idea: reimport numpy inside the faulty tests
maximelucas Aug 18, 2023
d664ca9
trying to remove dependency to np in one test
maximelucas Aug 18, 2023
c5849f1
tests: new for _draw_arg_to_arr
maximelucas Aug 18, 2023
5ca8206
style: black
maximelucas Aug 18, 2023
e7ec403
Merge branch 'main' into colorbar
maximelucas Aug 18, 2023
855ed79
moved tests back to test_draw_utils - maybe they were moved in main a…
maximelucas Aug 18, 2023
d459e63
removed double import np - hopefully not necessary anymore
maximelucas Aug 18, 2023
8589339
feat: changed from clipping to interp between min/max. added 'rescali…
maximelucas Aug 18, 2023
3765024
fix: consistent default vals in draw. feat: rescaling_sizes=True by d…
maximelucas Aug 18, 2023
26fa109
minor review comment
maximelucas Aug 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 180 additions & 15 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest

import xgi
Expand All @@ -8,43 +9,196 @@
def test_draw(edgelist8):
H = xgi.Hypergraph(edgelist8)

ax = xgi.draw(H)
fig, ax = plt.subplots()
ax, node_collection = xgi.draw(H, ax=ax)

# number of elements
assert len(ax.lines) == len(H.edges.filterby("size", 2)) # dyads
assert len(ax.patches) == len(H.edges.filterby("size", 2, mode="gt")) # hyperedges
assert len(ax.collections[0].get_sizes()) == H.num_nodes # nodes
offsets = node_collection.get_offsets()
assert offsets.shape[0] == H.num_nodes # nodes

# zorder
for line in ax.lines: # dyads
assert line.get_zorder() == 3
for patch, z in zip(ax.patches, [2, 2, 0, 2, 2]): # hyperedges
assert patch.get_zorder() == z
assert ax.collections[0].get_zorder() == 4 # nodes
assert node_collection.get_zorder() == 4 # nodes

plt.close()


def test_draw_nodes(edgelist8):

H = xgi.Hypergraph(edgelist8)

ax = xgi.draw_nodes(H)
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(H, ax=ax)
fig2, ax2 = plt.subplots()
ax2, node_collection2 = xgi.draw_nodes(
H,
ax=ax2,
node_fc="r",
node_ec="b",
node_lw=2,
node_size=20,
zorder=10,
node_shape="v",
)

# number of elements
assert len(ax.lines) == 0 # dyads
assert len(ax.patches) == 0 # hyperedges
assert len(ax.collections[0].get_sizes()) == H.num_nodes # nodes
offsets = node_collection.get_offsets()
assert offsets.shape[0] == H.num_nodes # nodes

# node_fc
assert np.all(
node_collection.get_facecolor() == np.array([[1.0, 1.0, 1.0, 1.0]])
) # white
assert np.all(
node_collection2.get_facecolor() == np.array([[1.0, 0.0, 0.0, 1.0]])
) # blue

# node_ec
assert np.all(
node_collection.get_edgecolor() == np.array([[0.0, 0.0, 0.0, 1.0]])
) # black
assert np.all(
node_collection2.get_edgecolor() == np.array([[0.0, 0.0, 1.0, 1.0]])
) # red

# node_lw
assert np.all(node_collection.get_linewidth() == np.array([1]))
assert np.all(node_collection2.get_linewidth() == np.array([2]))

# node_size
assert np.all(node_collection.get_sizes() == np.array([15**2]))
assert np.all(node_collection2.get_sizes() == np.array([20**2]))

# zorder
assert ax.collections[0].get_zorder() == 0 # nodes
assert node_collection.get_zorder() == 0
assert node_collection2.get_zorder() == 10

# negative node_lw or node_size
with pytest.raises(ValueError):
ax3, node_collection3 = xgi.draw_nodes(H, node_size=-1)
plt.close()
with pytest.raises(ValueError):
ax3, node_collection3 = xgi.draw_nodes(H, node_lw=-1)
plt.close()

plt.close("all")


def test_draw_nodes_fc_cmap(edgelist8):

H = xgi.Hypergraph(edgelist8)

# unused default when single color
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(H, ax=ax, node_fc="r")
assert node_collection.get_cmap() == plt.cm.viridis
plt.close()

# default cmap
fig, ax = plt.subplots()
colors = [11, 12, 14, 16, 17, 19, 21]
ax, node_collection = xgi.draw_nodes(H, ax=ax, node_fc=colors)
assert node_collection.get_cmap() == plt.cm.Reds
plt.close()

# set cmap
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H, ax=ax, node_fc=colors, node_fc_cmap="Greens"
)
assert node_collection.get_cmap() == plt.cm.Greens
assert (min(colors), max(colors)) == node_collection.get_clim()
plt.close()

# vmin/vmax
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(H, ax=ax, node_fc=colors, vmin=14, vmax=19)
assert (14, 19) == node_collection.get_clim()
plt.close()


def test_draw_nodes_interp(edgelist8):

H = xgi.Hypergraph(edgelist8)
arg = H.nodes.degree
deg_arr = np.array([6, 5, 4, 4, 3, 2, 2])
assert np.all(arg.aslist() == deg_arr)

fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(H, ax=ax, node_size=1, node_lw=10)
assert np.all(node_collection.get_sizes() == np.array([1]))
assert np.all(node_collection.get_linewidth() == np.array([10]))
plt.close()

# rescaling does not affect scalars
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H, ax=ax, node_size=1, node_lw=10, rescale_sizes=True
)
assert np.all(node_collection.get_sizes() == np.array([1]))
assert np.all(node_collection.get_linewidth() == np.array([10]))
plt.close()

# not rescaling IDStat
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H, ax=ax, node_size=arg, node_lw=arg, rescale_sizes=False
)
assert np.all(node_collection.get_sizes() == deg_arr**2)
assert np.all(node_collection.get_linewidth() == deg_arr)
plt.close()

# rescaling IDStat
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H, ax=ax, node_size=arg, node_lw=arg, rescale_sizes=True
)
assert min(node_collection.get_sizes()) == 5**2
assert max(node_collection.get_sizes()) == 30**2
assert min(node_collection.get_linewidth()) == 0
assert max(node_collection.get_linewidth()) == 5
plt.close()

# rescaling IDStat with manual values
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H,
ax=ax,
node_size=arg,
node_lw=arg,
rescale_sizes=True,
**{"min_node_size": 1, "max_node_size": 20, "min_node_lw": 1, "max_node_lw": 10}
)
assert min(node_collection.get_sizes()) == 1**2
assert max(node_collection.get_sizes()) == 20**2
assert min(node_collection.get_linewidth()) == 1
assert max(node_collection.get_linewidth()) == 10
plt.close()

# rescaling ndarray
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_nodes(
H, ax=ax, node_size=arg, node_lw=deg_arr, rescale_sizes=True
)
assert min(node_collection.get_sizes()) == 5**2
assert max(node_collection.get_sizes()) == 30**2
assert min(node_collection.get_linewidth()) == 0
assert max(node_collection.get_linewidth()) == 5
plt.close()


def test_draw_hyperedges(edgelist8):
H = xgi.Hypergraph(edgelist8)

ax = xgi.draw_hyperedges(H)
fig, ax = plt.subplots()
ax = xgi.draw_hyperedges(H, ax=ax)

# number of elements
assert len(ax.lines) == len(H.edges.filterby("size", 2)) # dyads
Expand All @@ -67,7 +221,9 @@ def test_draw_simplices(edgelist8):
plt.close()

S = xgi.SimplicialComplex(edgelist8)
ax = xgi.draw_simplices(S)

fig, ax = plt.subplots()
ax = xgi.draw_simplices(S, ax=ax)

# number of elements
assert len(ax.lines) == 18 # dyads
Expand All @@ -87,23 +243,26 @@ def test_draw_hypergraph_hull(edgelist8):

H = xgi.Hypergraph(edgelist8)

ax = xgi.draw_hypergraph_hull(H)
fig, ax = plt.subplots()
ax, node_collection = xgi.draw_hypergraph_hull(H, ax=ax)

# number of elements
assert len(ax.patches) == len(H.edges.filterby("size", 2, mode="gt")) # hyperedges
assert len(ax.collections[0].get_sizes()) == H.num_nodes # nodes
offsets = node_collection.get_offsets()
assert offsets.shape[0] == H.num_nodes # nodes

# zorder
for patch, z in zip(ax.patches, [2, 2, 0, 2, 2]): # hyperedges
assert patch.get_zorder() == z
assert ax.collections[0].get_zorder() == 4 # nodes
assert node_collection.get_zorder() == 4 # nodes

plt.close()


def test_correct_number_of_collections_draw_multilayer(edgelist8):
# hypergraph
H = xgi.Hypergraph(edgelist8)

ax1 = xgi.draw_multilayer(H)
sizes = xgi.unique_edge_sizes(H)
num_planes = max(sizes) - min(sizes) + 1
Expand Down Expand Up @@ -183,7 +342,9 @@ def test_correct_number_of_collections_draw_multilayer(edgelist8):

def test_draw_dihypergraph(diedgelist2, edgelist8):
DH = xgi.DiHypergraph(diedgelist2)
ax1 = xgi.draw_dihypergraph(DH)

fig, ax1 = plt.subplots()
ax1 = xgi.draw_dihypergraph(DH, ax=ax1)

# number of elements
assert len(ax1.lines) == 7 # number of source nodes
Expand All @@ -203,15 +364,17 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
plt.close()

# test toggle for edges
ax2 = xgi.draw_dihypergraph(DH, edge_marker_toggle=False)
fig, ax2 = plt.subplots()
ax2 = xgi.draw_dihypergraph(DH, edge_marker_toggle=False, ax=ax2)
assert len(ax2.collections) == 1

plt.close()

# test XGI ERROR raise
with pytest.raises(XGIError):
H = xgi.Hypergraph(edgelist8)
ax3 = xgi.draw_dihypergraph(H)
fig, ax3 = plt.subplots()
ax3 = xgi.draw_dihypergraph(H, ax=ax3)
plt.close()


Expand All @@ -224,7 +387,9 @@ def test_draw_dihypergraph_with_str_labels_and_isolated_nodes():
[{"six"}, {}],
]
)
ax4 = xgi.draw_dihypergraph(DH1)

fig, ax4 = plt.subplots()
ax4 = xgi.draw_dihypergraph(DH1, ax=ax4)
assert len(ax4.lines) == 3
assert len(ax4.patches) == 4
assert len(ax4.collections) == DH1.num_edges + 1 - len(
Expand Down
38 changes: 37 additions & 1 deletion tests/drawing/test_draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from matplotlib import cm

import xgi
from xgi.drawing.draw import _CCW_sort, _color_arg_to_dict, _scalar_arg_to_dict
from xgi.drawing.draw import (
_CCW_sort,
_color_arg_to_dict,
_interp_draw_arg,
_scalar_arg_to_dict,
_draw_arg_to_arr,
)


def test_CCW_sort():
Expand All @@ -21,6 +27,36 @@ def test_CCW_sort():
)


def test_draw_arg_to_arr(edgelist4):

H = xgi.Hypergraph(edgelist4)

# arg stat
arg = H.nodes.degree
degree = _draw_arg_to_arr(arg)
assert np.all(degree == np.array([1, 2, 3, 2, 2]))

# arg dict
arg_dict = {1: 1, 2: 2, 3: 3, 4: 2, 5: 2}
degree = _draw_arg_to_arr(arg_dict)
assert np.all(degree == np.array([1, 2, 3, 2, 2]))


def test_interp_draw_arg(edgelist4):

arg = np.linspace(0, 10, num=10)
out = _interp_draw_arg(arg, 1, 9)
assert np.allclose(out, np.linspace(1, 9, num=10))

arg = np.linspace(0, 10, num=10)
out = _interp_draw_arg(arg, 0, 9)
assert np.allclose(out, np.linspace(0, 9, num=10))

arg = np.linspace(0, 10, num=10)
out = _interp_draw_arg(arg, 1, 11)
assert np.allclose(out, np.linspace(1, 11, num=10))


def test_scalar_arg_to_dict(edgelist4):
ids = [1, 2, 3]
min_val = 1
Expand Down
Loading
Loading