<a href="https://colab.research.google.com/github/simrathanspal/deep_models_from_scratch/blob/main/Softmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import math
import heapq

class Softmax:
  """Pure-Python softmax variants: naive, numerically stable, online (tiled), and online top-k.

  Every method accepts a sequence (or any iterable) of floats and returns plain lists.
  """

  def naive(self, arr):
    """Return softmax(arr) by exponentiating the raw inputs (textbook reference).

    There is no overflow protection:
    * an input above roughly 709 makes ``math.exp`` raise OverflowError;
    * if every input is very negative, every ``exp`` underflows to 0.0 and the
      division raises ZeroDivisionError.
    Prefer :meth:`stable` or :meth:`online` for real data.

    Returns an empty list for empty input.
    """
    xs = list(arr)  # materialize: the input is traversed twice
    d = sum(math.exp(x) for x in xs)
    return [math.exp(x) / d for x in xs]

  def stable(self, arr):
    """Return softmax(arr) using the max-subtraction trick.

    Exponentiating extreme values makes training numerically unstable:

    * Very large input -> overflow. ``math.exp`` raises OverflowError; in IEEE
      float arithmetic (e.g. NumPy) the result becomes inf, and inf/inf -> NaN.
    * Very negative inputs -> every exp underflows to 0, so the normalizer is
      0. In Python this raises ZeroDivisionError; in IEEE arithmetic 0/0 -> NaN.

    Solution: softmax(x) == softmax(x - c) for any constant c. Subtracting
    m = max(arr) makes every exponent <= 0, so nothing overflows. The largest
    term becomes exp(0) == 1, so the normalizer is >= 1 and never 0.

    Returns an empty list for empty input.
    """
    xs = list(arr)
    if not xs:
      return []  # max() of an empty sequence would raise ValueError
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # each exponent <= 0: no overflow
    d = sum(exps)  # includes exp(0) == 1, so d >= 1
    return [e / d for e in exps]

  @staticmethod
  def _tiles(xs, tile_size):
    """Yield (start_index, tile) pairs that cover xs in consecutive chunks of tile_size.

    Raises ValueError (on first iteration) if tile_size < 1.
    """
    if tile_size < 1:
      raise ValueError(f"tile_size must be >= 1, got {tile_size}")
    for start in range(0, len(xs), tile_size):
      yield start, xs[start:start + tile_size]

  @staticmethod
  def _merge_tile(m, d, tile):
    """Fold one non-empty tile into the running (max, normalizer) pair; return the new pair.

    Invariant: d == sum(exp(x - m)) over every element folded in so far.
    When the running max grows from m to m_new, the old sum is rescaled by
    exp(m - m_new). The tile's own sum, taken relative to its local max, is
    rescaled by exp(m_tile - m_new).
    """
    m_tile = max(tile)
    d_tile = sum(math.exp(x - m_tile) for x in tile)
    m_new = max(m, m_tile)
    # On the first tile m is -inf, so exp(m - m_new) == 0.0 and the empty prefix adds nothing.
    d_new = d * math.exp(m - m_new) + d_tile * math.exp(m_tile - m_new)
    return m_new, d_new

  def online(self, arr, tile_size=2):
    """Return softmax(arr) using the online (running max / running sum) normalizer.

    The stable softmax makes three passes over the input: max, normalizer,
    output. Here the max and the normalizer are computed together in a single
    pass over fixed-size tiles. Previously accumulated sums are rescaled
    whenever a new, larger max appears. A second pass emits the probabilities.

    Why it matters in Transformers, where softmax is used in
    1) attention, over each row of the N x N score matrix (N = sequence length),
    2) next-token prediction, over the output (vocabulary) logits.
    Attention dominates: every token attends to every other token, for every
    head in every layer. FlashAttention relies on the same running-max
    rescaling to process attention scores tile by tile.

    Args:
      arr: sequence or iterable of floats.
      tile_size: elements folded into the running statistics per step
        (default 2). Must be >= 1.

    Returns:
      List of probabilities, same length and order as arr ([] for empty input).

    Raises:
      ValueError: if tile_size < 1.
    """
    xs = list(arr)
    m, d = float("-inf"), 0.0
    for _, tile in self._tiles(xs, tile_size):
      m, d = self._merge_tile(m, d, tile)
    return [math.exp(x - m) / d for x in xs]

  def online_topk(self, arr, k, tile_size=2):
    """Return the k largest softmax probabilities and their indices in one tiled pass.

    Based on "Online normalizer calculation for softmax" (Milakov &
    Gimelshein, NVIDIA, 2018). Decoding strategies such as beam search only
    need the top-k probabilities at each step, yet a plain softmax normalizes
    over the whole vocabulary in separate passes. Here a size-k min-heap is
    updated in the same pass that builds the online normalizer. Only k
    probabilities are ever materialized.

    Args:
      arr: sequence or iterable of floats (e.g. logits).
      k: number of entries to return. k <= 0 returns ([], []).
      tile_size: elements folded into the running statistics per step
        (default 2). Must be >= 1.

    Returns:
      (top_prob, top_ids): probabilities and indices of the k largest inputs,
      ordered by descending value. Ties are broken in favour of the lower
      index. Fewer than k entries are returned if len(arr) < k.

    Raises:
      ValueError: if tile_size < 1 (and k > 0).
    """
    xs = list(arr)
    if k <= 0:
      return [], []  # avoids indexing min_heap[0] on an always-empty heap

    m, d = float("-inf"), 0.0
    # Entries are (value, -index). The root is the weakest current candidate,
    # and among equal values the higher index is evicted first.
    min_heap = []

    for start, tile in self._tiles(xs, tile_size):
      m, d = self._merge_tile(m, d, tile)
      for offset, x in enumerate(tile):
        entry = (x, -(start + offset))
        if len(min_heap) < k:
          heapq.heappush(min_heap, entry)
        elif entry > min_heap[0]:
          heapq.heapreplace(min_heap, entry)

    ranked = sorted(min_heap, reverse=True)  # value desc, then index asc
    top_prob = [math.exp(x - m) / d for x, _ in ranked]
    top_ids = [-neg_idx for _, neg_idx in ranked]
    return top_prob, top_ids

