Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

You must find a solution with a memory complexity better than O(n^2).


Example 1:

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
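To make the "not the kth distinct element" note concrete, here is a brute-force reference sketch. The function names `kth_smallest_brute_force` and `kth_smallest_distinct` are hypothetical and not part of the problem. Flattening and sorting needs O(n^2) memory, so this does not meet the memory requirement. It is only a correctness check on Example 1, where the duplicated 13 is what makes the two answers differ.

```python
def kth_smallest_brute_force(matrix: list[list[int]], k: int) -> int:
    # Flatten and sort all n^2 values: O(n^2 log n) time and O(n^2) extra memory,
    # so this violates the memory requirement; use it only as a reference check.
    flat = sorted(x for row in matrix for x in row)
    return flat[k - 1]  # k is 1-indexed


def kth_smallest_distinct(matrix: list[list[int]], k: int) -> int:
    # NOT what the problem asks: duplicates are collapsed before ranking.
    distinct = sorted({x for row in matrix for x in row})
    return distinct[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_brute_force(matrix, 8))  # 13
print(kth_smallest_distinct(matrix, 8))     # 15
```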
Example 2:

Input: matrix = [[-5]], k = 1
Output: -5

Constraints:

n == matrix.length == matrix[i].length
1 <= n <= 300
-10^9 <= matrix[i][j] <= 10^9
All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
1 <= k <= n^2

Follow up:

Could you solve the problem with a constant memory (i.e., O(1) memory complexity)?
Could you solve the problem in O(n) time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.

# Intuition 1: min-heap
- A min-heap always hands back the smallest value it holds, so we can use it to simulate reading the matrix in sorted order.

    - Initially, push the first element of each row into the heap (only the first min(k, n) rows can matter).
    - Store (value, row index, column index) in the heap.
    - Pop the smallest element from the heap k-1 times.
    - For every pop, push the next element in the same row (i.e., move right) into the heap, if that row has one.
    - The k-th popped element is your answer.


[ [ 1,  5,  9 ],
  [10, 11, 13],
  [12, 13, 15] ]
k = 8

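Push: (1,0,0), (10,1,0), (12,2,0) → the first element of each row, i.e. the first column.

Pop 1 (1st) → Push (5,0,1)

Pop 5 (2nd) → Push (9,0,2)

Pop 9 (3rd) → nothing to push from row 0

Pop 10 (4th) → Push (11,1,1)

Pop 11 (5th) → Push (13,1,2)

Pop 12 (6th) → Push (13,2,1)

Pop 13 from row 1 (7th) → nothing to push from row 1

Pop 13 from row 2 (8th) → 8th element! ✅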

Answer = 13
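The pop sequence above can be reproduced with a small tracing helper. `trace_heap_pops` is a hypothetical name. It follows the same steps as the heap solution, but records every popped `(value, row, col)` tuple instead of returning only the last one. Ties such as the two 13s are broken by row index, because Python compares tuples element by element.

```python
import heapq


def trace_heap_pops(matrix: list[list[int]], k: int) -> list[tuple[int, int, int]]:
    """Hypothetical helper: return the first k (value, row, col) entries popped."""
    n = len(matrix)
    # Seed the heap with the first element of each of the first min(k, n) rows.
    heap = [(matrix[row][0], row, 0) for row in range(min(k, n))]
    heapq.heapify(heap)
    popped = []
    for _ in range(k):
        val, row, col = heapq.heappop(heap)
        popped.append((val, row, col))
        # Move right in the same row if possible (after the k-th pop this push
        # is unnecessary but harmless, since the trace stops here anyway).
        if col + 1 < n:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
    return popped


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
for i, entry in enumerate(trace_heap_pops(matrix, 8), start=1):
    print(i, entry)
# values popped in order: 1, 5, 9, 10, 11, 12, 13, 13
```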

| Complexity | Explanation                                                                        |
| ---------- | ---------------------------------------------------------------------------------- |
| **Time**   | `O(k log n)`: the heap never holds more than `n` entries, and we do `k` pops.       |
| **Space**  | `O(n)`: the heap stores at most one entry per row, i.e. at most `min(k, n)` entries. |


In [None]:
import heapq

class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        n = len(matrix)
        min_heap = []

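        # Step 1: Add the first element of each row. Only the first min(k, n) rows
        # can matter: an element in row r >= k has at least k elements above it
        # in its column that are <= it.
        for row in range(min(k, n)):
            heapq.heappush(min_heap, (matrix[row][0], row, 0))

        # Step 2: Pop the k-1 smallest elements; after each pop, push the next
        # element from the same row (move right), if that row has one.
        for _ in range(k - 1):
            val, row, col = heapq.heappop(min_heap)

            if col + 1 < n:
                heapq.heappush(min_heap, (matrix[row][col + 1], row, col + 1))

        # Step 3: The next element popped is the k-th smallest.
        return heapq.heappop(min_heap)[0]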
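
# Intuition 2: binary search on the value range
- The answer lies between `matrix[0][0]` (the smallest element) and `matrix[-1][-1]` (the largest).
- For a candidate value `mid`, count how many elements are `<= mid`; this count never decreases as `mid` grows.
- Binary search for the smallest `mid` whose count is at least `k`. That value is the k-th smallest element, and it always appears in the matrix.
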
In [None]:
class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        n = len(matrix)
        low = matrix[0][0]
        high = matrix[-1][-1]
        ans = -1

        while low <= high:
            mid = (low + high) // 2
            count = self.count_less_equal(matrix, mid)
            
            if count >= k:
                # Example: [[1,5,9],[10,11,13],[12,13,15]], k = 8
                # search space (values):
                # [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], k = 8
                # Predicate: are there at least k elements <= mid?
                # [f,f,f,f,f,f,f,f,f,f, f, f, t,  t, t]
                # The predicate is monotonic, so we look for the first true.
                ans = mid
                high = mid - 1
            else:
                low = mid + 1

        return ans

    def count_less_equal(self, matrix, mid):
        n = len(matrix)
        # Start at the bottom-left corner: moving up gives smaller values and
        # moving right gives larger ones (see "Why Bottom-Left?" below).
        i = n - 1
        j = 0
        count = 0
        
        while i >= 0 and j < n:
            
            if matrix[i][j] <= mid:
                count += i + 1
                j += 1
            else:
                i -= 1
        return count


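# Time: O(n * log(max - min)), where max - min = matrix[-1][-1] - matrix[0][0]:
#   the binary search runs over the value range, not over indices, so it takes
#   O(log(max - min)) iterations, and each count_less_equal call is O(n), not
#   O(n^2), because every step moves either up or right (at most 2n steps).
# Space: O(1)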
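As a rough illustration of how the binary search converges on Example 1, the sketch below re-implements the same staircase count as `count_less_equal` and records `(low, high, mid, count)` for every iteration. `trace_binary_search` is a hypothetical helper, not part of the solution above.

```python
def count_less_equal(matrix: list[list[int]], target: int) -> int:
    """Count elements <= target with the bottom-left staircase walk."""
    n = len(matrix)
    i, j, count = n - 1, 0, 0
    while i >= 0 and j < n:
        if matrix[i][j] <= target:
            count += i + 1  # rows 0..i of column j are all <= target
            j += 1
        else:
            i -= 1  # the rest of row i is > target
    return count


def trace_binary_search(
    matrix: list[list[int]], k: int
) -> tuple[int, list[tuple[int, int, int, int]]]:
    """Hypothetical helper: binary search on values, recording each iteration."""
    low, high = matrix[0][0], matrix[-1][-1]
    ans = high  # high has count n*n >= k, so it is always a valid fallback
    steps = []
    while low <= high:
        mid = (low + high) // 2  # floor division also behaves for negative values
        count = count_less_equal(matrix, mid)
        steps.append((low, high, mid, count))
        if count >= k:
            ans, high = mid, mid - 1  # mid qualifies; look for a smaller one
        else:
            low = mid + 1  # too few elements <= mid; the answer is larger
    return ans, steps


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
ans, steps = trace_binary_search(matrix, 8)
for step in steps:
    print(step)
print(ans)  # 13
```

On Example 1 with k = 8 the recorded steps are (1, 15, 8, 2), (9, 15, 12, 6), (13, 15, 14, 8) and (13, 13, 13, 8). Note that mid = 14 already has 8 elements <= it, but the search keeps shrinking `high` until it reaches 13, the smallest such value, which is an actual matrix element.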

# 🔍 Why Bottom-Left?
We want to count how many elements are <= mid without looking at all n^2 cells.

Start at matrix[i][j] = matrix[n-1][0] (bottom-left). From this corner, moving up gives smaller values and moving right gives larger ones, so a single comparison tells us which way to go.

If matrix[i][j] <= mid, then:

All elements above it in the same column (from row 0 to i) are also ≤ mid (because columns are sorted).

So we can count i + 1 elements in this column and move right (j += 1) to the next column.

If matrix[i][j] > mid, then:

All elements to its right in the same row are also > mid (because rows are sorted), so the rest of row i adds nothing to the count.

So we move up (i -= 1) to a smaller value.

Every step either finishes a column or discards the rest of a row, so the walk takes at most 2n steps: O(n) per count instead of O(n^2).
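To see the walk step by step, this sketch runs the bottom-left count and records each visited cell, the direction taken, and the running count. The name `staircase_path` is hypothetical; the logic mirrors `count_less_equal` above.

```python
def staircase_path(
    matrix: list[list[int]], target: int
) -> tuple[int, list[tuple[int, int, int, str, int]]]:
    """Hypothetical helper: count elements <= target and record the walk."""
    n = len(matrix)
    i, j, count = n - 1, 0, 0  # start at the bottom-left corner
    path = []
    while i >= 0 and j < n:
        if matrix[i][j] <= target:
            count += i + 1  # column j from row 0 to i is all <= target
            path.append((i, j, matrix[i][j], "right", count))
            j += 1
        else:
            # The rest of row i (to the right) is > target, so drop the row.
            path.append((i, j, matrix[i][j], "up", count))
            i -= 1
    return count, path


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
count, path = staircase_path(matrix, 13)
for step in path:
    print(step)
print(count)  # 8
```

For target = 13 on Example 1 the walk visits only 4 cells: (2,0) right, (2,1) right, (2,2) up, (1,2) right, finishing with a count of 8.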

In [4]:
Solution().kthSmallest(matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8)

13