In [25]:
import networkx as nx
from collections import deque

def generate_steiner_trees(G, source, terminals):
    """Enumerate directed paths from ``source`` that cover every terminal.

    The search is breadth-first. Each queue entry holds the node the walk is
    currently at and a ``DiGraph`` with the edges walked so far, oriented
    away from ``source``. A candidate is recorded as soon as the walk stands
    on a terminal and every terminal is already among its nodes; recorded
    candidates are not extended further.

    NOTE: every candidate is only ever grown from its current end node, so
    each result is a simple path, not a branching tree. Connections that
    need a branch (e.g. source -> a and source -> b on separate arms) are
    not produced by this routine.

    Args:
        G: Graph to search; only ``G.neighbors(node)`` is used.
        source: Node every result starts from.
        terminals: Iterable of nodes that every result must contain.

    Returns:
        A list with one entry per result, each being the ``edges()`` view of
        that result's ``DiGraph``. If ``source`` is the only terminal, the
        single result has no edges. If ``terminals`` is empty, the list is
        empty because no node can ever qualify as a terminal end point.
    """
    terminals_set = set(terminals)
    steiner_trees = []

    # Seed the walk with the source itself so that it counts as visited from
    # the start: this handles "source is the only terminal" and rejects
    # self-loops on the source without a separate acyclicity test.
    root = nx.DiGraph()
    root.add_node(source)

    queue = deque([(source, root)])
    while queue:
        current_node, tree = queue.popleft()

        # Standing on a terminal with every terminal covered: record and stop
        # extending this candidate.
        if current_node in terminals_set and terminals_set.issubset(tree.nodes):
            steiner_trees.append(tree.edges())
            continue

        for neighbor in G.neighbors(current_node):
            if neighbor in tree:
                # Already on this walk; revisiting would close a cycle.
                continue
            # Only unvisited nodes are appended, so the copy stays acyclic by
            # construction and needs no explicit DAG check.
            new_tree = tree.copy()
            new_tree.add_edge(current_node, neighbor)
            queue.append((neighbor, new_tree))

    return steiner_trees

# Example topology: an undirected graph built directly from its edge list.
G = nx.Graph(
    [
        (0, 5),
        (0, 3),
        (3, 4),
        (4, 5),
        (4, 6),
        (5, 2),
        (3, 1),
        (1, 6),
        (6, 2),
    ]
)
source = 0
receivers = [1, 2]

# Run the search from the source towards both receivers and report each
# result as a plain list of directed edges.
steiner_trees = generate_steiner_trees(G, source, receivers)
print(f"Found {len(steiner_trees)} Steiner trees.")
for index, tree in enumerate(steiner_trees):
    print(f"Tree {index + 1}: {list(tree)}")


Found 1 Steiner trees.
Tree 1: [(0, 5), (0, 3), (5, 4), (5, 2), (3, 1)]
