<a href="https://colab.research.google.com/github/travisormsby/python-tips-tricks/blob/main/docs/PerformanceMemory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Optimize Performance and Memory Use

When you begin to use Python regularly in your work, you'll start noticing bottlenecks in your code. Some workflows may run at lightning speed, while others take hours of processing time to complete, or even crash.

Avoiding bloat is invaluable as you move toward using code for automation, bigger data, and working with APIs. Code efficiency means:
- Less chance of a slowdown or crash: the dreaded MemoryError.
- Quicker response time and fewer bottlenecks for the larger workflow.
- Better scaling.
- Efficient code is often (but not always!) cleaner and more readable.

Let's look at some ways you can reduce bloat in your code.

tl;dr
<br>Access and store only what you need, no more.
- __Storage__: avoid a list where you could use a tuple
- __Membership look-up__: avoid a list (or tuple) where you could use a set (or dictionary)
- __Iteration__: avoid a function (or list comprehension) where you could use a generator (or generator expression)
- __Profile__: make time for performance checks by profiling your code for bottlenecks

## Storage: lists vs. tuples

If you have a collection of values, your first thought may be to store them in a list.

In [None]:
data_list = [17999712, 2015, 'Hawkins Road', 'Linden ', 'NC', 28356]

Lists are nice because they are very flexible. You can change the values in the list, including appending and removing values. But that flexibility comes at a cost. Lists are less efficient than tuples. For example, they use more memory.

In [None]:
import sys

data_tuple = (17999712, 2015, 'Hawkins Road', 'Linden ', 'NC', 28356)

print(sys.getsizeof(data_list))
print(sys.getsizeof(data_tuple))

104
88


If you aren't going to be changing the values in a collection, use a tuple instead of a list.

### Membership look-up: sequential vs. hashable

However, when you want to see if an element _already exists_ in a collection of elements, use a set or dictionary to store that collection if possible.

- List and tuple look-up is **sequential**, going at the speed of *O(n): linear time*.
    - With lists, Python scans the entire list until it finds the match (or reaches the end).
    - Worst case: it has to look at every element.
- Set and dictionary look-up are **hashable**: mapping keys to values. These go at the speed of *O(1): constant time*.
    - No matter how big the collection is, the set only ever has to check 1 value.
    - Sets are built on hash tables. Python computes the hash of the element and jumps straight to where it should be stored.

The example below shows that a set is over 1000x faster than a list in calculating the first 10,000 values of [Recaman's sequence](https://oeis.org/search?q=recaman&language=english&go=Search).

In [None]:
def recaman_check(cur, i, visited):
    return (cur - i) < 0 or (cur - i) in visited

def recaman_list(n: int) -> list[int]:
    """
    return a list of the first n numbers of the Recaman series
    """

    visited_list = [0]
    current = 0
    for i in range(1, n):
        if recaman_check(current, i, visited_list):
            current += i
        else:
            current -= i
        visited_list.append(current)
    return visited_list

In [None]:
%%timeit
recaman_list(10000)

In [None]:
def recaman_set(n: int) -> list[int]:
    """
    return a set of the first n numbers of the Recaman series
    """
    visited_set = {0}
    current = 0
    for i in range(1, 100_000):
        if recaman_check(current, i, visited_set):
            current += i
        else:
            current -= i
        visited_set.add(current)
    return visited_set

In [None]:
%%timeit
recaman_set(10000)

When you add an element to a set...
1. Python calls the element’s __hash__() method to get a hash value (an integer);
1. That hash value determines where the element will be stored in the set's internal structure; and
1. When checking if an element is in the set, Python uses the hash to quickly find it.

## Iteration: functions vs. generators

We often use functions to operate on data, but generators can be more memory-efficient and faster for certain tasks.

**Regular functions and comprehensions** typically store outputs into containers, like lists or dictionaries. This can take up unnecessary memory, especially when we're creating multi-step workflows with many intermediate outputs.

In contrast, **generators** only hold one data item in memory at a time. A generator is a type of iterator that produces results on-demand (lazily), maintaining its state between iterations.

Under the hood, a generator's syntax is similar to a function. Generally, you:
- define a process(),
- provide the logic, and
- ask for the result, either with a return statement (for functions) or a yield statement (for generators).

Imagine you have a large dataset containing millions of employee records. You want to calculate the combined hourly rates of all employees on an annual salary.

In [None]:
# For the sake of simplicity, we'll represent the dataset with a small sample.
employeeDatabase = [
  {'lastName': 'Knope', 'rate': 72000, 'pay_class': 'annual'},
  {'lastName': 'Gergich', 'rate': 17, 'pay_class': 'hourly'},
  {'lastName': 'Ludgate', 'rate': 60000, 'pay_class': 'annual'},
  {'lastName': 'Swanson', 'rate': 'redacted', 'pay_class': 'redacted'},
  {'lastName': 'Haverford', 'rate': 52000, 'pay_class': 'annual'}
]

You can use a function for this, but it means the entire input dataset will be held in memory.

In [None]:
def hourly_rate(payments):
  """Function that returns each salaried workers' hourly rate."""
  hourlyRates = []
  for worker in payments:
    if worker.get('pay_class') == 'annual':
      hourly = worker['rate'] / 2080
      hourlyRates.append(hourly)
  return hourlyRates

# Sum hourly rates for those receiving an annual salary.
salariesPerHour = sum(hourly_rate(employeeDatabase))

print(f"Total dispersments per hour for salaried employees: ${salariesPerHour:.2f}")

Total dispersments per hour for salaried employees: $88.46


If the input dataset is huge, this eats up a ton of space. Instead, what if we process data lazily, storing one row in memory at a time?

In [None]:
def hourly_rate_gen(payments):
  """Generator that yields each salaried worker's hourly rate."""
  for worker in payments:
    if worker.get('pay_class') == 'annual':
      hourly = worker['rate'] / 2080
      yield hourly

# Sum hourly rates for those receiving an annual salary.
salariesPerHour = sum(hourly_rate_gen(employeeDatabase))

print(f"Total dispersments per hour for salaried employees: ${salariesPerHour:.2f}")

Total dispersments per hour for salaried employees: $88.46


In a function, the return statement signals that the function can execute from start to finish.
- Every output that a function produces is *held in memory at the same time*.

In a generator, the yield statement signals that execution can proceed *one at a time*; when yield is executed, the generator pauses, retaining the generator's state until the next time it is called.
- Each output that a generator produces can be yielded, *then discarded* before the next output is yielded.
- A generator also makes it easier to stream input data: Each input (such as a row in a CSV) can be yielded, *then discarded* before the next input is yielded.

Note: **Generator pipelines** are a powerful tool for GIS and remote sensing. Use multiple generators to string tasks together lazily. These are hugely helpful for complex spatial analysis workflows, such as raster processing.

### Iteration, continued: List comprehension vs. generator expression

Generator expressions (aka generator comprehensions) are concise, one-line generators. Generator expressions can be a handy replacement for list comprehensions.

Let's look at how the analysis above would appear in list comprehension format.

In [None]:
hourly = [worker['rate'] / 2080 for worker in employeeDatabase if worker.get('pay_class') == 'annual']
salariesPerHour = sum(hourly)

print(f"${salariesPerHour:.2f}")

$88.46


As with the function, the list comprehension constructs a list of n values. Then, we use sum() to add all values in the list together.

A generator expression looks almost identical to a list comprehension: simply swap out square brackets with parentheses.

*Here's a fun tip: When a generator expression is the only argument in a function (in this case, sum()), you can drop the inner parentheses.*

In [None]:
hourly = (worker['rate'] / 2080 for worker in employeeDatabase if worker.get('pay_class') == 'annual')
salariesPerHour = sum(hourly)

print(f"${salariesPerHour:.2f}")

$88.46


## Profiling: finding bottlenecks

Profiling is any technique used to measure the performance of your code, in particular its speed. There are dozens of tools available for profiling. We'll use a few to:
1. **Check memory use:** Use `sys.getsizeof()` to check the memory size of variables.
1. **Spot-profile your code:** Use the `timeit` notebook magic to perform some basic profiling by cell or by line.
1. **Profile your script comprehensively:** The `cProfile` module has the ability to break down call by call to determine the number of calls and the total time spent on each.

### Check memory use with `getsizeof()`

Use this tool to quickly check how much memory a variable is taking up on your system.

In [None]:
import sys

tract1 = {
    "area": 100,
    "area_water": 20,
    "population": 1000
}

print(f"Bytes: {sys.getsizeof(tract1)}")

Bytes: 184


In [None]:
print(f"Bytes: {sys.getsizeof(recaman_list(1000))}")
print(f"Bytes: {sys.getsizeof(recaman_set(1000))}")

Bytes: 8856
Bytes: 2097368


"You said sets were better than lists!"

Remember, sets are preferred over lists for membership lookup because they are faster, not slimmer.
- If you care more about *output size*, make a list; it takes up less memory.
- If you care more about *task speed*, make a set.

### Spot-check speed with `%%timeit`

The `timeit` module measures the execution time of a selection of code. Among the many ways you'll see it written are "magic" commands:
- `%timeit` is a form of _line magic_. Line magic arguments only extend to the end of the current line.
- `%%timeit` is a form of _cell magic_. It measures the execution time of the entire notebook cell.

With both of these commands, the notebook will test your code multiple times and print the average speed of those calls.

In [None]:
%%timeit
# Cell magic example
from typing import NamedTuple

class Tract(NamedTuple):
    population: int
    households: int

tract1 = Tract(1000, 500)
tract2 = Tract(2000, 800)
tract3 = Tract(5000, 3000)

tract1.households

117 µs ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [None]:
# Line magic example
%timeit sum(hourly_rate(employeeDatabase))
%timeit sum(hourly_rate_gen(employeeDatabase))

3.56 µs ± 118 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
5.34 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)


"Wait! You said the generator would be faster!"

With small datasets (like this sample), a function might be slightly, negligibly faster than a generator. Increase the input size by 10x, 100x, or 100000x to see the generator outpace its competitors.

`timeit` tip: Optionally, you can limit the number of calls and repetitions with:
- -n (number of times to execute the main statement) and
- -r (number of times to repeat the timer).

In [None]:
%timeit -n 1 -r 5 sum(hourly_rate_gen(employeeDatabase))

The slowest run took 8.05 times longer than the fastest. This could mean that an intermediate result is being cached.
3.08 µs ± 2.65 µs per loop (mean ± std. dev. of 5 runs, 1 loop each)


### Profile with `cProfile`

Whereas `timeit` is a quick way to test speed, `cProfile` is useful as a comprehensive and holistic code profiler. Some perks of `cProfile`:
 - Compare which lines take longest to execute
 - See how often a function is executed
 - Sort profiling results by time
 - See the respective data the function interacts with
 - Print detailed reports with multiple statistics

Let's take a look:

In [None]:
import cProfile

cProfile.run('recaman_list(10000)')

         20002 function calls in 0.383 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     9999    0.371    0.000    0.371    0.000 <ipython-input-61-1cfc8d8a116c>:1(recaman_check)
        1    0.011    0.011    0.383    0.383 <ipython-input-61-1cfc8d8a116c>:4(recaman_list)
        1    0.000    0.000    0.383    0.383 <string>:1(<module>)
        1    0.000    0.000    0.383    0.383 {built-in method builtins.exec}
     9999    0.001    0.000    0.001    0.000 {method 'append' of 'list' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}




Some of the most useful outputs from a `cProfile` statement:
- Above the table, you are given the number of function calls and how long the code took overall.
- cumtime: The cumulative time it took to call a given function, including all of its subfunctions.
- filename: The data that the function interacted with

In [None]:
cProfile.run('recaman_set(10000)')

         200002 function calls in 0.102 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    99999    0.028    0.000    0.028    0.000 <ipython-input-61-1cfc8d8a116c>:1(recaman_check)
        1    0.059    0.059    0.101    0.101 <ipython-input-63-f612dc588057>:1(recaman_set)
        1    0.002    0.002    0.102    0.102 <string>:1(<module>)
        1    0.000    0.000    0.102    0.102 {built-in method builtins.exec}
    99999    0.014    0.000    0.014    0.000 {method 'add' of 'set' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}




These results show that the set-based function executed 10x more calls (200002), but ran 4x faster (0.086 seconds).

`cProfile` tip: Use cProfile.Profile() as a context manager!

In [7]:
import cProfile

with cProfile.Profile() as pr:
  def recaman_check(cur, i, visited):
    return (cur - i) < 0 or (cur - i) in visited

  def recaman_set(n: int) -> list[int]:
      """
      return a set of the first n numbers of the Recaman series
      """
      visited_set = {0}
      current = 0
      for i in range(1, 100_000):
          if recaman_check(current, i, visited_set):
              current += i
          else:
              current -= i
          visited_set.add(current)
      return visited_set

  recaman_set(1000)

  pr.print_stats('line') # Order by line number.

         200008 function calls in 0.090 seconds

   Ordered by: line number

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    99999    0.012    0.000    0.012    0.000 {method 'add' of 'set' objects}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    99999    0.025    0.000    0.025    0.000 <ipython-input-7-b276f05b123d>:4(recaman_check)
        1    0.053    0.053    0.090    0.090 <ipython-input-7-b276f05b123d>:7(recaman_set)
        1    0.000    0.000    0.000    0.000 cProfile.py:41(print_stats)
        1    0.000    0.000    0.000    0.000 cProfile.py:51(create_stats)
        1    0.000    0.000    0.000    0.000 pstats.py:108(__init__)
        1    0.000    0.000    

---

# Exercises

__Exercises summary__
1. Replace lists with efficient alternatives
    1. Storage: List to tuple
    1. Look-up: List to set
1. Replace iterables with efficient alternatives
    1. Iteration: List comprehension to generator expression
    1. Iteration: Function to generator
1. Check for speed bottlenecks
    1. Compare differences in speed with `timeit`
    1. Check for speed bottlenecks in detail with `cProfile`

## 1) Replace lists with efficient alternatives

### 1a) Tuple-based storage

The code below creates a list containing all years in a research study timeframe, from 1900 to 2030.

The values in this collection will not need to be changed because the study will always use this timeframe.

In [None]:
import sys

def listFromRange(r1, r2):
  """Create a list from a range of values"""
  return [item for item in range(r1, r2+1)]

start = 1900
end = 2030

studyYears = listFromRange(start, end)

print(studyYears)
print("Bytes used: ", sys.getsizeof(studyYears))

[range(1900, 2031)]
Bytes used:  64


**Your turn:** For the same timeframe, write a different implementation using a storage option that takes up less memory.

In [None]:
# # # Exercise solution # # #

def tupleFromRange(r1, r2):
  """Create a tuple from a range of values"""
  return tuple(range(r1, r2+1))

start = 1900
end = 2030

studyYears = tupleFromRange(start, end)

print(studyYears)
print("Bytes used: ", sys.getsizeof(studyYears))

(1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1912, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030)
Bytes used:  1088


### 1b) Set-based look-up

The code below assigns a collection of placenames to a list. Then, it checks whether a placename is in the list. If not, the placename is reported missing.

If you have 1 million placenames to look up and 6 names in the list, that’s up to 6 million checks.

In [None]:
placeNames_list = ["Kinshasa", "Duluth", "Uruguay", "Doherty Residence", "Dinkytown", "Khazad-dum"]

# List look-up
if "Dinkytown" not in placeNames_list:
    print("Missing.")  # O(n) look-up

**Your turn:** Write a different implementation using a storage option that allows quicker checks for membership at scale.

In [None]:
# # # Exercise solution # # #

placeNames_set = set(placeNames_list)

# Set look-up
if "Dinkytown" not in placeNames_set:
    print("Missing.")  # O(1) look-up

## 2) Replace iterables with efficient alternatives

### 2a) Generator expression

You have a list of random strings which contain a combination of upper and lowercase letters. You have written a list comprehension, lowerCase, to rewrite all of these strings into lowercase.

In [16]:
import random
import string

# Input dataset: A list of random strings. Each string is 8 letters long.
randomStrings = [''.join(random.choices(string.ascii_letters, k=8)) for i in range(50)]
print(randomStrings)

# Convert all strings to lowercase
lowerCase = [x.lower() for x in randomStrings]
print(lowerCase)

['DOhEEWyq', 'LwTpaYti', 'hISzIpVu', 'zVEOxAzd', 'kmYdVWrR', 'ZQUqNUen', 'uheqBKMt', 'HhXVDwcZ', 'NTzjXLTf', 'FweGNqLW', 'duhrAAyk', 'VWItTMOR', 'ysfectKt', 'hWtPjfsv', 'KaSGdbdN', 'xOcbxIcH', 'TBXlgafL', 'rXSpEuqM', 'guouEKoF', 'qgGIqSub', 'vuqfcXMB', 'ZTRNGwJs', 'AbNtBCHp', 'lgpjstts', 'xeTTpjpt', 'xIPhPuhZ', 'ZAilJHHh', 'EuguGHpD', 'dKvgFHQM', 'osrezGNR', 'MmSfpuWx', 'HkMDCSSL', 'IEyiLxZM', 'WrYweAcu', 'lvcJGTNl', 'efMKzQEY', 'qevCMkoo', 'ZsLBcASU', 'HmiiGMQb', 'ZceAavfS', 'pkHTOnwK', 'WnjLYTuA', 'xJLvJlue', 'YyEmyaKm', 'xjVOacdI', 'ulivhxvV', 'lLWNKRgh', 'mbDIxYuD', 'BALBErLW', 'xvZZKxOX']
['doheewyq', 'lwtpayti', 'hiszipvu', 'zveoxazd', 'kmydvwrr', 'zquqnuen', 'uheqbkmt', 'hhxvdwcz', 'ntzjxltf', 'fwegnqlw', 'duhraayk', 'vwittmor', 'ysfectkt', 'hwtpjfsv', 'kasgdbdn', 'xocbxich', 'tbxlgafl', 'rxspeuqm', 'guouekof', 'qggiqsub', 'vuqfcxmb', 'ztrngwjs', 'abntbchp', 'lgpjstts', 'xettpjpt', 'xiphpuhz', 'zailjhhh', 'eugughpd', 'dkvgfhqm', 'osrezgnr', 'mmsfpuwx', 'hkmdcssl', 'ieyilxzm', 'w

**Your turn**: Write a different implementation that still prints all the lowercase results, but operates faster than a list comprehension (when used with a large dataset).

In [19]:
# # # Exercise solution # # #

lowerCase_gen = (x.lower() for x in randomStrings)
for x in lowerCase_gen:
  print(x)

doheewyq
lwtpayti
hiszipvu
zveoxazd
kmydvwrr
zquqnuen
uheqbkmt
hhxvdwcz
ntzjxltf
fwegnqlw
duhraayk
vwittmor
ysfectkt
hwtpjfsv
kasgdbdn
xocbxich
tbxlgafl
rxspeuqm
guouekof
qggiqsub
vuqfcxmb
ztrngwjs
abntbchp
lgpjstts
xettpjpt
xiphpuhz
zailjhhh
eugughpd
dkvgfhqm
osrezgnr
mmsfpuwx
hkmdcssl
ieyilxzm
wryweacu
lvcjgtnl
efmkzqey
qevcmkoo
zslbcasu
hmiigmqb
zceaavfs
pkhtonwk
wnjlytua
xjlvjlue
yyemyakm
xjvoacdi
ulivhxvv
llwnkrgh
mbdixyud
balberlw
xvzzkxox


### 2b) Generator

In [50]:
# The data that a collection must match in order to proceed:
primList = [4, 7, 140]

# The collections that we are comparing to primList:
inputs = (
 [0, 3, 0],
 [5, 4, 3],
 [7, 150, 0.5, 1]
 )

In [51]:
def matchingStructure(collections, primList):
  """
  This function compares the length of each input collection to the primary list
  (primList). An input that matches in length gets multiplied by the primary
  list and returned.
  """
  results = []
  for item in collections:
    if len(item) == len(primList):
      multiplied = [a * b for a, b in zip(item, primList)]
      results.append(multiplied)
  return results

print(matchingStructure(inputs, primList))

[[0, 21, 0], [20, 28, 420]]


Your turn: Write a different implementation that uses a generator instead of a function to compare each list's shape.

In [52]:
# # # Exercise solution # # #

def matchingStructure_gen(collections, primList):
  """
  This generator compares the shape of each input list to the primary list
  (primList). An input that matches all conditions is multiplied by the
  primary list and yielded. An input that fails any condition is skipped.
  """
  for item in collections:
    if len(item) == len(primList):
      multiplied = [a * b for a, b in zip(item, primList)]
      yield multiplied

for item in matchingStructure_gen(inputs, primList):
  print(item)

[0, 21, 0]
[20, 28, 420]


## 3) Check for speed bottlenecks

### 3a) Compare differences in speed using `timeit`

Using `%%timeit`, compare the time it took to create myDataPaths as a list (original code) versus as a tuple (exercise solution).

In [None]:
%%timeit


In [None]:
%%timeit
## Your solution here ##

Use `%%timeit` again to compare list-based lookup to set intersection.

In [None]:
%%timeit


In [None]:
%%timeit
## Your solution here ##

Finally, compare the second list vs. set change that you made.

In [None]:
%%timeit


In [None]:
%%timeit
## Your solution here ##

### 3b) Check for speed bottlenecks in detail using `cProfile`

Use cProfile to locate the slowest calls in your improved script.

Hint: Sort by tottime instead of name to find hotspots more easily.



---



## Stretch Goal

### Raster Generator

Let's say you have a raster depicting 500 square meter population density (people per 500m²) across a country. That's a huge dataset! You want to resample the raster down to 1 square kilometer (people per 1km²) to make it easier to work with.

To do this, you have written a function that creates a new raster of 1km² grid cells. Each 1km² cell contains the total population of all 500m² cells within it.

In [None]:
import numpy as np

# Starting dataset: 80x80 grid of people per 500m².
highResPop = np.ones((80, 80)) * 5

*Note: The example here uses arrays to represent the rasters for simplicity, and each 500m² cell contains exactly 5 people.*

In [None]:
def densityKM(popArray):
    """
    Function that returns population density per km² cell from a
    500 m² resolution population source.

    Input:  500x500m 2D array
    Output: 1x1km 2D array, covering the same area of interest.
    """
    group_size = 20 # Every 20x20 group of 500m² cells equals 1km².
    rows, cols = popArray.shape

    # Aggregate
    kmArray = popArray.reshape(
        rows // group_size, group_size,
        cols // group_size, group_size
    )

    # Sum over each group
    kmDensity = kmArray.sum(axis=(1, 3))

    # Output
    return kmDensity

In [None]:
densityKM(highResPop)

array([[2000., 2000., 2000., 2000.],
       [2000., 2000., 2000., 2000.],
       [2000., 2000., 2000., 2000.],
       [2000., 2000., 2000., 2000.]])

Your turn: Convert this function into a generator.

In [None]:
# # # Exercise solution, version 1 # # #
def densityKM_gen(popArray):
    """
    Generator that yields rows of population density per km² cells from a
    500 m² resolution population source.

    Input:  500x500m 2D array
    Output: Each yield output is a 1D array representing one row of densities.
    """
    group_size = 20
    rows, cols = popArray.shape

    # Aggregate
    kmArray = popArray.reshape(
        rows // group_size, group_size,
        cols // group_size, group_size
    )

    # Sum over each group
    kmDensity = kmArray.sum(axis=(1, 3))

    for row in kmDensity:
        yield row  # Now yields an array

In [None]:
for row in densityKM_gen(highResPop):
    print(row)

[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]


In [None]:
# # # Exercise solution, version 2 (more memory efficient) # # #
def densityKM_gen2(popArray):
    """
    Generator that yields rows of population density per km² cells from a
    500 m² resolution population source.
    Unlike Solution Version 1, this generator does not create the entire km²
    array in memory. It saves memory by processing one group of 20x20
    cells at a time.

    Input:  500x500m 2D array
    Output: Each yield is a 1D NumPy array representing one row of km²
    densities, processed group by group.
    """
    import numpy as np

    group_size = 20
    rows, cols = popArray.shape

    num_row_blocks = rows // group_size
    num_col_blocks = cols // group_size

    for i in range(num_row_blocks):
        row_densities = []
        row_start = i * group_size

        for j in range(num_col_blocks):
            col_start = j * group_size
            block = popArray[row_start:row_start + group_size,
                             col_start:col_start + group_size]
            density = block.sum()
            row_densities.append(density)

        yield np.array(row_densities)

In [None]:
for row in densityKM_gen2(highResPop):
    print(row)

[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]
[2000. 2000. 2000. 2000.]
