In [42]:
import heapq
from collections import Counter
from itertools import combinations, product
from math import prod
from pathlib import Path
from statistics import mean, median, mode

import numpy as np
from bresenham import bresenham
from tqdm import tqdm

from utils import read_input

In [2]:
# All day-9 inputs live in one directory; change it here to point elsewhere.
DAY9_DIR = Path("input/day9")
test = DAY9_DIR / "test.txt"
test1 = DAY9_DIR / "test1.txt"
data = DAY9_DIR / "data.txt"

In [3]:
def get_heightmap(input_file):
    """Parse the puzzle input into a grid of integer heights.

    Each line returned by ``read_input`` becomes one row; every character of
    that line is converted with ``int``, so the result is indexed
    ``grid[row][col]``.
    """
    return [list(map(int, row_text)) for row_text in read_input(input_file)]

In [4]:
def is_lower_than_surrounds(row_position, col_position, position_height, heightmap):
    """Return True if ``position_height`` is strictly below every orthogonal neighbour.

    Only the (up to four) cells sharing an edge with
    ``(row_position, col_position)`` are compared; diagonals and positions
    outside the grid are ignored. A height of 9 is never a low point. A cell
    with no in-grid neighbours (a 1x1 grid) is vacuously lower than all of
    them and yields True.

    Args:
        row_position: row index of the cell in ``heightmap``.
        col_position: column index of the cell in ``heightmap``.
        position_height: height of the cell (the callers in this notebook pass
            ``heightmap[row_position][col_position]``).
        heightmap: rectangular grid of heights, indexed ``heightmap[row][col]``.

    Returns:
        True if the cell is a low point, False otherwise.
    """
    if position_height == 9:
        return False

    row_count = len(heightmap)
    column_count = len(heightmap[0])

    # Look at the neighbours directly. Building a full-grid mask on every call
    # made scanning the whole grid quadratic in its area.
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        row = row_position + d_row
        col = col_position + d_col
        if 0 <= row < row_count and 0 <= col < column_count:
            # Ties are not low points: the cell must be strictly lower.
            if heightmap[row][col] <= position_height:
                return False
    return True

In [5]:
def find_risk_levels(heightmap):
    """Lazily yield the risk level (height + 1) of every low point.

    Cells are visited in row-major order; a cell counts as a low point when
    ``is_lower_than_surrounds`` says so.
    """
    width = len(heightmap[0])
    for row_index, row in enumerate(heightmap):
        for col_index in range(width):
            height = row[col_index]
            if is_lower_than_surrounds(row_index, col_index, height, heightmap):
                yield height + 1

In [6]:
heightmap = get_heightmap(input_file=test)  # the small example grid

In [7]:
# Part 1 on the example grid, checked against the known answer so a regression
# in the helpers is caught on Restart & Run All.
example_risk_total = sum(find_risk_levels(heightmap))
assert example_risk_total == 15, "example grid should have a total risk level of 15"
example_risk_total

15

In [8]:
heightmap = get_heightmap(input_file=data)  # the full puzzle input

In [9]:
# Part 1 answer for the full puzzle input.
sum(find_risk_levels(heightmap))

In [43]:
def find_basin(heightmap, to_check=None, checked=None):
    """Flood-fill the basin(s) reachable from the given start cells.

    Starting from every cell in ``to_check``, repeatedly step to orthogonally
    adjacent in-grid cells whose height is not 9, collecting every cell reached.
    Cells listed in ``checked`` are treated as already visited: they are part of
    the result but are not expanded again.

    Args:
        heightmap: rectangular grid of heights, indexed ``heightmap[row][col]``.
        to_check: iterable of ``(row, col)`` start cells (default: none).
        checked: iterable of ``(row, col)`` cells already visited (default: none).

    Returns:
        A new set of ``(row, col)`` cells: the start cells, everything in
        ``checked``, and every cell reached from the start cells. With nothing
        to check, this is just a copy of ``checked`` (an empty set by default).

    Neither argument is mutated. The search uses an explicit stack, so long,
    winding basins cannot hit Python's recursion limit.
    """
    visited = set(checked or ())
    stack = list(to_check or ())
    visited.update(stack)
    if not stack:
        return visited

    row_count = len(heightmap)
    column_count = len(heightmap[0])

    while stack:
        row, col = stack.pop()
        for next_point in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
            next_row, next_col = next_point
            if not (0 <= next_row < row_count and 0 <= next_col < column_count):
                continue
            # 9s are basin walls; visited cells have already been queued.
            if next_point in visited or heightmap[next_row][next_col] == 9:
                continue
            visited.add(next_point)
            stack.append(next_point)

    return visited

In [44]:
heightmap = get_heightmap(input_file=test)  # back to the small example grid for part 2

In [60]:
def find_3_largest_basins(heightmap):
    """Return the sizes of the three largest basins, largest first.

    Every low point (per ``is_lower_than_surrounds``) seeds one basin; its size
    is the number of cells ``find_basin`` reaches from that low point. If the
    grid has fewer than three low points, fewer sizes are returned (an empty
    grid yields ``[]``). Progress for both phases is printed via ``tqdm``.
    """
    row_count = len(heightmap)
    column_count = len(heightmap[0]) if heightmap else 0
    all_positions = product(range(row_count), range(column_count))

    print("Finding low points")
    low_points = [
        (pos_r, pos_c)
        for pos_r, pos_c in tqdm(all_positions, total=row_count * column_count)
        if is_lower_than_surrounds(pos_r, pos_c, heightmap[pos_r][pos_c], heightmap)
    ]

    print("Finding basins")
    basin_sizes = (len(find_basin(heightmap, [low_point])) for low_point in tqdm(low_points))
    # nlargest keeps only the top three and always returns them sorted,
    # largest first, regardless of the order the low points were found in.
    return heapq.nlargest(3, basin_sizes)
            

In [61]:
heightmap = get_heightmap(test)
example_basins = find_3_largest_basins(heightmap)
# Known result for the example grid; sorted so the check does not depend on order.
assert sorted(example_basins, reverse=True) == [14, 9, 9], "unexpected basin sizes for the example grid"
example_basins

Finding low points


50it [00:00, 6812.70it/s]


Finding basins


100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 12576.62it/s]


[14, 9, 9]

In [62]:
# Part 2 on the full puzzle input.
heightmap = get_heightmap(input_file=data)
find_3_largest_basins(heightmap=heightmap)

Finding low points


10000it [00:03, 2579.86it/s]


Finding basins


100%|█████████████████████████████████████████| 209/209 [00:00<00:00, 11560.04it/s]


[121, 116, 114]

In [63]:
# Part 2 answer: product of the three largest basin sizes for the full input.
# Computed from `heightmap` rather than transcribed from an earlier output, so
# it stays correct if the input changes (this re-runs the basin search).
prod(find_3_largest_basins(heightmap))

1600104