Skip to content

Commit

Permalink
highlight edges, nbs integration, fix nodealph bug with highlightnodes
Browse files Browse the repository at this point in the history
  • Loading branch information
wiheto committed Apr 25, 2022
1 parent 0cc771e commit e40d00a
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 41 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,10 @@ plt.show()

```python
netplotbrain.plot(template='MNI152NLin2009cAsym',
templatestyle='surface',
templatestyle='glass',
view='AP',
frames=10,
frames=30,
gifduration=125,
nodes=nodes,
nodesize='centrality_measure1',
edges=edges,
Expand Down
9 changes: 5 additions & 4 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ nodecolorvminvmax | str, list | Scales continuous colormap between certain value

| Argument | Type | Description |
| --- | --- | --- |
edgecols | list | Edge columns names in edge dataframe. Default is i and j (specifying nodes).
edgealpha | float | Transparency of edges (default: 1).
edgecolumnames | list | Edge columns names in edge dataframe. Default is i and j (specifying nodes).
edgecolor | matplotlib coloring | Can be string (default 'black') or list of 3D/4D colors for each edge.
edgewidth | int, float | Specify width of edges. If auto, will plot the value in edge array (if array) or the weight column (if in pandas dataframe), otherwise 2.
edgeweights | str | String that specifies column in edge dataframe that contains weights. If numpy array is edge input, can be True (default) to specify edge weights.
edgealpha | float | Transparency of edges (default: 1).
edgehighlightbehaviour | str | Alternatives "both" or "any" or None. Governs edge dimming when highlightnodes is on. If both, then highlights only edges between highlighted nodes. If any, then only edges connecting any of the nodes are highlighted.
edgewidthscale | int, float | Scale the width of all edges by a factor (default: 1)
edgewidthscale | int, float | Scale the width of all edges by a factor (default: 1).
edgethresholddirection | str | can be "absabove", "above" (or ">"), "below" (or "<") to indicate thresholding behaviour. If absabove, then the thresholding behaviour is np.abs(edges) > edgethreshold.
edgethreshold | float | Edgeweight value to threshold edges.

### TEMPLATE KWARGS

Expand Down
2 changes: 1 addition & 1 deletion netplotbrain/__version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.4"
__version__ = "0.1.6-develop"
57 changes: 46 additions & 11 deletions netplotbrain/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
_select_single_hemisphere_nodes, _add_subplot_title, get_frame_input,\
_setup_legend, _process_edge_input, _process_node_input,\
_add_nodesize_legend, _add_nodecolor_legend, _init_figure, _check_axinput, \
_plot_gif
from .utils import _highlight_nodes, _get_colorby_colors, _set_axes_equal, _get_view, _load_profile, _nrows_in_fig
_plot_gif, _npedges2dfedges
from .utils import _highlight_nodes, _get_colorby_colors, _set_axes_equal, _get_view, \
_load_profile, _nrows_in_fig, _highlight_edges

def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L', frames=None, edges=None, template=None, templatestyle='filled',
arrowaxis='auto', arroworigin=None, edgecolor='k', nodesize=1, nodecolor='salmon', nodetype='circles', nodecolorby=None,
nodecmap='Dark2', edgeweights=None, nodeimg=None, hemisphere='both', title='auto', highlightnodes=None, showlegend=True, **kwargs):
nodecmap='Dark2', edgeweights=None, nodeimg=None, hemisphere='both', title='auto', highlightnodes=None, highlightedges=None, showlegend=True, **kwargs):
"""
Plot a network on a brain
Expand Down Expand Up @@ -40,7 +41,7 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
edges : dataframe, numpy array, or string
If dataframe, must include i, j columns (and weight, for weighted).
i and j specify indices in nodes.
See edgecols if you want to change the default column names.
See edgecolumnames if you want to change the default column names.
if numpy array, square adjacency array.
If string, can load a tsv file (tab seperator), assumes index column is the 0th column.
template : str or nibabel nifti
Expand All @@ -60,11 +61,18 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
If list, should match the size of views and contain strings to specify hemisphere.
Can be abbreviated to L, R and (empty string possible if both hemisphere plotted).
Between hemisphere edges are deleted.
highlightnodes : int, list, dict
highlightnodes : int, list, dict, str
List or int point out which nodes you want to be highlighted.
If dict, should be a single column-value pair.
Example: highlight all nodes of that, in the node dataframe, have a community
value of 1, the input will be {'community': 1}.
If string, should point to a column in the nodes dataframe and all non-zero values will be plotted.
highlightedges : array, dict, str
List or int point out which nodes you want to be highlighted.
If dict, should be a single column-value pair.
Example: highlight all nodes of that, in the edge dataframe, have a community
value of 1, the input will be {'community': 1}.
If string, should point to a column in the nodes dataframe and all non-zero values will be plotted.
highlightlevel : float
Intensity of the highlighting (opposite of alpha).
Value between 0 and 1, if 1, non-highlighted nodes are fully transparent.
Expand Down Expand Up @@ -99,8 +107,7 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
# Check and load the input of nodes and edges
nodes, nodeimg, profile['nodecols'] = _process_node_input(
nodes, nodeimg, profile['nodecols'], template, profile['templatevoxsize'])
edges, edgeweights = _process_edge_input(edges, edgeweights)

edges, edgeweights = _process_edge_input(edges, edgeweights, **profile)
# Set up legend row
# TODO compact code into subfunction
legends = None
Expand All @@ -124,7 +131,7 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
legendrows = len(legends)

# Figure setup
# Get number of non-legend rows
# Get number of non-legend rowsnon
nrows, view, frames = _nrows_in_fig(view, frames)
# Init figure, if not given as input
if ax is None:
Expand All @@ -136,9 +143,35 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
# Set nodecolor to colorby argument
if nodecolorby is not None:
nodecolor = _get_colorby_colors(nodes, nodecolorby, nodecmap, **profile)
if highlightnodes is not None and highlightedges is not None:
raise ValueError('Cannot highlight based on edges and nodes at the same time.')
if highlightnodes is not None:
nodecolor, highlightnodes = _highlight_nodes(
nodecolor, highlightnodes, profile['nodealpha'] = _highlight_nodes(
nodes, nodecolor, highlightnodes, **profile)
# if highlight edges is array, convert to pandas.
if isinstance(highlightedges, np.ndarray):
highlightedges = _npedges2dfedges(highlightedges)
if len(highlightedges != edges):
raise ValueError('Edges and highlight array are not compatible (different size matrices)')
# Get icol and jcol from profile as they may not be i, j, weight
icol, jcol, _ = profile['nodecols']
edges.sort_values(by=[icol, jcol], inplace=True)
highlightedges.sort_values(by=['i', 'j'], inplace=True)
highlightedges.rename(columns = {'weight': 'edge_to_highlight'}, inplace = True)
if all(edges[icol] == highlightedges['i']) and all(edges[jcol] == highlightedges['j']):
# Insert edge_to_highlight column into edges and set highligtedges to column name
edges['edge_to_highlight'] = highlightedges['edge_to_highlight']
highlightedges = 'edge_to_highlight'
else:
raise ValueError('Could not align edge input and highlight edge input')
if highlightedges is not None:
edgecolor, highlightedges, profile['edgealpha'] = _highlight_edges(edges, edgecolor, highlightedges, **profile)
# Get the nodes that are touched by highlighted edges
nodes_to_highlight = edges[highlightedges == 1]
nodes_to_highlight = np.unique(nodes_to_highlight[profile['edgecolumnnames']].values)
print(nodes_to_highlight)
nodecolor, highlightnodes, profile['nodealpha'] = _highlight_nodes(
nodes, nodecolor, nodes_to_highlight, **profile)

# Rename ax as ax_in and prespecfiy ax_out before forloop
ax_in = ax
Expand All @@ -151,9 +184,9 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
for fi in range(frames):
axind = (ri * nrows) + fi
# Get hemisphere for this frame
hemi_frame = get_frame_input(hemisphere, axind, ri, fi)
hemi_frame = get_frame_input(hemisphere, axind, ri, fi, str)
# Get title for this frame
title_frame = get_frame_input(title, axind, ri, fi)
title_frame = get_frame_input(title, axind, ri, fi, str)
# Set up subplot
if ax_in is None:
ax = fig.add_subplot(gridspec[ri, fi], projection='3d')
Expand Down Expand Up @@ -231,8 +264,10 @@ def plot(nodes=None, fig: Optional[plt.Figure] = None, ax=None, view: str = 'L',
ax_out.append(ax)
fig.tight_layout()

# If gif is requested, create the gif.
if profile['gif'] is True:
_plot_gif(fig, ax_out, profile['gifduration'], profile['savename'], profile['gifloop'])
# Save figure if set
elif profile['savename'] is not None:
if profile['savename'].endswith('.png'):
fig.savedfig(profile['savename'], dpi=profile['fig_dpi'])
Expand Down
28 changes: 23 additions & 5 deletions netplotbrain/plotting/plot_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd


def _npedges2dfedges(edges, edgethreshold=0):
def _npedges2dfedges(edges, edgethreshold=0, edgethresholddirection='absabove'):
"""
A function which transforms numpy array edges into dataframe.
Expand All @@ -14,8 +14,16 @@ def _npedges2dfedges(edges, edgethreshold=0):
n x n array of edges
edgethreshold : float
only find edges over a certain threshold
edgetype : str
Can be below, above or absabove. Default is absabove.
This argument says if keeping edges<edgethreshold, below(edge<threshold), or abs(edges) above
"""
ind = np.where(edges > edgethreshold)
if edgethresholddirection == 'above':
ind = np.where(edges > edgethreshold)
if edgethresholddirection == 'below':
ind = np.where(edges < edgethreshold)
if edgethresholddirection == 'absabove':
ind = np.where(np.abs(edges) > edgethreshold)
weights = edges[ind]
# Create dataframe
df = pd.DataFrame(data={'i': ind[0], 'j': ind[1], 'weight': weights})
Expand All @@ -26,7 +34,9 @@ def _get_edge_highlight_alpha(node_i, node_j, highlightnodes, **kwargs):
edgealpha = kwargs.get('edgealpha')
highlightlevel = kwargs.get('highlightlevel')
edgehighlightbehaviour = kwargs.get('edgehighlightbehaviour')
if highlightnodes is None or edgehighlightbehaviour is None:
if edgealpha is None:
pass
elif highlightnodes is None or edgehighlightbehaviour is None:
pass
elif node_i in highlightnodes and node_j in highlightnodes and edgehighlightbehaviour == 'both':
pass
Expand Down Expand Up @@ -82,16 +92,24 @@ def _plot_edges(ax, nodes, edges, edgewidth=None, edgecolor='k', highlightnodes=
# Convert highlightnodes binary list to index list
hl_idx = np.where(np.array(highlightnodes) == 1)[0]
# if dataframe
for _, row in edges.iterrows():
ecset = 0
for i, row in edges.iterrows():
# if row[edgecol[0]] != 0 and row[edgecol[1]] != 0:
if isinstance(edgecolor, np.ndarray):
if edgecolor.shape[0] == len(edges):
ec = edgecolor[i, :]
ecset = 1
if ecset == 0:
ec = edgecolor
if edgewidth is None:
ew = edgewidthscale
else:
ew = row[edgewidth] * edgewidthscale

if row[edgecol[0]] in nodes.index and row[edgecol[1]] in nodes.index:
ea = _get_edge_highlight_alpha(
row[edgecol[0]], row[edgecol[1]], hl_idx, **kwargs)
xp = nodes.loc[list((row[edgecol[0]], row[edgecol[1]]))]['x']
yp = nodes.loc[list((row[edgecol[0]], row[edgecol[1]]))]['y']
zp = nodes.loc[list((row[edgecol[0]], row[edgecol[1]]))]['z']
ax.plot(xp, yp, zp, color=edgecolor, linewidth=ew, alpha=ea)
ax.plot(xp, yp, zp, color=ec, linewidth=ew, alpha=ea)
35 changes: 23 additions & 12 deletions netplotbrain/plotting/process_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import matplotlib.pyplot as plt
from ..plotting import _npedges2dfedges, _get_nodes_from_nii


def get_frame_input(inputvar, axind, ri, fi):
def get_frame_input(inputvar, axind, ri, fi, exnotlist=True):
"""
Gets subplot varible.
Gets subplot variable.
The variable depends on whether the
input is a string, array or 2d array.
"""
if isinstance(inputvar, str):
var_frame = inputvar
elif isinstance(inputvar[0], str):
var_frame = inputvar[axind]
else:
var_frame = inputvar[ri][fi]
if exnotlist:
if not isinstance(inputvar, list):
var_frame = inputvar
elif not isinstance(inputvar[0], list):
var_frame = inputvar[axind]
else:
var_frame = inputvar[ri][fi]
return var_frame


Expand All @@ -40,18 +40,21 @@ def _process_node_input(nodes, nodeimg, nodecols, template, templatevoxsize):
return nodes, nodeimg, nodecols


def _process_edge_input(edges, edgeweights):
def _process_edge_input(edges, edgeweights, **kwargs):
"""
Takes the input edges and edgeweight.
Loads pandas dataframe if edges is string.
Creates pandas dataframe if edges is numpy array.
Sets defauly edgeweight to "weight" if not given.
Sets default edgeweight to "weight" if not given.
"""
edgethreshold = kwargs.get('edgethreshold')
edgethresholddirection = kwargs.get('edgethresholddirection')
if isinstance(edges, str):
edges = pd.read_csv(edges, sep='\t', index_col=0)
# Check input, if numpy array, make dataframe
if isinstance(edges, np.ndarray):
edges = _npedges2dfedges(edges)
edgeweights = 'weight'
# Set default behaviour of edgeweights
if isinstance(edges, pd.DataFrame):
if edgeweights is None or edgeweights is True:
Expand All @@ -60,7 +63,15 @@ def _process_edge_input(edges, edgeweights):
edgeweights = None
if edgeweights is not None and edgeweights not in edges:
raise ValueError('Edgeweights is specified and not in edges')

# If edgeweight and edgethreshold have been set
if edgeweights is not None and edgethreshold is not None:
if edgethresholddirection == 'absabove':
edges = edges[np.abs(edges[edgeweights]) > edgethreshold]
if edgethresholddirection == 'above' or edgethresholddirection == '>':
edges = edges[edges[edgeweights] > edgethreshold]
if edgethresholddirection == 'below' or edgethresholddirection == '<':
edges = edges[edges[edgeweights] < edgethreshold]

return edges, edgeweights


Expand Down
2 changes: 2 additions & 0 deletions netplotbrain/profiles/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"arrowlength": 10,
"edgealpha": 1,
"edgecolumnnames": ["i", "j"],
"edgethreshold": null,
"edgethresholddirection": "absabove",
"edgewidthscale": 1,
"edgehighlightbehaviour": "both",
"font": "DejaVu Sans",
Expand Down
4 changes: 2 additions & 2 deletions netplotbrain/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Util funcitons."""
from .plot_utils import _set_axes_equal, _set_axes_radius, _get_view, _rotate_data_to_viewingangle,\
_node_scale_vminvmax, _nrows_in_fig
from .coloring import _get_colorby_colors, _highlight_nodes, _colorarray_from_string, _detect_nodecolor_type
from .coloring import _get_colorby_colors, _highlight_nodes, _highlight_edges, _colorarray_from_string, _detect_nodecolor_type
from .settings import _load_profile


__all__ = ['_set_axes_equal', '_set_axes_radius', '_get_view', '_rotate_data_to_viewingangle', '_node_scale_vminvmax', '_nrows_in_fig']
__all__ += ['_get_colorby_colors', '_highlight_nodes', '_colorarray_from_string', '_detect_nodecolor_type']
__all__ += ['_get_colorby_colors', '_highlight_nodes', '_highlight_edges', '_colorarray_from_string', '_detect_nodecolor_type']
__all__ += ['_load_profile']
52 changes: 50 additions & 2 deletions netplotbrain/utils/coloring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import matplotlib.colors as pltcol
import numpy as np
import pandas as pd
from ..plotting import _npedges2dfedges


def _colorarray_from_string(cmap, ncolors):
"""Get colormap array from string that loops through colors."""
if cmap in pltcol.cnames or cmap[0] == '#':
if cmap in pltcol.ColorConverter.colors.keys():
colors = pltcol.to_rgba_array(pltcol.ColorConverter.colors[cmap])
colors = np.vstack([colors] * ncolors)
elif cmap[0] == '#':
colors = pltcol.to_rgba_array(cmap)
colors = np.vstack([colors] * ncolors)
else:
Expand All @@ -17,11 +21,24 @@ def _colorarray_from_string(cmap, ncolors):


def _highlight_nodes(nodes, nodecolor, highlightnodes, **kwargs):
"""
Returns
-------
nodecolor : array
a N x 4 color array for colouring of nodes where alpha is set here.
highlight_idx : array
Binary array of N index indicating which nodes are highlighted (for edge purposes)
"""
highlightlevel = kwargs.get('highlightlevel')
nodealpha = kwargs.get('nodealpha')
if isinstance(highlightnodes, dict):
highlight_idx = nodes[highlightnodes.keys()] == highlightnodes.values()
highlight_idx = np.squeeze(highlight_idx.values)
elif isinstance(highlightnodes, str):
if highlightnodes not in nodes:
raise ValueError('If highlightnodes is a str it must be a column in nodes')
highlightnodes = nodes[highlightnodes].values()
else:
highlight_idx = np.zeros(len(nodes))
highlight_idx[highlightnodes] = 1
Expand All @@ -32,7 +49,38 @@ def _highlight_nodes(nodes, nodecolor, highlightnodes, **kwargs):
[nodecolor, np.vstack([nodealpha]*len(nodecolor))])
# dim the non-highlighted nodes
nodecolor[highlight_idx == 0, 3] = nodealpha * (1 - highlightlevel)
return nodecolor, highlight_idx
# Nodealpha is now set in nodecolor, so set as None to avoid any later problems
nodealpha = None
return nodecolor, highlight_idx, nodealpha



def _highlight_edges(edges, edgecolor, highlightedges, **kwargs):
"""
"""
highlightlevel = kwargs.get('highlightlevel')
edgealpha = kwargs.get('edgealpha')
if isinstance(highlightedges, dict):
highlight_idx = edges[highlightedges.keys()] == highlightedges.values()
highlight_idx = np.squeeze(highlight_idx.values)
elif isinstance(highlightedges, str):
if highlightedges not in edges:
raise ValueError('If highlightnodes is a str it must be a column in nodes')
highlight_idx = np.zeros(len(edges))
highlight_idx[edges[highlightedges] != 0] = 1
else:
highlight_idx = np.zeros(len(edges))
highlight_idx[highlightedges] = 1
if isinstance(edgecolor, str):
edgecolor = _colorarray_from_string(edgecolor, len(edges))
if edgecolor.shape[1] == 3:
edgecolor = np.hstack(
[edgecolor, np.vstack([edgealpha]*len(edgecolor))])
# dim the non-highlighted edges
edgecolor[highlight_idx == 0, 3] = edgealpha * (1 - highlightlevel)
# Nodealpha is now set in nodecolor, so set as None to avoid any later problems
edgealpha = None
return edgecolor, highlight_idx, edgealpha


def assign_color(row, colordict):
Expand Down
Loading

0 comments on commit e40d00a

Please sign in to comment.