https://adventofcode.com/2021/day/16

In [2]:
#!/usr/bin/env python

from bitstring import BitArray

# Puzzle input, later decoded as hexadecimal via BitArray(hex=data).
with open('data/16.txt', encoding='utf-8') as fh:
    # Drop the trailing newline/whitespace so only hex digits are handed to
    # BitArray, independent of how leniently the library parses its input.
    data = fh.read().strip()

    
def iter_literal(ba, i):
    """Yield one ``(version, type_id, end)`` tuple for the literal packet at bit *i*.

    ``end`` is the bit index just past the packet: after the 6-bit header the
    value is stored in 5-bit groups, and a group whose first bit is 1 is
    followed by another group.
    """
    header_version = ba[i : i + 3].uint
    header_type = ba[i + 3 : i + 6].uint
    pos = i + 6
    # Skip every "continue" group; pos then points at the final group.
    while ba[pos]:
        pos += 5
    yield (header_version, header_type, pos + 5)

    
def iter_op_0(ba, i):
    """Walk an operator packet whose length type ID is 0 (total bit length).

    Yields ``(version, type_id, position)`` tuples. The first describes this
    packet, and its position is the bit index just past the 22-bit header
    (3 version + 3 type + 1 length-type + 15 length bits). After that come all
    tuples produced by ``dispatch_iter`` for each sub-packet inside the
    declared bit span.

    Raises:
        ValueError: "Wrong op type" if the length-type bit is set, or
            "Too far" if the sub-packets overrun the declared bit length.
    """
    pkt_version = ba[i : i + 3].uint
    pkt_type = ba[i + 3 : i + 6].uint
    if ba[i + 6]:
        raise ValueError("Wrong op type")
    span = ba[i + 7 : i + 22].uint
    cursor = i + 22
    yield (pkt_version, pkt_type, cursor)

    limit = cursor + span
    while cursor < limit:
        # The position in the last tuple emitted for a sub-packet is where the
        # next sub-packet is parsed from.
        for record in dispatch_iter(ba, cursor):
            cursor = record[2]
            yield record
    if cursor > limit:
        raise ValueError("Too far")

        
def iter_op_1(ba, i):
    """Walk an operator packet whose length type ID is 1 (sub-packet count).

    Yields ``(version, type_id, position)`` tuples. The first describes this
    packet, and its position is the bit index just past the 18-bit header
    (3 version + 3 type + 1 length-type + 11 count bits). After that come all
    tuples produced by ``dispatch_iter`` for each of the declared sub-packets.

    Raises:
        ValueError: "Wrong op type" if the length-type bit is not set.
    """
    version = ba[i : i + 3].uint
    i += 3
    ptype = ba[i : i + 3].uint
    i += 3
    if not ba[i]:
        raise ValueError("Wrong op type")
    i += 1
    subcount = ba[i : i + 11].uint
    i += 11
    yield (version, ptype, i)

    # Parse exactly `subcount` sub-packets. The position in the last tuple
    # emitted for one sub-packet is where the next one is parsed from.
    # (The old "Too many" check could never fire: the counter stops at
    # exactly `subcount`.)
    for _ in range(subcount):
        for (version, ptype, i) in dispatch_iter(ba, i):
            yield (version, ptype, i)

        
def dispatch_iter(ba, i):
    """Yield the ``(version, type_id, position)`` tuples for the packet at bit *i*.

    Type ID 4 is routed to ``iter_literal``. Any other type is an operator:
    the length-type bit at ``i + 6`` selects ``iter_op_1`` (bit set, sub-packet
    count) or ``iter_op_0`` (bit clear, total bit length).
    """
    # Only the type ID is needed to route; each iterator re-reads the header.
    ptype = ba[i + 3 : i + 6].uint
    if ptype == 4:
        yield from iter_literal(ba, i)
    elif ba[i + 6]:
        yield from iter_op_1(ba, i)
    else:
        yield from iter_op_0(ba, i)
            

# Part 1: add up the version field of every packet (outer and nested);
# dispatch_iter emits one tuple per packet, with the version first.
part_1 = sum(record[0] for record in dispatch_iter(BitArray(hex=data), 0))
print('part_1 =', part_1)

# Part 2

def collect_operands_0(ba, i):
    """Evaluate the sub-packets of a length-type-0 operator packet at bit *i*.

    The 15-bit field after the 7-bit header (3 version + 3 type + 1
    length-type bits) gives the total bit length of the sub-packets. They are
    evaluated with ``dispatch_eval`` until that span is consumed.

    Returns:
        ``(values, end)``: the sub-packet values in order, and the bit index
        just past the last sub-packet.

    Raises:
        ValueError: "Too far" if the sub-packets run past the declared bit
            length (the same check ``iter_op_0`` performs).
    """
    i += 7  # skip version (3), type ID (3) and length-type (1) bits
    sublen = ba[i : i + 15].uint
    i += 15
    j = i + sublen
    vals = []
    while i < j:
        val, i = dispatch_eval(ba, i)
        vals.append(val)
    if i > j:
        raise ValueError("Too far")
    return (vals, i)

def collect_operands_1(ba, i):
    """Evaluate the sub-packets of a length-type-1 operator packet at bit *i*.

    The 11-bit field after the 7-bit header gives the number of sub-packets.
    Each one is evaluated in turn with ``dispatch_eval``.

    Returns:
        ``(values, end)``: the sub-packet values in order, and the bit index
        just past the last sub-packet.
    """
    count = ba[i + 7 : i + 18].uint
    pos = i + 18
    values = []
    for _ in range(count):
        value, pos = dispatch_eval(ba, pos)
        values.append(value)
    return (values, pos)

# Indexed by the length-type bit ba[i + 6] of an operator packet.
operand_collectors = [
    collect_operands_0,  # 0: 15-bit total length of the sub-packets
    collect_operands_1,  # 1: 11-bit number of sub-packets
]

#0
def op_sum(ba, i):
    """Evaluate a sum packet (type ID 0); return ``(sum of operands, end)``."""
    collect = operand_collectors[ba[i + 6]]
    values, end = collect(ba, i)
    total = sum(values)
    return total, end

#1
def op_prod(ba, i):
    """Evaluate a product packet (type ID 1); return ``(product of operands, end)``.

    A packet with no operands evaluates to 1 (the empty product).
    """
    values, end = operand_collectors[ba[i + 6]](ba, i)
    result = 1
    for factor in values:
        result = result * factor
    return result, end

#2
def op_min(ba, i):
    """Evaluate a minimum packet (type ID 2); return ``(smallest operand, end)``."""
    length_type = ba[i + 6]
    values, end = operand_collectors[length_type](ba, i)
    smallest = min(values)
    return smallest, end

#3
def op_max(ba, i):
    """Evaluate a maximum packet (type ID 3); return ``(largest operand, end)``."""
    length_type = ba[i + 6]
    values, end = operand_collectors[length_type](ba, i)
    largest = max(values)
    return largest, end

#4
def literal(ba, i):
    """Decode the literal-value packet (type ID 4) at bit *i*.

    After the 6-bit header the value is stored in 5-bit groups. The first bit
    of each group says whether another group follows, and the remaining 4 bits
    are appended to the value.

    Returns:
        ``(value, end)``: the decoded unsigned integer, and the bit index just
        past the packet.
    """
    pos = i + 6
    digits = BitArray()
    more = True
    while more:
        more = ba[pos]
        digits += ba[pos + 1 : pos + 5]
        pos += 5
    return digits.uint, pos

#5
def op_gt(ba, i):
    """Evaluate a greater-than packet (type ID 5): 1 if first > second, else 0.

    Returns ``(flag, end)``. Unpacking raises ValueError unless there are
    exactly two operands.
    """
    values, end = operand_collectors[ba[i + 6]](ba, i)
    left, right = values
    return int(left > right), end

#6
def op_lt(ba, i):
    """Evaluate a less-than packet (type ID 6): 1 if first < second, else 0.

    Returns ``(flag, end)``. Unpacking raises ValueError unless there are
    exactly two operands.
    """
    values, end = operand_collectors[ba[i + 6]](ba, i)
    left, right = values
    return int(left < right), end

#7
def op_eq(ba, i):
    """Evaluate an equal-to packet (type ID 7): 1 if both operands match, else 0.

    Returns ``(flag, end)``. Unpacking raises ValueError unless there are
    exactly two operands.
    """
    values, end = operand_collectors[ba[i + 6]](ba, i)
    left, right = values
    return int(left == right), end

# Packet type ID -> evaluator. dispatch_eval indexes this list with the
# 3-bit type field, so the order must match IDs 0..7.
evaluators = [
    op_sum,   # 0
    op_prod,  # 1
    op_min,   # 2
    op_max,   # 3
    literal,  # 4
    op_gt,    # 5
    op_lt,    # 6
    op_eq,    # 7
]

def dispatch_eval(ba, i):
    """Evaluate the packet starting at bit *i*, chosen by its 3-bit type ID.

    Returns the selected evaluator's ``(value, end)`` pair.
    """
    type_id = ba[i + 3 : i + 6].uint
    evaluate = evaluators[type_id]
    return evaluate(ba, i)


# Part 2: evaluate the outermost packet; the end position is discarded.
ba = BitArray(hex=data)
(part_2, _) = dispatch_eval(ba, 0)
print('part_2 =', part_2)

# Answers recorded for this puzzle input. Note: these assignments overwrite
# the part_1 / part_2 values computed above.
part_1, part_2 = 951, 902198718880
