
Implement a `PrefixMapSum` class with the following methods:

- `insert(key: str, value: int)`: Set a given key's value in the map. If the key already exists, overwrite the value.
- `sum(prefix: str)`: Return the sum of all values of keys that begin with a given prefix.

For example, you should be able to run the following code:

```python
mapsum = PrefixMapSum()
mapsum.insert("columnar", 3)
assert mapsum.sum("col") == 3

mapsum.insert("column", 2)
assert mapsum.sum("col") == 5
```

To implement the `PrefixMapSum` class, we could store the key-value pairs in a plain dictionary. The drawback is that `sum` would then have to scan every stored key to find the ones that start with the prefix. A trie (prefix tree) is better suited to this problem. Keys that share a prefix also share a path from the root, so we can walk down to the node for the prefix and aggregate the values in the subtree below it.
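
For comparison, a dictionary-only baseline might look like the following minimal sketch. The class name `NaivePrefixMapSum` is hypothetical and does not come from the original. `insert` is a single dictionary assignment, which also handles overwrites. However, `sum` must test every stored key, so its cost grows with the number of keys rather than with the length of the prefix.

```python
class NaivePrefixMapSum:
    """Hypothetical baseline: a plain dict plus a linear scan on every sum()."""

    def __init__(self):
        self.values = {}  # key -> value; assigning an existing key overwrites it

    def insert(self, key: str, value: int) -> None:
        self.values[key] = value

    def sum(self, prefix: str) -> int:
        # Inside the method body, `sum` refers to the builtin, not this method.
        # Every stored key is checked: O(n * len(prefix)) for n keys.
        return sum(v for k, v in self.values.items() if k.startswith(prefix))


naive = NaivePrefixMapSum()
naive.insert("columnar", 3)
naive.insert("column", 2)
print(naive.sum("col"))  # 5
```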

We'll add one enhancement. Instead of storing only the keys and their values, each node in the trie also maintains the sum of all values in its subtree. The `sum` method can then read the answer directly from the prefix node. Its running time is therefore proportional to the length of the prefix rather than to the number of keys stored.
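
The subtree-total idea can be isolated in a small sketch that uses an explicit node class instead of reserved dictionary keys. The names `TrieNode`, `add_to_path` and `prefix_total` are hypothetical. This sketch assumes each key is inserted only once, so it deliberately ignores overwrites. Inserting a key adds its value to every node on the key's path, and a prefix query reads a single precomputed total.

```python
class TrieNode:
    """One trie node: children keyed by character, plus the subtree total."""
    __slots__ = ("children", "total")

    def __init__(self):
        self.children = {}  # char -> TrieNode
        self.total = 0      # sum of the values of all keys passing through this node


def add_to_path(root: TrieNode, key: str, amount: int) -> None:
    """Add `amount` to the total of every node on the path spelled by `key`."""
    node = root
    node.total += amount  # the root represents the empty prefix
    for char in key:
        node = node.children.setdefault(char, TrieNode())
        node.total += amount


def prefix_total(root: TrieNode, prefix: str) -> int:
    """Walk len(prefix) nodes and read the stored total; no subtree scan."""
    node = root
    for char in prefix:
        node = node.children.get(char)
        if node is None:
            return 0  # no key starts with this prefix
    return node.total


root = TrieNode()
add_to_path(root, "columnar", 3)
add_to_path(root, "column", 2)
print(prefix_total(root, "col"))      # 5: both keys pass through "col"
print(prefix_total(root, "columna"))  # 3: only "columnar" continues past "column"
```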

Here's the Python code for the `PrefixMapSum` class using this approach. The implementation ensures that:
1. Each node is a dictionary that maps characters to child nodes and holds a special key `_sum`, the sum of all values stored under that prefix. The node where a key ends also holds `_end`, the key's current value.
2. When inserting, the key's old value (0 if the key is new) is looked up first. The difference between the new and old values is then added to `_sum` at every node along the key's path, so an overwrite replaces the old value instead of being added on top of it.
3. The `sum` method walks down to the node for the last character of the prefix and returns the `_sum` stored there.
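
Point 2 depends on knowing the key's old value. One common variant keeps a flat dictionary of current values alongside the trie, which turns that lookup into O(1). This variant is an assumption and not the notebook's own code; the class name `DeltaPrefixMapSum` is hypothetical. With the old value in hand, `delta = value - old` is applied to every node on the path in a single pass.

```python
class DeltaPrefixMapSum:
    """Sketch: nested-dict trie with '_sum' per node, plus a flat dict of current values."""

    def __init__(self):
        self.values = {}         # key -> current value, for O(1) old-value lookup
        self.root = {"_sum": 0}  # '_sum' has several characters, so it never clashes with a child char

    def insert(self, key: str, value: int) -> None:
        # Only the change matters: a new key has old value 0,
        # and an overwritten key contributes its stored value.
        delta = value - self.values.get(key, 0)
        self.values[key] = value
        node = self.root
        node["_sum"] += delta  # the empty prefix covers every key
        for char in key:
            node = node.setdefault(char, {"_sum": 0})
            node["_sum"] += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for char in prefix:
            if char not in node:
                return 0  # no stored key starts with this prefix
            node = node[char]
        return node["_sum"]


m = DeltaPrefixMapSum()
m.insert("columnar", 3)
m.insert("column", 2)
m.insert("column", 4)  # overwrite: delta = 4 - 2 = +2 along c-o-l-u-m-n
print(m.sum("col"))    # 7, not 9
```

The design trades extra memory, one dictionary entry per key, for not having to walk the key's path a second time to find its old value.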

As a result, `insert` runs in O(k) time for a key of length k and `sum` runs in O(p) time for a prefix of length p. Neither cost depends on the number of keys stored.

```python
class PrefixMapSum:
    def __init__(self):
        # Root of the trie. Each node is a dict mapping a character to its child node,
        # plus two reserved keys: '_sum' (the total of all values stored under this
        # prefix) and '_end' (the value of the key that ends exactly at this node).
        # Reserved keys have several characters, so they never collide with a
        # single-character child key.
        self.trie = {'_sum': 0}

    def _get_value(self, key):
        # Return the value currently stored for `key`, or 0 if the key is absent.
        node = self.trie
        for char in key:
            if char not in node:
                return 0
            node = node[char]
        return node.get('_end', 0)

    def insert(self, key, value):
        # Only the change in value affects the prefix sums. For a new key the old
        # value is 0; for an existing key this replaces the old value instead of
        # adding the new value on top of it.
        delta = value - self._get_value(key)
        node = self.trie
        node['_sum'] += delta  # the empty prefix covers every key
        for char in key:
            if char not in node:
                node[char] = {'_sum': 0}
            node = node[char]
            node['_sum'] += delta  # every prefix of `key` changes by the same delta
        node['_end'] = value  # mark the end of the key and store its value

    def sum(self, prefix):
        node = self.trie
        for char in prefix:
            if char not in node:
                return 0  # no stored key starts with this prefix
            node = node[char]
        return node['_sum']  # cumulative sum of all keys under this prefix

# Example usage:
mapsum = PrefixMapSum()
mapsum.insert("columnar", 3)
assert mapsum.sum("col") == 3

mapsum.insert("column", 2)
assert mapsum.sum("col") == 5

# Overwriting a key replaces its value rather than adding to it
mapsum.insert("column", 4)
assert mapsum.sum("col") == 7
```