In [23]:
import numpy as np

def readInput(input):
    """Parse the puzzle text into a start string and a list of insertion rules.

    The text is expected to contain the start string, then a blank line, then
    one rule per line written as ``PAIR -> LETTER``.

    Args:
        input: The raw text. The parameter name shadows the builtin
            ``input``; it is kept so existing callers are unaffected.

    Returns:
        A ``(start, rules)`` tuple. ``start`` is the last non-blank line seen
        before the first blank line (``""`` if there is none) and ``rules`` is
        a list of ``(pair, inserted)`` string tuples in the order they appear.

    Raises:
        IndexError: If a line in the rules section does not contain `` -> ``.
    """
    start = ""
    rules = []
    readRules = False
    # splitlines() handles "\r\n" as well as "\n"; strip() lets a separator
    # line that only contains whitespace still count as the blank divider.
    for rawLine in input.splitlines():
        line = rawLine.strip()
        if line == "":
            readRules = True
            continue
        if readRules:
            vals = line.split(" -> ")
            rules.append((vals[0], vals[1]))
        else:
            start = line
    return start, rules

def checkRepeatRules(rules):
    """Report whether every rule has a distinct left-hand pair.

    Walks the rules in order. As soon as a pair is met for the second time an
    error is printed and False is returned; True is returned when no pair
    repeats.
    """
    seenPairs = set()
    for rule in rules:
        key = rule[0]
        if key not in seenPairs:
            seenPairs.add(key)
            continue
        # Second sighting of the same pair: report it and stop checking.
        print("Error: Pair already found, please change code to account for repeated rule!!")
        return False
    return True

#Get pair index is slow and should only be run for initializations
def getPairIndex(pair, rules):
    """Return the position of the first rule whose left-hand side is ``pair``.

    This is a linear scan over ``rules``. When no rule matches, an error is
    printed and -1 is returned.
    """
    for position, rule in enumerate(rules):
        if rule[0] == pair:
            return position
    print("Error: Unable to find pair in rule")
    return -1
        

#Rule CH -> B means replace the string CH with CBH which has pairs CB and BH. 
#We can do matrix multiplication to track the pairs
def getPairTransformMatrix(rules):
    """Build the matrix that advances a pair-count vector by one insertion step.

    Column ``i`` describes rule ``i``. A rule ``XY -> Z`` turns each ``XY``
    pair into one ``XZ`` pair and one ``ZY`` pair, so the rows belonging to
    those two pairs are incremented in column ``i``.

    Args:
        rules: List of ``(pair, inserted)`` tuples. The position of a rule in
            this list is the position of its pair in the count vector.

    Returns:
        An ``N x N`` float array ``M`` (``N = len(rules)``) such that
        ``np.matmul(M, v)`` is the pair-count vector after one step.

    Raises:
        ValueError: If a rule produces a pair that has no rule of its own, as
            there is then no row in which to count that pair.
    """
    N = len(rules)
    # Build the pair -> index lookup once instead of scanning the rule list for
    # every produced pair. setdefault keeps the first index of a repeated pair.
    pairIndex = {}
    for i, r in enumerate(rules):
        pairIndex.setdefault(r[0], i)

    M = np.zeros((N, N))
    for i, r in enumerate(rules):
        for produced in (r[0][0] + r[1], r[1] + r[0][1]):
            if produced not in pairIndex:
                # A missing pair must not fall through to index -1, which
                # would silently write into the last row of the matrix.
                raise ValueError(
                    "Rule %s -> %s produces pair %r which has no rule" % (r[0], r[1], produced)
                )
            # Accumulate rather than assign: a rule such as "AA -> A" produces
            # the same pair twice and must contribute 2, not 1.
            M[pairIndex[produced], i] += 1
    return M

def getStartVector(start, rules):
    """Count the adjacent character pairs of ``start`` into a vector.

    Args:
        start: The starting string.
        rules: List of ``(pair, inserted)`` tuples. Entry ``i`` of the result
            counts occurrences of ``rules[i][0]``.

    Returns:
        A float array of length ``len(rules)`` with the number of times each
        rule's pair occurs as two adjacent characters in ``start``. A string
        shorter than two characters yields all zeros.

    Raises:
        ValueError: If ``start`` contains an adjacent pair that has no rule.
    """
    # Build the pair -> index lookup once. setdefault keeps the first index of
    # a repeated pair.
    pairIndex = {}
    for i, r in enumerate(rules):
        pairIndex.setdefault(r[0], i)

    v = np.zeros(len(rules))
    for left, right in zip(start, start[1:]):
        pair = left + right
        if pair not in pairIndex:
            # A missing pair must not fall through to index -1, which would
            # silently increment the last entry of the vector.
            raise ValueError("Start string contains pair %r which has no rule" % pair)
        v[pairIndex[pair]] += 1
    return v

def getLetterCounts(v, rules, start):
    """Turn a pair-count vector into per-letter counts.

    Every letter mentioned by a rule (both pair characters and the inserted
    value) gets an entry starting at 0. Each pair then adds half of its count
    to each of its two characters, and the first and last characters of
    ``start`` each receive an extra 0.5.

    Returns the dict of letter -> count.
    """
    # Keys are created in the order: first pair char, second pair char,
    # inserted value, rule by rule.
    letters = dict.fromkeys(
        (symbol for r in rules for symbol in (r[0][0], r[0][1], r[1])),
        0,
    )

    # Each pair gives half of its count to each of its two characters, so
    # adjacent pairs do not count their shared character twice.
    for idx, rule in enumerate(rules):
        halfCount = v[idx] / 2.0
        first, second = rule[0][0], rule[0][1]
        letters[first] += halfCount
        letters[second] += halfCount

    # Top up the first and last characters of the string by 0.5 each.
    for edge in (start[0], start[-1]):
        letters[edge] += 0.5
    return letters

def getLetterCountsFromStr(s):
    """Tally how often each character occurs in ``s``.

    Returns a dict mapping each character to its number of occurrences, with
    keys in order of first appearance.
    """
    tally = {}
    for ch in s:
        tally[ch] = tally.get(ch, 0) + 1
    return tally

# Run 10 insertion steps and print the letter counts plus the spread between
# the most and least common letter.
numSteps = 10
# `input` must already hold the puzzle text (it is assigned in its own cell).
start, rules = readInput(input)
# Stop instead of carrying on: getPairIndex only ever resolves a pair to its
# first rule, so results computed from rules with repeated pairs are not valid.
if not checkRepeatRules(rules):
    raise ValueError("Rules contain a repeated pair; aborting")
M = getPairTransformMatrix(rules)
v = getStartVector(start, rules)
# Each multiplication by M advances the pair counts by one step.
for i in range(0,numSteps):
    v = np.matmul(M,v)
letters = getLetterCounts(v,rules,start)
print(letters)
print(max(letters.values()) - min(letters.values()))
        

{'C': 298.0, 'H': 161.0, 'B': 1749.0, 'N': 865.0}
1588.0


In [24]:
# Run 40 insertion steps and print the letter counts plus the spread between
# the most and least common letter.
numSteps = 40
# `input` must already hold the puzzle text (it is assigned in its own cell).
start, rules = readInput(input)
# Stop instead of carrying on: getPairIndex only ever resolves a pair to its
# first rule, so results computed from rules with repeated pairs are not valid.
if not checkRepeatRules(rules):
    raise ValueError("Rules contain a repeated pair; aborting")
M = getPairTransformMatrix(rules)
v = getStartVector(start, rules)
# Each multiplication by M advances the pair counts by one step.
for i in range(0,numSteps):
    v = np.matmul(M,v)
letters = getLetterCounts(v,rules,start)
print(letters)
print(max(letters.values()) - min(letters.values()))


{'C': 6597635301.0, 'H': 3849876073.0, 'B': 2192039569602.0, 'N': 1096047802353.0}
2188189693529.0


In [22]:
# Puzzle text in the layout readInput parses: the start string, a blank line,
# then one "PAIR -> LETTER" rule per line.
# NOTE(review): this cell carries execution count 22 but sits after the cells
# that read `input` (In [23], In [24]); on a fresh top-to-bottom run those
# cells would receive Python's builtin input function instead of this text.
# NOTE(review): the name shadows the builtin input().
input = """NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C"""