In [75]:
with open('day_16_input') as f:
    cmds = [bytes.fromhex(line.strip()) for line in f.readlines()]
cmds

[b'"\rK\x80I\x1f\xe6\xfb\xdc\xdaa\xf2?\x1d\x9bv0\x04\xa7\xc1(\x01/\x9d\xa8\x8c\xe2{\x00\x0b0\xf4\x80MI\xcdQS\x805!\x00v=\xc5\xe8\xec\x00\x08D3\x8b\x10\xb6g\xa1\xe6\x00\x94\xb7\xbe\x8d`\n\xcewM\xf3\x9d\xd3d\x97\x9fg\xa9\xac\r\x18\x02\xb2\xa4\x14\x015Ok\xf1\xdc\x06\'\xb1^\xc5\xcc\xc0\x16\x94\xf5\xba\xbf\xc0\td\xe9<\x95\xcf\x08\x02c\xf0\x04gA\xa7@\xa7kpC\x00\x82I&i2t\xbe|\xc8\x80&}\x00FHRHJ_tR\x00\x05\xd6Z\x1e\xad#4\xa7\x00\xbaN\xa4\x12V\xe4\xbb\xbd\x8d\xc0\x99\x9f\xc3\xa9r\x86\xc2\x01d\xb4\xff\x14\xa9?\xd2\x94t\x94\xe6\x83\xe7R\xe4\x9b\'7\xdf|@\x80\x18\x19sIe\t\xa5\xb9\xa8\xd3{|0\x044\x01i \xd9\xea\xef\x16\xae\xc0\xa4\xab}\xf5\xb1\xc0\x1c\x93;\x9a\xaf\x19\xe1\x81\x80\'\xa0\n\x80\x02\x1f\x1f\xa0\xe44\x00\x04>\x17F8W+\x98K\x06d\x01\xd3\xe8\x02sZJ\x9e\xce7\x17\x89hZ\xb3\xe0\xe8\x00rS3\xef\xfb\xb4\xb8\xd11\xa9\xf3\x9e\xd4\x13\xa1r\x00X\xf39\xee2\x05-H\xecN^\xc3\xa6\x00l\xc2\xb4\xbeo\xf3\xf4\x00\x17\xa0\xe4\xd5""`\t\xcagjv\x00\x98\x00!\xf1\x92\x14Fp\x00B\xa2<6\x8bq<\xc0\x15\xe0\x072J8\xdf0\xb

In [76]:
class StreamReader(object):
    
    def __init__(self, bytearr, initial_pos=0):
        self.bytearr = bytearr
        self.position = initial_pos
                
    def extract_value(self, byte_num, bit_off, num_bits):
        byte = self.bytearr[byte_num]
        mask = (0xFF & (0xFF << (8 - num_bits))) >> bit_off
        # print(f'byte={byte} mask={mask}')
        return (byte & mask) >> (8 - num_bits - bit_off)
            
    def merge_integer_values(self, values, bits_in_last_byte):
        result = 0
        for index, value in enumerate(reversed(values)):
            if index == 0:
                result += value
            else:
                result += value << ((8 * (index - 1)) + bits_in_last_byte)
        return result
            
        
    def extract_integer(self, num_bits):
        start = self.position
        remaining = num_bits
        bits_in_last_byte = remaining
        vals = []
        while remaining > 0:
            bits_in_last_byte = remaining
            byte = self.position // 8
            off = self.position % 8
            bits_in_this_byte = min(8 - off, remaining)
            val = self.extract_value(byte, off, bits_in_this_byte)
            # print(f'Extracting value {byte}:{off}@{bits_in_this_byte} rem: {remaining} value: {val}')
            vals.append(val)
            remaining -= bits_in_this_byte
            self.position += bits_in_this_byte
        # print(f'Merging {vals}, {bits_in_last_byte}')
        result = self.merge_integer_values(vals, bits_in_last_byte)
        # print(f'Read in {num_bits} from {start}: {result}')
        return result
    
    def extract_version(self):
        return self.extract_integer(3)
    
    def extract_type_id(self):
        return self.extract_integer(3)
    
    def extract_length_type_id(self):
        return self.extract_integer(1)
    
    def extract_total_length(self):
        return self.extract_integer(15)
    
    def extract_num_subpackets(self):
        return self.extract_integer(11)
    
    def extract_literal_value(self):
        vals = []
        while True:
            next_chunk = self.extract_integer(5)
            vals.append(next_chunk & 0xF)
            if next_chunk < 16:
                break
        result = 0
        for i, val in enumerate(reversed(vals)):
            result += val << (4 * i)
        return result
    
    def extract_packet(self):
        version = self.extract_version()
        type_id = self.extract_type_id()
        # print(f'Version={version} type_id={type_id}')
        if type_id == 4:
            literal_value = self.extract_literal_value()
            # print(f'Literal Value={literal_value}')
            return (version, type_id, [], literal_value)
        else:
            children = []
            length_type_id = self.extract_length_type_id()
            if length_type_id == 0:
                total_length = self.extract_total_length()
                # print(f'Must read {total_length} bits of children')
                end = self.position + total_length
                while self.position < end:
                    sub_reader = StreamReader(self.bytearr, self.position)
                    children.append(sub_reader.extract_packet())
                    self.position = sub_reader.position
            else:
                num_subs = self.extract_num_subpackets()
                # print(f'Must read {num_subs} children')
                for _ in range(num_subs):
                    sub_reader = StreamReader(self.bytearr, self.position)
                    children.append(sub_reader.extract_packet())
                    self.position = sub_reader.position
        return (version, type_id, children, None)

In [77]:
sr = StreamReader([255, 255])
print(sr.extract_integer(6))
print(sr.extract_integer(4))
print(sr.extract_integer(6))


63
15
63


In [78]:
def version_sum(pkt):
    result = 0
    result += pkt[0]
    for child in pkt[2]:
        result += version_sum(child)
    return result

for cmd in cmds:
    print(version_sum(StreamReader(cmd).extract_packet()))

977


In [79]:
def calculate_lit(pkt):
    return pkt[3]

def calculate_sum(pkt):
    result = 0
    for child in pkt[2]:
        result += calculate(child)
    return result

def calculate_prod(pkt):
    result = calculate(pkt[2][0])
    for child in pkt[2][1:]:
        result *= calculate(child)
    return result
        
def calculate_min(pkt):
    return min([calculate(child) for child in pkt[2]])

def calculate_max(pkt):
    return max([calculate(child) for child in pkt[2]])

def gt(pkt):
    left = calculate(pkt[2][0])
    right = calculate(pkt[2][1])
    if left > right:
        return 1
    return 0

def lt(pkt):
    left = calculate(pkt[2][0])
    right = calculate(pkt[2][1])
    if left < right:
        return 1
    return 0

def eq(pkt):
    left = calculate(pkt[2][0])
    right = calculate(pkt[2][1])
    if left == right:
        return 1
    return 0

def calculate(pkt):
    print(pkt)
    operators = [calculate_sum, calculate_prod, calculate_min, calculate_max, calculate_lit, gt, lt, eq]
    return operators[pkt[1]](pkt)

for cmd in cmds:
    packet = StreamReader(cmd).extract_packet()
    print(calculate(packet))

(1, 0, [(1, 3, [(4, 4, [], 15), (7, 4, [], 199071281), (7, 4, [], 499), (5, 4, [], 190)], None), (6, 1, [(2, 4, [], 240), (4, 5, [(7, 4, [], 3410), (1, 4, [], 201)], None)], None), (7, 3, [(1, 4, [], 7)], None), (5, 1, [(4, 4, [], 51786), (4, 7, [(2, 0, [(5, 4, [], 15), (3, 4, [], 11), (6, 4, [], 14)], None), (6, 0, [(0, 4, [], 6), (3, 4, [], 11), (0, 4, [], 5)], None)], None)], None), (5, 4, [], 2513), (7, 1, [(2, 4, [], 28408), (6, 5, [(5, 4, [], 52653), (7, 4, [], 52653)], None)], None), (4, 4, [], 751), (5, 4, [], 936579), (2, 1, [(5, 4, [], 84), (0, 5, [(2, 4, [], 246356739), (0, 4, [], 15228190)], None)], None), (3, 1, [(6, 4, [], 321243), (7, 7, [(4, 4, [], 212), (7, 4, [], 631)], None)], None), (4, 1, [(3, 7, [(1, 4, [], 208), (6, 4, [], 208)], None), (2, 4, [], 54640)], None), (4, 1, [(0, 4, [], 41), (1, 4, [], 169), (4, 4, [], 213), (7, 4, [], 246)], None), (2, 1, [(7, 6, [(1, 4, [], 34), (4, 4, [], 34)], None), (2, 4, [], 31562)], None), (2, 0, [(5, 4, [], 1625637177)], None