In [1]:
%matplotlib inline

import itertools
import re

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


def map_array(map_txt):
    """Turn a newline-separated grid string into a 2-D array of single characters.

    Each line of ``map_txt`` becomes one row; each character becomes one cell.
    """
    return np.array(list(map(list, map_txt.splitlines())))


def print_map(map_):
    """Print a character grid, one row per line with the cells concatenated."""
    for line in map(''.join, map_):
        print(line)

        
def out_of_bounds(pos, n, m):
    """Return True if ``pos`` lies outside an ``n`` x ``m`` grid.

    pos  -- ``(row, col)`` pair to check.
    n, m -- number of rows and columns of the grid.
    """
    row, col = pos[0], pos[1]
    # Chained comparisons with short-circuiting `and`/`not` instead of
    # bitwise `|` on booleans.
    return not (0 <= row < n and 0 <= col < m)


# Unit (row, col) steps to the four orthogonal neighbours: right, down, left, up.
directions = [
    (0, 1),
    (1, 0),
    (0, -1),
    (-1, 0),
]

In [4]:
# 5x5 sample grid: an "O" region surrounding four single-cell "X" regions.
txt = "\n".join([
    "OOOOO",
    "OXOXO",
    "OOOOO",
    "OXOXO",
    "OOOOO",
])

In [52]:
# 10x10 sample grid with many letter regions.
txt = "\n".join([
    "RRRRIICCFF",
    "RRRRIICCCF",
    "VVRRRCCFFF",
    "VVRCCCJFFF",
    "VVVVCJJCFE",
    "VVIVCCJJEE",
    "VVIIICJJEE",
    "MIIIIIJJEE",
    "MIIISIJEEE",
    "MMMISSJEEE",
])

In [105]:
# Load the real puzzle input; this rebinds `txt`, replacing any sample grid above.
with open('input.txt', 'r') as input_file:
    txt = input_file.read().strip()

# Part 1: fence price = area × perimeter

In [106]:
map_ = map_array(txt)
n, m = map_.shape         # number of rows, number of columns
plants = np.unique(map_)  # distinct plant labels present in the grid

In [107]:
# Flood-fill every connected region of identical plants and record its area
# (number of cells) and perimeter (number of cell sides that face either the
# grid edge or a different plant). Results are appended in discovery order.
not_visited = set(itertools.product(range(n), range(m)))

perimeters = []
areas = []
while not_visited:
    pos = not_visited.pop()

    # A tuple index on a 2-D array selects one element, same as map_[*pos].
    plant = map_[pos]
    perimeter = 0
    area = 0
    visited = set()

    # Cells of the current region still to be processed; a set, so a cell
    # reached from several neighbours is queued only once.
    consider_queue = set()
    while True:
        visited.add(pos)
        area += 1
        for d in directions:
            next_pos = pos[0] + d[0], pos[1] + d[1]
            # `or` short-circuits, so map_ is only indexed when in bounds
            # (negative indices would otherwise wrap around).
            if out_of_bounds(next_pos, n, m) or map_[next_pos] != plant:
                perimeter += 1
            elif next_pos not in visited:
                consider_queue.add(next_pos)
        if not consider_queue:
            break
        pos = consider_queue.pop()
        not_visited.remove(pos)

    perimeters.append(perimeter)
    areas.append(area)

In [108]:
# Per-region areas and perimeters, in the order the regions were discovered.
(areas, perimeters)

([83,
  149,
  71,
  164,
  94,
  48,
  114,
  121,
  86,
  70,
  191,
  98,
  109,
  11,
  12,
  61,
  73,
  213,
  48,
  265,
  64,
  28,
  265,
  136,
  99,
  45,
  49,
  52,
  77,
  106,
  128,
  91,
  62,
  15,
  140,
  47,
  189,
  68,
  76,
  105,
  32,
  72,
  200,
  148,
  39,
  60,
  6,
  39,
  85,
  262,
  59,
  66,
  91,
  48,
  1,
  1,
  145,
  104,
  210,
  126,
  37,
  67,
  108,
  44,
  107,
  250,
  57,
  13,
  151,
  1,
  56,
  20,
  24,
  125,
  119,
  207,
  114,
  85,
  46,
  75,
  18,
  76,
  56,
  31,
  128,
  38,
  51,
  21,
  39,
  89,
  35,
  141,
  49,
  106,
  123,
  61,
  88,
  22,
  5,
  114,
  195,
  119,
  305,
  128,
  47,
  5,
  25,
  59,
  78,
  71,
  130,
  11,
  41,
  6,
  154,
  71,
  126,
  148,
  95,
  1,
  132,
  93,
  27,
  72,
  18,
  31,
  166,
  147,
  79,
  106,
  2,
  118,
  33,
  39,
  53,
  55,
  102,
  146,
  130,
  253,
  104,
  59,
  108,
  77,
  103,
  84,
  22,
  41,
  50,
  33,
  138,
  89,
  15,
  2,
  169,
  156,
  69,
  63,
  96

In [109]:
# Part 1 answer: sum over regions of area * perimeter. strict=True raises if
# the two per-region lists ever get out of step instead of silently truncating.
sum(area * perimeter for area, perimeter in zip(areas, perimeters, strict=True))

1485656

# Part 2: fence price = area × number of sides

In [74]:
# 4x4 sample grid with five regions (A, B, C, D, E).
txt = "\n".join([
    "AAAA",
    "BBCD",
    "BBCC",
    "EEEC",
])

In [94]:
# 6x6 sample grid: an "A" region containing two diagonally touching "B" blocks.
txt = "\n".join([
    "AAAAAA",
    "AAABBA",
    "AAABBA",
    "ABBAAA",
    "ABBAAA",
    "AAAAAA",
])

In [113]:
# Reload the real puzzle input, replacing the Part 2 sample grids above.
with open('input.txt', 'r') as input_file:
    txt = input_file.read().strip()

In [116]:
map_ = map_array(txt)
n, m = map_.shape         # number of rows, number of columns
plants = np.unique(map_)  # distinct plant labels present in the grid

In [117]:
# Flood-fill every region, collect its fence segments, and count its sides
# (stored in `edges`) as connected components of adjacent same-direction fences.
not_visited = set(itertools.product(range(n), range(m)))

edges = []
areas = []
while not_visited:
    pos = not_visited.pop()

    plant = map_[pos]
    area = 0
    visited = set()
    # Fence segments as (dr, dc, r, c): the step direction from the region cell
    # and the neighbouring position (off-grid or another plant) it faces.
    outside = []

    consider_queue = set()
    while True:
        visited.add(pos)
        area += 1
        for d in directions:
            next_pos = pos[0] + d[0], pos[1] + d[1]
            # `or` short-circuits, so map_ is only indexed when in bounds.
            if out_of_bounds(next_pos, n, m) or map_[next_pos] != plant:
                outside.append((d[0], d[1], next_pos[0], next_pos[1]))
            elif next_pos not in visited:
                consider_queue.add(next_pos)
        if not consider_queue:
            break
        pos = consider_queue.pop()
        not_visited.remove(pos)

    # Connect fence tuples at L1 distance 1. Two distinct unit directions always
    # differ by 2, so such pairs share a direction and have positions one step
    # apart along a row or column. Looking up the +row / +col neighbour in a set
    # builds exactly the graph the all-pairs scan did, in O(P) instead of O(P^2).
    fences = set(outside)
    g = nx.Graph()
    g.add_nodes_from(outside)
    for dr, dc, r, c in outside:
        for neighbour in ((dr, dc, r + 1, c), (dr, dc, r, c + 1)):
            if neighbour in fences:
                g.add_edge((dr, dc, r, c), neighbour)

    # Each connected component of adjacent same-direction fences is one side.
    nedges = nx.number_connected_components(g)
    edges.append(nedges)
    areas.append(area)

In [118]:
# Per-region areas and side counts, in the order the regions were discovered.
(areas, edges)

([83,
  149,
  71,
  164,
  94,
  48,
  114,
  121,
  86,
  70,
  191,
  98,
  109,
  11,
  12,
  61,
  73,
  213,
  48,
  265,
  64,
  28,
  265,
  136,
  99,
  45,
  49,
  52,
  77,
  106,
  128,
  91,
  62,
  15,
  140,
  47,
  189,
  68,
  76,
  105,
  32,
  72,
  200,
  148,
  39,
  60,
  6,
  39,
  85,
  262,
  59,
  66,
  91,
  48,
  1,
  1,
  145,
  104,
  210,
  126,
  37,
  67,
  108,
  44,
  107,
  250,
  57,
  13,
  151,
  1,
  56,
  20,
  24,
  125,
  119,
  207,
  114,
  85,
  46,
  75,
  18,
  76,
  56,
  31,
  128,
  38,
  51,
  21,
  39,
  89,
  35,
  141,
  49,
  106,
  123,
  61,
  88,
  22,
  5,
  114,
  195,
  119,
  305,
  128,
  47,
  5,
  25,
  59,
  78,
  71,
  130,
  11,
  41,
  6,
  154,
  71,
  126,
  148,
  95,
  1,
  132,
  93,
  27,
  72,
  18,
  31,
  166,
  147,
  79,
  106,
  2,
  118,
  33,
  39,
  53,
  55,
  102,
  146,
  130,
  253,
  104,
  59,
  108,
  77,
  103,
  84,
  22,
  41,
  50,
  33,
  138,
  89,
  15,
  2,
  169,
  156,
  69,
  63,
  96

In [119]:
# Part 2 answer: sum over regions of area * number of sides. strict=True raises
# if the two per-region lists ever get out of step instead of silently truncating.
sum(area * sides for area, sides in zip(areas, edges, strict=True))

899196