In [1]:
import tensorflow as tf

In [33]:
def convert_to_instruction(op_name, val):
    """Turn one (operation, argument) pair into a two-element state delta.

    The delta is laid out as [program-counter step, accumulator step], so a
    machine state of the same layout advances by simply adding it.

    Args:
        op_name: operation name; compared against 'nop', 'acc' and 'jmp'.
        val: integer argument of the operation.

    Returns:
        The result of tf.stack over two values:
        [1, 0] for 'nop', [1, val] for 'acc', [val, 0] for 'jmp',
        and [0, 0] for any other operation name.
    """
    if op_name == 'nop':
        # Step to the next instruction, accumulator untouched.
        delta = tf.stack([1, 0], 0)
    elif op_name == 'acc':
        # Step to the next instruction and add the argument.
        delta = tf.stack([1, val], 0)
    elif op_name == 'jmp':
        # Move the program counter by the argument.
        delta = tf.stack([val, 0], 0)
    else:
        # Unrecognised operation: changes nothing.
        delta = tf.stack([0, 0], 0)
    return delta
    
# Read the whole input file as one scalar string. Surrounding whitespace is
# stripped first so that a trailing newline does not yield an empty final
# line, which the numeric parse below cannot convert.
s = tf.strings.strip(tf.io.read_file('day8.txt'))
s = tf.strings.split(s, '\n')  # one string per program line
s = tf.strings.split(s, ' ')   # ragged: the space-separated tokens of each line
s = s.to_tensor()              # dense string tensor, one row per line

# Column 0 holds the operation names, column 1 the textual arguments; plain
# slicing replaces the per-row tf.map_fn calls.
op_names = s[:, 0]

# The argument is parsed as float32 and then cast to int32, as before.
# NOTE(review): presumably the float parse is used because arguments carry a
# leading '+' sign -- confirm before parsing straight to an integer type.
vals = tf.cast(tf.strings.to_number(s[:, 1]), tf.int32)

# One [program-counter step, accumulator step] row per program line.
instructions = tf.map_fn(
    lambda i: convert_to_instruction(op_names[i], vals[i]),
    tf.range(vals.shape[0]),
    fn_output_signature=tf.TensorSpec((2,), tf.int32),
)

# NOTE(review): newer TensorFlow releases spell this option `jit_compile`;
# check the installed version before renaming it.
@tf.function(experimental_compile=True)
def solve(instructions):
    """Run a program until it leaves the instruction table or starts to loop.

    Args:
        instructions: int32 tensor with one [program-counter step,
            accumulator step] row per instruction; its first dimension is
            read from the static shape.

    Returns:
        The final machine state, a two-element int32 tensor
        [program counter, accumulator].
    """
    program_len = instructions.shape[0]
    # Arrival counts per instruction; the entry instruction starts at one.
    visit_counts = tf.one_hot(0, program_len, dtype=tf.int32)
    machine = tf.constant([0, 0])

    # Keep stepping while the program counter is still inside the program and
    # the instruction it points at has not been reached a second time.
    while machine[0] < program_len and visit_counts[machine[0]] < 2:
        machine = machine + instructions[machine[0]]
        visit_counts += tf.one_hot(machine[0], program_len, dtype=tf.int32)

    return machine

# Display element 1 of the final state returned by solve(), i.e. the
# accumulator at the moment the original program stops.
solve(instructions)[1]

<tf.Tensor: shape=(), dtype=int32, numpy=1384>

In [34]:
%%timeit
# Time repeated calls of the compiled solver on the parsed program.
solve(instructions)

7.8 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
@tf.function(
    input_signature=[
        tf.TensorSpec(shape=(None,) + instructions.shape, dtype=tf.int32),
        tf.TensorSpec(shape=None, dtype=tf.int32),
    ]
)
def combine_if_jmp_or_nop(acc, i):
    """Fold step that collects programs with instruction `i` swapped.

    Reads the module-level `op_names`, `vals` and `instructions`.

    Args:
        acc: int32 tensor of candidate programs gathered so far, stacked
            along the first axis.
        i: int32 index of the instruction to consider.

    Returns:
        `acc` unchanged when instruction `i` is an 'acc'. Otherwise `acc`
        with one extra program prepended: a copy of `instructions` in which
        row `i` is rebuilt with the opposite operation ('nop' becomes 'jmp',
        anything else becomes 'nop') and the same argument.
    """
    if op_names[i] == 'acc':
        # 'acc' instructions are never swapped.
        return acc
    else:
        if op_names[i] == 'nop':
            swapped_name = 'jmp'
        else:
            swapped_name = 'nop'

        # Rebuild row i with the swapped operation, keeping its argument.
        patched_row = tf.expand_dims(
            convert_to_instruction(swapped_name, vals[i]), 0
        )
        rows_before = instructions[:i]
        rows_after = instructions[i+1:]
        patched_program = tf.concat([rows_before, patched_row, rows_after], 0)

        # The new candidate goes in front of the ones already collected.
        return tf.concat([tf.expand_dims(patched_program, 0), acc], 0)

# Fold over every instruction index, starting from an empty stack of programs,
# to collect one patched copy of the program per non-'acc' instruction.
all_possible_programs = tf.foldr(
    fn=combine_if_jmp_or_nop,
    elems=tf.range(op_names.shape[0]),
    initializer=tf.zeros([0, *instructions.shape], dtype=tf.int32),
)

In [48]:
@tf.function
def solve2():
    """Find the first candidate program that runs past its last instruction.

    Folds over the module-level `all_possible_programs`, running each one
    through solve() until one finishes with its program counter at or beyond
    the program length.

    Returns:
        A two-element int32 tensor [found flag, accumulator]: [1, accumulator]
        for the first program that terminates, or [0, 0] if none does.
    """
    program_len = all_possible_programs.shape[1]

    def keep_first_terminating(found, candidate):
        # Once a terminating program has been recorded, pass it through.
        if found[0] != 0:
            return found

        final_state = solve(candidate)

        if final_state[0] < program_len:
            # Stopped inside the program, so this candidate looped.
            return found
        else:
            return tf.stack([1, final_state[1]], 0)

    return tf.foldl(
        fn=keep_first_terminating,
        elems=all_possible_programs,
        initializer=tf.constant([0, 0]),
    )

In [49]:
%%time
# Time a single run of solve2() and display element 1 of its result (the
# accumulator of the first terminating candidate program).
solve2()[1]

CPU times: user 1.4 s, sys: 34.4 ms, total: 1.44 s
Wall time: 1.21 s


<tf.Tensor: shape=(), dtype=int32, numpy=761>

In [11]:
import datetime

# Each profiling run writes to its own timestamped log directory.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/day8_graph/%s' % stamp

# Bracket the function call with tf.profiler.experimental.start() and
# tf.profiler.experimental.stop(). stop() sits in a `finally` clause so the
# profiling session is closed even when solve2() raises; otherwise the session
# would be left open after a failed run.
tf.profiler.experimental.start(logdir)
try:
    solve2()
finally:
    tf.profiler.experimental.stop()