In [1]:
import re
from doctest   import testmod
from itertools import combinations
from typing    import Union
from helpers   import data

In [2]:
def get_mask(line) -> str:
    """
    Return the bitmask string from a ``mask = ...`` instruction line.

    Everything after the ``mask = `` prefix is returned verbatim.
    Raises ValueError when ``line`` does not start with that prefix.

    >>> get_mask('mask = 010011101XX1111X100000000000XX10110X')
    '010011101XX1111X100000000000XX10110X'
    """
    parsed = re.match("mask = (.*)", line)
    if parsed is not None:
        return parsed.group(1)
    raise ValueError(f'Expected line of the form "mask = ...", but instead got {line}')

def get_memory_store(line) -> tuple[int, int]:
    """
    Parse a ``mem[address] = value`` instruction into ``(address, value)``.

    Both address and value must be non-empty runs of decimal digits;
    otherwise a ValueError describing the expected form is raised.

    >>> get_memory_store('mem[6391] = 812')
    (6391, 812)
    >>> get_memory_store('mem[55972] = 5779')
    (55972, 5779)
    """
    # Raw string: in a plain literal the regex escapes are invalid Python
    # string escapes (DeprecationWarning, SyntaxWarning since 3.12).
    # One-or-more digits so an empty address/value is rejected here with the
    # descriptive error, rather than failing later inside int('').
    m = re.match(r"mem\[(\d+)\] = (\d+)", line)
    if not m:
        raise ValueError(f'Expected line of the form "mem[address] = value", but instead got {line}')
    return int(m.group(1)), int(m.group(2))

# Each input line becomes either the raw mask string (for "mask = ..." lines)
# or an (address, value) tuple (for "mem[address] = value" lines).
instructions = [
    get_mask(line) if line.startswith("mask") else get_memory_store(line)
    for line in data(14)
]

instructions[:5]

['111101X010011110100100110100101X0X0X',
 (37049, 1010356),
 (5632, 28913),
 (9384, 7522),
 '00X1011X11X0000010100X011X10X10X10XX']

In [3]:
def apply_mask(mask: str, n: int) -> int:
    """
    Overwrite bits of ``n`` as dictated by ``mask``.

    ``mask`` is read most-significant bit first; its last character is bit 0.
    - '0': clear that bit
    - '1': set that bit
    - any other character (normally 'X'): keep the bit from ``n``
    Bits of ``n`` at or above position ``len(mask)`` are cleared.

    >>> apply_mask('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X', 11)
    73
    >>> apply_mask('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X', 101)
    101
    >>> apply_mask('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X', 0)
    64
    """
    all_ones = (1 << len(mask)) - 1
    cleared = sum(1 << pos for pos, ch in enumerate(reversed(mask)) if ch == '0')
    forced_on = sum(1 << pos for pos, ch in enumerate(reversed(mask)) if ch == '1')

    # Start from all ones and knock out the '0' positions, so ANDing keeps every
    # other bit of n within the mask width.
    keep = all_ones - cleared
    return (n & keep) | forced_on

**Part 1:** Apply the current bitmask to each value before writing it to memory, then sum all values left in memory. (A later write to the same address overwrites the earlier one.)

In [4]:
# Replay the program: a mask line replaces the current mask; a memory line
# stores the masked value at its address (later writes overwrite earlier ones).
mask = None
memory = {}
for inst in instructions:
    if not isinstance(inst, str):
        addr, value = inst
        memory[addr] = apply_mask(mask, value)
        continue
    mask = inst

sum(memory.values())

8471403462063

**Part 2:** Now the mask applies to the *memory address*, with different rules: 0 leaves the bit unchanged, 1 overwrites it with 1, and X marks a "floating" bit. Each value is written to every address obtained by setting the floating bits to all possible combinations of 0 and 1 — for 3 floating bits, that's $2^3 = 8$ addresses. Get the sum of all values left in memory.

In [5]:
def apply_floating_mask(mask: str, addr: int): 
    """
    Mask is a string of 0, 1, or X: 
    - 0: leave bit unchanged 
    - 1: force bit to be 1 
    - X: floating bit 
    
    The floating bits should be all possible values. For n floating bits, 
    we have 2**n results of applying the mask. 
    
    >>> list(apply_floating_mask('000000000000000000000000000000X1001X', 42))
    [26, 27, 58, 59]
    >>> list(apply_floating_mask('00000000000000000000000000000000X0XX', 26))
    [16, 17, 18, 24, 19, 25, 26, 27]
    """
    # Floating bits (OR with one of these to turn on)
    floating = []
    # AND with this to set to 0 (set all floating bits to 0 initially, then turn on)
    and_mask = 2**len(mask) - 1

    # OR with this to set to 1 
    or_mask = 0 
    
    for i, bit in enumerate(reversed(mask)):
        if bit == '1':
            or_mask += 2**i 
        elif bit == 'X':
            floating.append(2**i)
            and_mask -= 2**i 
            
    addr_without_floating = and_mask & addr | or_mask 
    
    for n in range(len(floating) + 1):
        for floating_comb in combinations(floating, n):
            addr_with_floating = addr_without_floating
            for bit in floating_comb: 
                addr_with_floating |= bit
            yield addr_with_floating

In [6]:
# Replay the program again, now decoding addresses: every memory write fans
# out to each address the current floating mask produces.
mask = None
memory = {}
for inst in instructions:
    if isinstance(inst, str):
        mask = inst
        continue
    initial_addr, value = inst
    for addr in apply_floating_mask(mask, initial_addr):
        memory[addr] = value

sum(memory.values())

2667858637669

In [7]:
# Final self-check: run every doctest embedded in the docstrings above.
# testmod() with no module argument inspects __main__, i.e. this notebook.
testmod()

TestResults(failed=0, attempted=8)