# Task

- [x] 表达式求导
- [x] 堆排序
- [x] 归并排序
- [x] 基数排序
- [x] 几种排序对比

## 表达式求导

这是第三次交表达式求导的作业了，第一天就已经跟着课件实现完，第二天添加了`simplify`，
但当时认为`__neg__`的`simplify`没用，今天上课想到还是有用的，就把负数运算符实现一下，作为最完整的提交

In [5]:
class Expr(object):
    """Base class of every expression node.

    Subclasses override ``eval`` and ``deriv``. The operator overloads below
    build the matching node and immediately return its ``simplify()`` result.
    """

    def eval(self, **values):
        """Evaluate the expression; the base implementation returns None."""
        return None

    def deriv(self, x):
        """Differentiate with respect to ``x``; the base implementation returns None."""
        return None

    def __add__(self, other):
        """Binary ``+``: simplified ``Add`` node."""
        node = Add(self, other)
        return node.simplify()

    def __sub__(self, other):
        """Binary ``-``: simplified ``Sub`` node."""
        node = Sub(self, other)
        return node.simplify()

    def __mul__(self, other):
        """Binary ``*``: simplified ``Mul`` node."""
        node = Mul(self, other)
        return node.simplify()

    def __neg__(self):
        """Unary ``-``: simplified ``Neg`` node."""
        node = Neg(self)
        return node.simplify()

    def __truediv__(self, other):
        """Binary ``/``: simplified ``TrueDiv`` node."""
        node = TrueDiv(self, other)
        return node.simplify()
    
class Const(Expr):
    """Leaf node that wraps a constant value."""

    def __init__(self, value):
        self.value = value

    def eval(self, **values):
        """Return the wrapped value; keyword bindings are ignored."""
        wrapped = self.value
        return wrapped

    def deriv(self, x):
        """The derivative of a constant is a fresh ``Const(0)``, whatever ``x`` is."""
        zero = Const(0)
        return zero

    def __repr__(self):
        """Show the constant as the ``str`` of its value."""
        text = str(self.value)
        return text
    
def get_var_name(x):
    """Return ``x.name`` when ``x`` is a ``Variable``; otherwise return ``x`` unchanged."""
    if isinstance(x, Variable):
        return x.name
    return x
    
class Variable(Expr):
    """Leaf node for a named variable."""

    def __init__(self, name):
        self.name = name

    def eval(self, **kwargs):
        """Look this variable's value up among the keyword bindings.

        x.eval(x=3, y=4) ==> 3
        x.eval(y=4)      ==> raises Exception (no binding for x)
        """
        if self.name not in kwargs:
            raise Exception(f"Variable {self.name} is not found")
        return kwargs[self.name]

    def deriv(self, x):
        """Differentiate with respect to ``x`` (a ``Variable`` or a plain name).

        Returns ``Const(1)`` when ``x`` names this variable, ``Const(0)`` otherwise.
        """
        same_variable = get_var_name(x) == self.name
        if same_variable:
            return Const(1)
        return Const(0)

    def __repr__(self):
        """A variable prints as its bare name."""
        return self.name
    
# Smoke test for Variable: its repr, its value under x=3, and its
# derivatives with respect to "x" (itself) and "y" (another name).
x1 = Variable("x")
print(x1)
print(x1.eval(x=3))
print(x1.deriv("x"))
print(x1.deriv("x").eval(x=4))
print(x1.deriv("y").eval(x=4))

x
3
1
1
0


In [6]:
class Add(Expr):
    """Binary node for ``left + right``."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def eval(self, **kwargs):
        """Evaluate both operands with the same bindings and add the results."""
        lhs = self.left.eval(**kwargs)
        rhs = self.right.eval(**kwargs)
        return lhs + rhs

    def deriv(self, x):
        """Sum rule: (u + v)' = u' + v'. The ``+`` builds a new, simplified node."""
        d_left = self.left.deriv(x)
        d_right = self.right.deriv(x)
        return d_left + d_right

    def __repr__(self):
        return f"({self.left} + {self.right})"

    def simplify(self):
        """Fold ``Const + Const`` into a single ``Const``; otherwise return ``self``."""
        lhs, rhs = self.left, self.right
        if not (isinstance(lhs, Const) and isinstance(rhs, Const)):
            return self
        return Const(lhs.value + rhs.value)

class Sub(Expr):
    """Binary node for ``left - right``."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def eval(self, **kwargs):
        """Evaluate both operands with the same bindings and subtract."""
        lhs = self.left.eval(**kwargs)
        rhs = self.right.eval(**kwargs)
        return lhs - rhs

    def deriv(self, x):
        """Difference rule: (u - v)' = u' - v'."""
        d_left = self.left.deriv(x)
        d_right = self.right.deriv(x)
        return d_left - d_right

    def __repr__(self):
        return f"({self.left} - {self.right})"

    def simplify(self):
        """Fold ``Const - Const`` into a single ``Const``; otherwise return ``self``."""
        lhs, rhs = self.left, self.right
        if not (isinstance(lhs, Const) and isinstance(rhs, Const)):
            return self
        return Const(lhs.value - rhs.value)
    
class Neg(Expr):
    """Unary node for ``-value``."""

    def __init__(self, value):
        self.value = value

    def eval(self, **kwargs):
        """Evaluate the operand and negate the result."""
        inner = self.value.eval(**kwargs)
        return -inner

    def deriv(self, x):
        """(-u)' = -(u')."""
        d_inner = self.value.deriv(x)
        return -d_inner

    def simplify(self):
        """Fold the negation of a ``Const`` into a ``Const``; otherwise return ``self``."""
        operand = self.value
        if not isinstance(operand, Const):
            return self
        return Const(-operand.value)

    def __repr__(self):
        return '(-%s)' % self.value
        
class Mul(Expr):
    """Binary node for ``left * right``."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def eval(self, **kwargs):
        """Evaluate both operands with the same bindings and multiply."""
        lhs = self.left.eval(**kwargs)
        rhs = self.right.eval(**kwargs)
        return lhs * rhs

    def deriv(self, x):
        """Product rule: (uv)' = u'v + v'u."""
        u, v = self.left, self.right
        first_term = u.deriv(x) * v
        second_term = v.deriv(x) * u
        return first_term + second_term

    def __repr__(self):
        return f"({self.left} * {self.right})"

    def simplify(self):
        """Apply constant folding and the ``0 * e`` / ``1 * e`` identities.

        Returns a ``Const`` when both operands are constants or either is 0,
        the other operand when one side is the constant 1, else ``self``.
        """
        lhs, rhs = self.left, self.right
        if isinstance(lhs, Const) and isinstance(rhs, Const):
            return Const(lhs.value * rhs.value)
        # At most one side is a constant here; check the left side first.
        for maybe_const, other in ((lhs, rhs), (rhs, lhs)):
            if not isinstance(maybe_const, Const):
                continue
            if maybe_const.value == 0:
                return Const(0)
            if maybe_const.value == 1:
                return other
        return self
    
class TrueDiv(Expr):
    """Binary node for ``left / right`` (true division)."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def eval(self, **kwargs):
        """Evaluate both operands with the same bindings and divide."""
        return self.left.eval(**kwargs) / self.right.eval(**kwargs)

    def deriv(self, x):
        '''
        (u/v)' = (u'v - v'u) / v^2 (Quotient Rule)
        '''
        u, v = self.left, self.right
        return (u.deriv(x)*v - u*v.deriv(x))/(v*v)

    def __repr__(self):
        return f"({self.left} / {self.right})"

    def simplify(self):
        """Apply constant folding and the ``0 / e`` and ``e / 1`` identities.

        Returns:
            A ``Const`` when both operands are constants or the numerator is
            the constant 0, the numerator when the denominator is the
            constant 1, otherwise ``self``.

        Raises:
            ZeroDivisionError: when both operands are constants and the
                denominator is 0.
        """
        left, right = self.left, self.right
        left_const = isinstance(left, Const)
        right_const = isinstance(right, Const)
        if left_const and right_const:
            # Test the operand's value, not the boolean isinstance() flag.
            if right.value == 0:
                raise ZeroDivisionError("divide by zero error")
            return Const(left.value / right.value)
        if left_const and left.value == 0:
            return Const(0)
        if right_const and right.value == 1:
            return left
        return self
    
# Demo: build one expression per overloaded operator from a constant (13)
# and a variable (a), print each with its value, then print the value of
# each binary expression's derivative with respect to "a" at a=3.
c1 = Const(13)
c2 = Variable("a")
c = c1 + c2
d = c1 - c2
e = c1 * c2
f = c1 / c2
g = -c1
h = -c2
print(f'{c} = {c1.eval()} + 5 = {c.eval(a=5)}')
print(f'{d} = {c1.eval()} - 5 = {d.eval(a=5)}')
print(f'{e} = {c1.eval()} * 5 = {e.eval(a=5)}')
print(f'{f} = {c1.eval()} / 5 = {f.eval(a=5)}')
print(f'{g} = {g.eval(a=2)}')
print(f'{h} = {h.eval(a=2)}')
print('-'*24)
print(f'{c}\' = {c.deriv("a").eval(a=3)}')
print(f'{d}\' = {d.deriv("a").eval(a=3)}')
print(f'{e}\' = {e.deriv("a").eval(a=3)}')
print(f'{f}\' = {f.deriv("a").eval(a=3)}')


(13 + a) = 13 + 5 = 18
(13 - a) = 13 - 5 = 8
(13 * a) = 13 * 5 = 65
(13 / a) = 13 / 5 = 2.6
-13 = -13
(-a) = -2
------------------------
(13 + a)' = 1
(13 - a)' = -1
(13 * a)' = 13
(13 / a)' = -1.4444444444444444


求$(2*x-6*y)/(3*x+4*y)$在x=2, y=3处的偏导数值

In [7]:
# Demo: build f = (2x - 6y) / (3x + 4y), differentiate it with respect to
# x and to y, then print both derivatives symbolically and at x=2, y=3.
x = Variable("x")
y = Variable("y")
exp1 = Const(2) * x - Const(6) * y
exp2 = Const(3) * x + Const(4) * y
exp = exp1 / exp2
d1 = exp.deriv(x)
d2 = exp.deriv(y)
values = {"x":2, "y":3}
print(exp)
print(f'∂f/∂x = {d1}')
print(f'∂f/∂y = {d2}')
print(values,'----'*10)
print(f'∂f/∂x = {d1.eval(**values)}')
print(f'∂f/∂y = {d2.eval(**values)}')

(((2 * x) - (6 * y)) / ((3 * x) + (4 * y)))
∂f/∂x = (((2 * ((3 * x) + (4 * y))) - (((2 * x) - (6 * y)) * 3)) / (((3 * x) + (4 * y)) * ((3 * x) + (4 * y))))
∂f/∂y = (((-6 * ((3 * x) + (4 * y))) - (((2 * x) - (6 * y)) * 4)) / (((3 * x) + (4 * y)) * ((3 * x) + (4 * y))))
{'x': 2, 'y': 3} ----------------------------------------
∂f/∂x = 0.24074074074074073
∂f/∂y = -0.16049382716049382


## 堆排序

以下代码是对如下示意图的实现： 

![heap_sort](img/heap_sort.gif)

在实现每一轮“找出数值较大的子节点并与父节点交换”的过程中，我之前用的是递归，在小数据量时顺利通过，但上万条数据时碰到了`RecursionError: maximum recursion depth exceeded in comparison`。查询本机的递归深度上限为1000，但调到几十万就不起作用了（虽然不报错），于是改成了`while`循环，代码几乎没变，但是秒过了。

可见递归并不是个好东西（在python世界？）

In [7]:
# helper
def get_parent_index(i):
    """Index of the parent of heap node ``i``; never below 0 (the root)."""
    return max((i - 1) // 2, 0)


def get_child_index(i):
    """Index of the left child of heap node ``i``."""
    return 2 * i + 1

def heap_sort(arr):
    """Sort ``arr`` in place with heap sort and return the same list.

    ``heapify`` first arranges the list as a max-heap; ``siftDown`` then
    repeatedly moves the current maximum to the end of the shrinking heap.
    """
    heapify(arr)                      # build the initial heap
    last_index = len(arr) - 1
    siftDown(arr, 0, last_index)      # extract maxima one by one
    return arr

def heapify(arr):
    """Rearrange ``arr`` in place by bubbling larger children upward; return it.

    Every index from 1 onward is compared with its parent. When the child is
    larger the two are swapped and ``siftUp`` continues from the parent.
    """
    for child_i in range(1, len(arr)):
        parent_i = get_parent_index(child_i)
        if arr[child_i] > arr[parent_i]:
            arr[parent_i], arr[child_i] = arr[child_i], arr[parent_i]
            siftUp(arr, parent_i)
    return arr

def siftUp(arr, c_index):
    """Walk from ``c_index`` up to the root, swapping wherever a parent is smaller.

    The walk always continues to index 0; it does not stop at the first
    position where no swap was needed.
    """
    node = c_index
    while True:
        above = get_parent_index(node)
        if arr[above] < arr[node]:
            arr[above], arr[node] = arr[node], arr[above]
        if above <= 0:
            break
        node = above

def siftDown(arr, start, end):
    '''
    Extraction phase of heap sort, run in place on ``arr``.

    Each round:
    1. Swap the first element with ``arr[end]``, so the largest value
       lands at the tail.
    2. Starting from the root, compare the current node with the larger of
       its two children (only indices below ``end`` count), swap when the
       child is larger, and keep descending.

    ``start == 0`` doubles as the "begin a new round" flag; ``end`` shrinks
    by one at the end of every round and the loop stops once ``start < end``
    no longer holds.
    '''
    while start < end:
        if start == 0:
            # New round: move the current root to the tail slot.
            arr[0], arr[end] = arr[end], arr[0]
        left_i  = get_child_index(start)
        if left_i >= end: 
            # The left child is at or past end, so there is nothing to
            # compare; close this round and shrink the range.
            start = 0
            end -= 1
            continue
        else:
            right_i = left_i + 1
            index = left_i
            if right_i < end:
                # The right child is still before end: pick the larger child.
                # The chosen child becomes the start of the next iteration.
                index = left_i if arr[left_i] > arr[right_i] else right_i
            parent  = arr[start]
            if parent < arr[index]:
                arr[start], arr[index] = arr[index], arr[start]
        # If the left child is at/after end   (already handled via continue)
        # If the right child is at/after end
        # If the index right after the right child is end
        # then this round has reached the bottom; just move end back by one.
        if right_i >= end or (end - right_i) == 1:
            start = 0 # start=0 requests another root/tail swap and a fresh pass from the top
            end -= 1
        else:
            # Otherwise continue with the next iteration:
            # the new start is the child that was just compared with its parent,
            # and end stays the same.
            start = index

# Demo: time heap_sort on a short list of digit characters and print the
# elapsed seconds followed by (up to) the first 100 sorted items.
if __name__ == "__main__":
    import numpy as np
    import time
    # Larger random input, disabled:
#     np.random.seed(7)
#     length = 20000
#     arr = list(np.random.randint(0, length*5, size=(length,)))
    arr = list("65318724")
    start = time.time()
    s = heap_sort(arr)
    print(time.time()-start, '\n', s[:100])

4.696846008300781e-05 
 ['1', '2', '3', '4', '5', '6', '7', '8']


## 归并排序

以下代码是对如下示意图的实现： 

![merge_sort](img/merge_sort.gif)

In [8]:
import math
def merge_sort(arr):
    '''
    Bottom-up (iterative) merge sort; returns a new list in ascending order.

    Every pass merges neighbouring sorted runs of length ``step`` from one
    buffer into the other, then the two buffers swap roles, so only two
    lists are ever allocated. The input list itself is left untouched.

    Args:
        arr: list of mutually comparable items.

    Returns:
        A sorted list containing the items of ``arr``.
    '''
    last     = len(arr) - 1          # highest valid index
    buffers  = [list(arr), []]       # work on a copy: each pass clears its target
    src, dst = 0, 1
    step     = 1
    # Keep merging while a run is shorter than the whole list, i.e.
    # ``step < len(arr)``. A stricter ``step < last`` would skip the final
    # pass whenever len(arr) is 2 or 2**k + 1 and return unsorted data.
    while step <= last:
        compare(buffers[src], 0, step, last, buffers[dst])
        step *= 2
        src, dst = dst, src
    return buffers[src]

def gen_indexs(start, step, length):
    '''
    Compute the bounds of the two runs merged next.

    Args:
        start:  first index of the left run.
        step:   nominal run length for this pass.
        length: highest valid index (not the element count).

    Returns:
        ``(left_start, left_end, right_start, right_end)``; the right-run
        bounds are clamped to ``length``.
    '''
    left_end    = start + step - 1
    right_start = min(start + step, length)
    right_end   = min(right_start + step - 1, length)
    return start, left_end, right_start, right_end


def compare(arr, start, step, length, result):
    '''
    Run one merge pass: merge each pair of neighbouring runs of ``arr`` into
    ``result``.

    Args:
        arr:    source list whose runs of ``step`` items are already sorted.
        start:  index at which the pass begins.
        step:   run length for this pass.
        length: highest valid index of ``arr``.
        result: output list; it is cleared first and then filled.
    '''
    result.clear()
    left_start, left_end, right_start, right_end \
                = gen_indexs(start, step, length)
    left_index  = 0  # offsets inside the current pair of runs (0 .. step-1)
    right_index = 0
    while left_start <= length:
        left    = left_start + left_index
        right   = min(right_start + right_index, length)
        l_done  = False
        r_done  = False
        if arr[left] < arr[right]:
            result.append(arr[left])
            left_index += 1
            left   = left_start + left_index
            l_done = left == right_start
        else:
            result.append(arr[right])
            right_index += 1
            r_done = (right_start + right_index) > right_end
        if l_done or r_done:
            if l_done:
                # Left run exhausted: copy everything left in the right run.
                result += arr[right:right_end]
                result.append(arr[right_end])
            else:
                # Right run exhausted: copy everything left in the left run.
                result += arr[left:right_start]
            # Move on to the pair of runs that follows the one just merged.
            left_start, left_end, right_start, right_end \
                        = gen_indexs(right_end+1, step, length)
            left_index  = 0
            right_index = 0

# Demo: time merge_sort on a small hand-written list and print the elapsed
# seconds followed by (up to) the first 100 sorted items.
if __name__ == "__main__":
    import numpy as np
    import time
    # Larger random input, disabled:
#     np.random.seed(7)
#     length = 20000
#     arr = list(np.random.randint(0, length*5, size=(length,)))
    arr = [1,6,17,5,2,9,3,1,22,9,8,0,7,65]#,2,13,4,6,17,33,8,0,4,17,22]
    start = time.time()
    s = merge_sort(arr)
    print(time.time()-start, s[:100])

5.626678466796875e-05 [0, 1, 1, 2, 3, 5, 6, 7, 8, 9, 9, 17, 22, 65]


## 基数排序

![radix_sort](img/radix_sort.png)

In [15]:
def get_number(num, index):
    '''
    Return the decimal digit of ``num`` at position ``index`` (0 = ones).

    ones:      527 % 10^1 // 10^0 = 7
    tens:      527 % 10^2 // 10^1 = 2
    hundreds:  527 % 10^3 // 10^2 = 5
    thousands: 527 % 10^4 // 10^3 = 0
    '''
    upper = 10 ** (index + 1)   # strips every digit above the wanted one
    lower = 10 ** index         # shifts the wanted digit down to the ones place
    return num % upper // lower

def digit_length(number):
    """Return how many decimal digits the integer part of ``number`` has.

    The sign is ignored and 0 counts as one digit. The digits are counted on
    the exact integer rather than through ``math.log10``, whose floating
    point rounding over-counts values just below a power of ten
    (e.g. 999999999999999).
    """
    return len(str(abs(int(number))))

def digit_sort(arr, index):
    '''
    Run one stable bucket pass of radix sort on digit ``index`` (0 = ones).

    Each number is appended to the bucket of its digit at that position and
    the buckets are concatenated from digit 0 to 9, so numbers sharing a
    digit keep their relative order.

    Args:
        arr:   iterable of integers.
        index: digit position to distribute on.

    Returns:
        A new list with the numbers ordered by that digit.
    '''
    # Ten independent buckets, one per decimal digit. ``[[]] * 10`` would
    # repeat a single shared list ten times.
    buckets = [[] for _ in range(10)]
    for num in arr:
        buckets[get_number(num, index)].append(num)
    return [num for bucket in buckets for num in bucket]  # flatten the buckets

def radix_sort(arr):
    """Sort integers with least-significant-digit radix sort.

    One ``digit_sort`` pass is run per decimal digit of the largest element,
    starting at the ones place.

    Args:
        arr: sequence of integers. NOTE(review): the digit extraction looks
            designed for non-negative values only - confirm before passing
            negatives.

    Returns:
        A new sorted list; an empty input yields an empty list.
    """
    if len(arr) == 0:
        # max() raises ValueError on an empty sequence.
        return []
    # Number of passes = digit count of the largest element.
    passes = digit_length(max(arr))
    for i in range(passes):
        arr = digit_sort(arr, i)
    return arr

# Demo: time radix_sort on four three-digit numbers and print the elapsed
# seconds followed by (up to) the first 100 sorted items.
if __name__ == "__main__":
    import numpy as np
    import time
    # Larger random input, disabled:
#     np.random.seed(7)
#     length = 20000
#     arr = list(np.random.randint(0, length*50, size=(length,)))
    arr = [954,354,309,411]
    start = time.time()
    s = radix_sort(arr)
    print(time.time()-start, s[:100])

4.076957702636719e-05 [309, 354, 411, 954]


In [2]:
def quick_sort(arr):
    """Return the items of ``arr`` in ascending order using quicksort.

    The last element is the pivot. The remaining items are split into those
    ``<=`` the pivot and those ``>`` it, and each part is sorted recursively.
    The caller's list is not modified.

    Args:
        arr: list of mutually comparable items.

    Returns:
        A sorted list; a list of length 0 or 1 is returned as-is.
    """
    if len(arr) <= 1:
        return arr
    # Read the pivot by index and slice the rest, so nothing is popped from
    # the caller's list.
    pivot = arr[-1]
    rest  = arr[:-1]
    left  = [item for item in rest if item <= pivot]
    right = [item for item in rest if item > pivot]
    return quick_sort(left) + [pivot] + quick_sort(right)

# Demo call; as the last expression of the notebook cell its result is displayed.
quick_sort([1,34,8,6,76])

[1, 6, 8, 34, 76]

In [4]:
def insert_sort(arr):
    """Return a new ascending list built by insertion sort.

    Each item is inserted in front of the first already-placed item that is
    greater than it, so equal items keep their input order. The insertion
    point is located by binary search instead of a linear scan.

    Args:
        arr: iterable of mutually comparable items; it is not modified.

    Returns:
        A new sorted list.
    """
    import bisect

    rst = []
    for item in arr:
        # insort_right places item after any equal entries, i.e. directly
        # before the first strictly greater one.
        bisect.insort_right(rst, item)
    return rst

# Demo call; as the last expression of the notebook cell its result is displayed.
insert_sort([1,34,5,6,9,0])

[0, 1, 5, 6, 9, 34]

In [5]:
def shell_sort(arr):
    """Sort ``arr`` in place with Shell sort and return the same list.

    The gap starts at half the length and is halved after every pass; each
    pass is an insertion sort over elements that lie ``gap`` positions apart.
    """
    size = len(arr)
    gap = size // 2
    while gap:
        for pos in range(gap, size):
            pending = arr[pos]
            slot = pos
            # Shift larger gap-neighbours to the right until pending fits.
            while slot >= gap and arr[slot - gap] > pending:
                arr[slot] = arr[slot - gap]
                slot -= gap
            arr[slot] = pending
        gap //= 2
    return arr

# Demo call; as the last expression of the notebook cell its result is displayed.
shell_sort([34,24,538,536,1])

[1, 24, 34, 536, 538]

In [10]:
# Benchmark: sort the same 20000 seeded random integers with each algorithm,
# five rounds in a row, printing the wall-clock time of every call. Each
# algorithm receives its own copy of the data.
if __name__ == "__main__":
    import numpy as np
    import time
    np.random.seed(7)
    length = 20000
    arr = list(np.random.randint(0, length*50, size=(length,)))
    print(f'{length} random integers sort comparation:')
    for i in range(5):
        print(f'-------------round {i+1}------------')
        # insert is too slow
        # or my implementation is not so good
#         start = time.time()
#         s1 = insert_sort(arr)
#         print(f"insert_sort\t {time.time()-start:.5f} seconds")
        start = time.time()
        s2 = quick_sort(arr.copy())
        print(f"quick_sort\t {time.time()-start:.5f} seconds")
        start = time.time()
        s3 = shell_sort(arr.copy())
        print(f"shell_sort\t {time.time()-start:.5f} seconds")
        start = time.time()
        s4 = heap_sort(arr.copy())
        print(f"heap_sort\t {time.time()-start:.5f} seconds")
        start = time.time()
        s5 = merge_sort(arr.copy())
        print(f"merge_sort\t {time.time()-start:.5f} seconds")
        start = time.time()
        s6 = radix_sort(arr.copy())
        print(f"radix_sort\t {time.time()-start:.5f} seconds")
    # Print the first ten items of each result from the last round so they
    # can be compared by eye.
    print(f"first 10 numbers:\n{s2[:10]}\n{s3[:10]}\n{s4[:10]}\n{s5[:10]}\n{s6[:10]}")

20000 random integers sort comparation:
-------------round 1------------
quick_sort	 0.07970 seconds
shell_sort	 0.17623 seconds
heap_sort	 0.32919 seconds
merge_sort	 0.20177 seconds
radix_sort	 0.18000 seconds
-------------round 2------------
quick_sort	 0.05894 seconds
shell_sort	 0.15423 seconds
heap_sort	 0.28844 seconds
merge_sort	 0.20043 seconds
radix_sort	 0.19310 seconds
-------------round 3------------
quick_sort	 0.06169 seconds
shell_sort	 0.18299 seconds
heap_sort	 0.33159 seconds
merge_sort	 0.20836 seconds
radix_sort	 0.20003 seconds
-------------round 4------------
quick_sort	 0.05780 seconds
shell_sort	 0.15414 seconds
heap_sort	 0.26780 seconds
merge_sort	 0.18810 seconds
radix_sort	 0.17084 seconds
-------------round 5------------
quick_sort	 0.05508 seconds
shell_sort	 0.15128 seconds
heap_sort	 0.27173 seconds
merge_sort	 0.19646 seconds
radix_sort	 0.17663 seconds
first 10 numbers:
[74, 87, 144, 148, 236, 254, 291, 326, 336, 363]
[74, 87, 144, 148, 236, 254, 291,