Skip to content

Commit

Permalink
Adding Transitive Reduction Function (Qiskit#923)
Browse files Browse the repository at this point in the history
* Implementing filter_nodes and filter_edges funcs

* Running fmt and clippy

* Fixed issue where errors were not being propagated up to Python. Created tests for filter_edges and filter_nodes for both PyGraph and PyDiGraph. Created release notes for the functions.

* Ran fmt, clippy, and tox

* Fixing release notes

* Fixing release notes again

* Fixing release notes again again

* Fixed release notes

* Fixed release notes. Changed Vec allocation. Expanded on documentation.

* ran cargo fmt and clippy

* working on adding different parallel edge behavior

* Fixing docs for filter functions

* Working on graph_adjacency_matrix

* Implementing changes to graph_adjacency_matrix and digraph_adjacency_matrix

* working on release notes

* Fixed release notes and docs

* Ran cargo fmt

* Ran cargo clippy

* Fixed digraph_adjacency_matrix, passes tests

* Removed mpl_draw from r
elease notes

* Changed if-else blocks in adjacency_matrix functions to match blocks. Wrote tests.

* Fixed tests to pass lint

* Added transitive reduction function to dag algo module

* Fixed issue with graph that have nodes removed. Function now returns index_map for cases where there were nodes removed. Added tests.

* Changing graph.nodes_removed to be false again. Return graph does not have removed nodes

* Adding requested changes:
- Fixing Docs
- Fixing Maps to only have capacity of node_count
- Fixing tests

* Adding function to DAG algorithsm index
  • Loading branch information
danielleodigie authored and raynelfss committed Aug 10, 2023
1 parent f6666dd commit fdc17eb
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/api/algorithm_functions/dag_algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ DAG Algorithms
rustworkx.dag_weighted_longest_path_length
rustworkx.is_directed_acyclic_graph
rustworkx.layers
rustworkx.transitive_reduction
36 changes: 36 additions & 0 deletions releasenotes/notes/transitive-reduction-6db2b80351c15887.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
features:
- |
Added a new function, :func:`~.transitive_reduction` which returns the transtive reduction
of a given :class:`~rustworkx.PyDiGraph` and a dictionary with the mapping of indices from the given graph to the returned graph.
The given graph must be a Directed Acyclic Graph (DAG).
For example:
.. jupyter-execute::
from rustworkx import PyDiGraph
from rustworkx import transitive_reduction
graph = PyDiGraph()
a = graph.add_node("a")
b = graph.add_node("b")
c = graph.add_node("c")
d = graph.add_node("d")
e = graph.add_node("e")
graph.add_edges_from([
(a, b, 1),
(a, d, 1),
(a, c, 1),
(a, e, 1),
(b, d, 1),
(c, d, 1),
(c, e, 1),
(d, e, 1)
])
tr, _ = transitive_reduction(graph)
list(tr.edge_list())
Ref: https://en.wikipedia.org/wiki/Transitive_reduction
93 changes: 92 additions & 1 deletion src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@

mod longest_path;

use super::DictMap;
use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use rustworkx_core::dictmap::InitWithHasher;
use std::cmp::Ordering;
use std::collections::BinaryHeap;

use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode};
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};

use rustworkx_core::traversal::dfs_edges;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -637,3 +642,89 @@ pub fn collect_bicolor_runs(

Ok(block_list)
}

/// Returns the transitive reduction of a directed acyclic graph
///
/// The transitive reduction of :math:`G = (V,E)` is a graph :math:`G\prime = (V,E\prime)`
/// such that for all :math:`v` and :math:`w` in :math:`V` there is an edge :math:`(v, w)` in
/// :math:`E\prime` if and only if :math:`(v, w)` is in :math:`E`
/// and there is no path from :math:`v` to :math:`w` in :math:`G` with length greater than 1.
///
/// :param PyDiGraph graph: A directed acyclic graph
///
/// :returns: a directed acyclic graph representing the transitive reduction, and
/// a map containing the index of a node in the original graph mapped to its
/// equivalent in the resulting graph.
/// :rtype: Tuple[PyGraph, dict]
///
/// :raises PyValueError: if ``graph`` is not a DAG

#[pyfunction]
#[pyo3(text_signature = "(graph, /)")]
pub fn transitive_reduction(
graph: &digraph::PyDiGraph,
py: Python,
) -> PyResult<(digraph::PyDiGraph, DictMap<usize, usize>)> {
let g = &graph.graph;
let mut index_map = DictMap::with_capacity(g.node_count());
if !is_directed_acyclic_graph(graph) {
return Err(PyValueError::new_err(
"Directed Acyclic Graph required for transitive_reduction",
));
}
let mut tr = StablePyGraph::<Directed>::with_capacity(g.node_count(), 0);
let mut descendants = DictMap::new();
let mut check_count = HashMap::with_capacity(g.node_count());

for node in g.node_indices() {
let i = node.index();
index_map.insert(
node,
tr.add_node(graph.get_node_data(i).unwrap().clone_ref(py)),
);
check_count.insert(node, graph.in_degree(i));
}

for u in g.node_indices() {
let mut u_nbrs: IndexSet<NodeIndex> = g.neighbors(u).collect();
for v in g.neighbors(u) {
if u_nbrs.contains(&v) {
if !descendants.contains_key(&v) {
let dfs = dfs_edges(&g, Some(v));
descendants.insert(v, dfs);
}
for desc in &descendants[&v] {
u_nbrs.remove(&NodeIndex::new(desc.1));
}
}
*check_count.get_mut(&v).unwrap() -= 1;
if check_count[&v] == 0 {
descendants.remove(&v);
}
}
for v in u_nbrs {
tr.add_edge(
*index_map.get(&u).unwrap(),
*index_map.get(&v).unwrap(),
graph
.get_edge_data(u.index(), v.index())
.unwrap()
.clone_ref(py),
);
}
}
return Ok((
digraph::PyDiGraph {
graph: tr,
node_removed: false,
multigraph: graph.multigraph,
attrs: py.None(),
cycle_state: algo::DfsSpace::default(),
check_cycle: graph.check_cycle,
},
index_map
.iter()
.map(|(k, v)| (k.index(), v.index()))
.collect::<DictMap<usize, usize>>(),
));
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(dag_longest_path_length))?;
m.add_wrapped(wrap_pyfunction!(dag_weighted_longest_path))?;
m.add_wrapped(wrap_pyfunction!(dag_weighted_longest_path_length))?;
m.add_wrapped(wrap_pyfunction!(transitive_reduction))?;
m.add_wrapped(wrap_pyfunction!(number_connected_components))?;
m.add_wrapped(wrap_pyfunction!(connected_components))?;
m.add_wrapped(wrap_pyfunction!(is_connected))?;
Expand Down
76 changes: 76 additions & 0 deletions tests/rustworkx_tests/digraph/test_transitive_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx


class TestTransitiveReduction(unittest.TestCase):
def test_tr1(self):
graph = rustworkx.PyDiGraph()
a = graph.add_node("a")
b = graph.add_node("b")
c = graph.add_node("c")
d = graph.add_node("d")
e = graph.add_node("e")
graph.add_edges_from(
[(a, b, 1), (a, d, 1), (a, c, 1), (a, e, 1), (b, d, 1), (c, d, 1), (c, e, 1), (d, e, 1)]
)
tr, _ = rustworkx.transitive_reduction(graph)
self.assertCountEqual(list(tr.edge_list()), [(0, 2), (0, 1), (1, 3), (2, 3), (3, 4)])

def test_tr2(self):
graph2 = rustworkx.PyDiGraph()
a = graph2.add_node("a")
b = graph2.add_node("b")
c = graph2.add_node("c")
graph2.add_edges_from(
[
(a, b, 1),
(b, c, 1),
(a, c, 1),
]
)
tr2, _ = rustworkx.transitive_reduction(graph2)
self.assertCountEqual(list(tr2.edge_list()), [(0, 1), (1, 2)])

def test_tr3(self):
graph3 = rustworkx.PyDiGraph()
graph3.add_nodes_from([0, 1, 2, 3])
graph3.add_edges_from([(0, 1, 1), (0, 2, 1), (0, 3, 1), (1, 2, 1), (1, 3, 1)])
tr3, _ = rustworkx.transitive_reduction(graph3)
self.assertCountEqual(list(tr3.edge_list()), [(0, 1), (1, 2), (1, 3)])

def test_tr_with_deletion(self):
graph = rustworkx.PyDiGraph()
a = graph.add_node("a")
b = graph.add_node("b")
c = graph.add_node("c")
d = graph.add_node("d")
e = graph.add_node("e")

graph.add_edges_from(
[(a, b, 1), (a, d, 1), (a, c, 1), (a, e, 1), (b, d, 1), (c, d, 1), (c, e, 1), (d, e, 1)]
)

graph.remove_node(3)

tr, index_map = rustworkx.transitive_reduction(graph)

self.assertCountEqual(list(tr.edge_list()), [(0, 1), (0, 2), (2, 3)])
self.assertEqual(index_map[4], 3)

def test_tr_error(self):
digraph = rustworkx.generators.directed_cycle_graph(1000)
with self.assertRaises(ValueError):
rustworkx.transitive_reduction(digraph)

0 comments on commit fdc17eb

Please sign in to comment.