In [42]:
def read_input(filename):
    """Parse a puzzle input file into recorded samples and a program.

    The file is expected to contain sample records of three lines each::

        Before: [r0, r1, r2, r3]
        op a b c
        After:  [r0, r1, r2, r3]

    followed by the program, one space-separated instruction per line.
    Blank lines are ignored wherever they occur, so the number of separator
    lines between sections and any trailing blank lines do not matter.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(sample, program)``. ``sample`` is a list of
        ``[before, code, after]`` entries, each a list of ints; ``program``
        is a list of instructions, each a list of ints.
    """

    def parse_registers(text):
        # Take whatever sits between the brackets; int() tolerates the
        # whitespace left around each comma-separated value.
        inner = text[text.index('[') + 1:text.rindex(']')]
        return [int(value) for value in inner.split(',')]

    def parse_numbers(text):
        # split() without arguments also copes with repeated spaces and the
        # trailing newline.
        return [int(value) for value in text.split()]

    sample = []
    program = []
    with open(filename, 'r') as f:
        lines = iter(f)
        for line in lines:
            if line.startswith('Before'):
                before = parse_registers(line)
                # A sample always spans three consecutive lines.
                code = parse_numbers(next(lines, ''))
                after = parse_registers(next(lines, ''))
                sample.append([before, code, after])
            elif line.strip():
                # Any other non-blank line is a program instruction.
                program.append(parse_numbers(line))
    return sample, program

class Program:
    def __init__(self, trans_dict, reg=[0, 0, 0, 0]):
        self.reg = reg
        self.trans_dict = trans_dict
    
    def run(self, instruction):
        reg = self.reg
        op, a, b, c = instruction
        op = self.trans_dict[op]
        match op:
            case 'addr':
                res = reg[a] + reg[b]
            case 'addi':
                res = reg[a] + b
            case 'mulr':
                res = reg[a] * reg[b]
            case 'muli':
                res = reg[a] * b
            case 'banr':
                res = reg[a] & reg[b]
            case 'bani':
                res = reg[a] & b
            case 'borr':
                res = reg[a] | reg[b]
            case 'bori':
                res = reg[a] | b
            case 'setr':
                res = reg[a]
            case 'seti':
                res = a
            case 'gtir':
                res = 1 if a > reg[b] else 0
            case 'gtri':
                res = 1 if reg[a] > b else 0
            case 'gtrr':
                res = 1 if reg[a] > reg[b] else 0
            case 'eqir':
                res = 1 if a == reg[b] else 0
            case 'eqri':
                res = 1 if reg[a] == b else 0
            case 'eqrr':
                res = 1 if reg[a] == reg[b] else 0
        self.reg[c] = res

def test_opcodes(before, instruction, after):
    """Return the opcode names consistent with one recorded sample.

    Every opcode is evaluated against the ``before`` registers using the
    operands of ``instruction``; an opcode is consistent when its result
    equals the value found in the output register of ``after``. Only the
    output register is compared.

    Args:
        before: Register values before the instruction ran.
        instruction: Sequence ``(n, a, b, c)``; the numeric opcode ``n`` is
            not used here.
        after: Register values after the instruction ran.

    Returns:
        List of matching opcode names, in the fixed order
        addr, addi, mulr, muli, banr, bani, borr, bori, setr, seti,
        gtir, gtri, gtrr, eqir, eqri, eqrr.
    """
    _, a, b, c = instruction
    results = {
        'addr': before[a] + before[b],
        'addi': before[a] + b,
        'mulr': before[a] * before[b],
        'muli': before[a] * b,
        'banr': before[a] & before[b],
        'bani': before[a] & b,
        'borr': before[a] | before[b],
        'bori': before[a] | b,
        'setr': before[a],
        'seti': a,
        'gtir': 1 if a > before[b] else 0,
        'gtri': 1 if before[a] > b else 0,
        'gtrr': 1 if before[a] > before[b] else 0,
        'eqir': 1 if a == before[b] else 0,
        'eqri': 1 if before[a] == b else 0,
        'eqrr': 1 if before[a] == before[b] else 0,
    }
    observed = after[c]
    return [name for name, value in results.items() if value == observed]


# Despite the "test_" prefix this is a regular helper; the flag stops pytest
# from collecting it as a test when it is imported into a test module.
test_opcodes.__test__ = False

def translate(samples):
    """Deduce which opcode name each numeric opcode stands for.

    For every numeric opcode the candidate names are those consistent with
    all of its samples (the intersection of the per-sample matches from
    ``test_opcodes``). Numbers left with a single candidate are then fixed
    and their name is removed from every other number's candidates, until
    all numbers are resolved.

    Args:
        samples: Iterable of ``[before, instruction, after]`` entries as
            accepted by ``test_opcodes``.

    Returns:
        Dict mapping each numeric opcode seen in ``samples`` to its name.

    Raises:
        ValueError: If the samples do not determine a unique mapping.
    """
    candidates = {}
    for sample in samples:
        number = sample[1][0]
        matching = set(test_opcodes(*sample))
        if number in candidates:
            # The real opcode matches every sample, so only names that
            # match all of them remain candidates.
            candidates[number] &= matching
        else:
            candidates[number] = matching

    trans = {}
    while candidates:
        resolved = [n for n, names in candidates.items() if len(names) == 1]
        if not resolved:
            # No progress is possible; looping again would never terminate.
            raise ValueError(
                'samples do not determine a unique opcode mapping')
        for n in resolved:
            names = candidates.pop(n)
            if len(names) != 1:
                # Another number resolved in this pass claimed the same name.
                raise ValueError(
                    'samples do not determine a unique opcode mapping')
            (trans[n],) = names
            for remaining in candidates.values():
                remaining.discard(trans[n])
    return trans

In [52]:
# Load the recorded samples and the program from the puzzle input file.
samples, program = read_input('16_input.txt')

In [53]:
# Number of samples that are consistent with three or more opcodes.
sum(1 for sample in samples if len(test_opcodes(*sample)) >= 3)

588

In [54]:
# Work out which opcode name each numeric opcode corresponds to.
trans = translate(samples)

In [55]:
# Machine that decodes numeric opcodes through the deduced mapping.
p = Program(trans)

In [56]:
# Execute the program one instruction at a time, in file order.
for instruction in program:
    p.run(instruction)

In [57]:
# Show the register contents after the program has run.
p.reg

[627, 627, 2, 3]