|
| 1 | +""" |
| 2 | +Find the kth smallest element in linear time using divide and conquer. |
| 3 | +Recall we can do this trivially in O(nlogn) time. Sort the list and |
| 4 | +access kth element in constant time. |
| 5 | +
|
| 6 | +This is a divide and conquer algorithm that can find a solution in O(n) time. |
| 7 | +
|
| 8 | +For more information of this algorithm: |
| 9 | +https://web.stanford.edu/class/archive/cs/cs161/cs161.1138/lectures/08/Small08.pdf |
| 10 | +""" |
| 11 | +from random import choice |
| 12 | +from typing import List |
| 13 | + |
| 14 | + |
| 15 | +def random_pivot(lst): |
| 16 | + """ |
| 17 | + Choose a random pivot for the list. |
| 18 | + We can use a more sophisticated algorithm here, such as the median-of-medians |
| 19 | + algorithm. |
| 20 | + """ |
| 21 | + return choice(lst) |
| 22 | + |
| 23 | + |
| 24 | +def kth_number(lst: List[int], k: int) -> int: |
| 25 | + """ |
| 26 | + Return the kth smallest number in lst. |
| 27 | + >>> kth_number([2, 1, 3, 4, 5], 3) |
| 28 | + 3 |
| 29 | + >>> kth_number([2, 1, 3, 4, 5], 1) |
| 30 | + 1 |
| 31 | + >>> kth_number([2, 1, 3, 4, 5], 5) |
| 32 | + 5 |
| 33 | + >>> kth_number([3, 2, 5, 6, 7, 8], 2) |
| 34 | + 3 |
| 35 | + >>> kth_number([25, 21, 98, 100, 76, 22, 43, 60, 89, 87], 4) |
| 36 | + 43 |
| 37 | + """ |
| 38 | + # pick a pivot and separate into list based on pivot. |
| 39 | + pivot = random_pivot(lst) |
| 40 | + |
| 41 | + # partition based on pivot |
| 42 | + # linear time |
| 43 | + small = [e for e in lst if e < pivot] |
| 44 | + big = [e for e in lst if e > pivot] |
| 45 | + |
| 46 | + # if we get lucky, pivot might be the element we want. |
| 47 | + # we can easily see this: |
| 48 | + # small (elements smaller than k) |
| 49 | + # + pivot (kth element) |
| 50 | + # + big (elements larger than k) |
| 51 | + if len(small) == k - 1: |
| 52 | + return pivot |
| 53 | + # pivot is in elements bigger than k |
| 54 | + elif len(small) < k - 1: |
| 55 | + return kth_number(big, k - len(small) - 1) |
| 56 | + # pivot is in elements smaller than k |
| 57 | + else: |
| 58 | + return kth_number(small, k) |
| 59 | + |
| 60 | + |
| 61 | +if __name__ == "__main__": |
| 62 | + import doctest |
| 63 | + |
| 64 | + doctest.testmod() |
0 commit comments