In [14]:
use std::convert::TryInto;
use std::collections::HashMap;

/// Narrows a character to one byte by keeping only the low eight bits of its
/// Unicode scalar value.
///
/// ASCII characters map to their ASCII code; scalar values above `0xFF` are
/// truncated rather than rejected.
fn char_to_byte(c: char) -> u8 {
    let scalar = u32::from(c);
    (scalar & 0xFF) as u8
}

/// Parses the input text into a polymer template and its pair-insertion rules.
///
/// Expected layout: the template on the first line, one separator line (always
/// skipped, whatever it contains), then one rule per line in the form
/// `AB -> C`. Every character is narrowed to a byte with `char_to_byte`.
///
/// Both `\n` and `\r\n` line endings are accepted, and blank lines among the
/// rules (for example the one produced by a trailing newline) are ignored.
/// Empty input yields an empty template and no rules.
///
/// # Panics
///
/// Panics if a non-blank rule line has no ` -> ` separator, has fewer than two
/// characters on its left-hand side, or has an empty right-hand side.
fn parse_input(input: &str) -> (Vec<u8>, HashMap<(u8, u8), u8>) {
    let mut lines = input.lines();
    let template: Vec<u8> = lines
        .next()
        .unwrap_or("")
        .chars()
        .map(char_to_byte)
        .collect();

    // The second line separates the template from the rules.
    lines.next();

    let mut rules = HashMap::new();
    for line in lines.filter(|line| !line.trim().is_empty()) {
        let mut parts = line.split(" -> ");
        let mut pair = parts
            .next()
            .expect("split always yields at least one part")
            .chars()
            .map(char_to_byte);
        let inserted = parts
            .next()
            .expect("rule line must look like `AB -> C`")
            .chars()
            .next()
            .map(char_to_byte)
            .expect("rule must name the element to insert");
        let first = pair.next().expect("rule pair needs two elements");
        let second = pair.next().expect("rule pair needs two elements");
        rules.insert((first, second), inserted);
    }

    (template, rules)
}

In [15]:
// Parse the full input (template line, blank separator, then 100 insertion
// rules) and echo the parsed template bytes and rule map.
let inp = parse_input("FNFPPNKPPHSOKFFHOFOC

VS -> B
SV -> C
PP -> N
NS -> N
BC -> N
PB -> F
BK -> P
NV -> V
KF -> C
KS -> C
PV -> N
NF -> S
PK -> F
SC -> F
KN -> K
PN -> K
OH -> F
PS -> P
FN -> O
OP -> B
FO -> C
HS -> F
VO -> C
OS -> B
PF -> V
SB -> V
KO -> O
SK -> N
KB -> F
KH -> C
CC -> B
CS -> C
OF -> C
FS -> B
FP -> H
VN -> O
NB -> N
BS -> H
PC -> H
OO -> F
BF -> O
HC -> P
BH -> S
NP -> P
FB -> C
CB -> H
BO -> C
NN -> V
SF -> N
FC -> F
KK -> C
CN -> N
BV -> F
FK -> C
CF -> F
VV -> B
VF -> S
CK -> C
OV -> P
NC -> N
SS -> F
NK -> V
HN -> O
ON -> P
FH -> O
OB -> H
SH -> H
NH -> V
FF -> B
HP -> B
PO -> P
HB -> H
CH -> N
SN -> P
HK -> P
FV -> H
SO -> O
VH -> V
BP -> V
CV -> P
KP -> K
VB -> N
HV -> K
SP -> N
HO -> P
CP -> H
VC -> N
CO -> S
BN -> H
NO -> B
HF -> O
VP -> K
KV -> H
KC -> F
HH -> C
BB -> K
VK -> P
OK -> C
OC -> C
PH -> H");
inp

([70, 78, 70, 80, 80, 78, 75, 80, 80, 72, 83, 79, 75, 70, 70, 72, 79, 70, 79, 67], {(72, 78): 79, (80, 79): 80, (67, 72): 78, (66, 78): 72, (78, 86): 86, (86, 75): 80, (75, 78): 75, (79, 72): 70, (83, 72): 72, (80, 86): 78, (83, 67): 70, (80, 80): 78, (67, 70): 70, (83, 75): 78, (78, 75): 86, (83, 70): 78, (79, 78): 80, (79, 86): 80, (70, 75): 67, (79, 66): 72, (70, 70): 66, (80, 70): 86, (66, 86): 70, (75, 66): 70, (83, 66): 86, (83, 78): 80, (86, 80): 75, (66, 66): 75, (79, 67): 67, (75, 70): 67, (72, 66): 72, (79, 75): 67, (80, 78): 75, (80, 72): 72, (70, 66): 67, (86, 78): 79, (72, 72): 67, (72, 83): 70, (80, 83): 80, (78, 83): 78, (70, 67): 70, (86, 86): 66, (72, 80): 66, (67, 86): 80, (86, 66): 78, (72, 75): 80, (75, 72): 67, (79, 83): 66, (86, 79): 67, (67, 75): 67, (67, 79): 83, (78, 67): 78, (66, 83): 72, (70, 79): 67, (78, 66): 78, (79, 70): 67, (75, 75): 67, (67, 83): 67, (67, 80): 72, (72, 67): 80, (72, 79): 80, (83, 86): 67, (67, 78): 78, (86, 83): 66, (66, 75): 80, (70, 7

In [16]:
// Parse the small sample (template `NNCB` plus 16 insertion rules) and echo
// the parsed template bytes and rule map.
let sample = parse_input("NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C");
sample

([78, 78, 67, 66], {(67, 67): 78, (72, 72): 78, (72, 66): 67, (78, 78): 67, (66, 78): 66, (66, 72): 72, (66, 67): 66, (72, 78): 67, (72, 67): 66, (67, 66): 72, (78, 66): 66, (66, 66): 78, (78, 67): 66, (67, 78): 67, (78, 72): 67, (67, 72): 66})

In [17]:
/// Performs one pair-insertion step on `pattern` and returns the new polymer.
///
/// For every adjacent pair `(a, b)` with an entry in `rules`, the mapped
/// element is placed between `a` and `b`. Pairs are read from the input
/// pattern only, so an inserted element never takes part in another insertion
/// during the same step.
///
/// Pairs without a rule are left unchanged, and an empty pattern yields an
/// empty result.
fn polymerize(pattern: &[u8], rules: &HashMap<(u8, u8), u8>) -> Vec<u8> {
    let last = match pattern.last() {
        Some(&last) => last,
        None => return Vec::new(),
    };

    // Upper bound on the output size: one insertion per adjacent pair.
    // `pattern` is non-empty here, so the subtraction cannot underflow.
    let mut result = Vec::with_capacity(pattern.len() * 2 - 1);
    for pair in pattern.windows(2) {
        result.push(pair[0]);
        if let Some(&inserted) = rules.get(&(pair[0], pair[1])) {
            result.push(inserted);
        }
    }
    // `windows(2)` only emits the left element of each pair; add the final one.
    result.push(last);
    result
}

In [18]:
// One insertion step on the sample template, rendered back as text.
String::from_utf8(polymerize(&sample.0, &sample.1)).unwrap()

"NCNBCHB"

In [19]:
/// Expands `pattern` by applying `polymerize` `step` times, then returns the
/// difference between the counts of the most and least frequent elements.
///
/// The polymer is materialised in full at every step, so the cost depends on
/// how much `polymerize` grows it each time.
///
/// Returns `0` for an empty pattern, which has no elements to count.
fn solve(pattern: &[u8], rules: &HashMap<(u8, u8), u8>, step: usize) -> usize {
    if pattern.is_empty() {
        return 0;
    }

    let mut polymer: Vec<u8> = pattern.to_vec();
    for _ in 0..step {
        polymer = polymerize(&polymer, rules);
    }

    let mut counts: HashMap<u8, usize> = HashMap::new();
    for elem in polymer {
        *counts.entry(elem).or_default() += 1;
    }

    let most = counts.values().copied().max().unwrap_or(0);
    let least = counts.values().copied().min().unwrap_or(0);
    most - least
}

In [20]:
// Naive solver on the sample, 10 steps.
solve(&sample.0, &sample.1, 10)

1588

In [21]:
// Naive solver on the full input, 10 steps.
solve(&inp.0, &inp.1, 10)

2975

In [22]:
use std::ops::{Add, AddAssign};
use std::convert::From;
use std::fmt::{Debug, Formatter, Error};

/// Integer type used for the final "most common minus least common" answer.
type Uint = usize;

/// Multiset of polymer elements: maps each element byte to how many times it
/// occurs. Elements whose count drops to zero are removed from the map, so
/// every stored count is at least one.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct ElemCount(HashMap<u8, usize>);

impl Add for ElemCount {
    type Output = ElemCount;

    /// Returns the element-wise sum of both counts.
    fn add(mut self, rhs: ElemCount) -> Self::Output {
        // `self` is already owned, so accumulate into it instead of cloning.
        self += rhs;
        self
    }
}

impl AddAssign for ElemCount {
    /// Adds every count from `rhs` into `self`.
    fn add_assign(&mut self, rhs: ElemCount) {
        for (elem, count) in rhs.0 {
            *self.0.entry(elem).or_default() += count;
        }
    }
}

impl From<[u8; 2]> for ElemCount {
    /// Counts the two elements of a pair; a pair of equal elements counts as two.
    fn from(pattern: [u8; 2]) -> Self {
        let mut result = Self::default();
        result.incr(pattern[0]);
        result.incr(pattern[1]);
        result
    }
}

impl ElemCount {
    /// Returns the highest count minus the lowest count, or `0` when nothing
    /// has been counted.
    fn solve(&self) -> Uint {
        let most = self.0.values().copied().max().unwrap_or(0);
        let least = self.0.values().copied().min().unwrap_or(0);
        most - least
    }

    /// Adds one occurrence of `elem`.
    fn incr(&mut self, elem: u8) {
        *self.0.entry(elem).or_default() += 1;
    }

    /// Removes one occurrence of `elem`, dropping its entry once the count
    /// reaches zero. Decrementing an element that is not present is a no-op
    /// rather than an integer underflow.
    fn decr(&mut self, elem: u8) {
        let remaining = match self.0.get_mut(&elem) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count
            }
            None => return,
        };
        if remaining == 0 {
            self.0.remove(&elem);
        }
    }
}

// Sanity checks: two distinct elements each occur once, so the spread is 0;
// adding a second pair that repeats element 1 gives it a count of 2 against 1
// for the others, so the spread is 1.
assert_eq!(ElemCount::from([1, 2]).solve(), 0);
assert_eq!((ElemCount::from([1, 2]) + ElemCount::from([1, 0])).solve(), 1);

/// Memoised element counter: computes how many of each element a pair expands
/// into after a number of insertion steps, without building the polymer.
#[derive(Clone, Default)]
struct Counter {
    /// Cache of results already computed by `count`, keyed by
    /// `(pair, remaining steps)`.
    mem: HashMap<([u8; 2], usize), ElemCount>,
    /// Pair-insertion rules: `(left, right)` maps to the element inserted
    /// between them.
    rules: HashMap<(u8, u8), u8>,
}

impl Debug for Counter {
    /// Prints the cache as a map, showing each pair as text instead of raw bytes.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        f.debug_map()
            .entries(self.mem.iter().map(|(key, counts)| {
                // Lossy conversion so formatting cannot panic on non-UTF-8 pairs.
                ((String::from_utf8_lossy(&key.0), key.1), counts)
            }))
            .finish()
    }
}

impl Counter {
    /// Creates a counter with an empty cache and its own copy of `rules`.
    fn new(rules: &HashMap<(u8, u8), u8>) -> Self {
        Self {
            mem: Default::default(),
            rules: rules.clone(),
        }
    }

    /// Returns the element counts of the polymer that `pattern` grows into
    /// after `step` insertion steps, with both ends of the pair included.
    ///
    /// A pair with no matching rule never changes, so its count is just the
    /// pair itself.
    fn count(&mut self, pattern: [u8; 2], step: usize) -> ElemCount {
        if step == 0 {
            return pattern.into();
        }
        if let Some(cached) = self.mem.get(&(pattern, step)) {
            return cached.clone();
        }

        let mid = match self.rules.get(&(pattern[0], pattern[1])) {
            Some(&mid) => mid,
            None => return pattern.into(),
        };

        let mut result = self.count([pattern[0], mid], step - 1);
        result += self.count([mid, pattern[1]], step - 1);
        // `mid` ends the left half and starts the right half, so it was
        // counted twice.
        result.decr(mid);

        self.mem.insert((pattern, step), result.clone());
        result
    }

    /// Returns the most-common minus least-common element count of `pattern`
    /// after `step` insertion steps. An empty pattern yields `0`.
    fn solve(&mut self, pattern: &[u8], step: usize) -> Uint {
        let last = match pattern.last() {
            Some(&last) => last,
            None => return 0,
        };

        let mut result = ElemCount::default();
        for pair in pattern.windows(2) {
            result += self.count([pair[0], pair[1]], step);
            // The right end of this pair is the left end of the next one;
            // drop it here so it is not counted twice.
            result.decr(pair[1]);
        }
        // The final element was never counted (no pairs) or was dropped by the
        // loop above, so add it once.
        result.incr(last);
        result.solve()
    }
}

In [23]:
// assert_eq!(solve(&sample.0, &sample.1, 10), Counter::new(&sample.1).solve(&sample.0, 10));

In [24]:
// Memoised counter on the sample, 40 steps.
Counter::new(&sample.1).solve(&sample.0, 40)

2188189693529

In [25]:
// Cross-check the naive and memoised solvers on the full input: advance `i`
// while both agree, stopping at the first mismatch or once `i` passes 10.
// `i == 11` afterwards means they agreed for every step count in 0..=10.
let mut i = 0;
while i <= 10 && solve(&inp.0, &inp.1, i) == Counter::new(&inp.1).solve(&inp.0, i) {
    i += 1;
}
// Print where the loop stopped and what each solver returns at that step count.
dbg!(i);
dbg!(solve(&inp.0, &inp.1, i));
dbg!(Counter::new(&inp.1).solve(&inp.0, i));

[src/lib.rs:263] i = 11
[src/lib.rs:264] solve(&inp.0, &inp.1, i) = 5733
[src/lib.rs:265] Counter::new(&inp.1).solve(&inp.0, i) = 5733


In [26]:
// With zero steps `count` returns before it reads or writes the cache, so the
// debug dump of the counter is an empty map.
let mut counter = Counter::new(&inp.1);
counter.solve(&inp.0, 0);
dbg!(counter);

[src/lib.rs:261] counter = {}


In [27]:
// The naive and memoised solvers must agree on the full input at 10 steps.
assert_eq!(solve(&inp.0, &inp.1, 10), Counter::new(&inp.1).solve(&inp.0, 10));

In [28]:
// Memoised counter on the full input, 40 steps.
Counter::new(&inp.1).solve(&inp.0, 40)

3015383850689