Skip to content

Commit

Permalink
Merge 13554c2 into b7b350b
Browse files Browse the repository at this point in the history
  • Loading branch information
fail committed Jan 23, 2019
2 parents b7b350b + 13554c2 commit 1be337c
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 4 deletions.
117 changes: 113 additions & 4 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint: disable=invalid-name

import itertools

import networkx
import numpy as np


Expand Down Expand Up @@ -778,11 +778,120 @@ def out_of_bounds(self, pos):

class NetworkGrid:
""" Network Grid where each node contains zero or more agents. """
def __init__(self, G=None, generator="Graph", args=()):
"""Default values create empty graph"""
if G is not None:
self.G = G
elif generator is not None:
if isinstance(args, tuple):
self.G = getattr(networkx, generator)(*args)
else:
self.G = getattr(networkx, generator)(args)

def __init__(self, G):
self.G = G
for node_id in self.G.nodes:
G.nodes[node_id]['agent'] = list()
self.G.nodes[node_id]['agent'] = list()

def add_nodes(self, node_ids):
"""adds nodes to the graph"""
if isinstance(node_ids, list):
for node_id in node_ids:
self._add_node(node_id)
else:
self._add_node(node_ids)

def _add_edge(self, edge):
"""adds an edge to the graph and initialises the agent list in any new nodes"""
self.G.add_edge(*edge)
for node_id in edge:
if not isinstance(self.G.nodes[node_id]['agent'], list):
self.G.nodes[node_id]['agent'] = list()

def _add_node(self, node_id):
"""adds a node to the graph and initialises the agent list"""

self.G.add_node(node_id)
self.G.nodes[node_id]['agent'] = list()

def _remove_node(self, node_id):
"""removes a node from the graph and returns a list of it's agents"""
if self.G.has_node(node_id):
agents = self.G.nodes[node_id]['agent']
self.G.remove_node(node_id)
return agents
else:
raise ValueError(str(node_id) + ": node not found")

def _remove_edge(self, node_ids):
"""removes an edge between nodes on the graph"""

if self.G.has_edge(*node_ids):
self.G.remove_edge(*node_ids)
else:
raise ValueError(str(node_ids) + ": edge not found")

def add_edges(self, node_ids):
"""adds edges between each node_ids tuple to the graph"""

if isinstance(node_ids, list):
for edge in node_ids:
if not isinstance(edge, tuple):
raise TypeError("node_ids must be a list of tuples or a tuple")
else:
self._add_edge(edge)

elif isinstance(node_ids, tuple):
self._add_edge(node_ids)
else:
raise TypeError("node_ids must be a list of tuples or a tuple")

def remove_nodes(self, node_ids):
"""removes the nodes from node_ids from the graph and returns their agents"""
agents = []
if isinstance(node_ids, list):
for node_id in node_ids:
agents.extend(self._remove_node(node_id))
else:
agents.extend(self._remove_node(node_ids))

return agents

def remove_edges(self, node_ids):
"""removes the edges between the node_ids tuples from the graph"""
if isinstance(node_ids, list):
for edge in node_ids:
if not isinstance(edge, tuple):
raise TypeError("node_ids must be a list of tuples or a tuple")
else:
if self.G.has_edge(*edge):
self._remove_edge(edge)
else:
raise ValueError("Edge not found")

elif isinstance(node_ids, tuple):
if self.G.has_edge(*node_ids):
self._remove_edge(node_ids)
else:
raise ValueError("Edge not found")
else:
raise TypeError("node_ids must be a list of tuples or a tuple")

def label_edge(self, node_ids, label, value=""):
"""Adds a label to an edge which can also have a value (i.e. for weight, colour)"""

if not isinstance(node_ids, tuple):
raise TypeError("node_ids must be a tuple")
else:
if self.G.has_edge(*node_ids):
self.G[node_ids[0]][node_ids[1]][label] = value
else:
raise ValueError("edge not found")

def label_node(self, node_id, label, value=""):
"""Adds a label to a node which can also have a value"""
if(node_id in self.G):
self.G.node[node_id][label] = value
else:
raise ValueError("node not found")

def place_agent(self, agent, node_id):
""" Place a agent in a node. """
Expand Down
203 changes: 203 additions & 0 deletions tests/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import networkx as nx
import numpy as np
import pytest
from random import choice, randint

from mesa.space import ContinuousSpace
from mesa.space import SingleGrid
Expand Down Expand Up @@ -413,5 +415,206 @@ def test_get_all_cell_contents(self):
self.agents[2]]


class TestNetworkXWrappers(unittest.TestCase):
GRAPH_SIZE = 10
SINGLE_NODE = "node"
MULTI_NODE = [i for i in range(1, 10)]

def test_init(self):
"""
Create a grid with a graph generator and populate with Mock Agents.
"""
self.space = NetworkGrid(generator="Graph")
assert nx.is_isomorphic(self.space.G, nx.Graph())

self.space = NetworkGrid(generator="scale_free_graph", args=(10))
assert self.space.G.number_of_nodes() == 10
for node in list(self.space.G.nodes):
assert type(self.space.G.nodes[node]['agent']) is list

def setUp(self):
"""
Create a test network grid and populate with Mock Agents
"""
G = nx.complete_graph(TestNetworkXWrappers.GRAPH_SIZE)
self.space = NetworkGrid(G)
self.agents = []
for i, pos in enumerate(TEST_AGENTS_NETWORK_MULTIPLE):
a = MockAgent(i, None)
self.agents.append(a)
self.space.place_agent(a, pos)

def test_add_nodes(self):
"""
Check that nodes are added to the graph and initialised
"""
# Add 1 node
self.space.add_nodes(TestNetworkXWrappers.SINGLE_NODE)
assert TestNetworkXWrappers.SINGLE_NODE in self.space.G
assert type(self.space.G.nodes[TestNetworkXWrappers.SINGLE_NODE]['agent']) is list

# Add multiple nodes
self.space.add_nodes(TestNetworkXWrappers.MULTI_NODE)
for node in TestNetworkXWrappers.MULTI_NODE:
assert node in self.space.G
assert type(self.space.G.nodes[node]['agent']) is list

def test_add_edges(self):
"""
Check that edges are added, existing nodes agents aren't overwritten
"""
# Add one edge
node1 = choice(list(self.space.G.nodes()))
node2 = choice(list(self.space.G.nodes()))
a = MockAgent(len(self.agents) + 1, None)
self.agents.append(a)
self.space.place_agent(a, node1)
edge = (node1, node2)
self.space.add_edges(edge)
assert self.space.G.has_edge(node1, node2)
assert a in self.space.G.nodes[node1]['agent']

# Add multiple edges
nodes = []
for i in range(0, 4):
unique = False
while not unique:
chosen = choice(list(self.space.G.nodes()))
if chosen not in nodes:
nodes.append(chosen)
unique = True
a = MockAgent(len(self.agents) + 1, None)
self.space.place_agent(a, nodes[0])

self.space.add_edges([(nodes[0], nodes[1]), (nodes[2], nodes[3])])
assert self.space.G.has_edge(nodes[0], nodes[1])
assert self.space.G.has_edge(nodes[2], nodes[3])
assert a in self.space.G.nodes[nodes[0]]['agent']

# Try to add not an edge
with pytest.raises(Exception):
self.space.add_edges("strings are not a valid edge")

def test_remove_nodes(self):
"""
Check that node is removed and all agents on node are returned for
redeployment
"""
# remove one node
node = choice(list(self.space.G.nodes()))
a = MockAgent(len(self.agents) + 1, None)
self.agents.append(a)
self.space.place_agent(a, node)
ays = self.space.G.nodes[node]['agent']
agents = self.space.remove_nodes(node)
assert agents == ays
assert node not in self.space.G

# Remove multiple nodes
nodes = [choice(list(self.space.G.nodes())),
choice(list(self.space.G.nodes()))]
a = [MockAgent(len(self.agents) + 1, None), MockAgent(len(self.agents) + 2, None)]
self.agents.extend(a)
ays = []
for i, ag in enumerate(a):
self.space.place_agent(ag, nodes[i])
ays.extend(self.space.G.nodes[nodes[i]]['agent'])
agents = self.space.remove_nodes(nodes)
assert len(agents) == len(ays)
for ag in a:
assert ag not in self.space.G

# Try to remove a node that isn't there
present = True
while present:
chosen = randint(0, 1000)
if chosen not in self.space.G:
present = False
with pytest.raises(Exception):
agents = self.space.remove_nodes(chosen)

def test_remove_edges(self):
"""
Check that correct edges are removed
"""
# remove one edge
edge = (choice(list(self.space.G.nodes())),
choice(list(self.space.G.nodes())))
if not self.space.G.has_edge(*edge):
self.space.add_edges(edge)
self.space.remove_edges(edge)
assert self.space.G.has_edge(*edge) is False

# remove multiple edges
edges = []
for i in range(0, 4):
unique = False
edge = ()
while(not unique):
edge = (choice(list(self.space.G.nodes())), choice(list(self.space.G.nodes())))
if edge not in edges:
unique = True
edges.append(edge)
if not self.space.G.has_edge(*edges[i]):
self.space.add_edges(edges[i])

self.space.remove_edges(edges)
for edge in edges:
assert self.space.G.has_edge(*edge) is False

# try to remove not an edge
edge = "strings are not valid edges"
with pytest.raises(Exception):
self.space.remove_edges(edge)

def test_label_edge(self):
"""
Check that edge is labeled and apply different values to label
"""
# label an edge
edge = (choice(list(self.space.G.nodes())),
choice(list(self.space.G.nodes())))
self.space.add_edges(edge)
self.space.label_edge(edge, "some attribute")
assert "some attribute" in self.space.G[edge[0]][edge[1]]

# label an edge with a value
edge = (choice(list(self.space.G.nodes())),
choice(list(self.space.G.nodes())))
self.space.add_edges(edge)
self.space.label_edge(edge, "attribute", "value")
assert "attribute" in self.space.G[edge[0]][edge[1]]
assert self.space.G[edge[0]][edge[1]]['attribute'] == "value"

# try to label an edge that's not there
edge = (choice(list(self.space.G.nodes())),
choice(list(self.space.G.nodes())))
if self.space.G.has_edge(*edge):
self.space.remove_edges(edge)
with pytest.raises(Exception):
self.space.label_edge(edge)

def test_label_node(self):
"""
Check that node is labeled and apply different values
"""
# label a node
node = choice(list(self.space.G.nodes()))
self.space.label_node(node, "attribute")
assert "attribute" in self.space.G.nodes[node]

# label a node with a value
node = choice(list(self.space.G.nodes()))
self.space.label_node(node, "attribute", "value")
assert "attribute" in self.space.G.nodes[node]
assert self.space.G.nodes[node]["attribute"] == "value"

# try to label a node that isn't there
node = choice(list(self.space.G.nodes()))
self.space.remove_nodes(node)
with pytest.raises(Exception):
self.space.label_node(node, "attribute")


if __name__ == '__main__':
unittest.main()

0 comments on commit 1be337c

Please sign in to comment.