# Smoke Basin

The analysis that follows pertains to the ninth day of the [Python Problem-Solving Bootcamp](https://mathspp.com/pythonbootcamp).

In the analysis that follows you may be confronted with code that you do not understand, especially as you reach the end of the explanation of each part.

If you find functions that you didn't know before, remember to [check the docs](https://docs.python.org/3/) for those functions and play around with them in the REPL.
The analysis is written so that the difficulty increases within each part of the problem, so it is natural for it to get harder as you keep reading.
That's perfectly fine: you don't have to understand everything _right now_, especially because I can't know for sure what _your level_ is.

## Part 1 problem statement

(From [Advent of Code 2021, day 9](https://adventofcode.com/2021/day/9))

These caves seem to be [lava tubes](https://en.wikipedia.org/wiki/Lava_tube). Parts are even still volcanically active; small hydrothermal vents release smoke into the caves that slowly settles like rain.

If you can model how the smoke flows through the caves, you might be able to avoid it and be that much safer. The submarine generates a heightmap of the floor of the nearby caves for you (your puzzle input).

Smoke flows to the lowest point of the area it's in. For example, consider the following heightmap:

```
2199943210
3987894921
9856789892
8767896789
9899965678
```

Each number corresponds to the height of a particular location, where `9` is the highest and `0` is the lowest a location can be.

Your first goal is to find the _low points_ - the locations that are lower than any of its adjacent locations. Most locations have four adjacent locations (up, down, left, and right); locations on the edge or corner of the map have three or two adjacent locations, respectively. (Diagonal locations do not count as adjacent.)

In the above example, there are _four_ low points: two are in the first row (a `1` and a `0`), one is in the third row (a `5`), and one is in the bottom row (also a `5`), all of them marked with a `.` in the grid below. All other locations on the heightmap have some lower adjacent location, and so are not low points.

```
2.9994321.
3987894921
98.6789892
8767896789
989996.678
```

The _risk level_ of a low point is _1 plus its height_. In the above example, the risk levels of the low points are `2`, `1`, `6`, and `6`. The sum of the risk levels of all low points in the heightmap is therefore `15`.

Find all of the low points on your heightmap. _What is the sum of the risk levels of all low points on your heightmap?_

_Using the input file `input.txt`, the result should be `502`._
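Before working with the real input, it can help to check the definitions above against the small example. The following is a minimal sketch that assumes the example heightmap is stored in a string. The names `EXAMPLE`, `heights`, and `find_low_points` are only for illustration.

```py
EXAMPLE = """\
2199943210
3987894921
9856789892
8767896789
9899965678"""

# Parse the heightmap into a list of rows of ints.
heights = [[int(ch) for ch in line] for line in EXAMPLE.splitlines()]

def find_low_points(grid):
    """Return (row, col, height) for every location lower than all its neighbours."""
    points = []
    for r, row in enumerate(grid):
        for c, val in enumerate(row):
            # Only up, down, left, and right count; skip positions outside the grid.
            neighbours = [
                grid[r_][c_]
                for r_, c_ in [(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)]
                if 0 <= r_ < len(grid) and 0 <= c_ < len(row)
            ]
            if all(val < n for n in neighbours):
                points.append((r, c, val))
    return points

points = find_low_points(heights)
print([h for _, _, h in points])         # [1, 0, 5, 5]
print(sum(h + 1 for _, _, h in points))  # 15
```

The four heights match the `1`, `0`, `5`, and `5` marked in the problem statement, and the total risk level is `15`.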

```py
# IMPORTANT: Set this to the correct path for you!
INPUT_FILE = "data/input.txt"
```

## Baseline solution

For the baseline solution to this problem, we will do exactly what it says in the problem statement:
we will go through all of the positions of the grid, and find out which positions are low points:

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        # Positions outside the grid get +infinity, so they never block a low point.
        up = grid[r - 1][c] if r > 0 else float("inf")
        down = grid[r + 1][c] if r < len(grid) - 1 else float("inf")
        left = grid[r][c - 1] if c > 0 else float("inf")
        right = grid[r][c + 1] if c < len(row) - 1 else float("inf")
        if val < up and val < down and val < left and val < right:
            risk += val + 1
print(risk)
```

```
502
```

This solution is pretty straightforward.
Just note that we make use of `float("inf")` to make sure that, in case we are on the edge of the grid, we assign something that is always greater than `val` to that non-existing position.

In fact, there are many ways to handle the edges, so we will go over those now.

## Handling the edges of the grid

Another alternative, which is sometimes more useful, is to take the value we are comparing against (in this case, `val`) and use it to build a default that can never stop `val` from being a low point.
In our case, we want to check if `val` is the smaller item, so we can use `val` to build a value that is greater than `val`:

```py
risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        # Off-grid neighbours get `val + 1`, which is always greater than `val`.
        up = grid[r - 1][c] if r > 0 else val + 1
        down = grid[r + 1][c] if r < len(grid) - 1 else val + 1
        left = grid[r][c - 1] if c > 0 else val + 1
        right = grid[r][c + 1] if c < len(row) - 1 else val + 1
        if val < up and val < down and val < left and val < right:
            risk += val + 1
print(risk)
```

```
502
```


For integers and a `<` comparison, using `val + 1` might look weird; but imagine that, instead, we were comparing lengths of strings: what would you use as the default value?
You couldn't create a string with infinite length, and creating a really long string probably wouldn't be a good idea either, because it could take up too much memory.

And even if you did create a _really_ long string, how could you know that it would always be longer than whatever string showed up in `val`?
Instead, you could write `... else val + " "`, which creates a string that is always one character longer than `val`:

```py
matrix = [
    ["short", "longer", "short"],
    ["longer", "really long", "longer"],
    ["short", "longer", "short"],
]
for r, row in enumerate(matrix):
    for c, val in enumerate(row):
        # Off-grid neighbours get a string one character longer than `val`.
        up = matrix[r - 1][c] if r > 0 else val + " "
        down = matrix[r + 1][c] if r < len(matrix) - 1 else val + " "
        left = matrix[r][c - 1] if c > 0 else val + " "
        right = matrix[r][c + 1] if c < len(row) - 1 else val + " "
        L = len(val)
        if L < len(up) and L < len(down) and L < len(left) and L < len(right):
            print(val)
```

```
short
short
short
short
```


(Although, to be fair, this specific example could also be rewritten by using `float("inf")`; can you see how?)

Whichever way you choose to define the `else` part of the conditional expression, we should probably put all the values inside a list and then use the built-in `min` to figure out which number we really need to compare `val` against:

```py
risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [
            grid[r - 1][c] if r > 0 else float("inf"),
            grid[r + 1][c] if r < len(grid) - 1 else float("inf"),
            grid[r][c - 1] if c > 0 else float("inf"),
            grid[r][c + 1] if c < len(row) - 1 else float("inf"),
        ]
        if min(values) > val:
            risk += val + 1
print(risk)
```

```
502
```


This is significantly better than what we had previously (four different variables) because it scales much better.
If the problem specification changed to say that diagonals also matter, we would just need to add four lines to the list.
With the previous approach, we would have to create four new variables and then extend the `if` statement, which would probably become really long.
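Here is a sketch of how that extension could look if the neighbour offsets live in a list. The names `OFFSETS_4`, `OFFSETS_8`, `neighbour_values`, and the small grid `small` are made up for this illustration.

```py
# Offsets (dr, dc) for the four orthogonal neighbours...
OFFSETS_4 = [(-1, 0), (1, 0), (0, -1), (0, 1)]
# ...and for all eight neighbours, diagonals included.
OFFSETS_8 = OFFSETS_4 + [(-1, -1), (-1, 1), (1, -1), (1, 1)]

def neighbour_values(grid, r, c, offsets):
    """Values of the neighbours of (r, c) that lie inside the grid."""
    return [
        grid[r + dr][c + dc]
        for dr, dc in offsets
        if 0 <= r + dr < len(grid) and 0 <= c + dc < len(grid[0])
    ]

small = [
    [2, 1, 9],
    [3, 9, 8],
    [9, 8, 5],
]
print(sorted(neighbour_values(small, 1, 1, OFFSETS_4)))  # [1, 3, 8, 8]
print(sorted(neighbour_values(small, 1, 1, OFFSETS_8)))  # [1, 2, 3, 5, 8, 8, 9, 9]
print(sorted(neighbour_values(small, 0, 0, OFFSETS_8)))  # [1, 3, 9]
```

With this setup, switching between four and eight neighbours only requires passing a different list of offsets.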

You can still do things differently.
Instead of hardcoding the list of values to consider, we can write a short function that computes the positions neighbouring a given point.
That function is responsible for checking that the coordinates are legal, and it only returns the legal positions:

```py
def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c)]
        if min(values) > val:
            risk += val + 1
print(risk)
```

```
502
```


This is a decent way to work around the issue of the edges of the grid, but there is yet another way of thinking about this that you might enjoy considering.

## Unconditionally accessing neighbours

Instead of having to use `if` statements to check that we are on the boundary, we could try to do something that allows us to index into the neighbours without worries.

One such alternative is to pad the whole grid with `9`s. No height is greater than `9`, so the padding can never stop a position from being a low point.
Then, we just work on the interior of the grid as usual:

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

WIDTH, HEIGHT = len(grid[0]), len(grid)

grid = (  # Enclose the grid in an artificial boundary of 9s.
    [[9] * (WIDTH + 2)] +                # Create a row at the top containing only 9s.
    [[9] + row + [9] for row in grid] +  # Put a 9 on the left and right of each row.
    [[9] * (WIDTH + 2)]                  # Create a row at the bottom containing only 9s.
)

risk = 0
for r in range(1, HEIGHT + 1):
    for c in range(1, WIDTH + 1):
        values = [grid[r - 1][c], grid[r + 1][c], grid[r][c - 1], grid[r][c + 1]]
        if min(values) > grid[r][c]:
            risk += grid[r][c] + 1
print(risk)
```

```
502
```


By creating this artificial padding around the grid we are able to access the neighbours of all the positions that we care about, and we can do it without having to waste any time checking for boundary conditions.
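If you need this trick more than once, you can wrap the padding in a small helper. The following is a minimal sketch, where `pad` is a hypothetical name and the fill value is a parameter so that it can be chosen per problem. The helper builds new lists, so the original grid is left untouched.

```py
def pad(grid, fill):
    """Return a new grid with a one-cell border of `fill` around `grid`."""
    width = len(grid[0])
    return (
        [[fill] * (width + 2)]                     # Top border row.
        + [[fill] + row + [fill] for row in grid]  # Fill on both sides of each row.
        + [[fill] * (width + 2)]                   # Bottom border row.
    )

original = [[1, 2], [3, 4]]
padded = pad(original, 9)
for row in padded:
    print(row)
```

This prints `[9, 9, 9, 9]`, `[9, 1, 2, 9]`, `[9, 3, 4, 9]`, and `[9, 9, 9, 9]`.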

Instead of padding, we can trade even more memory for another representation of the neighbours that is even more convenient, and a bit more array-oriented in nature:
we will pre-compute the four neighbours of each position beforehand, and just use them throughout the algorithm:

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

NEIGHBOURS = (
    [[9] * len(grid[0])] + grid[:-1],  # The “up” neighbours (r - 1).
    grid[1:] + [[9] * len(grid[0])],   # The “down” neighbours (r + 1).
    [[9] + row[:-1] for row in grid],  # The “left” neighbours (c - 1).
    [row[1:] + [9] for row in grid],   # The “right” neighbours (c + 1).
)

risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        if min(neighbour[r][c] for neighbour in NEIGHBOURS) > val:
            risk += val + 1
print(risk)
```

```
502
```


In this alternative, we also did an implicit padding of the edges of the grid.
The tuple `NEIGHBOURS` contains four items, each one corresponding to a single shift of the grid.
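To see what those shifts look like, here is a sketch on a made-up 2×3 grid `small`. The same slicing produces four shifted copies, and `shifted[r][c]` is the corresponding neighbour of `small[r][c]`, with `9` standing in for missing positions.

```py
small = [
    [1, 2, 3],
    [4, 5, 6],
]
FILL = 9  # Plays the role of the implicit padding.
width = len(small[0])

up = [[FILL] * width] + small[:-1]           # up[r][c] == small[r - 1][c]
down = small[1:] + [[FILL] * width]          # down[r][c] == small[r + 1][c]
left = [[FILL] + row[:-1] for row in small]  # left[r][c] == small[r][c - 1]
right = [row[1:] + [FILL] for row in small]  # right[r][c] == small[r][c + 1]

for name, shifted in [("up", up), ("down", down), ("left", left), ("right", right)]:
    print(name, shifted)

# The four neighbours of small[0][1] (the 2): nothing above, 5 below, 1 left, 3 right.
print([shifted[0][1] for shifted in (up, down, left, right)])  # [9, 5, 1, 3]
```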

So far, we have ignored a very basic optimisation that could be done, which is to skip the processing of the `9`s that we find along the grid:

```py
risk = 0
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        if val == 9:
            continue
        if min(neighbour[r][c] for neighbour in NEIGHBOURS) > val:
            risk += val + 1
print(risk)
```

```
502
```


## Avoid repeated work

The solutions we have so far are straightforward, but they also do a lot of work.
Could we do something to save ourselves from some of that extra work?

For example, even when we locate a low point, we still check if the points immediately next to it are low points as well...
But we already know they can't be!
So, we could have skipped those.

In order to implement this, we will start off with a set of all the positions that we want to investigate, and as we find positions that cannot be low points, we remove them from consideration.
For this particular idea, it is useful to know which neighbours we are looking at, and thus we'll bring back the auxiliary function `neighbouring_positions`:

```py
def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

# Rows go up to len(grid) and columns up to len(grid[0]).
to_consider = {(r, c) for r in range(len(grid)) for c in range(len(grid[0]))}
risk = 0
while to_consider:
    r, c = to_consider.pop()
    val = grid[r][c]
    if val == 9:
        continue
    is_low_point = True
    for r_, c_ in neighbouring_positions(grid, r, c):
        if val < grid[r_][c_]:
            to_consider.discard((r_, c_))  # Position r_, c_ can't be low point.
        else:
            is_low_point = False
            break
    if is_low_point:
        risk += val + 1
print(risk)
```

```
502
```


This solution isn't great.
Even though we had good intentions, the code we wrote is a bit more convoluted and it didn't even end up being faster!
However, this is a good piece of code to show you how the `else` clause of a `for` loop works.

If you look at the code above, we have a Boolean flag `is_low_point` that has the value `True` by default, and that lets us know if the current position is a low point, or not.
When we find a value below the tentative low point, we set that flag to `False` and leave the loop.
Then, we have an `if` statement to check if the flag was changed (and the loop broken) or not.

Instead of doing all this bookkeeping, we can use the `else` clause of the `for` loop.
The `else` clause of a `for` loop runs only when the loop was **not** stopped by a `break`.
Thus, the `else` clause in our code runs exactly when `val` is smaller than all of its neighbours, that is, whenever we find a low point:

```py
def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

# Rows go up to len(grid) and columns up to len(grid[0]).
to_consider = {(r, c) for r in range(len(grid)) for c in range(len(grid[0]))}
risk = 0
while to_consider:
    r, c = to_consider.pop()
    val = grid[r][c]
    if val == 9:
        continue
    for r_, c_ in neighbouring_positions(grid, r, c):
        if val < grid[r_][c_]:
            to_consider.discard((r_, c_))  # Position r_, c_ can't be low point.
        else:
            break
    else:  # Only runs if the loop didn't `break`, i.e., (r, c) is a low point.
        risk += val + 1
print(risk)
```

```
502
```


There you have it: premature optimisation is the root of all evil, but it did give us a chance to learn something else about `for` loops!
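To look at the `for`/`else` behaviour in isolation, here is a small sketch with a hypothetical function `describe` that searches a list for a negative number.

```py
def describe(numbers):
    for n in numbers:
        if n < 0:
            result = f"found {n}"
            break
    else:
        # Runs only when the loop ends without hitting `break`
        # (this includes the case where `numbers` is empty).
        result = "no negative numbers"
    return result

print(describe([3, -1, 4]))  # found -1
print(describe([3, 1, 4]))   # no negative numbers
print(describe([]))          # no negative numbers
```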

## Part 2 problem statement

(From [Advent of Code 2021, day 9](https://adventofcode.com/2021/day/9))

Next, you need to find the largest basins so you know what areas are most important to avoid.

A _basin_ is all locations that eventually flow downward to a single low point. Therefore, every low point has a basin, although some basins are very small. Locations of height `9` do not count as being in any basin, and all other locations will always be part of exactly one basin.

The _size_ of a basin is the number of locations within the basin, including the low point. The example above has four basins.
They are shown below with `.`s in place of the `9`s that surround them and with `o` inside the basin:

The top-left basin, size `3`:

```
oo.9943210
o.87894921
.856789892
8767896789
9899965678
```

The top-right basin, size `9`:

```
2199.ooooo
39878.o.oo
985678.8.o
876789678.
9899965678
```

The middle basin, size `14`:

```
21...43210
3.ooo.4.21
.ooooo.892
ooooo.6789
.o...65678
```

The bottom-right basin, size `9`:

```
2199943210
3987894.21
985678.o.2
87678.ooo.
9899.ooooo
```

Find the three largest basins and multiply their sizes together. In the above example, this is `9 * 14 * 9 = 1134`.

_What do you get if you multiply together the sizes of the three largest basins?_

_Using the input file `input.txt`, the result should be `1330560`._
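As with Part 1, we can first check our understanding against the example. The following is a minimal sketch that assumes, like the approaches below, that each basin is exactly a connected region of non-`9` locations. It runs a breadth-first flood fill using `collections.deque`. The names `EXAMPLE`, `heights`, and `basin_sizes` are only for illustration.

```py
from collections import deque
from math import prod

EXAMPLE = """\
2199943210
3987894921
9856789892
8767896789
9899965678"""
heights = [[int(ch) for ch in line] for line in EXAMPLE.splitlines()]

def basin_sizes(grid):
    """Sizes of the connected regions of non-9 cells (up/down/left/right)."""
    rows, cols = len(grid), len(grid[0])
    seen = set()
    sizes = []
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == 9 or (r, c) in seen:
                continue
            # Breadth-first flood fill starting at (r, c).
            seen.add((r, c))
            queue = deque([(r, c)])
            size = 0
            while queue:
                r_, c_ = queue.popleft()
                size += 1
                for nr, nc in [(r_ - 1, c_), (r_ + 1, c_), (r_, c_ - 1), (r_, c_ + 1)]:
                    if (
                        0 <= nr < rows and 0 <= nc < cols
                        and grid[nr][nc] != 9 and (nr, nc) not in seen
                    ):
                        seen.add((nr, nc))
                        queue.append((nr, nc))
            sizes.append(size)
    return sizes

sizes = basin_sizes(heights)
print(sorted(sizes))             # [3, 9, 9, 14]
print(prod(sorted(sizes)[-3:]))  # 1134
```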

## Flood fill

A very common approach to a problem of this type is to use a [flood fill](https://en.wikipedia.org/wiki/Flood_fill) algorithm, or something similar.
In essence, we view the grid as an implicit graph and use any search algorithm we like to traverse all the nodes in a given basin, until we hit the walls (the `9`s).

Part 1 had us find the lowest point of each basin, so we can use that information to seed our search algorithm:

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def find_basin(matrix, r, c, visited=None):
    if visited is None:
        visited = set()
    visited.add((r, c))

    for r_, c_ in neighbouring_positions(matrix, r, c):
        if matrix[r_][c_] == 9 or (r_, c_) in visited:
            continue
        visited = find_basin(matrix, r_, c_, visited)  # Recursive call
    return visited

low_points = []
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c)]
        if min(values) > val:
            low_points.append((r, c))

lengths = []
for r, c in low_points:
    basin = find_basin(grid, r, c)
    lengths.append(len(basin))

prod = 1
for n in sorted(lengths)[-3:]:
    prod *= n
print(prod)
```

```
1330560
```


The algorithm we just implemented works essentially like the one shown in the following GIF:

![](floodfill.gif)

We pick a position (the lowest position in each basin) and recursively visit the neighbours of the part of the basin we have explored so far, to keep expanding the basin.
Because we are doing it recursively, we pass down the information about the visited nodes in a fourth argument, which has a default value of `None`.
Why couldn't we have set the default value to `set()`?

Well, we can try:

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def find_basin(matrix, r, c, visited=set()):
    visited.add((r, c))
    for r_, c_ in neighbouring_positions(matrix, r, c):
        if matrix[r_][c_] == 9 or (r_, c_) in visited:
            continue
        visited = find_basin(matrix, r_, c_, visited)  # Recursive call

    return visited

low_points = []
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c)]
        if min(values) > val:
            low_points.append((r, c))

lengths = []
for r, c in low_points:
    basin = find_basin(grid, r, c)
    lengths.append(len(basin))

prod = 1
for n in sorted(lengths)[-3:]:
    prod *= n
print(prod)
```

```
398200898865
```


Wow!
The returned value is _huge_!
What's the issue?
Default values are evaluated only once, when the `def` statement runs, and a set is a mutable data type, so the default value `visited=set()` doesn't behave like you might expect.
You can read a more thorough explanation of how this works [in this article](https://mathspp.com/blog/pydonts/pass-by-value-reference-and-assignment). The gist is that the set used as a default value is _reused_ for every call, and it just keeps accumulating the positions from all the previous basins.
Thus, we _do_ need a default value of `visited=None` and the accompanying `if` statement.
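The same pitfall shows up with any mutable default, and it is easier to see in a tiny example. Here is a sketch with hypothetical functions `add_item` and `add_item_fixed`. The first one shares one default list across calls, and the second uses the `None` pattern from `find_basin`.

```py
def add_item(item, bucket=[]):
    # The default list is created once, when `def` runs, and shared by every call.
    bucket.append(item)
    return bucket

print(add_item(1))  # [1]
print(add_item(2))  # [1, 2]: the same list as before!
print(add_item.__defaults__)  # ([1, 2],): the default itself has changed

def add_item_fixed(item, bucket=None):
    if bucket is None:
        bucket = []  # A fresh list for every call that doesn't pass one.
    bucket.append(item)
    return bucket

print(add_item_fixed(1))  # [1]
print(add_item_fixed(2))  # [2]
```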

## Flood fill in imperative style

Now, as much as I find recursion a very elegant way of reasoning about problems (because I do!), I believe it is important to share a solution that uses a `while` loop.
(That way, our solution also works on larger grids: Python limits how deep recursion can go, and a big enough basin could exceed that limit.)

```py
with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def find_basin(matrix, r, c):
    to_visit, seen, basin = {(r, c)}, {(r, c)}, set()
    while to_visit:
        r_, c_ = to_visit.pop()
        if matrix[r_][c_] == 9:
            continue
        basin.add((r_, c_))
        neighbs = neighbouring_positions(matrix, r_, c_)
        to_visit.update(neighbs - seen)
        seen.update(neighbs)
    return basin

low_points = []
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c)]
        if min(values) > val:
            low_points.append((r, c))

lengths = []
for r, c in low_points:
    basin = find_basin(grid, r, c)
    lengths.append(len(basin))

prod = 1
for n in sorted(lengths)[-3:]:
    prod *= n
print(prod)
```

```
1330560
```
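To see the recursion limit mentioned above in action, here is a small sketch using `sys.getrecursionlimit`. The function `depth` is a hypothetical stand-in for a recursive flood fill. A recursive `find_basin` can nest up to one call per cell in the basin, so a large enough basin would hit the same `RecursionError`.

```py
import sys

def depth(n):
    """Recurse n levels deep and return n."""
    return 0 if n == 0 else 1 + depth(n - 1)

limit = sys.getrecursionlimit()  # 1000 by default in CPython
print(limit)
print(depth(100))  # 100: a shallow recursion is fine

hit_limit = False
try:
    depth(limit + 100)  # Deeper than the interpreter allows.
except RecursionError:
    hit_limit = True
print(hit_limit)  # True
```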


What might be interesting in the imperative implementation of the function `find_basin` above is the choice of data structure to hold all of the information we care about, as we have three sets:

 - one keeps track of the points that still have to be visited;
 - one keeps track of the points that have already been seen (“touched”) by the algorithm; and
 - one keeps track of all the points that actually belong to the basin.

Because we want to be able to quickly tell if a position has been seen before, it makes sense for `seen` to be a built-in `set`.
For the data structure containing the positions to visit next and for the data structure containing the positions in the basin, other options would also make sense.
For example, we could use lists for both those variables:

```py
def find_basin(matrix, r, c):
    to_visit, seen, basin = [(r, c)], {(r, c)}, []
    while to_visit:
        r_, c_ = to_visit.pop()
        if matrix[r_][c_] == 9:
            continue
        basin.append((r_, c_))
        neighbs = neighbouring_positions(matrix, r_, c_)
        to_visit.extend(neighbs - seen)
        seen.update(neighbs)
    return basin

low_points = []
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        values = [grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c)]
        if min(values) > val:
            low_points.append((r, c))

lengths = []
for r, c in low_points:
    basin = find_basin(grid, r, c)
    lengths.append(len(basin))

prod = 1
for n in sorted(lengths)[-3:]:
    prod *= n
print(prod)
```

```
1330560
```


As you can see, there aren't that many changes.

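For this particular algorithm, there is no need for something like `collections.deque`. The order in which we visit positions doesn't change which positions end up in the basin. Popping from the end of a list (depth-first order) works just as well as popping from the front of a deque (breadth-first order).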
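To back that up, here is a sketch that compares the list-based `find_basin` from above (renamed `find_basin_dfs`) with a hypothetical breadth-first variant, `find_basin_bfs`, that uses `collections.deque`. Both run on the example heightmap from the problem statement, starting at the low point `(2, 2)` of the middle basin.

```py
from collections import deque

EXAMPLE = """\
2199943210
3987894921
9856789892
8767896789
9899965678"""
example = [[int(ch) for ch in line] for line in EXAMPLE.splitlines()]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def find_basin_dfs(matrix, r, c):
    # list.pop() takes the most recently added position (LIFO): depth-first.
    to_visit, seen, basin = [(r, c)], {(r, c)}, []
    while to_visit:
        r_, c_ = to_visit.pop()
        if matrix[r_][c_] == 9:
            continue
        basin.append((r_, c_))
        neighbs = neighbouring_positions(matrix, r_, c_)
        to_visit.extend(neighbs - seen)
        seen.update(neighbs)
    return basin

def find_basin_bfs(matrix, r, c):
    # deque.popleft() takes the oldest position (FIFO): breadth-first.
    to_visit, seen, basin = deque([(r, c)]), {(r, c)}, []
    while to_visit:
        r_, c_ = to_visit.popleft()
        if matrix[r_][c_] == 9:
            continue
        basin.append((r_, c_))
        neighbs = neighbouring_positions(matrix, r_, c_)
        to_visit.extend(neighbs - seen)
        seen.update(neighbs)
    return basin

dfs = find_basin_dfs(example, 2, 2)
bfs = find_basin_bfs(example, 2, 2)
print(len(dfs), len(bfs))    # 14 14
print(set(dfs) == set(bfs))  # True
```

Both lists contain the same positions. Only the order in which they were discovered differs.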

## Avoiding unnecessary `for` loops

In our quest for writing clean and elegant code, we have fought many battles against unnecessary `for` loops, and now it is time for one more.

First off, we don't need to write a `for` loop to build the list of lengths from `low_points`: a list comprehension does the job.
Secondly, since Python 3.8, the module `math` contains a function `prod` that computes the product of an iterable of numbers, much like the built-in `sum` computes the sum of an iterable of numbers.

Therefore, the end of our algorithm can be improved to look like this:

```py
from math import prod

lengths = [len(find_basin(grid, r, c)) for r, c in low_points]
print(prod(sorted(lengths)[-3:]))
```

```
1330560
```


If you don't have access to the function `math.prod` because you are using an earlier version of Python, you can always make your own:

```py
from functools import reduce
from operator import mul

def prod(iterable):
    return reduce(mul, iterable, 1)

lengths = [len(find_basin(grid, r, c)) for r, c in low_points]
print(prod(sorted(lengths)[-3:]))
```

```
1330560
```


Once more, `functools.reduce` comes to the rescue; as we've come to notice, many useful algorithms are reductions.
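To make the reduction explicit, here is a sketch using the three largest basin sizes from the example, `9`, `14`, and `9`. `reduce` folds the list from left to right, starting from the initial value `1`, and gives the same result as `math.prod`.

```py
from functools import reduce
from math import prod
from operator import mul

sizes = [9, 14, 9]  # The three largest basins in the example.

# reduce computes ((1 * 9) * 14) * 9.
print(reduce(mul, sizes, 1))  # 1134
print(prod(sizes))            # 1134

# Both fall back to the start value 1 for an empty iterable.
print(reduce(mul, [], 1), prod([]))  # 1 1
```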

## No pre-computation of lowest points

The first part of the challenge segued nicely into the second part and we were able to reuse information about the lowest points to make it slightly easier for us to find the basins.
However, this extra work is definitely not needed.

What we can do, instead, is to go over the grid and try to find a basin whenever we find a position that hasn't been put inside a basin yet:

```py
from math import prod

with open(INPUT_FILE, "r") as f:
    grid = [
        [int(num) for num in line.strip()]
        for line in f
    ]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def find_basin(matrix, r, c):
    to_visit, seen, basin = [(r, c)], {(r, c)}, []
    while to_visit:
        r_, c_ = to_visit.pop()
        if matrix[r_][c_] == 9:
            continue
        basin.append((r_, c_))
        neighbs = neighbouring_positions(matrix, r_, c_)
        to_visit.extend(neighbs - seen)
        seen.update(neighbs)
    return basin

lengths = []
# We mark with -1 points already in a basin.
for r, row in enumerate(grid):
    for c, val in enumerate(row):
        if val == -1 or val == 9:
            continue
        basin = find_basin(grid, r, c)
        for r_, c_ in basin:
            grid[r_][c_] = -1
        lengths.append(len(basin))

print(prod(sorted(lengths)[-3:]))
```

```
1330560
```


We had to go back to explicit loops, but at least we don't go over the whole grid more than needed!

Before we conclude this analysis, I would like to take a look at a fixed point application of the flood fill algorithm, which will allow us to apply the flood fill algorithm to the whole grid at once.

## Fixed point flood fill

The flood fill algorithm we implemented is standard, and we did it for one basin at a time.
What if we could do it for all basins at the same time?!

We will start by going through the grid and assigning a unique ID to each position that belongs to a basin:

```py
next_id = 0
with open(INPUT_FILE, "r") as f:
    grid = [
        [(next_id := next_id + 1) if num != "9" else 0 for num in line.strip()]
        for line in f
    ]
```

Note that this code only runs on Python 3.8 or later, because it uses the [walrus operator](https://mathspp.com/blog/pydonts/assignment-expressions-and-the-walrus-operator) when parsing the grid.
The walrus operator is being used to provide a unique index to each position that is _not_ a wall.

Using the walrus operator is a cool way of achieving this effect, but the most appropriate way is probably using the correct tool for the job, `itertools.count`:

```py
from itertools import count

id_counter = count(1)
with open(INPUT_FILE, "r") as f:
    grid = [
        [next(id_counter) if num != "9" else 0 for num in line.strip()]
        for line in f
    ]
```

`itertools.count` is a lazy iterator that counts up from `0` forever.
Passing it one argument tells it to start at a different value (`1`, in our case). A second argument sets the step between consecutive values, which is `1` by default.
The built-in `next` is what asks `count` for its next item.
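Here is a quick sketch of those arguments in action. `itertools.islice` is used only to take a finite number of items from the otherwise infinite iterator. The last lines apply the grid-parsing pattern to the first row of the example heightmap.

```py
from itertools import count, islice

print(list(islice(count(), 5)))       # [0, 1, 2, 3, 4]
print(list(islice(count(1), 5)))      # [1, 2, 3, 4, 5]
print(list(islice(count(10, 5), 4)))  # [10, 15, 20, 25]

# Walls ("9") get 0 and every other position gets the next ID.
ids = count(1)
row_ids = [next(ids) if num != "9" else 0 for num in "2199943210"]
print(row_ids)  # [1, 2, 0, 0, 0, 3, 4, 5, 6, 7]
```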

After the grid is populated with IDs, we go over the grid and update each non-wall position to the largest ID among itself and its neighbours (walls keep the ID `0`).
We repeat this until a full pass over the grid makes no changes.
As we go through the grid repeatedly, the largest ID of each basin spreads through that basin, essentially filling it up.
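The same idea may be easier to follow in one dimension. Here is a sketch on a single made-up row of IDs, where `0` marks a wall. `spread_max` is a hypothetical 1D counterpart of the grid update just described.

```py
from collections import Counter

def spread_max(row):
    """One pass: each non-wall cell takes the largest ID among itself and its neighbours."""
    new_row = []
    for i, val in enumerate(row):
        if val == 0:  # Walls (the 9s) keep the ID 0.
            new_row.append(0)
            continue
        # The slice holds the cell and its left/right neighbours, clipped at the ends.
        new_row.append(max(row[max(i - 1, 0):i + 2]))
    return new_row

row = [1, 2, 0, 3, 4, 5, 0, 6]  # Three "basins" separated by walls.
while (new_row := spread_max(row)) != row:
    row = new_row
    print(row)

sizes = Counter(v for v in row if v != 0)
print(sorted(sizes.values()))  # [1, 2, 3]
```

The loop prints `[2, 2, 0, 4, 5, 5, 0, 6]` and then `[2, 2, 0, 5, 5, 5, 0, 6]`. After that, another pass changes nothing, so the loop stops.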

After we fill the basins up, we use the object `collections.Counter` to figure out which IDs were used the most:

```py
from collections import Counter
from itertools import chain
from math import prod

next_id = 0
with open(INPUT_FILE, "r") as f:
    grid = [
        [(next_id := next_id + 1) if num != "9" else 0 for num in line.strip()]
        for line in f
    ]

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

grid, new_grid = [], grid
while grid != new_grid:
    grid = new_grid
    new_grid = []
    for r in range(len(grid)):
        new_row = []
        for c in range(len(grid[0])):
            if grid[r][c] == 0:
                new_row.append(0)
                continue
            M = max(grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c))
            new_row.append(max(grid[r][c], M))
        new_grid.append(new_row)

counter = Counter(chain.from_iterable(grid))
counter[0] = 0  # Remove counts from the basin walls (9s that were turned into 0s).
print(prod(sorted(counter.values())[-3:]))
```

```
1330560
```


From a computational standpoint, this generally won't be the best method we could've come up with, and that's because we have to go through the _whole_ grid to make updates to the filling of the basins, even if there is just one basin that needs filling up.

On the other hand, this is an interesting approach because it showcases another idea of functional programming: the idea of a fixed point.
What we did was write a piece of code that keeps running until it has no effect on the object we are working with!

Let's take that piece of code and write it as an auxiliary function, so that we can express the fixed point in a more idiomatic way:

```py
from collections import Counter
from itertools import chain
from math import prod

next_id = 0  # Reset the ID counter before parsing the grid again.
with open(INPUT_FILE, "r") as f:
    grid = [
        [(next_id := next_id + 1) if num != "9" else 0 for num in line.strip()]
        for line in f
    ]

def update_flood_fill(grid):
    new_grid = []
    for r in range(len(grid)):
        new_row = []
        for c in range(len(grid[0])):
            if grid[r][c] == 0:
                new_row.append(0)
                continue
            M = max(grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c))
            new_row.append(max(grid[r][c], M))
        new_grid.append(new_row)
    return new_grid

while grid != (grid := update_flood_fill(grid)):  # Fixed point!
    pass

counter = Counter(chain.from_iterable(grid))
counter[0] = 0  # Remove counts from the basin walls (9s that were turned into 0s).
print(prod(sorted(counter.values())[-3:]))
```

```
1330560
```


Notice how we rewrote the outer `while` loop to be an empty loop.
Now, the main logic lives in `update_flood_fill`, and the loop only takes care of the fixed point search.
The code

```py
while grid != (grid := update_flood_fill(grid)):
    pass
```

reads as “while the `grid` isn't the fixed point of `update_flood_fill`, keep going”.
Again, the “fixed point” of a function `f` is just the point `x` for which `f(x) == x`.

It is also worth mentioning that in the above, we use the [walrus operator](https://mathspp.com/blog/pydonts/assignment-expressions-and-the-walrus-operator) to update the value of the variable `grid` and, _at the same time_, compare its new value to what it was just before the assignment.
This works because the left operand of `!=` (the old `grid`) is evaluated before the right one, and this is what removes the need for the auxiliary variable `new_grid`.
The walrus operator is available from Python 3.8, so you _do_ need the auxiliary variable if you are using an earlier version of Python.
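You can also package the fixed point search as a reusable helper. The following is a minimal sketch, assuming a hypothetical function `fixed_point(f, x)` that keeps applying `f` until the value stops changing. It is applied here to the example heightmap from the problem statement, with the same update function as above.

```py
from collections import Counter
from itertools import chain, count
from math import prod

EXAMPLE = """\
2199943210
3987894921
9856789892
8767896789
9899965678"""

def fixed_point(f, x):
    """Apply f repeatedly until f(x) == x, then return x."""
    while (y := f(x)) != x:
        x = y
    return x

def neighbouring_positions(matrix, r, c):
    return {
        (r_, c_) for r_, c_ in [(r - 1, c), (r, c - 1), (r + 1, c), (r, c + 1)]
        if 0 <= r_ < len(matrix) and 0 <= c_ < len(matrix[0])
    }

def update_flood_fill(grid):
    # Non-wall cells take the largest ID among themselves and their neighbours.
    new_grid = []
    for r in range(len(grid)):
        new_row = []
        for c in range(len(grid[0])):
            if grid[r][c] == 0:
                new_row.append(0)
                continue
            M = max(grid[r_][c_] for r_, c_ in neighbouring_positions(grid, r, c))
            new_row.append(max(grid[r][c], M))
        new_grid.append(new_row)
    return new_grid

ids = count(1)
ids_grid = [
    [next(ids) if ch != "9" else 0 for ch in line]
    for line in EXAMPLE.splitlines()
]
filled = fixed_point(update_flood_fill, ids_grid)

counter = Counter(chain.from_iterable(filled))
counter.pop(0, None)  # Drop the walls.
print(sorted(counter.values()))             # [3, 9, 9, 14]
print(prod(sorted(counter.values())[-3:]))  # 1134
```

The basin sizes match the four basins drawn in the problem statement.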

## Conclusion

We saw a couple of ways to handle boundary conditions in a grid, including the conditional generation of the coordinates of the neighbours, padding the grid, and implicit padding with the pre-computation of all neighbours.
Depending on the context, all these can be useful.

We also glanced at an interesting concept from mathematics/functional programming: the concept of a fixed point.
It may look esoteric, but the pattern of looking for some `x` such that `f(x) == x` shows up more often than you might think.

If you have any questions, suggestions, remarks, recommendations, corrections, or anything else, you can reach out to me [on Twitter](https://twitter.com/mathsppblog) or via email to rodrigo at mathspp dot com.

**GIF credits**: By André Karwath aka Aka - Own work, CC BY-SA 2.5, https://commons.wikimedia.org/w/index.php?curid=481651