In [1]:
import tensorflow as tf
# Run everything on CPU: hide all GPUs from TensorFlow.
try:
    tf.config.set_visible_devices([], 'GPU')
except (ValueError, RuntimeError) as err:
    # ValueError: invalid device list. RuntimeError: visible devices can no
    # longer be changed because the TF runtime is already initialized (e.g.
    # after other TF work in this kernel). Best effort: warn and continue,
    # instead of a bare `except:` that also hid real failures.
    print(f"Could not hide GPUs, continuing with current devices: {err}")
else:
    # Confirm the change actually took effect.
    visible_devices = tf.config.get_visible_devices()
    for device in visible_devices:
        assert device.device_type != 'GPU'

# Alternative for GPU runs (original note: "fixes bad convolution"): instead
# of hiding the GPUs, let TensorFlow allocate GPU memory on demand.
#gpus = tf.config.experimental.list_physical_devices('GPU')
#for gpu in gpus:
#    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
# Size both of TensorFlow's CPU thread pools (between independent ops and
# inside a single op). TF refuses to change these once its runtime has been
# initialized, so this cell has to run before any other TF work.
N_THREADS = 16
tf.config.threading.set_inter_op_parallelism_threads(N_THREADS)
tf.config.threading.set_intra_op_parallelism_threads(N_THREADS)

In [3]:
# Sample starting numbers; `solve` receives its starting numbers as an
# argument, so this global is not read by it.
inp = tf.constant([1, 3, 2])

@tf.function(experimental_compile=True)
def solve(inp, until_turn=2020):
    """Return the number spoken on turn `until_turn` of the memory game.

    The game starts with the numbers in `inp`. Every later number is 0 if the
    previous number had never been spoken before, otherwise the number of
    turns between its two most recent occurrences.

    The full history is kept in a dense tensor and rescanned on every turn,
    so the cost grows quadratically with `until_turn` (see `solve1` for a
    hash-table version).

    Args:
        inp: 1-D integer tensor with a statically known, non-zero length.
            Values are assumed to be >= 0, because -1 marks unplayed turns.
        until_turn: Python int, the 1-based turn to report. It fixes tensor
            shapes at trace time, so each distinct value retraces and
            recompiles. It must be >= len(inp), otherwise the pad amount is
            negative.

    Returns:
        Scalar int32 tensor.
    """
    inp = tf.cast(inp, tf.int32)
    n = inp.shape[0]
    # History of spoken numbers, indexed by 0-based turn; -1 = not yet played.
    res = tf.pad(inp, [[0, until_turn - n]], constant_values=-1)

    # rnge[1:] provides 1-based positions, so a max of 0 below means "no match".
    rnge = tf.range(until_turn + 1, dtype=tf.int32)

    # Number spoken on turn index n. It depends on whether the LAST starting
    # number already occurred among the earlier starting numbers. Hard-coding
    # 0 here was only correct when all starting numbers are distinct.
    prev_hits = tf.cast(inp[:-1] == inp[-1], tf.int32) * rnge[1:n]
    # The leading 0 keeps reduce_max defined when there is one starting number.
    prev_pos1 = tf.reduce_max(tf.pad(prev_hits, [[1, 0]]))
    x = (n - prev_pos1) * tf.cast(prev_pos1 > 0, tf.int32)

    for i in rnge[n:-1]:
        # x is the number spoken on turn index i; find its previous occurrence
        # (1-based, 0 if none). Turns >= i still hold -1, so they never match.
        mask = tf.cast(res == x, tf.int32) * rnge[1:]
        max_val = tf.reduce_max(mask)

        if max_val == 0:
            # New number: use i itself as the "previous turn" so the gap is 0.
            t_ps = i
        else:
            t_ps = max_val - 1

        res = tf.tensor_scatter_nd_update(res, [[i]], [x])

        x = i - t_ps

    return res[-1]

# Known examples: starting numbers -> number spoken on turn 2020 (the default
# `until_turn`).
for start, expected in (
    ([0, 3, 6], 436),
    ([1, 3, 2], 1),
    ([2, 1, 3], 10),
    ([1, 2, 3], 27),
):
    assert solve(tf.constant(start)) == expected

# 10th number of the [0, 3, 6] game, fed as int64 to exercise solve's cast.
print(solve(tf.constant([0, 3, 6], tf.int64), 10))

tf.Tensor(0, shape=(), dtype=int32)


In [5]:
%%time
# 300k turns with the scan-based `solve`; every turn rescans the whole history.
solve(tf.constant([2,0,1,9,5,19]), 300000)

CPU times: user 16.7 s, sys: 52.2 ms, total: 16.7 s
Wall time: 16.8 s


<tf.Tensor: shape=(), dtype=int32, numpy=8336>

In [None]:
%%time
# NOTE: never executed (In [None]). `solve` rescans the full history every
# turn, so going from 300k to 30M turns is roughly 100x more turns, each about
# 100x more expensive. Expect on the order of 10^4 x the 300k timing; use
# `solve1` (hash table) for this size.
solve(tf.constant([2,0,1,9,5,19]), 30_000_000)

In [30]:
import datetime

# Profile one mid-sized run of `solve` and write the trace to a timestamped
# directory (viewable in TensorBoard).
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/day15/%s' % stamp

tf.profiler.experimental.start(logdir)
try:
    solve(tf.constant([2,0,1,9,5,19]), 30000)
finally:
    # Always stop the session. Otherwise a failed run leaves the profiler
    # running, and the next start() in this kernel fails.
    tf.profiler.experimental.stop()

In [70]:
def play_memory_game(start, n_turns):
    """Return the first `n_turns` numbers of the memory game seeded with `start`.

    After the starting numbers, each number is 0 if the previous number had
    not been spoken before, otherwise the number of turns since it was last
    spoken. This is the pure-Python reference for the TensorFlow versions.

    Raises:
        ValueError: if `start` is empty.
    """
    if not start:
        raise ValueError("start must contain at least one number")
    spoken = list(start)
    # number -> last 0-based turn index, covering every turn BEFORE the newest
    # one. The last starting number is left out on purpose, so a repeat inside
    # `start` (e.g. [1, 1]) is measured correctly; it gets recorded at the end
    # of the first iteration instead.
    last_pos = {x: pos for pos, x in enumerate(spoken[:-1])}
    for i in range(len(spoken) - 1, n_turns - 1):
        current = spoken[i]
        spoken.append(i - last_pos.get(current, i))
        last_pos[current] = i
    return spoken[:n_turns]


puzzle_start = [2, 0, 1, 9, 5, 19]
res = play_memory_game(puzzle_start, 10)
# Show the numbers generated after the starting ones, one per line.
print(*res[len(puzzle_start):], sep="\n")

0
5
3
0


In [71]:
# Display the list produced by the pure-Python cell above (the next cell
# rebinds `res` to a tensor).
res

[2, 0, 1, 9, 5, 19, 0, 5, 3, 0]

In [15]:
# Puzzle input for the hash-table solver (rebinds `res`, previously a list).
res = tf.constant([2,0,1,9,5,19])

@tf.function
def solve1(res, n):
    """Return the number spoken on turn `n` of the memory game started with `res`.

    Instead of rescanning the history each turn, a `DenseHashTable` maps each
    number to the 0-based turn index on which it was last spoken. Each turn is
    then one lookup plus one insert.

    Args:
        res: 1-D integer tensor of starting numbers with a statically known,
            non-zero length. Values must not be -1 or -2, which the table
            reserves as its empty and deleted keys.
        n: 1-based turn to report; assumed to be >= len(res).

    Returns:
        Scalar int32 tensor.
    """
    # The table is keyed by int32, so normalise the input dtype (as `solve` does).
    res = tf.cast(res, tf.int32)
    # number -> last turn index; lookups of unseen numbers return -1.
    # NOTE(review): the table is created inside the traced function under a
    # fixed name; confirm no state is shared between calls/traces.
    last_pos = tf.lookup.experimental.DenseHashTable(
        tf.int32, tf.int32, -1, -1, -2, initial_num_buckets=2**16, name='stuffs4'
    )

    # Record every starting number EXCEPT the last: the main loop must look up
    # the last number's previous occurrence before recording it. Inserting it
    # here made a repeated final starting number (e.g. [1, 1]) yield 0.
    for i in tf.range(res.shape[0] - 1):
        last_pos.insert(res[i], i)

    last_x = res[-1]

    for i in tf.range(res.shape[0] - 1, n - 1):
        # last_x was spoken on turn index i; when was it spoken before that?
        last_pos_val = last_pos[last_x]

        if last_pos_val == -1:
            # Never spoken before: use i itself so the gap is 0.
            last_pos_val = i

        last_pos.insert_or_assign(last_x, i)
        last_x = i - last_pos_val

    return last_x

In [16]:
# Sanity check: `solve1` is wrapped by tf.function, not a plain Python function.
type(solve1)

tensorflow.python.eager.def_function.Function

In [17]:
%%time
# Full 30M-turn run with the hash-table solver on the tensor `res` defined
# alongside `solve1` (about 7 minutes of CPU time in the recorded run).
solve1(res, 30_000_000)

CPU times: user 6min 57s, sys: 49.6 ms, total: 6min 57s
Wall time: 6min 58s


<tf.Tensor: shape=(), dtype=int32, numpy=62714>