In [3]:
class STNode:
    """One character cell of a ternary search tree (used by StringTree).

    Attributes:
        data: the character stored in this cell.
        left, right: cells on the same level holding smaller / larger characters.
        mid: first cell of the next level (characters that follow ``data``).
        mult: how many copies of the string ending at this cell were added
            (0 means the string is not stored, only used as a prefix).
        dlistPtr: list node that StringTree associates with that string, if any.
        parent: the cell linking to this one via left, mid or right (None at the root).
    """

    def __init__(self, d):
        self.data = d
        self.left = None
        self.mid = None
        self.right = None
        self.mult = 0
        self.dlistPtr = None
        self.parent = None

    def binary_search(self, d):
        """Search this cell's level (left/right links only) for character d.

        Returns the cell holding d when present; otherwise the last cell
        visited, i.e. the one under which d would be attached as a child.
        """
        node = self
        while True:
            if node.data == d:
                return node
            if node.data > d:
                nxt = node.left
            elif node.data < d:
                nxt = node.right
            else:
                # values that are neither equal nor ordered yield None
                return None
            if nxt is None:
                return node
            node = nxt

    def __str__(self):
        return "(" + str(self.data) + ", " + str(self.mult) + ")"

    def strFromTop(self):
        """Return the string spelled by the mid-links from the root down to this cell."""
        chars = []
        node = self
        while node is not None:
            chars.append(str(node.data))
            # climb over left/right links to the top of this level;
            # the parent of that top cell owns the previous character
            while node.parent is not None and node.parent.mid is not node:
                node = node.parent
            node = node.parent
        return "".join(reversed(chars))

    def strTree(self):
        """Render the subtree rooted here as '(data, mult) -> [left, mid, right]'.

        Missing children are shown as '□'; a cell with no children is just
        '(data, mult)'.
        """
        label = "(" + str(self.data) + ", " + str(self.mult) + ")"
        children = (self.left, self.mid, self.right)
        if all(child is None for child in children):
            return label
        rendered = ["□" if child is None else child.strTree() for child in children]
        return label + " -> [" + ", ".join(rendered) + "]"
    
class StringTree:
    """Multiset of non-empty strings stored in a ternary search tree.

    Each STNode holds one character: ``left``/``right`` lead to smaller /
    larger characters on the same level and ``mid`` to the next character.
    A node's ``mult`` counts the stored copies of the string spelled by the
    path to it. ``dlist`` additionally keeps every distinct stored string in
    sorted order, and the node ending a stored string points at its list
    entry through ``dlistPtr``.

    Invariant maintained by add/remove: every node either has mult > 0 or a
    non-None mid (no dead branches), so min/max/succ/pred always end on a
    stored string.
    """

    def __init__(self):
        self.root = None
        self.size = 0  # number of stored strings, counting duplicates
        self.dlist = DLinkedList()

    def __str__(self):
        if self.root is None:
            return "empty"
        return self.root.strTree()

    def add(self, st):
        """Add one copy of st (empty strings are ignored)."""
        if st == "":
            return
        # Must run before the tree changes; returns None if st is already listed.
        dlistPtr = self.updateDList(st)
        if self.root is None:
            self.root = STNode(st[0])
        ptr = self.root
        for i, character in enumerate(st):
            found_node = ptr.binary_search(character)
            # binary_search stops at the would-be parent when character is absent
            if character < found_node.data:
                found_node.left = STNode(character)
                found_node.left.parent = found_node
                ptr = found_node.left
            elif character > found_node.data:
                found_node.right = STNode(character)
                found_node.right.parent = found_node
                ptr = found_node.right
            else:
                ptr = found_node
            if i < len(st) - 1:
                if ptr.mid is None:
                    ptr.mid = STNode(st[i + 1])
                    ptr.mid.parent = ptr
                ptr = ptr.mid
        ptr.mult += 1
        if ptr.mult == 1:
            ptr.dlistPtr = dlistPtr
        self.size += 1

    def addAll(self, A):
        """Add every string of the iterable A."""
        for x in A:
            self.add(x)

    def printElems(self):
        """Print the distinct stored strings in sorted order, comma-separated."""
        items = []
        ptr = self.dlist.head
        while ptr is not None:
            items.append(ptr.data)
            ptr = ptr.next
        print(", ".join(items))

    def min(self):
        """Return the smallest stored string (None if the tree is empty)."""
        if self.root is None:
            return None
        return self._min(self.root).strFromTop()

    def _min(self, node):
        """Return the node ending the smallest stored string in the subtree at node."""
        ptr = node
        while True:
            while ptr.left is not None:
                ptr = ptr.left
            if ptr.mult > 0:
                return ptr  # a prefix sorts before all of its extensions
            ptr = ptr.mid

    def max(self):
        """Return the largest stored string (None if the tree is empty)."""
        if self.root is None:
            return None
        return self._max(self.root).strFromTop()

    def _max(self, node):
        """Return the node ending the largest stored string in the subtree at node."""
        ptr = node
        while True:
            while ptr.right is not None:
                ptr = ptr.right
            if ptr.mid is None:
                return ptr
            ptr = ptr.mid  # extensions sort after the prefix itself

    def count(self, st):
        """Return how many copies of st are stored (0 if absent or st is empty)."""
        node = self._locate(st, self.root)
        return 0 if node is None else node.mult

    def _locate(self, st, node):
        """Return the node ending st in the tree rooted at node, or None.

        The node is returned even when its mult is 0 (st is then only a
        prefix of stored strings). Returns None for an empty st.
        """
        if not st:
            return None
        i, last = 0, len(st) - 1
        while node is not None:
            character = st[i]
            if character < node.data:
                node = node.left
            elif character > node.data:
                node = node.right
            elif i == last:
                return node
            else:
                node = node.mid
                i += 1
        return None

    def updateDList(self, st):
        """Insert st into the sorted list if missing; return its new list node.

        Returns None when st is already in the list. Scans the list linearly.
        """
        if self.dlist.length == 0:
            return self.dlist.insertLeft(st, None)
        ptr = self.dlist.head
        while ptr is not None and ptr.data < st:
            ptr = ptr.next
        if ptr is None:
            return self.dlist.append(st)
        if ptr.data == st:
            return None
        return self.dlist.insertLeft(st, ptr)

    def succ(self, st):
        """Return the smallest stored string strictly greater than st, or None."""
        node = self._succ_node(st)
        return None if node is None else node.strFromTop()

    def _succ_node(self, st):
        """Return the node ending the smallest stored string > st, or None.

        Each candidate recorded during the descent is smaller than every
        earlier one, so only the latest needs to be kept.
        """
        if not st:
            return None if self.root is None else self._min(self.root)
        best = None
        node, i, last = self.root, 0, len(st) - 1
        while node is not None:
            character = st[i]
            if character < node.data:
                # node's own strings (itself + mid subtree) are all > st
                best = node if node.mult > 0 else self._min(node.mid)
                node = node.left
            elif character > node.data:
                node = node.right
            else:
                if node.right is not None:
                    best = self._min(node.right)
                if i == last:
                    # the node spells st itself; its extensions come next
                    if node.mid is not None:
                        return self._min(node.mid)
                    break
                node = node.mid
                i += 1
        return best

    def pred(self, st):
        """Return the largest stored string strictly smaller than st, or None."""
        node = self._pred_node(st)
        return None if node is None else node.strFromTop()

    def _pred_node(self, st):
        """Return the node ending the largest stored string < st, or None.

        Each candidate recorded during the descent is larger than every
        earlier one, so only the latest needs to be kept.
        """
        if not st:
            return None
        best = None
        node, i, last = self.root, 0, len(st) - 1
        while node is not None:
            character = st[i]
            if character > node.data:
                # node's own strings (itself + mid subtree) are all < st
                best = self._max(node.mid) if node.mid is not None else node
                node = node.right
            elif character < node.data:
                node = node.left
            else:
                if node.left is not None:
                    best = self._max(node.left)
                if i == last:
                    break
                if node.mult > 0:
                    best = node  # st[:i+1] is a proper prefix of st, hence smaller
                node = node.mid
                i += 1
        return best

    def remove(self, st):
        """Remove one copy of st; does nothing if st is not stored."""
        node = self._locate(st, self.root)
        if node is None or node.mult == 0:
            return None
        node.mult -= 1
        self.size -= 1
        if node.mult == 0:
            # Last copy gone: drop st from the sorted list and prune cells
            # that no longer lead to any stored string.
            if node.dlistPtr is not None:
                self.dlist.remove(node.dlistPtr)
                node.dlistPtr = None
            self._prune(node)
        return None

    def _prune(self, node):
        """Unlink node and then its level owners while they store nothing and have no mid."""
        while node is not None and node.mult == 0 and node.mid is None:
            owner = self._level_owner(node)
            self.remove_node(node)
            node = owner

    def _level_owner(self, node):
        """Return the node whose mid link leads to node's level (None for the root level)."""
        while node.parent is not None and node.parent.mid is not node:
            node = node.parent
        return node.parent

    def _replace_child(self, node, new):
        """Make node's parent (or the root pointer) refer to new instead of node."""
        parent = node.parent
        if parent is None:
            self.root = new
        elif parent.left is node:
            parent.left = new
        elif parent.mid is node:
            parent.mid = new
        else:
            parent.right = new
        if new is not None:
            new.parent = parent

    def remove_node(self, node):
        """Unlink node from its level, keeping the level ordered.

        Expects node.mid to be None (a mid subtree would be discarded).
        """
        if node is None:
            return
        if node.left is None and node.right is None:
            self.remove_leaf(node)
        elif node.right is None:
            self.remove_node_with_left_child(node)
        elif node.left is None:
            self.remove_node_with_right_child(node)
        else:
            # Two children: move the in-order successor into node's place.
            # Nodes are relinked rather than having their data swapped, so each
            # keeps its own mult, dlistPtr and mid subtree (whose parent
            # pointers strFromTop relies on).
            successor = self._find_min(node.right)
            if successor is not node.right:
                self.remove_node_with_right_child(successor)  # it has no left child
                successor.right = node.right
                node.right.parent = successor
            successor.left = node.left
            node.left.parent = successor
            self._replace_child(node, successor)

    def remove_leaf(self, node):
        """Unlink a node that has no left/right children."""
        self._replace_child(node, None)

    def remove_node_with_left_child(self, node):
        """Replace node by its left child."""
        self._replace_child(node, node.left)

    def remove_node_with_right_child(self, node):
        """Replace node by its right child (which may be None)."""
        self._replace_child(node, node.right)

    def _find_min(self, node):
        """Return the leftmost node on node's level, starting from node."""
        while node.left is not None:
            node = node.left
        return node

    def updateDList2(self, st):
        """Tree-based alternative to updateDList with the same contract.

        Finds the insertion point through the successor search instead of
        scanning the list. Must be called before st is inserted into the
        tree, and relies on dlist being in sync with the tree.
        """
        node = self._locate(st, self.root)
        if node is not None and node.mult > 0:
            return None
        nxt = self._succ_node(st)
        if nxt is None:
            return self.dlist.append(st)
        return self.dlist.insertLeft(st, nxt.dlistPtr)

class DNode:
    """Node of a doubly linked list: a payload plus next/prev links."""

    def __init__(self, d, n, p):
        # d: payload, n: following node, p: preceding node (None at the ends)
        self.data, self.next, self.prev = d, n, p

    def __str__(self):
        text = str(self.data)
        return text
        
class DLinkedList:
    """Doubly linked list of DNode objects with head, tail and length tracking."""

    def __init__(self):
        self.head = self.tail = None
        self.length = 0

    # inserts a node with data d to the left of node n and returns it.
    # If the list is empty, n is ignored and the new node becomes the only node.
    # If n is None on a non-empty list, the node is appended at the tail
    # ("left of past-the-end").
    def insertLeft(self, d, n):
        if self.length == 0:
            self.head = self.tail = DNode(d, None, None)
            self.length += 1
            return self.head
        if n is None:
            # Handle before touching length so the count stays in sync.
            return self.append(d)
        self.length += 1
        new = DNode(d, n, n.prev)
        if n.prev is None:
            self.head = new
        else:
            n.prev.next = new
        n.prev = new
        return new

    # inserts node with d at tail of list and returns it
    def append(self, d):
        if self.length == 0:
            return self.insertLeft(d, None)
        self.length += 1
        self.tail.next = DNode(d, None, self.tail)
        self.tail = self.tail.next
        return self.tail

    # removes node n (which must belong to this list) off the list.
    # n's own next/prev links are left untouched.
    def remove(self, n):
        self.length -= 1
        if n.prev is None:
            self.head = n.next
        else:
            n.prev.next = n.next
        if n.next is None:
            # also resets tail to None when n was the only node
            self.tail = n.prev
        else:
            n.next.prev = n.prev

    def __str__(self):
        if self.head is None:
            return "empty"
        parts = []
        ptr = self.head
        while ptr is not None:
            parts.append("-> " + str(ptr) + " ")
            ptr = ptr.next
        return "-" + "".join(parts) + "|"