In [2]:
class TreeNode:
    """Binary tree node holding a key, two child links and a cached height.

    A new node has no children, so its height starts at 0.
    """

    def __init__(self, x):
        self.key, self.left, self.right, self.height = x, None, None, 0

In [31]:
class AVLTree:
    """Self-balancing (AVL) binary search tree of keys.

    Inserting a key equal to an existing one leaves the tree unchanged.
    ``query`` reports the stored keys that bracket a given value.
    """

    def __init__(self):
        self.root = None

    def height(self, node):
        """Return the cached height of ``node``; an empty subtree has height -1."""
        if node is None:
            return -1
        return node.height

    def _update_height(self, node):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = max(self.height(node.left), self.height(node.right)) + 1

    def _unbalance(self, node):
        """Return True if ``node``'s child subtrees differ in height by more than one."""
        # Compare by value: ``is`` on ints only works by accident of
        # CPython's small-int cache (and warns on Python 3.8+).
        return abs(self.height(node.left) - self.height(node.right)) > 1

    def _right_rotate(self, node):
        """Rotate right around ``node`` to fix a left-left (LL) imbalance.

        Returns the new subtree root (``node``'s former left child).
        """
        new_root = node.left
        node.left = new_root.right
        new_root.right = node
        # The demoted node is updated first because the new root's height
        # depends on it.
        self._update_height(node)
        self._update_height(new_root)
        return new_root

    def _left_rotate(self, node):
        """Rotate left around ``node`` to fix a right-right (RR) imbalance.

        Returns the new subtree root (``node``'s former right child).
        """
        new_root = node.right
        node.right = new_root.left
        new_root.left = node
        self._update_height(node)
        self._update_height(new_root)
        return new_root

    def _left_right_rotate(self, node):
        """Fix a left-right (LR) imbalance: left-rotate the left child, then right-rotate ``node``."""
        node.left = self._left_rotate(node.left)
        return self._right_rotate(node)

    def _right_left_rotate(self, node):
        """Fix a right-left (RL) imbalance: right-rotate the right child, then left-rotate ``node``."""
        node.right = self._right_rotate(node.right)
        return self._left_rotate(node)

    def insert(self, key):
        """Insert ``key`` into the tree, rebalancing as needed."""
        # _insert creates the root itself when the tree is empty.
        self.root = self._insert(key, self.root)

    def _insert(self, key, node):
        """Insert ``key`` into the subtree rooted at ``node``.

        Returns the (possibly new) root of that subtree. Equal keys are
        ignored.
        """
        if node is None:
            node = TreeNode(key)

        elif key < node.key:  # insert into the left subtree
            node.left = self._insert(key, node.left)
            if self._unbalance(node):
                if key < node.left.key:  # LL case
                    node = self._right_rotate(node)
                else:  # LR case
                    node = self._left_right_rotate(node)

        elif key > node.key:  # insert into the right subtree
            node.right = self._insert(key, node.right)
            if self._unbalance(node):
                if key < node.right.key:  # RL case
                    node = self._right_left_rotate(node)
                else:  # RR case
                    node = self._left_rotate(node)

        self._update_height(node)
        return node

    def query(self, val: int, maxVal: int) -> "tuple[int, int]":
        """Return the stored keys that bracket ``val``.

        Returns ``(lo, hi)``:
        - ``lo`` is the largest key ``<= val``, or -1 if there is none.
        - ``hi`` is the smallest key ``> val``, or ``maxVal`` if there is none.

        The sentinels assume keys lie in ``[0, maxVal)``, e.g. array indices.
        """
        p = self.root
        lo, hi = -1, maxVal
        while p:
            if val < p.key:
                hi = min(hi, p.key)
                p = p.left
            else:
                lo = max(lo, p.key)
                p = p.right
        return lo, hi

In [32]:
def maxDepth(root: TreeNode) -> int:
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree has depth 0 and a single node has depth 1.
    """
    if not root:
        return 0
    return 1 + max(maxDepth(root.left), maxDepth(root.right))

In [33]:
def showTreeNode(tree: TreeNode):
    """Serialize ``tree`` in level order as a LeetCode-style list.

    Missing nodes are written as ``'null'``. The output always has exactly
    ``2**maxDepth(tree) - 1`` entries, i.e. every slot of a complete binary
    tree of that depth. An empty tree gives ``[]``.
    """
    from collections import deque

    # The slot count is loop-invariant, so the O(n) depth walk runs only once.
    slots = 2 ** maxDepth(tree) - 1
    queue = deque([tree])  # FIFO for breadth-first order
    ret = []
    while queue and len(ret) < slots:
        node = queue.popleft()
        if not node:
            ret.append('null')
            # Enqueue placeholder children so deeper levels keep their slots.
            queue.append(None)
            queue.append(None)
        else:
            ret.append(node.key)
            queue.append(node.left)
            queue.append(node.right)
    return ret

In [36]:
class Solution:
    def largestRectangleArea(self, heights: list) -> int:
        """Return the area of the largest rectangle inside a histogram.

        Bars are visited from lowest to highest, one group of equal heights
        at a time. When a group is reached, ``lower`` contains exactly the
        indices of strictly lower bars. The nearest such indices on either
        side of a bar therefore bound the widest rectangle of that bar's
        height.

        :param heights: bar heights (any indexable sequence, e.g. a list or range).
        :return: the maximal area, or 0 for an empty input.
        """
        from itertools import groupby

        if not heights:
            return 0
        n = len(heights)
        order = sorted(range(n), key=lambda i: heights[i])
        lower = AVLTree()  # indices of bars strictly lower than the current group
        best = heights[0]  # bar 0 on its own is always a valid rectangle
        for _, group in groupby(order, key=lambda i: heights[i]):
            members = list(group)
            # Query every member before inserting any, so that equal-height
            # bars never act as each other's boundaries.
            for i in members:
                left, right = lower.query(i, n)
                best = max(best, (right - left - 1) * heights[i])
            for i in members:
                lower.insert(i)
        return best

In [39]:
so = Solution()
# Each entry pairs a histogram with its expected largest-rectangle area;
# both values are printed side by side for visual comparison.
checks = [
    ([2], 2),
    ([2, 3], 4),
    ([2, 1, 5, 6, 2, 3], 10),
    ([], 0),
    ([0, 0, 0, 0, 0, 0], 0),
    (range(10000), 5000**2),
    ([1, 1], 2),
    ([2, 30, 30], 60),
    ([2, 0, 2], 2),
    ([0, 1, 0, 2, 0, 3, 0], 3),
]
for bars, expected in checks:
    print(so.largestRectangleArea(bars), expected)

2 2
4 4
10 10
0 0
0 0
25000000 25000000
2 2
60 60
2 2
3 3
