In [13]:
pip install networkx matplotlib



In [15]:
import matplotlib.pyplot as plt
import networkx as nx

In [16]:
class Node:
    def __init__(self, key, color, parent, left, right):
        self.key = key
        self.color = color
        self.parent = parent
        self.left = left
        self.right = right

class RedBlackTree:
    def __init__(self):
        self.NIL = Node(None, "BLACK", None, None, None)
        self.root = self.NIL

    def insert(self, key):
        node = Node(key, "RED", None, self.NIL, self.NIL)
        self._insert(node)
        self._fix_insert(node)

    def _insert(self, node):
        current = self.root
        parent = None

        while current != self.NIL:
            parent = current

            if node.key < current.key:
                current = current.left
            else:
                current = current.right

        node.parent = parent

        if parent is None:
            self.root = node
        elif node.key < parent.key:
            parent.left = node
        else:
            parent.right = node

    def _fix_insert(self, node):
        while node.parent.color == "RED":
            if node.parent == node.parent.parent.left:
                uncle = node.parent.parent.right

                if uncle.color == "RED":
                    node.parent.color = "BLACK"
                    uncle.color = "BLACK"
                    node.parent.parent.color = "RED"
                    node = node.parent.parent
                else:
                    if node == node.parent.right:
                        node = node.parent
                        self._left_rotate(node)

                    node.parent.color = "BLACK"
                    node.parent.parent.color = "RED"
                    self._right_rotate(node.parent.parent)
            else:
                uncle = node.parent.parent.left

                if uncle.color == "RED":
                    node.parent.color = "BLACK"
                    uncle.color = "BLACK"
                    node.parent.parent.color = "RED"
                    node = node.parent.parent
                else:
                    if node == node.parent.left:
                        node = node.parent
                        self._right_rotate(node)

                    node.parent.color = "BLACK"
                    node.parent.parent.color = "RED"
                    self._left_rotate(node.parent.parent)

        self.root.color = "BLACK"

    def _left_rotate(self, x):
        y = x.right
        x.right = y.left

        if y.left != self.NIL:
            y.left.parent = x

        y.parent = x.parent

        if x.parent is None:
            self.root = y
        elif x == x.parent.left:
            x.parent.left = y
        else:
            x.parent.right = y

        y.left = x
        x.parent = y

    def _right_rotate(self, y):
        x = y.left
        y.left = x.right

        if x.right != self.NIL:
            x.right.parent = y

        x.parent = y.parent

        if y.parent is None:
            self.root = x
        elif y == y.parent.right:
            y.parent.right = x
        else:
            y.parent.left = x

        x.right = y
        y.parent = x

    def inorder_traversal(self, node):
        if node != self.NIL:
            self.inorder_traversal(node.left)
            print(f"{node.key} ({node.color})", end=" ")
            self.inorder_traversal(node.right)



In [17]:
class RedBlackTreeVisualization:
    def __init__(self, tree):
        self.tree = tree
        self.G = nx.DiGraph()

    def _build_graph(self, node, parent=None, parent_direction=None):
        if node is not None and node.key is not None:
            self.G.add_node(node.key, color=node.color)

            if parent is not None:
                self.G.add_edge(parent, node.key, color="red" if parent_direction == "left" else "black")

            self._build_graph(node.left, node.key, "left")
            self._build_graph(node.right, node.key, "right")

    def visualize_tree(self):
        self._build_graph(self.tree.root)
        pos = self._graph_pos()
        node_colors = [self.G.nodes[node]["color"] for node in self.G.nodes]
        edge_colors = [self.G.edges[edge]["color"] for edge in self.G.edges]

        nx.draw(self.G, pos, with_labels=True, node_color=node_colors, edge_color=edge_colors)
        plt.show()

    def _graph_pos(self):
        pos = {}
        levels = self._calculate_levels(self.tree.root)
        for node in self.G.nodes:
            level = levels[node]
            pos[node] = (level, -level)

        return pos

    def _calculate_levels(self, node, level=0, levels=None):
        if levels is None:
            levels = {}

        if node is not None:
            levels[node.key] = level
            self._calculate_levels(node.left, level + 1, levels)
            self._calculate_levels(node.right, level + 1, levels)

        return levels

In [18]:
tree = RedBlackTree()
keys = [7, 3, 18, 10, 22, 8, 11, 26, 2, 6, 13]

for key in keys:
  tree.insert(key)

visualizer = RedBlackTreeVisualization(tree)
visualizer.visualize_tree()

AttributeError: ignored