In [36]:
from utils import read_lines

class Packet:
    """A decoded BITS packet: a literal value or an operator over sub-packets.

    Attributes:
        version: version number read from the packet header.
        tp: type ID read from the packet header. 4 marks a literal; every
            other ID is an operator whose meaning is implemented in
            ``get_value``.
        value: the literal's integer value (left at the default 0 for
            operators).
        sub_packets: child packets of an operator (empty for literals).
    """

    def __init__(self, version, tp, value=0, sub_packets=None):
        self.version = version
        self.tp = tp
        self.value = value
        # Build a fresh list per instance: a shared ``[]`` default would be
        # aliased by every packet constructed without explicit children.
        self.sub_packets = [] if sub_packets is None else sub_packets

    def sum_versions(self):
        """Return this packet's version plus the versions of all its descendants."""
        return self.version + sum(p.sum_versions() for p in self.sub_packets)

    def get_value(self):
        """Evaluate the expression rooted at this packet.

        Type IDs: 0 sum, 1 product, 2 minimum, 3 maximum, 4 literal,
        5 greater-than, 6 less-than, 7 equal-to. The comparison types (5-7)
        compare the first two sub-packets and yield 1 (true) or 0 (false).

        Raises:
            ValueError: if ``tp`` is not one of the type IDs listed above.
            IndexError: if a comparison packet has fewer than two sub-packets.
        """
        if self.tp == 4:
            return self.value

        values = [sub.get_value() for sub in self.sub_packets]
        if self.tp == 0:
            return sum(values)
        if self.tp == 1:
            product = 1
            for v in values:
                product *= v
            return product
        if self.tp == 2:
            return min(values)
        if self.tp == 3:
            return max(values)
        if self.tp == 5:
            return int(values[0] > values[1])
        if self.tp == 6:
            return int(values[0] < values[1])
        if self.tp == 7:
            return int(values[0] == values[1])
        # Previously fell through and returned None, hiding decode errors.
        raise ValueError('unknown packet type ID: {}'.format(self.tp))

    def __repr__(self) -> str:
        return str(self.__dict__)
    
def parse_packet(bin_str, start):
    """Decode one BITS packet from ``bin_str`` beginning at bit index ``start``.

    Layout: 3-bit version, 3-bit type ID, then either
      * type 4 (literal): 5-bit groups, each a flag bit followed by 4 value
        bits; the group whose flag is '0' is the last one, or
      * any other type (operator): a 1-bit length type ID, followed by a
        15-bit total bit length of the sub-packets when it is '0', otherwise
        an 11-bit count of immediate sub-packets.

    Args:
        bin_str: string of '0'/'1' characters (e.g. produced by ``hex_to_bin``).
        start: bit offset at which this packet's header begins.

    Returns:
        ``(packet, next_pos)`` where ``next_pos`` is the bit offset just past
        the packet. For length-type-'0' operators it is the declared end of the
        sub-packet region.
    """
    type_pos = start + 3
    body_pos = start + 6
    version = int(bin_str[start:type_pos], 2)
    tp = int(bin_str[type_pos:body_pos], 2)

    if tp == 4:
        pos = body_pos
        nibbles = []
        more_groups = True
        while more_groups:
            # A leading '0' marks the final group; the next 4 bits are value bits.
            more_groups = bin_str[pos] != '0'
            nibbles.append(bin_str[pos + 1:pos + 5])
            pos += 5
        return Packet(version, tp, int(''.join(nibbles), 2)), pos

    children = []
    if bin_str[body_pos] == '0':
        # 15-bit field: total length in bits of all sub-packets.
        region_start = body_pos + 16
        region_end = region_start + int(bin_str[body_pos + 1:region_start], 2)
        pos = region_start
        while pos < region_end:
            child, pos = parse_packet(bin_str, pos)
            children.append(child)
        return Packet(version, tp, sub_packets=children), region_end

    # 11-bit field: number of immediate sub-packets.
    region_start = body_pos + 12
    child_count = int(bin_str[body_pos + 1:region_start], 2)
    pos = region_start
    for _ in range(child_count):
        child, pos = parse_packet(bin_str, pos)
        children.append(child)
    return Packet(version, tp, sub_packets=children), pos

def parse_packet_from_hex(hex_str):
    """Decode the outermost BITS packet of a hexadecimal transmission.

    Converts ``hex_str`` to bits with ``hex_to_bin`` and parses from bit 0.

    Returns:
        The ``(packet, next_pos)`` pair produced by ``parse_packet``.
    """
    return parse_packet(hex_to_bin(hex_str), 0)

def hex_to_bin(hex_str):
    """Convert a hexadecimal string into a bit string with 4 bits per hex digit.

    Every hex digit contributes exactly four characters, so leading zero
    digits are preserved (e.g. '0A' -> '00001010'). BITS fields are
    positional, so dropping leading zeros would shift every field that follows.

    Args:
        hex_str: hexadecimal digits without a '0x' prefix; surrounding
            whitespace (such as a trailing newline) is ignored.

    Returns:
        A string of '0'/'1' characters of length ``4 * len(hex_str.strip())``.

    Raises:
        ValueError: if the stripped string is empty or not valid hexadecimal.
    """
    digits = hex_str.strip()
    # Zero-pad to the full width implied by the digit count; padding only up
    # to a multiple of 4 (the old behavior) lost whole leading '0' nibbles.
    return format(int(digits, 16), '0{}b'.format(4 * len(digits)))
    
def part1(input_file):
    """Return the sum of all packet version numbers in the puzzle input.

    The first line returned by ``read_lines(input_file)`` is decoded as the
    hexadecimal transmission.
    """
    # Reuse the shared hex -> packet helper instead of repeating the
    # hex_to_bin / parse_packet pipeline here and in part2.
    packet, _ = parse_packet_from_hex(read_lines(input_file)[0])
    return packet.sum_versions()

def part2(input_file):
    """Return the value of the expression encoded by the puzzle input's outermost packet.

    The first line returned by ``read_lines(input_file)`` is decoded as the
    hexadecimal transmission.
    """
    # Reuse the shared hex -> packet helper instead of repeating the
    # hex_to_bin / parse_packet pipeline here and in part1.
    packet, _ = parse_packet_from_hex(read_lines(input_file)[0])
    return packet.get_value()


In [33]:
# Part 1: sum of the version numbers of every packet in the input transmission.
part1('inputs/day16.txt')

984

In [37]:
# Part 2: evaluate the expression encoded by the outermost packet.
part2('inputs/day16.txt')

1015320896946

In [23]:
# Worked example: a single literal packet (version 6, type 4) encoding 2021
# in three 5-bit groups, so it ends at bit 6 + 15 = 21.
literal_pkt, literal_end = parse_packet_from_hex('D2FE28')
assert (literal_pkt.version, literal_pkt.tp) == (6, 4)
assert literal_pkt.value == 2021
assert literal_end == 21

In [25]:
# Worked example: operator packet (version 1, type 6 = less-than) with length
# type '0': a 27-bit sub-packet region holding the literals 10 and 20,
# so parsing stops at 22 + 27 = 49.
len_op_pkt, len_op_end = parse_packet_from_hex('38006F45291200')
assert (len_op_pkt.version, len_op_pkt.tp) == (1, 6)
assert [sub.value for sub in len_op_pkt.sub_packets] == [10, 20]
assert len_op_end == 49
assert len_op_pkt.sum_versions() == 1 + 6 + 2
assert len_op_pkt.get_value() == 1  # 10 < 20

In [38]:
# Worked example: operator packet (version 7, type 3 = maximum) with length
# type '1': an 11-bit count of three literal sub-packets (1, 2, 3).
count_op_pkt, count_op_end = parse_packet_from_hex('EE00D40C823060')
assert (count_op_pkt.version, count_op_pkt.tp) == (7, 3)
assert [sub.value for sub in count_op_pkt.sub_packets] == [1, 2, 3]
assert count_op_end == 51
assert count_op_pkt.sum_versions() == 7 + 2 + 4 + 1
assert count_op_pkt.get_value() == 3