In [1]:
import itertools

# 最基础的iterator/generator

In [2]:
def integers(start=1):
    while True:
        yield start
        start += 1

In [3]:
a = integers(start=4)
count = 0
while True:
    print(next(a), end=', ')
    count += 1
    if count == 10:
        break

4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 

## 上面的generator在`itertools`中已经自带

In [4]:
a = itertools.count(start=4, step=1)
count = 0
while True:
    print(next(a), end=', ')
    count += 1
    if count == 10:
        break

4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 

# 想要增加的小功能

## 使用`send()`, 重置起点

In [5]:
def integers_v2(start=1):
    while True:
        i = yield start
        if i is not None:
            start = i - 1
        start += 1

In [6]:
a = integers_v2(start=4)
print(next(a))
print(next(a))
print(a.send(100))
print(next(a))
print(a.send(10))
print(next(a))

4
5
100
101
10
11


## 剔除掉某些值

In [7]:
def is_divided_by(n):
    return lambda x: x % n == 0


def filter_integers(rules, basic_generator):
    # 例如剔除所有3的倍数和5的倍数
    while True:
        x = next(basic_generator)
        if any((rule(x) for rule in rules)):
            continue
        else:
            yield x

In [8]:
a = filter_integers([is_divided_by(3), is_divided_by(5)], integers(start=0))

In [9]:
count = 0
while True:
    print(next(a), end=', ')
    count += 1
    if count == 10:
        break

1, 2, 4, 7, 8, 11, 13, 14, 16, 17, 

### 这个功能在`itertools`中也有自带: `filterfalse`

In [10]:
a = itertools.filterfalse(is_divided_by(5), integers(start=0))
a = itertools.filterfalse(is_divided_by(3), a)
count = 0
while True:
    print(next(a), end=', ')
    count += 1
    if count == 10:
        break

1, 2, 4, 7, 8, 11, 13, 14, 16, 17, 

### 基于上述的工具, 我们可以写一个无穷质数生成器

In [11]:
from math import sqrt

def prime_number_v1():
    basic_generator = integers(start=3)
    prime_numbers = {2}
    yield 2
    while True:
        x = next(basic_generator)
        if any((is_divided_by(n)(x) for n in prime_numbers)):
            pass
        else:
            yield x
            prime_numbers.add(x)

def prime_number_v2():
    """v2比v1更高效,因为筛子范围更小..."""
    basic_generator = integers(start=3)
    prime_numbers = {2}
    yield 2
    while True:
        x = next(basic_generator)
        if any((is_divided_by(n)(x) for n in prime_numbers if n <= sqrt(x))):
            pass
        else:
            yield x
            prime_numbers.add(x)


In [12]:
%%timeit
a = prime_number_v1()
for i in range(10000):
    next(a)

In [None]:
%%timeit
a = prime_number_v2()
for i in range(10000):
    next(a)

10.8 s ± 736 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### 对上述的无穷质数生成器, 增加一个新功能: 跳过一些数字, 只生成比某数大的质数

In [None]:
# @no_repeat_generator_decorator
def prime_number_v3(start=2):
    while True:
        if start == 2:
            x = yield start
        if (not any((is_divided_by(n)(start) for n in range(2, int(sqrt(start)) + 2)))) and start > 1:
            x = yield start
        else:
            start += 1
            continue
        
        if x is not None:
            start = x - 1
        start += 1

In [None]:
%%timeit
a = prime_number_v3()
for i in range(10000):
    next(a)

1.79 s ± 155 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
a = prime_number_v1()
b = prime_number_v2()
c = prime_number_v3()
for i in range(10000):
    assert next(a) == next(b) == next(c)

In [None]:
a = prime_number_v3(start=2)

In [None]:
print(next(a))
print(next(a))
print(a.send(99))
print(next(a))
print(a.send(99))
print(next(a))
print(a.send(121))
print(next(a))
print(a.send(100))


2
3
5
7
11
13
17
19
23


## 使用yield用二分法来解方程

In [None]:
def func(x):
    return x ** 2 - 6 * x + 8

In [None]:
func(-1.1)

-2.220446049250313e-16

In [None]:
from math import isclose
def is_close(a, b):
    return isclose(a, b, rel_tol=1e-3, abs_tol=1e-3)

In [None]:
def left_flow(start=3.1):
    while True:
        i = yield start
        if i is not None:
            start = i
def right_flow(start=6):
    while True:
        i = yield start
        if i is not None:
            start = i
left = left_flow()
right = right_flow()

# 假设初始区间是正确的
while True:
    left_value = next(left)
    right_value = next(right)
    mid_value = (left_value + right_value) * 0.5
    func_mid_value = func(mid_value)
    func_left_value = func(left_value)
    func_right_value = func(right_value)

    if is_close(func_left_value, 0):
        print(left_value)
        break
    elif is_close(func_right_value, 0):
        print(right_value)
        break
    elif is_close(func_mid_value, 0):
        print(mid_value)
        break
    else:
        if func_left_value * func_mid_value >= 0:
            left.send(mid_value)
            # right.send(right_value)
        elif func_right_value * func_mid_value >= 0:
            right.send(mid_value)
            # left.send(left_value)

4.000000000023283


### 写得更像流一点, 并且对初始区间的正确性不限制太死

In [None]:
def value_flow(start):
    while True:
        i = yield start
        if i is not None:
            start = i

def no_repeat_generator_decorator(gen, check_last=5, tol=1e-2):
    def new_gen(*args, **kwargs):
        cache = []
        tmp = gen(*args, **kwargs)
        while True:
            x = next(tmp)
            if cache and any((abs(x - i) / ((abs(x) + abs(i)) * 0.5) < tol) for i in cache):
                break
            yield x
            cache.append(x)
            if len(cache) > check_last:
                cache.pop(0)
    return new_gen
            

@no_repeat_generator_decorator
def solve_equation(left_start=1.1, right_start=6):
    # print(f'running {left_start = }, {right_start = }')
    left_flow = value_flow(left_start)
    right_flow = value_flow(right_start)
    while True:
        left = next(left_flow)
        right = next(right_flow)
        mid = (left + right) / 2
        # 情况1：中点恰好满足， yield it！
        for _ in (left, mid):
            if abs(func(_) - 0) <= 1e-5:
                yield _
                break
        else:
            # 情况2：中点函数值不满足，并且区间长度太小
            if abs(left - right) <= 1e-5:
                # raise Exception(f'{mid = }')
                break
            # 情况3：中点函数值不满足，并且给定的区间不满足初始区间的要求
            # todo: 还有bug
            if func(left) * func(right) > 0:
                yield from solve_equation(left, mid)
                yield from solve_equation(mid+1e-5, right)
                break
        # 情况4：二分法，更新边界
        # print(f'hello, {mid = }')
        if func(left) * func(mid) >= 0:
            left_flow.send(mid)
        elif func(right) * func(mid) >= 0:
            right_flow.send(mid) 

In [None]:
a = solve_equation(-1,100)
print(next(a))
print(next(a))
print(next(a))
print(next(a))
print(next(a))

1.9999991059303284
3.9999966482543936


RuntimeError: generator raised StopIteration