In [54]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))
from utils.util import file2list
import numpy as np

In [17]:
def to_bits(s):
    """Expand a hexadecimal string into a string of binary digits.

    Each hex digit of ``s`` contributes four bits, so leading zero digits
    survive as leading ``'0'`` bits instead of being dropped by ``bin``.

    Args:
        s: hexadecimal text such as ``"D2FE28"``.

    Returns:
        A ``'0'``/``'1'`` string left-padded to ``4 * len(s)`` characters.
    """
    width = 4 * len(s)
    digits = bin(int(s, 16))[2:]  # strip the "0b" prefix
    # pad on the left so the leading zero nibbles are kept
    return digits.rjust(width, "0")

In [50]:
class op_packet:
    """Operator packet: a header version, a type ID and its sub-packets."""

    def __init__(self, version, p_type):
        # version field read from the packet header
        self.version = version
        # type ID from the header; calc_value uses it to pick the operation
        self.p_type = p_type
        # child packets, appended in the order they are parsed
        self.packets = []

    def __str__(self):
        """Render as ``op|ver: V, [child, child, ]|`` (trailing separator kept)."""
        pieces = ["op|ver: {}, [".format(self.version)]
        pieces.extend(str(child) + ", " for child in self.packets)
        pieces.append("]|")
        return "".join(pieces)

class literal_packet:
    """Literal packet: a header version plus its decoded integer value."""

    def __init__(self, version, value):
        # version field from the header and the decoded literal value
        self.version, self.value = version, value

    def __str__(self):
        """Render as ``l|ver: V, val: X|``."""
        return "l|ver: {0.version}, val: {0.value}|".format(self)

In [48]:
def parse_literal(p, version):
    """Decode the payload of a literal packet.

    The payload is a sequence of 5-character groups: a leading flag bit
    followed by four value bits. Groups are consumed until one whose flag
    is ``'0'``; their 4-bit chunks are concatenated most-significant first.

    Args:
        p: bit string positioned at the first group (``parse`` passes the
            text that follows the 6 header bits).
        version: header version to store on the resulting packet.

    Returns:
        ``(literal_packet, consumed)``, where ``consumed`` is the number of
        characters of ``p`` occupied by the groups.
    """
    value = 0
    pos = 0
    while True:
        is_last = p[pos] == '0'
        # shift previous nibbles up by 4 bits and append this one
        value = value * 16 + int(p[pos + 1:pos + 5], 2)
        pos += 5
        if is_last:
            return literal_packet(version, value), pos

def parse(p):
    """Parse one packet, recursing into sub-packets, from a bit string.

    Args:
        p: ``'0'``/``'1'`` string whose first character starts a packet
            header (3-bit version, then 3-bit type ID).

    Returns:
        ``(packet, consumed)``. ``packet`` is a ``literal_packet`` for type
        ID 4 and an ``op_packet`` otherwise. ``consumed`` is the number of
        characters of ``p`` the packet occupies; anything after is ignored.
    """
    version = int(p[0:3], 2)
    type_id = int(p[3:6], 2)
    pos = 6

    if type_id == 4:
        literal, used = parse_literal(p[pos:], version)
        return literal, pos + used

    length_type = p[pos]
    pos += 1
    node = op_packet(version, type_id)

    if length_type == '0':
        # next 15 bits give the total bit length of all sub-packets
        budget = int(p[pos:pos + 15], 2)
        pos += 15
        consumed = 0
        while consumed < budget:
            child, used = parse(p[pos:])
            node.packets.append(child)
            consumed += used
            pos += used
    else:
        # next 11 bits give the number of immediate sub-packets
        count = int(p[pos:pos + 11], 2)
        pos += 11
        for _ in range(count):
            child, used = parse(p[pos:])
            node.packets.append(child)
            pos += used

    return node, pos

In [45]:
def sum_version(rp):
    """Return the sum of the version fields of ``rp`` and every nested packet.

    Args:
        rp: a ``literal_packet`` or ``op_packet``, such as the root returned
            by ``parse``.
    """
    if type(rp) is literal_packet:
        return rp.version
    return rp.version + sum(sum_version(child) for child in rp.packets)

In [55]:
def calc_value(rp):
    """Evaluate the expression encoded by a packet tree.

    Literal packets evaluate to their stored value. Operator packets apply
    the operation selected by ``p_type`` to their sub-packets' values:

    * 0 sum, 1 product, 2 minimum, 3 maximum (over all sub-packets);
    * 5 greater-than, 6 less-than, 7 equal-to (first two sub-packets),
      each yielding the integer 1 or 0.

    Args:
        rp: a ``literal_packet`` or ``op_packet``, such as the root returned
            by ``parse``.

    Returns:
        The computed value as an exact Python ``int``.

    Raises:
        ValueError: if an operator packet's ``p_type`` is not listed above.
    """
    if type(rp) is literal_packet:
        return rp.value

    t = rp.p_type
    if t in (5, 6, 7):
        left = calc_value(rp.packets[0])
        right = calc_value(rp.packets[1])
        # int() so comparisons yield 1/0 rather than True/False
        if t == 5:  # gt
            return int(left > right)
        if t == 6:  # lt
            return int(left < right)
        return int(left == right)  # eq

    values = [calc_value(p) for p in rp.packets]
    if t == 0:  # sum
        return sum(values)
    if t == 1:  # product
        # Multiply with Python ints: np.prod coerces to fixed-width int64,
        # which silently wraps on large products (and gives 1.0 for []).
        product = 1
        for v in values:
            product *= v
        return product
    if t == 2:  # min
        return min(values)
    if t == 3:  # max
        return max(values)
    raise ValueError("unsupported operator packet type: {}".format(t))

In [56]:
# Decode the transmission: the first line of input.txt is a single hex string.
# strip(): to_bits sizes its output from len(s), so a kept line terminator
# (in case file2list does not remove it) would add bogus leading zero bits.
l = file2list("./input.txt")[0].strip()
bits = to_bits(l)
root, end = parse(bits)
# Anything after the outermost packet should only be hex padding zeros.
assert set(bits[end:]) <= {"0"}, "unexpected non-zero bits after the outermost packet"
print("version sum :{}".format(sum_version(root)))
print("value :{}".format(calc_value(root)))


version sum :999
version sum :3408662834145
