In [1]:
import numpy as np
from collections import defaultdict

In [2]:
def load_data(datafile):
    """Read a polymer template and its pair-insertion rules from a text file.

    Expected layout: the template on the first line, then one rule per line
    of the form ``AB -> C``. Blank lines are ignored, so the separator after
    the template and trailing empty lines do not break parsing. Whitespace
    around ``->`` is tolerated.

    Parameters
    ----------
    datafile : str or path-like
        Path of the input file.

    Returns
    -------
    tuple[str, dict[str, str]]
        The template string, and a mapping from a pair (e.g. ``'CH'``) to
        the element inserted between its two members (e.g. ``'B'``).
    """
    polymer_map = {}
    with open(datafile, 'r') as fp:
        # splitlines() drops line terminators ("\n" and "\r\n") uniformly.
        lines = fp.read().splitlines()
    init_str = lines[0].strip()
    for line in lines[1:]:
        line = line.strip()
        if not line:
            # The blank separator after the template, or trailing blank lines.
            continue
        pair, element = line.split('->')
        polymer_map[pair.strip()] = element.strip()
    return init_str, polymer_map

In [3]:
# Sanity-check the parser on the example input: show the template and the rule map.
load_data("test.txt")

('NNCB',
 {'CH': 'B',
  'HH': 'N',
  'CB': 'H',
  'NH': 'C',
  'HB': 'C',
  'HC': 'B',
  'HN': 'C',
  'NN': 'C',
  'BH': 'H',
  'NC': 'B',
  'NB': 'B',
  'BN': 'B',
  'BB': 'N',
  'BC': 'B',
  'CC': 'N',
  'CN': 'C'})

In [52]:
def poly_grow(datafile, nstep=4):
    """Build the polymer string explicitly by applying pair insertion ``nstep`` times.

    Every adjacent pair ``AB`` that has a rule ``AB -> X`` becomes ``AXB``.
    Pairs with no rule are left untouched, since nothing is inserted between
    them. All insertions within one step are based on the polymer as it was
    at the start of that step.

    A polymer of length n grows to 2n - 1 per step (when every pair has a
    rule), so this is only practical for small ``nstep``. ``poly_count``
    gives element counts without building the string.

    Parameters
    ----------
    datafile : str or path-like
        File in the format read by ``load_data`` (template, then rules).
    nstep : int, default 4
        Number of insertion steps to apply.

    Returns
    -------
    str
        The polymer after ``nstep`` steps (``''`` for an empty template).
    """
    poly_str, poly_map = load_data(datafile)
    if not poly_str:
        return ''
    poly_arr = list(poly_str)
    for _ in range(nstep):
        new_arr = []
        for left, right in zip(poly_arr, poly_arr[1:]):
            new_arr.append(left)
            inserted = poly_map.get(left + right)
            if inserted is not None:
                new_arr.append(inserted)
        # zip() stops before the last element; it is never followed by an insertion.
        new_arr.append(poly_arr[-1])
        poly_arr = new_arr
    return ''.join(poly_arr)

def poly_count(datafile, nstep=4):
    """Count elements after ``nstep`` pair-insertion steps without building the polymer.

    Instead of the string itself, track how many times each adjacent pair
    occurs. One step turns each pair ``AB`` with rule ``AB -> X`` into ``AX``
    and ``XB``, each inheriting the count of ``AB``. A pair with no rule
    carries over unchanged. Work per step is bounded by the number of
    distinct pairs, not by the polymer length.

    Element totals are recovered by counting the first element of every pair.
    That counts every position except the polymer's last element, which is
    added separately. The last element is always the template's last
    character, because insertions only happen between existing elements.

    Parameters
    ----------
    datafile : str or path-like
        File in the format read by ``load_data`` (template, then rules).
    nstep : int, default 4
        Number of insertion steps to simulate.

    Returns
    -------
    tuple[collections.defaultdict, int]
        A ``defaultdict(int)`` mapping element -> occurrence count, and the
        most-common count minus the least-common count. An empty template
        gives ``(empty mapping, 0)``.
    """
    poly_str, poly_map = load_data(datafile)
    totals = defaultdict(int)
    if not poly_str:
        return totals, 0

    pairs_counter = defaultdict(int)
    for first, second in zip(poly_str, poly_str[1:]):
        pairs_counter[first + second] += 1

    for _ in range(nstep):
        next_pairs = defaultdict(int)
        for pair, count in pairs_counter.items():
            inserted = poly_map.get(pair)
            if inserted is None:
                # No rule for this pair: nothing is inserted, so it survives as-is.
                next_pairs[pair] += count
            else:
                next_pairs[pair[0] + inserted] += count
                next_pairs[inserted + pair[1]] += count
        pairs_counter = next_pairs

    for pair, count in pairs_counter.items():
        totals[pair[0]] += count
    # The final element is not the first member of any pair; add it once.
    totals[poly_str[-1]] += 1
    diff = max(totals.values()) - min(totals.values())
    return totals, diff
    

In [88]:
# Check the brute-force polymer after the default 4 steps against the expected example string.
poly_grow("test.txt") == "NBBNBNBBCCNBCNCCNBBNBBNBBBNBBNBBCBHCBHHNHCBBCBHCB"

True

In [48]:
# Part 1 by brute force: build the full polymer after 10 steps and tally every element.
full_str = poly_grow("input.txt", nstep=10)
cnt = defaultdict(int)
for c in full_str:
    cnt[c] += 1
# Display the per-element counts and (most common count - least common count).
cnt, max(cnt.values()) - min(cnt.values())

(defaultdict(int,
             {'B': 965,
              'N': 1985,
              'F': 2659,
              'V': 2381,
              'C': 2150,
              'S': 2323,
              'P': 1240,
              'K': 3147,
              'O': 1805,
              'H': 802}),
 2345)

In [53]:
# Cross-check: the pair-counting method should reproduce the brute-force counts from the previous cell.
poly_count("input.txt", nstep=10)

(defaultdict(int,
             {'B': 965,
              'N': 1985,
              'F': 2659,
              'V': 2381,
              'C': 2150,
              'S': 2323,
              'P': 1240,
              'K': 3147,
              'O': 1805,
              'H': 802}),
 2345)

In [57]:
# Part 2: 40 steps. From the 10-step total above, the template has 20 elements,
# so the explicit string would have 19 * 2**40 + 1 (> 10**13) characters.
# Pair counting stays fast because it only tracks distinct pairs.
%timeit poly_count("input.txt", nstep=40)

1.45 ms ± 7.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
