In [3]:
# Load the puzzle input: sample blocks first, then the test program, one line per entry.
with open("input16.txt", "r") as handle:
    raw_text = handle.read()
data = raw_text.strip().split("\n")

In [4]:
# Peek at the first and last 12 input lines (samples at the top, program at the bottom).
for line in data[:12] + data[-12:]:
    print(line)

Before: [1, 1, 2, 0]
8 1 0 3
After:  [1, 1, 2, 1]

Before: [1, 1, 1, 2]
8 1 0 3
After:  [1, 1, 1, 1]

Before: [2, 2, 0, 3]
5 1 3 1
After:  [2, 0, 0, 3]

12 0 2 2
4 2 1 2
2 2 3 3
12 3 3 1
4 1 0 3
14 3 0 3
15 2 0 2
12 0 2 3
4 3 2 3
4 3 2 3
2 3 1 1
12 1 0 0


In [5]:
import copy
import re
from collections import defaultdict

In [6]:
def _store(registers, c, value):
    """Return a new register list equal to ``registers`` but with register ``c`` set to ``value``.

    Every opcode below writes exactly one register and must not mutate its
    input (samples are re-tested against all sixteen operations), so they
    all delegate the copy-and-write step here.
    """
    reg = list(registers)
    reg[c] = value
    return reg

def addr(registers, a, b, c):
    """addr: reg[c] = reg[a] + reg[b]."""
    return _store(registers, c, registers[a] + registers[b])

def addi(registers, a, b, c):
    """addi: reg[c] = reg[a] + b (b is an immediate value)."""
    return _store(registers, c, registers[a] + b)

def mulr(registers, a, b, c):
    """mulr: reg[c] = reg[a] * reg[b]."""
    return _store(registers, c, registers[a] * registers[b])

def muli(registers, a, b, c):
    """muli: reg[c] = reg[a] * b (b is an immediate value)."""
    return _store(registers, c, registers[a] * b)

def banr(registers, a, b, c):
    """banr: reg[c] = reg[a] & reg[b] (bitwise AND)."""
    return _store(registers, c, registers[a] & registers[b])

def bani(registers, a, b, c):
    """bani: reg[c] = reg[a] & b (bitwise AND with an immediate)."""
    return _store(registers, c, registers[a] & b)

def borr(registers, a, b, c):
    """borr: reg[c] = reg[a] | reg[b] (bitwise OR)."""
    return _store(registers, c, registers[a] | registers[b])

def bori(registers, a, b, c):
    """bori: reg[c] = reg[a] | b (bitwise OR with an immediate)."""
    return _store(registers, c, registers[a] | b)

def setr(registers, a, b, c):
    """setr: reg[c] = reg[a]; b is ignored."""
    return _store(registers, c, registers[a])

def seti(registers, a, b, c):
    """seti: reg[c] = a (a is an immediate value); b is ignored."""
    return _store(registers, c, a)

def gtir(registers, a, b, c):
    """gtir: reg[c] = 1 if immediate a > reg[b] else 0."""
    return _store(registers, c, int(a > registers[b]))

def gtri(registers, a, b, c):
    """gtri: reg[c] = 1 if reg[a] > immediate b else 0."""
    return _store(registers, c, int(registers[a] > b))

def gtrr(registers, a, b, c):
    """gtrr: reg[c] = 1 if reg[a] > reg[b] else 0."""
    return _store(registers, c, int(registers[a] > registers[b]))

def eqir(registers, a, b, c):
    """eqir: reg[c] = 1 if immediate a == reg[b] else 0."""
    return _store(registers, c, int(a == registers[b]))

def eqri(registers, a, b, c):
    """eqri: reg[c] = 1 if reg[a] == immediate b else 0."""
    return _store(registers, c, int(registers[a] == b))

def eqrr(registers, a, b, c):
    """eqrr: reg[c] = 1 if reg[a] == reg[b] else 0."""
    return _store(registers, c, int(registers[a] == registers[b]))

In [7]:
# Mnemonic -> implementing function, for all sixteen candidate operations.
func_dict = dict(
    addr=addr,
    addi=addi,
    mulr=mulr,
    muli=muli,
    banr=banr,
    bani=bani,
    borr=borr,
    bori=bori,
    setr=setr,
    seti=seti,
    gtir=gtir,
    gtri=gtri,
    gtrr=gtrr,
    eqir=eqir,
    eqri=eqri,
    eqrr=eqrr,
)
def test_all(registers, code, result):
    matches = []
    for key in func_dict:
        output = func_dict[key](registers, code[1], code[2], code[3])
        if output == result:
            matches.append(key)
    return matches

In [9]:
# Part 1: count samples consistent with three or more operations, and record
# every sample's candidate list under its opcode number for part 2.
total = 0
registers = None
after = None
code = None
result_dict = defaultdict(list)

for row in data:
    values = [int(token) for token in re.findall(r'\d+', row)]
    if "Before" in row:
        registers = values
        continue
    if "After" in row:
        after = values
        matches = test_all(registers, code, after)
        total += 1 if len(matches) >= 3 else 0
        result_dict[code[0]].append(matches)
        continue
    if row:
        # Any other non-blank line is an instruction "opcode a b c".
        code = values
print(total)

614


In [10]:
def resolve_opcodes(samples, names):
    """Deduce the unique opcode-number -> operation-name mapping by elimination.

    Parameters
    ----------
    samples : dict
        Maps each opcode number to a list of match lists; each match list holds
        the names of every operation consistent with one Before/After sample.
    names : iterable of str
        Every operation name that must be assigned to some opcode.

    Returns
    -------
    dict
        opcode number -> operation name, sorted by opcode number.

    Raises
    ------
    ValueError
        If the samples do not determine a unique mapping (previously this
        situation looped forever).
    """
    names = list(names)

    # Each sample lists a superset of its opcode's true operation, so the
    # intersection over all of an opcode's samples is its tightest candidate set.
    candidates = {}
    for opcode, match_lists in samples.items():
        possible = set(names)
        for matches in match_lists:
            possible &= set(matches)
        candidates[opcode] = possible

    resolved = {}
    while len(resolved) < len(names):
        resolved_before = len(resolved)

        # Pass 1: an opcode left with exactly one not-yet-assigned candidate is that candidate.
        # (The original computed this filtered list but then tested the raw one.)
        for opcode, possible in candidates.items():
            if opcode in resolved:
                continue
            remaining = possible - set(resolved.values())
            if len(remaining) == 1:
                resolved[opcode] = remaining.pop()

        # Pass 2: a name that only one unresolved opcode can still be belongs to that opcode.
        for name in names:
            if name in resolved.values():
                continue
            owners = [op for op, possible in candidates.items()
                      if op not in resolved and name in possible]
            if len(owners) == 1:
                resolved[owners[0]] = name

        if len(resolved) == resolved_before:
            raise ValueError(
                "cannot resolve opcodes; still unassigned: {}".format(
                    sorted(set(names) - set(resolved.values()))))

    return dict(sorted(resolved.items()))


opcode_dict = resolve_opcodes(result_dict, func_dict)

In [11]:
# Show the deduced opcode-number -> operation-name mapping.
print("{}".format(opcode_dict))

{0: 'eqir', 1: 'borr', 2: 'addr', 3: 'gtri', 4: 'muli', 5: 'gtir', 6: 'mulr', 7: 'banr', 8: 'bori', 9: 'eqri', 10: 'eqrr', 11: 'bani', 12: 'setr', 13: 'gtrr', 14: 'addi', 15: 'seti'}


In [12]:
# Index of the final "After" line; the test program starts after it (None if absent).
after_rows = [index for index, line in enumerate(data) if "After" in line]
last_after = after_rows[-1] if after_rows else None

In [13]:
# Part 2: run the test program from all-zero registers using the decoded opcodes.
registers = [0, 0, 0, 0]
for line in data[last_after + 1:]:
    if line:
        fields = [int(token) for token in re.findall(r'\d+', line)]
        op_name = opcode_dict[fields[0]]
        op_func = func_dict[op_name]
        registers = op_func(registers, fields[1], fields[2], fields[3])
print(registers[0])

656
