diff --git a/tests/classes/test_hypergraph.py b/tests/classes/test_hypergraph.py index edb69e40..2aa16124 100644 --- a/tests/classes/test_hypergraph.py +++ b/tests/classes/test_hypergraph.py @@ -418,3 +418,20 @@ def test_duplicate_nodes(edgelist1): if 1 not in members and 2 in members: H.add_node_to_edge(edgeid, 1) assert list(H.nodes.duplicates()) == [1, 2] + + +def test_remove_node_weak(edgelist1): + H = xgi.Hypergraph(edgelist1) + assert 1 in H + H.remove_node(1) + assert 1 not in H + with pytest.raises(IDNotFound): + H.remove_node(10) + + +def test_remove_node_strong(edgelist1): + H = xgi.Hypergraph(edgelist1) + assert 1 in H + H.remove_node(1, strong=True) + assert 1 not in H + assert 0 not in H.edges diff --git a/xgi/classes/hypergraph.py b/xgi/classes/hypergraph.py index e98d747d..32d16fc8 100644 --- a/xgi/classes/hypergraph.py +++ b/xgi/classes/hypergraph.py @@ -322,13 +322,21 @@ def add_nodes_from(self, nodes_for_adding, **attr): self._node_attr[n] = self._node_attr_dict_factory() self._node_attr[n].update(newdict) - def remove_node(self, n): - """Remove a single node and all adjacent hyperedges. + def remove_node(self, n, strong=False): + """Remove a single node. + + The removal may be weak (default) or strong. In weak removal, the node is + removed from each of its containing edges. If it is contained in any singleton + edges, then these are also removed. In strong removal, all edges containing the + node are removed, regardless of size. Parameters ---------- n : node - A node in the hypergraph + A node in the hypergraph + + strong : bool (default False) + Whether to execute weak or strong removal. Raises ------ @@ -343,11 +351,17 @@ def remove_node(self, n): edge_neighbors = self._node[n] del self._node[n] del self._node_attr[n] - for edge in edge_neighbors: - self._edge[edge].remove(n) - if not self._edge[edge]: + + if strong: + for edge in edge_neighbors: del self._edge[edge] del self._edge_attr[edge] + else: # weak removal + for edge in edge_neighbors: + self._edge[edge].remove(n) + if not self._edge[edge]: + del self._edge[edge] + del self._edge_attr[edge] def remove_nodes_from(self, nodes): """Remove multiple nodes.