In [135]:
from dataclasses import dataclass
from math import prod

In [2]:
#!pip install bitstruct
import bitstruct

In [3]:
# Input file whose first line holds the hex text to decode.
# Swap the two assignments below to run against the example file instead.
# filename = 'example.txt'
filename = 'input.txt'

In [4]:
# Only the first line of the input file is used; trailing whitespace
# (including the newline) is stripped from it.
with open(filename) as infile:
    first_line = infile.readline()
intxt = first_line.rstrip()

In [5]:
def decode_hex(hextxt: str) -> bytes:
    """Convert a string of hexadecimal digits into the bytes it encodes.

    Thin wrapper around ``bytes.fromhex``; text that is not valid hex makes
    that call raise ``ValueError``.
    """
    raw = bytes.fromhex(hextxt)
    return raw

In [140]:
@dataclass
class Packet:
    """Common base for decoded packets.

    Attributes:
        vx: version number taken from the packet header.
        tx: type id taken from the packet header.
    """
    vx: int
    tx: int

    def eval(self) -> int:
        """Value of the packet; concrete subclasses must override this."""
        raise NotImplementedError()

    def vsum(self) -> int:
        """Version total of the packet; concrete subclasses must override this."""
        raise NotImplementedError()
        
@dataclass
class Literal(Packet):
    """Packet that carries a single integer value.

    Attributes:
        v: the decoded literal value.
    """
    v: int

    def eval(self) -> int:
        """A literal evaluates to its stored value."""
        return self.v

    def vsum(self) -> int:
        """A literal has no sub-packets, so the total is its own version."""
        return self.vx
        
@dataclass
class Operator(Packet):
    """Packet that combines the values of its sub-packets.

    Attributes:
        sp: the sub-packets, in the order they were parsed.
    """
    # Elements are packets (vsum()/eval() are called on them), not ints.
    sp: list[Packet]

    def vsum(self) -> int:
        """Return this packet's version plus the version sums of all sub-packets."""
        return self.vx + sum(p.vsum() for p in self.sp)

    def eval(self) -> int:
        """Evaluate every sub-packet and combine the results according to ``tx``.

        Type ids handled:
            0: sum of the values
            1: product of the values
            2: minimum value
            3: maximum value
            5: 1 if the first value is greater than the second, else 0
            6: 1 if the first value is less than the second, else 0
            7: 1 if the first two values are equal, else 0

        Raises:
            ValueError: if ``tx`` is not one of the type ids listed above.
            IndexError: if a comparison (5, 6, 7) has fewer than two sub-packets.
        """
        args = [p.eval() for p in self.sp]

        if self.tx == 0:
            return sum(args)
        elif self.tx == 1:
            return prod(args)
        elif self.tx == 2:
            return min(args)
        elif self.tx == 3:
            return max(args)
        elif self.tx == 5:
            return int(args[0] > args[1])
        elif self.tx == 6:
            return int(args[0] < args[1])
        elif self.tx == 7:
            return int(args[0] == args[1])
        else:
            # ValueError is still an Exception, so existing handlers keep
            # working, but the message now identifies the bad type id.
            raise ValueError(f'unknown operator type id: {self.tx}')

In [125]:
class PacketSlice:
    """Cursor over a byte string that decodes packets bit by bit.

    ``packet`` holds the raw bytes and ``ofs`` is the current read position
    in bits. Every ``unpack`` call advances ``ofs`` past the fields it read,
    so the ``parse_*`` methods must be called in the order the fields appear.
    """

    def __init__(self, packet):
        """Start reading ``packet`` (a bytes-like object) from bit offset 0."""
        self.packet = packet
        self.ofs = 0

    @classmethod
    def fromhex(cls, hextxt):
        """Build a reader from hexadecimal text.

        Whitespace in ``hextxt`` is ignored. Each hex digit is 4 bits, so
        text with an odd number of digits is padded with a trailing ``'0'``
        to make whole bytes; this only appends zero bits after the data.

        Raises:
            ValueError: if the text contains non-hexadecimal characters.
        """
        digits = ''.join(hextxt.split())
        if len(digits) % 2:
            digits += '0'
        return cls(bytes.fromhex(digits))

    def unpack(self, fmt):
        """Read the fields described by the bitstruct format ``fmt``.

        Reads at the current bit offset, then advances the offset by the
        size of ``fmt``. Returns the tuple of decoded fields.
        """
        fields = bitstruct.unpack_from(fmt, self.packet, self.ofs)
        self.ofs += bitstruct.calcsize(fmt)
        return fields

    def parse_header(self):
        """Read two 3-bit unsigned fields; ``parse`` treats them as (version, type id)."""
        return self.unpack('u3u3')

    def parse_literal(self):
        """Read a literal payload and return it as an int.

        The payload is a run of 5-bit groups: a 1-bit "more groups follow"
        flag and 4 value bits. Groups are accumulated most-significant
        first until a group whose flag is clear has been read.
        """
        has_more = True
        payload = 0
        while has_more:
            has_more, nibble = self.unpack('b1u4')
            payload <<= 4
            payload |= nibble
        return payload

    def parse_operator(self):
        """Read the sub-packets of an operator and return them as a list.

        A 1-bit flag selects how the sub-packets are delimited: when clear,
        a 15-bit field gives their total length in bits; when set, an
        11-bit field gives how many sub-packets follow.
        """
        (ltid,) = self.unpack('b1')
        subpackets = []
        if not ltid:
            (bitlen,) = self.unpack('u15')
            start_ofs = self.ofs
            # Keep parsing until the cursor has consumed the declared bit span.
            while self.ofs < start_ofs + bitlen:
                subpackets.append(self.parse())
        else:
            (num_subpackets,) = self.unpack('u11')
            for _ in range(num_subpackets):
                subpackets.append(self.parse())

        return subpackets

    def parse(self):
        """Parse one packet at the current offset, recursing into sub-packets.

        Returns a ``Literal`` when the header's type id is 4 and an
        ``Operator`` for every other type id.
        """
        version, type_id = self.parse_header()

        if type_id == 4:
            return Literal(version, type_id, self.parse_literal())
        else:
            return Operator(version, type_id, self.parse_operator())

In [94]:
# Step through a literal packet by hand: read the header fields first,
# then the literal payload that follows them.
packet = PacketSlice.fromhex('D2FE28')
print(packet.parse_header())
print(packet.parse_literal())

(6, 4)
2021


In [95]:
# Same bytes as the previous cell, decoded in one call through parse().
packet = PacketSlice.fromhex('D2FE28')
packet.parse()

Literal(vx=6, v=2021)

In [96]:
# Operator example; the recorded output shows two literal sub-packets.
packet = PacketSlice.fromhex('38006F45291200')
packet.parse()

Operator(vx=1, sp=[Literal(vx=6, v=10), Literal(vx=2, v=20)])

In [97]:
# Operator example; the recorded output shows three literal sub-packets.
packet = PacketSlice.fromhex('EE00D40C823060')
packet.parse()

Operator(vx=7, sp=[Literal(vx=2, v=1), Literal(vx=4, v=2), Literal(vx=1, v=3)])

In [98]:
# Nested example: the recorded output shows three operators wrapped around one literal.
PacketSlice.fromhex('8A004A801A8002F478').parse()

Operator(vx=4, sp=[Operator(vx=1, sp=[Operator(vx=5, sp=[Literal(vx=6, v=15)])])])

In [99]:
# Nested example: the recorded output shows an operator holding two operators of two literals each.
PacketSlice.fromhex('620080001611562C8802118E34').parse()

Operator(vx=3, sp=[Operator(vx=0, sp=[Literal(vx=0, v=10), Literal(vx=5, v=11)]), Operator(vx=1, sp=[Literal(vx=0, v=12), Literal(vx=3, v=13)])])

In [100]:
# Nested example: the recorded output again shows an operator holding two operators of two literals each.
PacketSlice.fromhex('C0015000016115A2E0802F182340').parse()

Operator(vx=6, sp=[Operator(vx=0, sp=[Literal(vx=0, v=10), Literal(vx=6, v=11)]), Operator(vx=4, sp=[Literal(vx=7, v=12), Literal(vx=0, v=13)])])

In [101]:
# Nested example: the recorded output shows three nested operators, the innermost holding five literals.
PacketSlice.fromhex('A0016C880162017C3686B18A3D4780').parse()

Operator(vx=5, sp=[Operator(vx=1, sp=[Operator(vx=3, sp=[Literal(vx=7, v=6), Literal(vx=6, v=6), Literal(vx=5, v=12), Literal(vx=2, v=15), Literal(vx=2, v=15)])])])

In [112]:
def packetsum(hexp):
    """Return the sum of the version numbers of every packet encoded in ``hexp``.

    Parses the outermost packet from the hex text and delegates to the
    packet classes' ``vsum`` methods, which add each packet's version to the
    totals of its sub-packets, instead of repeating that traversal here.
    """
    return PacketSlice.fromhex(hexp).parse().vsum()

In [113]:
# Version sum for the first nested example parsed above.
packetsum('8A004A801A8002F478')

16

In [114]:
# Version sum for the second nested example parsed above.
packetsum('620080001611562C8802118E34')

12

In [115]:
# Version sum for the third nested example parsed above.
packetsum('C0015000016115A2E0802F182340')

23

In [116]:
# Version sum for the fourth nested example parsed above.
packetsum('A0016C880162017C3686B18A3D4780')

31

In [117]:
# Version sum for the hex text read from the input file.
packetsum(intxt)

883

In [123]:
# Check the vsum() method against the same example packets used above.
version_examples = [
    'D2FE28',
    '38006F45291200',
    'EE00D40C823060',
    '8A004A801A8002F478',
    '620080001611562C8802118E34',
    'C0015000016115A2E0802F182340',
    'A0016C880162017C3686B18A3D4780'
]
[PacketSlice.fromhex(hext).parse().vsum() for hext in version_examples]

[6, 9, 14, 16, 12, 23, 31]

In [144]:
# Evaluate a set of example packets, one result per hex string.
eval_examples = [
    'C200B40A82',
    '04005AC33890',
    '880086C3E88112',
    'CE00C43D881120',
    'D8005AC2A8F0',
    'F600BC2D8F',
    '9C005AC2F8F0',
    '9C0141080250320F1802104A08'
]
[PacketSlice.fromhex(hext).parse().eval() for hext in eval_examples]

[3, 54, 7, 9, 1, 0, 0, 1]

In [145]:
# Evaluate the outermost packet of the hex text read from the input file.
PacketSlice.fromhex(intxt).parse().eval()

1675198555015