# December 16, 2021

https://adventofcode.com/2021/day/16

In [1]:
import pandas as pd
import numpy as np
from collections import defaultdict, deque
from queue import PriorityQueue

In [2]:
with open("../data/2021/16.txt", "r") as f:
    data = f.read()

In [3]:
tests = [
    "8A004A801A8002F478"
    , "620080001611562C8802118E34"
    , "C0015000016115A2E0802F182340"
    , "A0016C880162017C3686B18A3D4780"
]

In [4]:
def hex_to_binary( text ):

    bit_char = {
        "0": "0000"
        , "1": "0001"
        , "2": "0010"
        , "3": "0011"
        , "4": "0100"
        , "5": "0101"
        , "6": "0110"
        , "7": "0111"
        , "8": "1000"
        , "9": "1001"
        , "A": "1010"
        , "B": "1011"
        , "C": "1100"
        , "D": "1101"
        , "E": "1110"
        , "F": "1111"
    }

    return "".join( [bit_char[hex_char] for hex_char in text] )

def str_to_int(s):
    n = len(s) - 1
    return sum( [int(d) * 2**(n-i) for i,d in enumerate(s)] )

In [5]:
assert hex_to_binary(tests[0]) == "110100101111111000101000"

AssertionError: 

In [12]:
class BitStream():
    def __init__(self, stream):
        self.stream = stream
        self.pos = 0
        self.version_sum = 0

    def read_binary_number(self, nbits):
        s = self.stream[self.pos:self.pos+nbits]
        self.pos += nbits
        return self.str_to_int(s)
    
    def read_version_id(self):
        return self.read_binary_number(3)
    
    def read_type_id(self):
        return self.read_binary_number(3)
    
    def read_length_type_id(self):
        return self.read_binary_number(1)
    
    def read_length_type0(self):
        '''read packets with a given total bit-length'''
        nbits = self.read_binary_number(15)
        values = []
        start_pos = self.pos
        while self.pos - start_pos < nbits:
            values += [self.read_packet()]
        return values
    
    def read_length_type1( self ):
        '''read a specified number of packets'''
        npackets = self.read_binary_number(11)
        values = [0] * npackets
        for i in range(npackets):
            values[i] = self.read_packet()
        return values
    
    def read_literal(self):
        s = ""
        while True:
            end_flag = self.stream[self.pos]
            s += self.stream[self.pos+1:self.pos+5]
            self.pos += 5
            if end_flag == "0":
                break
        return self.str_to_int(s)
    
    def read_packet(self):
        '''return the value of the packet starting at pos'''
        vers_id = self.read_version_id()
        self.version_sum += vers_id
        type_id = self.read_type_id()

        if type_id == 4:
            return self.read_literal()
        
        length_type_id = self.read_length_type_id()
        if length_type_id == 0:
            values = self.read_length_type0()
        else:
            values = self.read_length_type1()

        # perform an operation
        if type_id == 0: # sum
            value = 0
            for v in values:
                value += v
        elif type_id == 1: # prod
            value = 1
            for v in values:
                value *= v
        elif type_id == 2: #min
            value = values[0]
            for v in values[1:]: # works even if values is len 1. Nice job, Python
                if v < value: value = v
        elif type_id == 3: # max
            value = values[0]
            for v in values[1:]: # works even if values is len 1. Nice job, Python
                if v > value: value = v
        elif type_id == 5: # gt
            if values[0] > values[1]:
                value = 1
            else:
                value = 0
        elif type_id == 6: # lt
            if values[0] < values[1]:
                value = 1
            else:
                value = 0
        elif type_id == 7: # eq
            if values[0] == values[1]:
                value = 1
            else:
                value = 0

        return value
    
    def get_version_sum(self):
        vsum = []
        while len(self.stream) - self.pos > 3:
            vsum += [self.read_packet()]
        return vsum
    
    def reset(self):
        self.pos = 0

    @staticmethod
    def str_to_int(s):
        n = len(s) - 1
        return sum( [int(d) * 2**(n-i) for i,d in enumerate(s)] )

# Part 1

In [13]:
for t in tests:
    bit_string = hex_to_binary(t)
    bs = BitStream(bit_string)
    bs.read_packet()
    print(bs.version_sum)

16
12
23
31


In [14]:
bit_string = hex_to_binary(data)
bs = BitStream(bit_string)
bs.read_packet()
print(bs.version_sum)

940


# Part 2

evaluate packets

In [None]:
tests = [
    "C200B40A82"
    , "04005AC33890"
    , "880086C3E88112"
    , "CE00C43D881120"
    , "D8005AC2A8F0"
    , "F600BC2D8F"
    , "9C005AC2F8F0"
    , "9C0141080250320F1802104A08"
]

In [17]:
for t in tests:
    bit_string = hex_to_binary(t)
    bs = BitStream(bit_string)
    print(t, "--->", bs.read_packet())

print("How'd we do?")

C200B40A82 ---> 3
04005AC33890 ---> 54
880086C3E88112 ---> 7
CE00C43D881120 ---> 9
D8005AC2A8F0 ---> 1
F600BC2D8F ---> 0
9C005AC2F8F0 ---> 0
9C0141080250320F1802104A08 ---> 1
How'd we do?


In [24]:
bit_string = hex_to_binary(data)
bs = BitStream(bit_string)
bs.read_packet(), bs.version_sum

(13476220616073, 940)