In [1]:
from __future__  import annotations
from collections import Counter, defaultdict, namedtuple, deque
from itertools   import permutations, combinations, cycle, product, islice, chain
from functools   import lru_cache
from typing      import Dict, Tuple, Set, List, Iterator, Optional
from sys         import maxsize

import re
import ast
import operator

import numpy as np

In [2]:
def read_data(input: str, parser=str, sep='\n', testing=False) -> list:
    """Split puzzle text into sections and parse each section.

    Args:
        input: Path to the input file or, when ``testing`` is True, the raw
            text itself. (The name shadows the builtin ``input``; it is kept
            so existing keyword callers keep working.)
        parser: Callable applied to every section; defaults to ``str``.
        sep: Separator used to split the text into sections.
        testing: When True, treat ``input`` as literal text instead of a path.

    Returns:
        A list with one parsed value per section, in original order.
    """
    if testing:
        text = input
    else:
        # Context manager closes the handle deterministically instead of
        # leaving it open until garbage collection.
        with open(input) as fh:
            text = fh.read()
    return [parser(section) for section in text.split(sep)]

In [3]:
# Small example program used to sanity-check run_part1.
string = "\n".join([
    "mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X",
    "mem[8] = 11",
    "mem[7] = 101",
    "mem[8] = 0",
])

In [4]:
def parse_ins(ins: str):
    """Parse one line of the initialization program.

    Args:
        ins: A single program line, either ``mask = <pattern>`` or
            ``mem[<address>] = <value>``.

    Returns:
        ``('mask', <pattern>)`` for a mask line. Otherwise, a list of every
        integer found in the line (``[address, value]`` for a mem line, and
        an empty list for a blank line).
    """
    if 'mask' in ins:
        # mask instruction: keep the raw pattern to the right of '='
        return 'mask', ins.split("=")[-1].strip()
    # mem instruction: extract the address and the value
    return [int(x) for x in re.findall(r'\d+', ins)]
# Parse the example program with the same parser used for the real input.
test_ins = read_data(
    string,
    parser=parse_ins,
    sep="\n",
    testing=True,
)
print(test_ins)



[('mask', 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X'), [8, 11], [7, 101], [8, 0]]


In [5]:
# Sanity check: mask instructions carry a str header, which is how the
# run_part* loops tell them apart from mem instructions.
isinstance('mask', str)

True

## Part I

Execute the initialization program. What is the sum of all values left in memory after it completes? (Do not truncate the sum to 36 bits.)

In [6]:
def int_to_bin36(val: int) -> str:
    """Render ``val`` in binary, zero-padded to at least 36 characters."""
    return format(val, '036b')

def process_mask(mask: str, val: int) -> int:
    """Apply a version-1 (value) mask.

    At each bit position, ``'X'`` keeps the corresponding bit of ``val``.
    Any other mask character (``'0'`` or ``'1'``) overwrites that bit.

    Args:
        mask: Mask pattern. Its length sets the bit width (36 for the puzzle).
        val: Value to mask. Assumed to fit in ``len(mask)`` bits.

    Returns:
        The masked value as an integer.
    """
    # Pad to the mask's own width so bits line up for any mask length.
    # For 36-character masks this is exactly int_to_bin36(val).
    bits = format(val, f"0{len(mask)}b")
    masked = ''.join(b if m == 'X' else m for m, b in zip(mask, bits))
    return int(masked, base=2)

def run_part1(ins: list) -> int:
    """Run the program with version-1 (value-masking) semantics.

    Args:
        ins: Parsed instructions as produced by ``parse_ins``. Each is either
            ``('mask', pattern)`` or ``[address, value]``.

    Returns:
        Sum of all values left in memory (not truncated to 36 bits).
    """
    mask = ''
    mem = {}
    for header, val in ins:
        if isinstance(header, str):
            # ('mask', pattern): later writes use the new mask
            mask = val
        else:
            # [address, value]: store the masked value, overwriting any
            # earlier write to the same address
            mem[header] = process_mask(mask, val)
    return sum(mem.values())

In [7]:
# The worked example must sum to 165. Fail loudly if that ever regresses.
part1_example = run_part1(test_ins)
assert part1_example == 165, part1_example
part1_example

165

In [8]:
# Load the puzzle input (one instruction per line) and solve Part I.
real_ins = read_data("input.txt", parser=parse_ins, sep="\n")
run_part1(real_ins)

4886706177792

## Part II

Execute the initialization program using an emulator for a version 2 decoder chip. What is the sum of all values left in memory after it completes?

In [9]:
def process_mask2(mask: str, header: int) -> List[int]:
    """Expand an address through a version-2 mask into every concrete address.

    Mask semantics per bit position:
      * ``'X'``: floating; the bit takes both 0 and 1.
      * ``'0'``: keep the address bit unchanged.
      * anything else (``'1'``): force the bit to 1.

    Args:
        mask: Mask pattern. Its length sets the bit width (36 for the puzzle).
        header: Memory address from a ``mem[...]`` instruction. Assumed to
            fit in ``len(mask)`` bits.

    Returns:
        A list of all 2**k resulting addresses, where k is the number of
        ``'X'`` characters. Order follows ``itertools.product('01', repeat=k)``,
        with the floating bits filled from left to right.
    """
    # Pad to the mask's own width so bits line up for any mask length.
    # For 36-character masks this is exactly int_to_bin36(header).
    bits = format(header, f"0{len(mask)}b")
    # '{}' placeholders mark floating bits, filled in by str.format below.
    template = ''.join(
        '{}' if m == 'X' else (b if m == '0' else '1')
        for m, b in zip(mask, bits)
    )
    n_floating = template.count('{}')
    return [
        int(template.format(*choice), base=2)
        for choice in product('01', repeat=n_floating)
    ]

def run_part2(ins: list) -> int:
    """Run the program with version-2 (address-decoding) semantics.

    Args:
        ins: Parsed instructions as produced by ``parse_ins``. Each is either
            ``('mask', pattern)`` or ``[address, value]``.

    Returns:
        Sum of all values left in memory.
    """
    mask = ''
    mem = {}
    for header, val in ins:
        if isinstance(header, str):
            # ('mask', pattern): later writes use the new mask
            mask = val
        else:
            # Every address produced by the floating bits receives the value,
            # overwriting any earlier write to that address.
            for addr in process_mask2(mask, header):
                mem[addr] = val
    return sum(mem.values())


In [10]:
# Small example program used to sanity-check run_part2.
string2 = "\n".join([
    "mask = 000000000000000000000000000000X1001X",
    "mem[42] = 100",
    "mask = 00000000000000000000000000000000X0XX",
    "mem[26] = 1",
])

# Parse the Part II example with the shared instruction parser.
test_ins2 = read_data(
    string2,
    parser=parse_ins,
    sep="\n",
    testing=True,
)
print(test_ins2)

[('mask', '000000000000000000000000000000X1001X'), [42, 100], ('mask', '00000000000000000000000000000000X0XX'), [26, 1]]


In [11]:
# The worked example must sum to 208. Fail loudly if that ever regresses.
part2_example = run_part2(test_ins2)
assert part2_example == 208, part2_example
part2_example

208

In [12]:
# Solve Part II on the same parsed puzzle input.
part2_answer = run_part2(real_ins)
part2_answer

3348493585827