In [1]:
import tensorflow as tf
try:
    # Hide every GPU so TensorFlow runs on the CPU only.
    tf.config.set_visible_devices([], 'GPU')
    visible_devices = tf.config.get_visible_devices()
    # Post-condition: no GPU may remain visible. Deliberately outside the
    # caught exception types so a failure here is not silently swallowed.
    for device in visible_devices:
        assert device.device_type != 'GPU'
except (ValueError, RuntimeError) as err:
    # ValueError: invalid device; RuntimeError: the runtime is already
    # initialized, so visible devices can no longer be changed.
    # Best effort: keep going, but say why the GPUs could not be hidden.
    print(f"Could not hide GPUs: {err}")

# Alternative when running on a GPU instead of hiding it: enable memory
# growth (noted by the author as a workaround for failing convolutions).
#gpus = tf.config.experimental.list_physical_devices('GPU')
#for gpu in gpus:
#    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
# Thread-pool sizes for TensorFlow: inter-op threads run independent ops
# concurrently, intra-op threads parallelize work inside a single op.
NUM_THREADS = 16
tf.config.threading.set_inter_op_parallelism_threads(NUM_THREADS)
tf.config.threading.set_intra_op_parallelism_threads(NUM_THREADS)

In [3]:
# Example starting numbers; not used by the checks below.
inp = tf.constant([1, 3, 2])

@tf.function(experimental_compile=True)
def solve(inp, until_turn=2020):
    """Return the number spoken on turn `until_turn` of the memory game.

    The starting numbers in `inp` are spoken on turns 1..n. After that, each
    turn speaks 0 if the previously spoken number had never been spoken
    before, otherwise the gap between the two most recent turns on which it
    was spoken.

    Args:
        inp: 1-D integer tensor of starting numbers with a static length
            n >= 1. Every value must be < until_turn, because values index a
            table of that size.
        until_turn: Python int, the 1-based turn whose number is returned;
            must be >= n.

    Returns:
        Scalar int32 tensor: the number spoken on turn `until_turn`.
    """
    # last_turn[v] = most recent turn on which v was spoken (0 = never).
    # Every spoken number after the start is a gap < until_turn, so
    # until_turn slots suffice. element_shape=[] makes unwritten slots read
    # back as scalar zeros.
    last_turn = tf.TensorArray(
        tf.int32, size=until_turn, element_shape=[], clear_after_read=False)
    n = inp.shape[0]

    # Record every starting number except the last one. The last one becomes
    # the first "current" number below, so a starting number that repeats an
    # earlier one still produces the correct gap.
    for i in tf.range(n - 1):
        last_turn = last_turn.write(inp[i], i + 1)

    # Invariant: at the top of each iteration, x was spoken on turn i + 1.
    x = tf.cast(inp[n - 1], tf.int32)
    for i in tf.range(n - 1, until_turn - 1):
        prev_turn = last_turn.read(x)
        last_turn = last_turn.write(x, i + 1)

        # Number spoken on turn i + 2.
        x = 0
        if prev_turn > 0:
            x = i + 1 - prev_turn

    return x

assert solve(tf.constant([0, 3, 6])) == 436
assert solve(tf.constant([1, 3, 2])) == 1
assert solve(tf.constant([2, 1, 3])) == 10
assert solve(tf.constant([1, 2, 3])) == 27
# Repeated starting number: turn 3 speaks 2 - 1 = 1.
assert solve(tf.constant([0, 0]), 3) == 1
# until_turn == len(inp) returns the last starting number.
assert solve(tf.constant([0, 1, 2]), 3) == 2

In [4]:
%%timeit
solve(tf.constant([2,0,1,9,5,19]))

794 µs ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
%%time
solve(tf.constant([2,0,1,9,5,19]), 30_000_000)

In [30]:
import datetime
# Set up logging: one timestamped profiler log directory per run.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/day15/%s' % stamp

# Profile one solve() call with the TensorFlow profiler. stop() runs in
# `finally` so a failing call does not leave the profiling session open
# (start() raises if a session is already running).
# NOTE(review): this is the first call with until_turn=30000 in the visible
# cells, so the profile likely includes tracing/XLA compilation; add a
# warm-up call before start() if only steady-state execution is wanted.
tf.profiler.experimental.start(logdir)
try:
    solve(tf.constant([2, 0, 1, 9, 5, 19]), 30000)
finally:
    tf.profiler.experimental.stop()