## Main

- This is legit one of the hardest problems I've seen so far
- I'll be going through the solution early, because I'm pretty sure I'll forget what the solution even means by the time we get around to doing this problem

### Theory

- Let's first talk about the brute force approach
    - For an array of size $N$, there are $\frac{N \cdot (N+1)}{2}$ possible contiguous subarrays
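        - e.g. for [1,2,3,4,5], you can have [1,2,3,4,5], [1,2,3,4], ... [1], [2,3,4,5], [2,3,4] ... [2], and so on
        - This gives you 5 + 4 + 3 + 2 + 1 = 15 possible subarrays
        - In general, an $N$-sized array has $N + (N-1) + \dots + 1 = \frac{N \cdot (N+1)}{2}$ subarrays
    - If we keep a running sum and a running minimum while extending each subarray to the right, each subarray costs $O(1)$, so the brute force solution runs in $O(N^2)$ time (recomputing the sum and min from scratch for every subarray would be $O(N^3)$)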
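A minimal sketch of that brute force, assuming all values are positive integers (as in the LeetCode problem). `brute_force_max_min_product` is a hypothetical name; the point is that the running sum and running min are updated as the right end grows, which keeps it at $O(N^2)$:

```python
from typing import List


def brute_force_max_min_product(nums: List[int]) -> int:
    """Try every contiguous subarray nums[i..j] and return the largest min * sum."""
    best = 0
    for i in range(len(nums)):
        running_sum = 0
        running_min = nums[i]
        for j in range(i, len(nums)):
            # Extend the subarray by one element and update sum/min in O(1)
            running_sum += nums[j]
            running_min = min(running_min, nums[j])
            best = max(best, running_min * running_sum)
    return best


print(brute_force_max_min_product([1, 2, 3, 4, 5]))  # → 36
```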

- This can be faster! But where is the repeated/redundant work?

- Let's take $A=[1,2,3,4,5]$ as an example
    - Imagine we have a subarray starting at index 0
    - Since we know that $A$ is increasing, the minimum value of any subarray starting at index 0 MUST be at index 0!
    - No matter how far we expand our subarray to the right, our minimum value is fixed.
    - Hence, the min-product (sum of subarray * min of subarray) must be maximised when the subarray length is maximised! Since every value is positive, whatever number appears at the end of the subarray can only increase the sum
    - For example:
        - We previously said that there are 5 subarrays that start at index 0: 
            - [1,2,3,4,5], [1,2,3,4], [1,2,3], [1,2], [1]
    - But actually, we only need to care about [1,2,3,4,5], because the other subarrays cannot have a larger min-product!
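A quick numeric check of this claim on $A=[1,2,3,4,5]$: the min-products of the subarrays starting at index 0 only grow as the subarray gets longer.

```python
nums = [1, 2, 3, 4, 5]

# Min-product of every subarray that starts at index 0, shortest first
prefix_min_products = [min(nums[:k]) * sum(nums[:k]) for k in range(1, len(nums) + 1)]
print(prefix_min_products)  # → [1, 3, 6, 10, 15]

# Because nums is increasing, min(nums[:k]) is always nums[0],
# so the min-product can only grow as the subarray gets longer
assert prefix_min_products == sorted(prefix_min_products)
```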

- This argument also holds for subarrays starting at index 1, 2, 3, and 4
    - So in effect, we only need to consider 5 cases
        - [1,2,3,4,5], [2,3,4,5], [3,4,5], [4,5], [5]
    - Their sums can be accumulated from right to left, so this takes $O(N)$ time
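A minimal sketch of that single pass, valid only under the assumption that the input is non-decreasing and positive (`max_min_product_increasing` is a hypothetical helper, not part of the final solution). Walking right to left keeps the suffix sum in $O(1)$ per step:

```python
from typing import List


def max_min_product_increasing(nums: List[int]) -> int:
    """Max min-product, ASSUMING nums is non-decreasing and positive.

    For each start index i, the best subarray is nums[i:], whose minimum is nums[i].
    """
    best = 0
    suffix_sum = 0
    for i in range(len(nums) - 1, -1, -1):
        suffix_sum += nums[i]  # suffix_sum == sum(nums[i:])
        best = max(best, nums[i] * suffix_sum)
    return best


print(max_min_product_increasing([1, 2, 3, 4, 5]))  # → 36
```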

- So far so good. What happens when we see an array in **decreasing** order?

- Let's suppose we have [5,4,3,2,1] now
    - When the array is increasing, expanding to the right only increases the sum, while the minimum is guaranteed to stay the same
        - e.g. going from [1,2] to [1,2,3] strictly increases the min-product, so there is no tradeoff
    - In a decreasing array this is no longer the case
        - e.g. going from [5,4] to [5,4,3] increases the sum but decreases the minimum, so there is a tradeoff
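To see the tradeoff in numbers, here's a small sketch that grows a subarray of [5,4,3,2,1] to the right from index 0 and prints the sum, the minimum, and the min-product at each step:

```python
nums = [5, 4, 3, 2, 1]

# Grow the subarray to the right from index 0: the sum goes up while the minimum goes down
products = []
for k in range(1, len(nums) + 1):
    sub = nums[:k]
    products.append(min(sub) * sum(sub))
    print(sub, "min =", min(sub), "sum =", sum(sub), "min-product =", products[-1])
# min-products: 25, 36, 36, 28, 15 (rises, then falls)
```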

- Again, brute forcing this will take us $O(N^2)$ time, same as the [1,2,3,4,5] example

- There is again repeated work we can cut!
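    - In the increasing case, expanding rightwards always increases our min-product
    - In the decreasing case, the minimum of a subarray is its rightmost element, so expanding leftwards always increases our min-product
    - So again we only need to check 5 subarrays, one ending at each index and stretched as far left as possible: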
        - [5], [5,4], [5,4,3], [5,4,3,2], [5,4,3,2,1]
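The mirror-image sketch for this case, assuming the input is non-increasing and positive (`max_min_product_decreasing` is a hypothetical helper): walk left to right, keep a running prefix sum, and treat each element as the minimum of everything up to it.

```python
from typing import List


def max_min_product_decreasing(nums: List[int]) -> int:
    """Max min-product, ASSUMING nums is non-increasing and positive.

    For each end index j, the best subarray ending at j is nums[:j+1], whose minimum is nums[j].
    """
    best = 0
    prefix_sum = 0
    for num in nums:
        prefix_sum += num  # sum of everything from index 0 up to and including num
        best = max(best, num * prefix_sum)
    return best


print(max_min_product_decreasing([5, 4, 3, 2, 1]))  # → 36
```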

- More generally: once we fix which element is the minimum, we want to expand the subarray around it as far as it can go (left AND right) without including a smaller value, to maximise the min-product

- So let's consider a case where we have [3,2,1,2,3]
    - We will iterate through the array, and at each point, we treat the value as the minimum value
    - Starting at 3
        - we cannot expand left (because we are at the start of the array), and we cannot expand right (because 2 would become the new minimum)
        - Therefore, the maximum min-product of the subarray with `3` must be min=3 * sum=3 = 9
    - Next, assume min is at 2
        - We can expand left without changing the minimum, so expand left to get [3,2]
        - We cannot expand right without changing the minimum
        - So the maximum min-product where `2` is the minimum must be min=2 * sum=5 = 10
    - Next, assume min is at 1
        - We can expand left 2 times without changing the minimum, so expand left to get [3,2,1]
        - We can expand right 2 times without changing the minimum, so expand right to get [3,2,1,2,3]
        - So the maximum min-product where `1` is the minimum must be min=1 * sum=11 = 11
    - Next, assume min is at 2
        - We cannot expand left without changing the minimum
        - We can expand right, so expand right to get [2,3]
        - So the maximum min-product where `2` is the minimum must be min=2 * sum=5 = 10
    - Next, assume min is at 3
        - We cannot expand right because we are at the end of the array
        - We cannot expand left without changing the minimum
        - So the maximum min-product where `3` is the minimum must be min=3 * sum=3 = 9
    - $\therefore$ Max min-product must be 11
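The walk-through above is an "expand around each element" procedure. Here's a minimal sketch of it (the function name is hypothetical) that prints the same five steps; note it is still $O(N^2)$ in the worst case, e.g. when every element is equal, since each expansion can cross the whole array.

```python
from typing import List


def expand_around_each_min(nums: List[int]) -> int:
    """Treat each nums[i] as the minimum and grow the window while neighbours are >= nums[i]."""
    best = 0
    for i, m in enumerate(nums):
        # Expand left while the neighbour would not become a new, smaller minimum
        left = i
        while left > 0 and nums[left - 1] >= m:
            left -= 1
        # Expand right under the same rule
        right = i
        while right < len(nums) - 1 and nums[right + 1] >= m:
            right += 1
        window = nums[left:right + 1]
        print(f"min={m} at index {i}: window={window}, min-product={m * sum(window)}")
        best = max(best, m * sum(window))
    return best


print(expand_around_each_min([3, 2, 1, 2, 3]))  # → 11 (after the five step lines)
```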



- So this is the basic idea: we treat every element as the minimum, and find the largest subarray we can build around it that keeps it as the minimum

- But notice how this is still doing a lot of redundant comparisons
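    - Take step 3 (where `1` is the minimum) for example
    - We know that 1 expands twice to the right, i.e. we have already compared the 2 and the 3 to its right
    - This already tells us about step 4 (the `2` at index 3): it cannot expand to the left, because its left neighbour is the smaller 1, AND it can expand at most once to the right, because only one element is left before the end of the array
    - Yet, if we follow the algorithm above, we end up making those comparisons again from scratch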
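To make the redundancy concrete, here's a hypothetical comparison counter for the naive expansion. On an array where every value is equal, every element expands across the whole array, so the count is $N \cdot (N-1)$, which grows quadratically.

```python
from typing import List


def count_expansion_comparisons(nums: List[int]) -> int:
    """Count neighbour comparisons made by the naive 'expand around each minimum' approach."""
    comparisons = 0
    for i, m in enumerate(nums):
        left = i
        while left > 0:
            comparisons += 1
            if nums[left - 1] < m:  # smaller neighbour: stop expanding left
                break
            left -= 1
        right = i
        while right < len(nums) - 1:
            comparisons += 1
            if nums[right + 1] < m:  # smaller neighbour: stop expanding right
                break
            right += 1
    return comparisons


for n in (10, 100, 1000):
    print(n, count_expansion_comparisons([1] * n))
# 10 90
# 100 9900
# 1000 999000
```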

- What we truly want at each step of the iteration is a way to get the index of the previous smaller value (which halts our expansion leftwards) and the next smaller value (which halts our expansion rightwards)

- How can we do this? Simple, we use a stack!
    - Finding the next greater element is actually a Leetcode problem in itself (see [here](https://leetcode.com/problems/next-greater-element-i/description/)); next/previous smaller element is the same idea with the comparison flipped
    - So first, we initialise a stack, and an array (or hashmap) mapping each index to the index of its next smaller element, defaulting to -1
    - Next, traversing the array from left to right, we push the current index onto the stack **IF** its value is not smaller than the value at the top of the stack
    - If the current value is smaller than the value at the top of the stack, we pop from the stack and record the current index as the next smaller element of the popped index; we keep popping until the top is no longer larger, then push the current index
    - If we reach the end of the array with indices still on the stack, those elements have no next smaller element (they keep -1)
    - Similar logic applies for the previous smaller element
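A minimal sketch of the stack idea. Here the stack holds indices (not values), so repeated values stay unambiguous. For the previous smaller element, this sketch simply reruns the same routine on the reversed array and maps the indices back; that is one way to apply "similar logic", not the only one. Both function names are hypothetical.

```python
from typing import List


def next_smaller_index(nums: List[int]) -> List[int]:
    """result[i] = index of the first element right of i that is strictly smaller, else -1."""
    result = [-1] * len(nums)
    stack: List[int] = []  # indices whose next smaller element hasn't been found yet
    for i, num in enumerate(nums):
        # num is the next smaller element for every stacked index holding a larger value
        while stack and num < nums[stack[-1]]:
            result[stack.pop()] = i
        stack.append(i)
    # Indices still on the stack keep -1: nothing smaller appears after them
    return result


def prev_smaller_index(nums: List[int]) -> List[int]:
    """Mirror image: next smaller in the reversed array == previous smaller in the original."""
    n = len(nums)
    rev = next_smaller_index(nums[::-1])
    # Original index i sits at reversed index n - 1 - i, and vice versa
    return [-1 if rev[n - 1 - i] == -1 else n - 1 - rev[n - 1 - i] for i in range(n)]


print(next_smaller_index([3, 2, 1, 2, 3]))  # → [1, 2, -1, -1, -1]
print(prev_smaller_index([3, 2, 1, 2, 3]))  # → [-1, -1, -1, 2, 3]
```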

- Once these are precomputed, together with a prefix-sum array, each iteration finds its min-product in $O(1)$ time: the previous smaller index, the next smaller index, and the subarray sum (a difference of two prefix sums) are all $O(1)$ lookups
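Concretely, with a prefix-sum array the window sum is one subtraction. This uses the [3,2,1,2,3] example and hard-codes the boundary indices (worked out by hand, with "no smaller element to the right" replaced by `len(nums)`) so the snippet stands alone; `itertools.accumulate` with `initial=` needs Python 3.8+.

```python
from itertools import accumulate

nums = [3, 2, 1, 2, 3]
prefix = list(accumulate(nums, initial=0))  # prefix[k] == sum(nums[:k]) → [0, 3, 5, 6, 8, 11]

# Strictly-smaller boundaries for each index (-1 / len(nums) when there is none)
prev_smaller = [-1, -1, -1, 2, 3]
next_smaller = [1, 2, 5, 5, 5]

products = []
for i, m in enumerate(nums):
    # sum(nums[prev_smaller[i] + 1 : next_smaller[i]]) in O(1)
    window_sum = prefix[next_smaller[i]] - prefix[prev_smaller[i] + 1]
    products.append(m * window_sum)

print(products)       # → [9, 10, 11, 10, 9]
print(max(products))  # → 11
```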

- Hence, the algorithm runs in $O(N)$ time (each index is pushed onto and popped off each stack at most once) and requires $O(N)$ space

### Solution

In [64]:
from typing import List

class Solution:

    def get_next_smaller_element(self, nums: List[int]) -> List[int]:
        ## The stack stores indices rather than values, so repeated values can't be confused with each other
        stack = []
        next_smaller_element_index = [-1] * len(nums)
        
        for i, num in enumerate(nums):
            ## If num is less than the value at the top of the stack, it must be the next smaller value for that index. We keep popping, because num may be the next smaller value for the rest of the stack too
            while stack and num < nums[stack[-1]]:
                next_smaller_element_index[stack.pop()] = i

            ## num is now not smaller than anything left on the stack, so push its index
            stack.append(i)

        ## Indices still on the stack have no smaller element to their right, so they keep -1
        return next_smaller_element_index
    
    def get_prev_smaller_element(self, nums: List[int]) -> List[int]:
        stack = []
        prev_smaller_element_index = [-1] * len(nums)

        for i, num in enumerate(nums):
            # Pop every index whose value is >= num. Anything at least as large as num that comes before it can never be a left boundary for a later element (i.e. if num is 5 and the stack value is 6, then whenever 6 would stop an expansion, 5 stops it first)
            while stack and (num <= nums[stack[-1]]):
                stack.pop()

            # Whatever is left at the top of the stack is strictly smaller than num, so it is the left boundary
            if stack:
                prev_smaller_element_index[i] = stack[-1]

            # Push the index of num
            stack.append(i)

        return prev_smaller_element_index
            
    def get_cumulative_sum(self, nums: List[int]) -> List[int]:
        ## Prefix sums with a leading 0: the returned list satisfies res[k] == sum(nums[:k])
        cumval = 0
        res = [0] * len(nums)
        for i,num in enumerate(nums):
            cumval += num
            res[i] = cumval
        return [0]+res
            
    def maxSumMinProduct(self, nums: List[int]) -> int:
        next_smaller_element = self.get_next_smaller_element(nums)
        prev_smaller_element = self.get_prev_smaller_element(nums)
        cumulative_sum = self.get_cumulative_sum(nums)
        
        max_min_product = 0  # all values are positive, so any real candidate beats 0
        for i, candidate_minval in enumerate(nums):
            # nums[i] is the minimum of every element strictly between these two boundaries
            left_boundary = prev_smaller_element[i]
            right_boundary = next_smaller_element[i]

            if right_boundary == -1:
                right_boundary = len(nums)

            # sum(nums[left_boundary+1 : right_boundary]) from the prefix sums
            candidate_array_sum = cumulative_sum[right_boundary] - cumulative_sum[left_boundary+1]

            candidate_min_product = candidate_array_sum * candidate_minval
            max_min_product = max(max_min_product, candidate_min_product)

        # LeetCode asks for the answer modulo 10**9 + 7; take the modulo only after finding the max
        return max_min_product % (10**9 + 7)

In [66]:
s = Solution()
nums = [2,4,6,1,3,5]
# s.get_next_smaller_element(nums)
# s.get_prev_smaller_element(nums)
# s.get_cumulative_sum(nums)
s.maxSumMinProduct(nums)

40
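The test input above has no repeated values, so it doesn't exercise duplicates. A hedged sketch of a randomized cross-check against the brute force, using small values to force lots of duplicates; both functions are restated here (names hypothetical) so the snippet runs on its own, and the modulo is skipped since the values are tiny.

```python
import random
from typing import List


def brute_force(nums: List[int]) -> int:
    """O(N^2): extend every start index to the right, tracking sum and min."""
    best = 0
    for i in range(len(nums)):
        running_sum, running_min = 0, nums[i]
        for j in range(i, len(nums)):
            running_sum += nums[j]
            running_min = min(running_min, nums[j])
            best = max(best, running_min * running_sum)
    return best


def stack_solution(nums: List[int]) -> int:
    """O(N): strictly-smaller boundaries from two monotonic-stack passes, plus prefix sums."""
    n = len(nums)
    prev_smaller, next_smaller = [-1] * n, [n] * n  # n means "nothing smaller to the right"

    stack: List[int] = []
    for i, num in enumerate(nums):
        while stack and num < nums[stack[-1]]:
            next_smaller[stack.pop()] = i
        stack.append(i)

    stack = []
    for i, num in enumerate(nums):
        while stack and nums[stack[-1]] >= num:
            stack.pop()
        if stack:
            prev_smaller[i] = stack[-1]
        stack.append(i)

    prefix = [0]
    for num in nums:
        prefix.append(prefix[-1] + num)

    return max(nums[i] * (prefix[next_smaller[i]] - prefix[prev_smaller[i] + 1]) for i in range(n))


random.seed(0)
for _ in range(500):
    # Values drawn from 1..4 so that repeated values show up constantly
    nums = [random.randint(1, 4) for _ in range(random.randint(1, 10))]
    assert stack_solution(nums) == brute_force(nums), nums
print("all random tests passed")
```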