Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Improving the graphs.

  • Loading branch information...
commit afa58645cec7a2e71e96d0a016246845cbc829b8 1 parent c45d037
@rik0 authored
View
11 pynetsym/graph/_abstract.py
@@ -0,0 +1,11 @@
+
+from traits.api import HasTraits, implements, Instance
+
+from .interface import IGraph
+from pynetsym import identifiers_manager
+
+class AbstractGraph(HasTraits):
+ implements(IGraph)
+
+ index_store = Instance(identifiers_manager.IntIdentifierStore,
+ allow_none=False, args=())
View
6 pynetsym/graph/interface.py
@@ -1,11 +1,9 @@
from traits.api import Interface
class IGraph(Interface):
- def add_node(self, identifier):
+ def add_node(self):
"""
- Add specified node to the graph.
- @param identifier: the identifier of the node
- @type identifier: int
+ Add a node to the graph.
"""
def add_edge(self, source, target):
View
11 pynetsym/graph/nx_impl.py
@@ -2,9 +2,14 @@
import networkx as nx
from traits.api import HasTraits, Instance
from traits.api import DelegatesTo
+from traits.has_traits import implements
+from ._abstract import AbstractGraph
+
+
+class NxGraph(AbstractGraph):
+ implements(interface.IGraph)
-class NxGraph(HasTraits):
nx_graph = Instance(nx.Graph, allow_none=False)
number_of_nodes = DelegatesTo('nx_graph')
@@ -15,8 +20,8 @@ class NxGraph(HasTraits):
def __init__(self, graph_type=nx.Graph, data=None, **kwargs):
self.nx_graph = graph_type(data=data, **kwargs)
- def add_node(self, node):
- self.nx_graph.add_node(node)
+ def add_node(self):
+ self.nx_graph.add_node(self.index_store.take())
def add_edge(self, source, target):
self.nx_graph.add_edge(source, target)
View
26 pynetsym/graph/scipy_impl.py
@@ -1,26 +1,34 @@
from scipy import sparse
from traits.api import HasTraits, implements, Callable, Instance
from traits.api import Int
+from traits.trait_types import Set
from .interface import IGraph
+from ._abstract import AbstractGraph
-
-class ScipyGraph(HasTraits):
+class ScipyGraph(AbstractGraph):
implements(IGraph)
matrix_factory = Callable(sparse.lil_matrix)
matrix = Instance(
sparse.spmatrix, factory=matrix_factory,
allow_none=False)
- max_nodes = Int(0)
+
+ _nodes = Set(Int)
+
+ def _max_nodes(self):
+ return self.matrix.shape[0]
def __init__(self, max_nodes):
self.matrix = self.matrix_factory(
(max_nodes, max_nodes), dtype=bool)
- def add_node(self, node):
- assert node < self.max_nodes
+ def add_node(self):
+ node_index = self.index_store.take()
+ if node_index >= self._max_nodes():
+ self._enlarge(node_index)
+ self._nodes.add(node_index)
def add_edge(self, source, target):
# consider direct vs. undirected
@@ -28,16 +36,20 @@ def add_edge(self, source, target):
self.matrix[target, source] = True
def number_of_nodes(self):
- return self.matrix.shape[0]
+ return len(self._nodes)
def number_of_edges(self):
- assert self.matrix.nnx % 2 == 0
+ assert self.matrix.nnz % 2 == 0
return self.matrix.nnz / 2
def remove_edge(self, source, target):
self.matrix[source, target] = \
self.matrix[target, source] = False
+ def _enlarge(self, node_index):
+ self.matrix.reshape((node_index, node_index))
+
+
class DirectedScipyGraph(ScipyGraph):
def number_of_edges(self):
return self.matrix.nnz
View
5 tests/test_graph.py
@@ -15,4 +15,9 @@ def setParameters(self, graph_factory, *args):
def testEmpty(self):
self.assertEqual(0, self.graph.number_of_nodes())
+ self.assertEqual(0, self.graph.number_of_edges())
+
+ def testAddedNode(self):
+ self.graph.add_node()
+ self.assertEqual(1, self.graph.number_of_nodes())
self.assertEqual(0, self.graph.number_of_edges())
Please sign in to comment.
Something went wrong with that request. Please try again.