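# Day 16

Decode a hexadecimal transmission into a tree of packets, then sum every packet's version number (Part 1) and evaluate the expression the packets encode (Part 2).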

In [1]:
from itertools import islice, takewhile
from enum import Enum
from functools import reduce
from operator import mul

class PacketType(Enum):
    """Packet type IDs, as read from a packet's 3-bit type field (0-7).

    LITERAL (4) carries a number; every other member is an operator applied
    to the packet's subpackets.
    """
    SUM = 0
    PRODUCT = 1
    MINIMUM = 2
    MAXIMUM = 3
    LITERAL = 4
    GREATER_THAN = 5
    LESS_THAN = 6
    EQUAL_TO = 7

def to_bit_sequence(s):
    zeros = len([*takewhile(lambda x: x == '0', s)])
    zeros = '0' * 4 * zeros if zeros else ''
    n = int(s, base=16)
    s = f'{zeros}{n:b}'
    n = len(s) % 4
    if n:
        for _ in range(4 - n):
            yield '0'
    yield from s    
    
def parse_bits(n, g):
    """Consume the next ``n`` bits from iterator ``g`` and return them as an int.

    ``g`` yields '0'/'1' characters, most significant bit first.

    Raises ValueError if fewer than ``n`` bits remain, including when ``g``
    is already exhausted. A truncated read is therefore never silently
    turned into a smaller number.
    """
    bits = ''.join(islice(g, n))
    if len(bits) < n:
        raise ValueError(f'expected {n} bits, got {len(bits)}')
    return int(bits, base=2)

def literal_value_parser(g):
    """Yield the payload bits of a literal value read from bit iterator ``g``.

    The value is stored in 5-bit groups: the first bit of each group is a
    continuation flag and the other four are payload. Groups are read until
    one whose flag is '0' has been processed.
    """
    keep_reading = True
    while keep_reading:
        flag, *payload = islice(g, 5)
        yield from payload
        keep_reading = flag != '0'
    
def op_0_parser(g):
    """Yield the subpackets of an operator packet with length type ID 0.

    A 15-bit header gives the total number of bits the subpackets occupy.
    Only that many bits of ``g`` are exposed to ``parse_packets``, which
    keeps parsing packets until parsing within that window stops
    (normally because the window is used up).
    """
    region_length = parse_bits(15, g)
    region = islice(g, region_length)
    for subpacket in parse_packets(region):
        yield subpacket

def op_1_parser(g):
    """Yield the subpackets of an operator packet with length type ID 1.

    An 11-bit header gives how many subpackets follow directly in ``g``.
    """
    count = parse_bits(11, g)
    yield from (parse_packet(g) for _ in range(count))

def op_parser(g):
    """Yield the subpackets of an operator packet.

    The first bit is the length type ID: 1 means a subpacket-count header
    follows (``op_1_parser``), 0 means a bit-length header follows
    (``op_0_parser``).
    """
    counted = parse_bits(1, g)
    sub_parser = op_1_parser if counted else op_0_parser
    yield from sub_parser(g)

def parse_packet(g):
    """Parse one packet (and, recursively, its subpackets) from bit iterator ``g``.

    Returns a dict with keys 'version' (int) and 'type' (PacketType). It also
    has 'value' (int) for literal packets, or 'subpackets' (list of packet
    dicts) for operator packets.
    """
    version = parse_bits(3, g)
    kind = PacketType(parse_bits(3, g))
    packet = {'version': version, 'type': kind}
    if kind is not PacketType.LITERAL:
        packet['subpackets'] = list(op_parser(g))
    else:
        literal_bits = ''.join(literal_value_parser(g))
        packet['value'] = int(literal_bits, base=2)
    return packet

def parse_packets(g):
    """Yield consecutive packets parsed from bit iterator ``g`` until it runs out.

    End of stream is detected by the ValueError raised when ``parse_bits``
    reads from an exhausted iterator. Only ValueError is treated as end of
    stream, so any other exception (a genuine bug, KeyboardInterrupt, ...)
    propagates to the caller.

    Caveat: a packet that is truncated part-way also raises ValueError and is
    therefore also treated as end of stream.
    """
    while True:
        try:
            packet = parse_packet(g)
        except ValueError:
            return
        # Yield outside the try so that exceptions thrown into the generator
        # (e.g. GeneratorExit on close()) are not intercepted here.
        yield packet

def parse(s):
    """Decode hexadecimal string ``s`` and return its outermost packet dict.

    The bit stream is consumed lazily, so bits after the outermost packet
    (trailing padding) are never read.
    """
    return parse_packet(to_bit_sequence(s))

def versions_extractor(packet):
    """Yield the version of ``packet`` and of every nested subpacket.

    Versions come out in pre-order: a packet before its subpackets, and
    subpackets from left to right. An explicit stack replaces recursion.
    """
    pending = [packet]
    while pending:
        current = pending.pop()
        yield current['version']
        if current['type'] is not PacketType.LITERAL:
            # Push children reversed so the leftmost one is popped first.
            pending.extend(reversed(list(current['subpackets'])))

def calc_packet(packet):
    """Evaluate the expression tree rooted at ``packet``.

    Operator packets apply their operator to their recursively evaluated
    subpackets. Comparison operators return 1 or 0 and compare the first two
    subpackets. Any other packet (i.e. a literal) evaluates to its stored
    'value'.
    """
    combiners = {
        PacketType.SUM: sum,
        PacketType.PRODUCT: lambda values: reduce(mul, values),
        PacketType.MINIMUM: min,
        PacketType.MAXIMUM: max,
        PacketType.GREATER_THAN: lambda values: int(next(values) > next(values)),
        PacketType.LESS_THAN: lambda values: int(next(values) < next(values)),
        PacketType.EQUAL_TO: lambda values: int(next(values) == next(values)),
    }
    # Lazy: children are only evaluated as the chosen combiner consumes them.
    children = (calc_packet(child) for child in packet.get('subpackets', []))
    combine = combiners.get(packet['type'])
    if combine is None:
        return packet['value']
    return combine(children)

In [2]:
# Check parse / versions_extractor / calc_packet against the worked examples
# from the puzzle statement before running on the real input.

# A single literal packet: version 6, value 2021.
literal_packet = parse('D2FE28')
assert sum(versions_extractor(literal_packet)) == 6
assert calc_packet(literal_packet) == 2021

# Part 1: (hex transmission, expected sum of all version numbers)
version_sum_examples = [
    ('8A004A801A8002F478', 16),
    ('620080001611562C8802118E34', 12),
    ('C0015000016115A2E0802F182340', 23),
    ('A0016C880162017C3686B18A3D4780', 31)
]
for hex_input, expected_sum in version_sum_examples:
    actual_sum = sum(versions_extractor(parse(hex_input)))
    assert actual_sum == expected_sum, (hex_input, actual_sum, expected_sum)

# Part 2: (hex transmission, expected value of the outermost packet)
value_examples = [
    ('C200B40A82', 3),
    ('04005AC33890', 54),
    ('880086C3E88112', 7),
    ('CE00C43D881120', 9),
    ('D8005AC2A8F0', 1),
    ('F600BC2D8F', 0),
    ('9C005AC2F8F0', 0),
    ('9C0141080250320F1802104A08', 1)
]
for hex_input, expected_value in value_examples:
    actual_value = calc_packet(parse(hex_input))
    assert actual_value == expected_value, (hex_input, actual_value, expected_value)

In [3]:
# Solve both parts for the real puzzle input (a single hex line).
with open('day16.txt', 'r') as f:
    s = f.read().strip()

packet = parse(s)
sum_of_versions, value = sum(versions_extractor(packet)), calc_packet(packet)
print(sum_of_versions, value, sep='\n')

953
246225449979
