In [1]:
input_location = "inputs/input_20211209.txt"

# Puzzle input: one row of the height map per line (newlines stripped).
with open(input_location) as input_file:
    data = input_file.read().splitlines()

In [2]:
def clean_data(data):
    """
    Convert raw input lines into a grid of integers.

    Each character of each line is passed through ``int``, so a line such as
    "2199" becomes [2, 1, 9, 9]. Returns a new list of lists, one per line.
    """
    return [list(map(int, line)) for line in data]

# Parse the raw lines once; both solutions below work from this grid.
cleaned_data = clean_data(data=data)

In [3]:
def pad_heatmap(data, i=9):
    """
    Surround a heatmap with a one-cell border of value ``i`` on all four sides.

    The border acts as a wall, so that every original cell has four
    neighbours that can be indexed without going out of range.
    The input rows are not modified.

    Parameters
    ----------
    data : sequence of sequences of int
        Grid of heights. The border width is taken from the first row, so
        rows are assumed to all have the same length.
    i : int, default 9
        Value used for the border cells.

    Returns
    -------
    list[list[int]]
        A new grid with two more rows and two more columns than ``data``.
        An empty ``data`` yields a 2x2 grid of ``i``.
    """
    width = (len(data[0]) if data else 0) + 2
    padded_heatmap = [[i] * width]
    padded_heatmap.extend([i, *row, i] for row in data)
    # Build a fresh list for the bottom border so it does not alias the
    # top border (mutating one would otherwise mutate both).
    padded_heatmap.append([i] * width)
    return padded_heatmap

def get_low_points(heatmap):
    """
    Find the local minima in the interior of a heatmap.

    The outermost ring of cells is never inspected (it is expected to be
    padding). An interior cell is a low point when its height is strictly
    lower than each of its four orthogonal neighbours.

    Returns a dictionary mapping (row, column) -> height, in row-major order.
    """
    return {
        (row, col): heatmap[row][col]
        for row in range(1, len(heatmap) - 1)
        for col in range(1, len(heatmap[0]) - 1)
        if heatmap[row][col] < min(
            heatmap[row][col + 1],
            heatmap[row][col - 1],
            heatmap[row - 1][col],
            heatmap[row + 1][col],
        )
    }



In [4]:
from collections import deque

def get_basin_count(low_point_coordinate, heatmap, wall=9):
    """
    Return the number of cells in the basin containing ``low_point_coordinate``.

    The basin is every cell reachable from the starting cell by orthogonal
    steps through cells whose height is below ``wall``. The starting cell
    itself is always counted. The search is breadth-first.

    Parameters
    ----------
    low_point_coordinate : tuple[int, int]
        (row, column) of the cell to start from.
    heatmap : list[list[int]]
        Grid of heights. There are no bounds checks, so the grid is assumed
        to be padded on every side with values >= ``wall``. ``pad_heatmap``
        produces such a grid.
    wall : int, default 9
        Heights >= ``wall`` are treated as basin boundaries.

    Returns
    -------
    int
        Number of distinct cells in the basin.
    """
    seen = {low_point_coordinate}
    to_check = deque([low_point_coordinate])

    while to_check:
        x, y = to_check.popleft()
        for neighbour in ((x, y + 1), (x, y - 1), (x - 1, y), (x + 1, y)):
            if neighbour in seen:
                continue
            nx, ny = neighbour
            if heatmap[nx][ny] < wall:
                # Mark on enqueue (not on dequeue) so each cell is queued at
                # most once; set membership keeps this O(1).
                seen.add(neighbour)
                to_check.append(neighbour)

    return len(seen)

In [5]:
def solution_1(data):
    """
    Part 1: the total risk level of all low points.

    A low point's risk level is its height plus one. This is equivalent to
    the sum of the low-point heights plus the number of low points.
    """
    low_points = get_low_points(pad_heatmap(data))
    return sum(height + 1 for height in low_points.values())

def solution_2(data):
    """
    Part 2: the product of the sizes of the three largest basins.

    Each low point seeds one basin. If fewer than three basins exist, the
    product covers all of them, and it is 1 when there are none.
    """
    heatmap = pad_heatmap(data)
    basin_sizes = [
        get_basin_count(point, heatmap) for point in get_low_points(heatmap)
    ]

    largest_three = sorted(basin_sizes, reverse=True)[:3]
    product = 1
    for size in largest_three:
        product *= size
    return product

In [6]:
# Known answers for this input: part 1 -> 541, part 2 -> 847504.
for solver in (solution_1, solution_2):
    print(solver(cleaned_data))

541
847504
