# Generators

Programming in Python

School of Computer Science, University of St Andrews

## Generators
* Iterators are typically written as *classes* (implementing `__iter__` and `__next__`) that produce a sequence of values
* Generators are *functions* that produce a sequence of values
* It can be expensive to build an entire list at once (e.g. using a list comprehension)
* A generator performs so-called "lazy generation", yielding one element at a time, only when it is requested
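
To make the contrast between the two approaches concrete, here is a minimal sketch that produces the same values twice: once with a class-based iterator and once with a generator function. The names `CountUpTo` and `count_up_to` are made up for this illustration and are not used elsewhere in these notes.

```python
class CountUpTo:
    """Class-based iterator (hypothetical example): produces 1, 2, ..., limit."""

    def __init__(self, limit):
        self.limit = limit
        self.current = 0

    def __iter__(self):
        # An iterator returns itself from __iter__
        return self

    def __next__(self):
        # Signal the end of the sequence with StopIteration
        if self.current >= self.limit:
            raise StopIteration
        self.current += 1
        return self.current


def count_up_to(limit):
    """Generator function (hypothetical example) producing the same values."""
    current = 1
    while current <= limit:
        yield current
        current += 1


print(list(CountUpTo(5)))    # [1, 2, 3, 4, 5]
print(list(count_up_to(5)))  # [1, 2, 3, 4, 5]
```

The generator version keeps its state in ordinary local variables, so there is no need to write `__iter__` and `__next__` by hand.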

## A simple generator function
* Note the `yield` statement: its presence is what makes this a generator function

In [1]:
def count_from(n): 
    while True:
        yield n 
        n += 1

In [2]:
for i in count_from(1):
    if i < 11:
        print(i, end=' ')
    else:
        break

1 2 3 4 5 6 7 8 9 10 

## A generator function returns a generator object

In [3]:
count_from(-5)

<generator object count_from at 0x7fbab0bc8270>

In [4]:
for i in count_from(-5):
    if i < 6:
        print(i, end=' ')
    else:
        break

-5 -4 -3 -2 -1 0 1 2 3 4 5 

## How does it work?
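* Calling a generator function does not run its body; the body only runs when the resulting generator object is iterated
* The first time a value is requested, the body runs until it reaches `yield`
* `yield` hands a value back like `return`, but pauses the function so that it can be resumed later
* Each subsequent request resumes execution right after the `yield` and runs until the next `yield`
* This goes on until the function returns or, for an infinite loop like the one here, until the caller stops asking for values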

In [5]:
def count_from_info(n):
    while True:
        print("--- before yield, n =", n)
        yield n
        n += 1
        print("--- after yield, n =", n)

In [6]:
for i in count_from_info(1):
    print('*** new iteration')
    if i <= 3: 
        print('*** i =', i)  
    else:
        break

--- before yield, n = 1
*** new iteration
*** i = 1
--- after yield, n = 2
--- before yield, n = 2
*** new iteration
*** i = 2
--- after yield, n = 3
--- before yield, n = 3
*** new iteration
*** i = 3
--- after yield, n = 4
--- before yield, n = 4
*** new iteration
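
The `for` loop hides the individual requests for values. As a rough sketch, the same stepping can be driven by hand with the built-in `next()`. The finite generator `first_three` is a hypothetical example introduced only here; it shows what happens when a generator function returns instead of looping forever.

```python
def first_three():
    """Hypothetical finite generator used only for this illustration."""
    yield 1
    yield 2
    yield 3


gen = first_three()   # nothing in the body has run yet
print(next(gen))      # runs up to the first yield  -> 1
print(next(gen))      # resumes, runs to the second -> 2
print(next(gen))      # resumes, runs to the third  -> 3

try:
    next(gen)         # the function body ends, so there is nothing to yield
except StopIteration:
    print("generator is exhausted")
```

A `for` loop does essentially this behind the scenes and stops quietly when `StopIteration` is raised.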


## Generator comprehension vs list comprehension

In [7]:
squares = [n*n for n in range(0,11)]
squares

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100]

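A generator comprehension (also called a generator expression) uses the same notation as a list comprehension, but with round brackets. Because it is lazy, it can even be built over an infinite generator such as `count_from(1)`.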

In [8]:
squares = (n*n for n in count_from(1))

Instead of returning a list, the generator comprehension returns another generator!

In [9]:
squares

<generator object <genexpr> at 0x7fbab0bc8900>

In [10]:
for j in squares:
    if j <= 1000:
        print(j, end=' ')
    else:
        break

1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361 400 441 484 529 576 625 676 729 784 841 900 961 
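
The "loop until a condition fails, then `break`" pattern can also be written with `itertools.takewhile` from the standard library. The sketch below is self-contained, so it uses `itertools.count(1)` as a stand-in for `count_from(1)`; the names `squares_demo` and `small_squares` are chosen just for this example.

```python
from itertools import count, takewhile

# itertools.count(1) stands in for the count_from(1) generator defined earlier
squares_demo = (n * n for n in count(1))

# Collect values while they are <= 1000; stop at the first one that is larger
small_squares = list(takewhile(lambda j: j <= 1000, squares_demo))

print(small_squares[:5], "...", small_squares[-1])  # [1, 4, 9, 16, 25] ... 961
```

Like the loop with `break`, `takewhile` has to fetch the first value that fails the test (here 1024) in order to know when to stop, so that value is consumed from `squares_demo` as well.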

## It can even be used with `enumerate`

In [11]:
enum = enumerate(squares)

In [12]:
for x in enum:
    print(x, end=' ')
    if x[0] > 10:
        break

(0, 1089) (1, 1156) (2, 1225) (3, 1296) (4, 1369) (5, 1444) (6, 1521) (7, 1600) (8, 1681) (9, 1764) (10, 1849) (11, 1936) 

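- But why does it start from 1089?
- `squares` is a generator, so it remembers its position: the previous loop consumed every value up to and including 1024 (the value that triggered `break`), so the next value is 33*33 = 1089
- A generator cannot be rewound; to start again from 1, `squares` has to be created anew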
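
The single-pass behaviour is easiest to see on a small finite generator. This is a minimal sketch; `finite_squares` is a name used only for this illustration.

```python
finite_squares = (n * n for n in range(5))

first_pass = list(finite_squares)    # consumes every value
second_pass = list(finite_squares)   # nothing is left to produce

print(first_pass)    # [0, 1, 4, 9, 16]
print(second_pass)   # []
```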

In [13]:
squares = (n*n for n in count_from(1))
enum = enumerate(squares)

In [14]:
for x in enum:
    print(x, end=' ')
    if x[0] > 10:
        break

(0, 1) (1, 4) (2, 9) (3, 16) (4, 25) (5, 36) (6, 49) (7, 64) (8, 81) (9, 100) (10, 121) (11, 144) 

## Generating prime numbers

In [15]:
# a helper function to check if `n` is a multiple of some number from the list `numbers`
def is_multiple(numbers, n):
    for i in numbers:
        if n % i == 0:
            return True # n is a multiple of i
    return False

In [16]:
def prime_generator():
    primes = []
    for i in count_from(2):
        if not is_multiple(primes, i):
            primes.append(i)
            yield i

In [17]:
for i in prime_generator():
    if i >= 500:
        break
    print(i, end=" ")

2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53 59 61 67 71 73 79 83 89 97 101 103 107 109 113 127 131 137 139 149 151 157 163 167 173 179 181 191 193 197 199 211 223 227 229 233 239 241 251 257 263 269 271 277 281 283 293 307 311 313 317 331 337 347 349 353 359 367 373 379 383 389 397 401 409 419 421 431 433 439 443 449 457 461 463 467 479 487 491 499 

## Exercise
* Rewrite the code of `prime_generator` above to avoid using `count_from`
* Test it on the same example to generate all primes less than 500


In [18]:
def prime_generator2():
    primes = []
    i = 1
    while True:
        i += 1
        if not is_multiple(primes, i):
            primes.append(i)
            yield i

In [19]:
for i in prime_generator2():
    if i >= 500:
        break
    print(i, end=" ")


2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53 59 61 67 71 73 79 83 89 97 101 103 107 109 113 127 131 137 139 149 151 157 163 167 173 179 181 191 193 197 199 211 223 227 229 233 239 241 251 257 263 269 271 277 281 283 293 307 311 313 317 331 337 347 349 353 359 367 373 379 383 389 397 401 409 419 421 431 433 439 443 449 457 461 463 467 479 487 491 499 
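
Both versions above test each candidate against every prime found so far. As an optional variation (not part of the exercise), the sketch below stops the trial division once the prime being tried exceeds the square root of the candidate, since any composite number has a prime factor no larger than its square root. The name `prime_generator_sqrt` is hypothetical, and `itertools.count` and `itertools.takewhile` are used to keep the example self-contained.

```python
from itertools import count, takewhile


def prime_generator_sqrt():
    """Hypothetical variant: only divide by primes p with p * p <= candidate."""
    primes = []
    for candidate in count(2):
        is_prime = True
        for p in primes:
            if p * p > candidate:
                break              # no divisor found up to sqrt(candidate)
            if candidate % p == 0:
                is_prime = False   # candidate is a multiple of p
                break
        if is_prime:
            primes.append(candidate)
            yield candidate


primes_below_500 = list(takewhile(lambda p: p < 500, prime_generator_sqrt()))
print(len(primes_below_500), primes_below_500[:5], primes_below_500[-1])
# 95 [2, 3, 5, 7, 11] 499
```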