In [1]:
import tensorflow as tf

In [2]:
# Turn on on-demand memory growth for every GPU TensorFlow can see.
# Original note: this "fixes bad convolution" (presumably the convolution
# kernels failing when TF grabs all GPU memory up front; confirm on your setup).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [41]:
# Load the puzzle input: one number per line, parsed into a float64 vector `s`.
# The text is stripped first because a trailing newline would otherwise leave
# an empty last element after the split, which to_number cannot parse.
raw_text = tf.strings.strip(tf.io.read_file('day9.txt'))
lines = tf.strings.split(raw_text, '\n')
s = tf.strings.to_number(lines, tf.float64)

from itertools import combinations

def generate_onehots(n, k=2):
    """Build a selection matrix with one row per k-subset of ``range(n)``.

    Row ``r`` holds 1.0 at the ``k`` positions of the r-th combination (in
    ``itertools.combinations`` order) and 0.0 everywhere else, so multiplying
    the matrix by a length-``n`` column vector yields the sum of every
    k-subset of that vector's entries.

    Args:
        n: Python int, the length of each row.
        k: Python int, the number of ones per row (defaults to 2).

    Returns:
        A ``tf.float64`` tensor of shape ``(number_of_combinations, n)``.
    """
    # The combinations are enumerated in Python, so n and k must be plain
    # ints; the result is an integer index tensor of shape (C, k).
    index_sets = tf.constant(list(combinations(range(n), k)))
    # one_hot expands to (C, k, n); summing over the k axis merges each index
    # set into one k-hot row. This is a single op instead of one scatter per
    # combination, which keeps the traced graph small.
    return tf.reduce_sum(tf.one_hot(index_sets, n, dtype=tf.float64), axis=1)

@tf.function(experimental_compile=True, experimental_relax_shapes=True)
def solve_xla(s, n, k):
    """Flag every entry of ``s[n:]`` that is not a k-sum of the n entries before it.

    For each start offset ``i`` the window ``s[i:i+n]`` is multiplied by the
    k-hot selection matrix, giving the sum of every choice of ``k`` positions
    inside that window. Entry ``s[i+n]`` is flagged when none of those sums
    equals it.

    Args:
        s: 1-D ``tf.float64`` tensor with a statically known length.
        n: Python int, the size of the look-back window.
        k: Python int, how many window entries are summed together.

    Returns:
        A boolean tensor of length ``len(s) - n``; ``True`` at position ``i``
        means ``s[n + i]`` matches no k-sum of ``s[i:i+n]``.
    """
    onehots = generate_onehots(n, k)

    def window_sums(start):
        # (C, n) @ (n, 1) -> (C, 1): one sum per combination for this window.
        window = tf.slice(s, [start], [n])[:, None]
        return onehots @ window

    sums = tf.map_fn(
        window_sums,
        tf.range(0, s.shape[0] - n),
        fn_output_signature=tf.TensorSpec([onehots.shape[0], 1], tf.float64)
    )
    # Drop only the trailing unit axis. An axis-less squeeze would also remove
    # the window or combination axis whenever it has size 1 (e.g. k == n),
    # which breaks the axis-1 reduction below.
    sums = tf.squeeze(sums, axis=-1)

    # Compare each window's sums with the value that follows the window.
    targets = s[n:, None]
    return ~tf.reduce_any(sums == targets, 1)

@tf.function
def solve1(s, n, k):
    """Return the first value in ``s[n:]`` that ``solve_xla`` flags.

    Args:
        s: 1-D tensor of numbers.
        n: Python int, the look-back window size forwarded to ``solve_xla``.
        k: Python int, the number of summands forwarded to ``solve_xla``.

    Returns:
        The scalar ``s[n + i]`` for the smallest flagged position ``i``.
        Indexing fails if nothing is flagged.
    """
    invalid_mask = solve_xla(s, n, k)
    # tf.where lists the coordinates of the True entries in order; take the first.
    first_invalid = tf.where(invalid_mask)[0, 0]
    # The mask starts at s[n], so shift back into s's own indexing.
    return s[n + first_invalid]

# Puzzle parameters: each number is checked against sums of `k` entries
# taken from the `n` numbers that precede it.
n = 25
k = 2

# Time part 1, then call it once more so the answer shows as the cell output.
%timeit solve1(s, n, k)
solve1(s, n, k)

15.4 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


<tf.Tensor: shape=(), dtype=float64, numpy=257342611.0>

In [57]:
@tf.function(experimental_compile=True)
def single_conv(s_conv, k):
    """Sum every run of ``k`` consecutive entries with an all-ones convolution.

    Args:
        s_conv: the sequence laid out for ``conv1d``; the caller in this
            notebook passes ``s[None, :, None]``, i.e. shape ``(1, len, 1)``.
        k: the window length, used as the width of the all-ones kernel.

    Returns:
        The squeezed window sums, followed by ``k - 1`` zeros so the output
        length along the sequence axis matches the input for every ``k``.
    """
    box_kernel = tf.ones((k, 1, 1), tf.float64)
    window_sums = tf.nn.conv1d(s_conv, box_kernel, stride=1, padding='VALID')
    # 'VALID' padding shortens the sequence axis by k-1; pad it back with zeros.
    tail_padding = tf.zeros((1, k-1, 1), tf.float64)
    padded = tf.concat([window_sums, tail_padding], axis=1)
    return tf.squeeze(padded)

@tf.function
def solve2(s, x):
    """Find a run of consecutive entries of ``s`` summing to ``x``; return its min + max.

    Row ``i`` of ``convs`` holds the sums of all windows of length ``i + 2``,
    and the column index is the window's start position. The first match in
    ``tf.where`` order is used.

    Args:
        s: 1-D ``tf.float64`` tensor with a statically known length.
        x: the target sum (a scalar comparable with the entries of ``convs``).

    Returns:
        The scalar ``min + max`` over the matching run. Indexing fails if no
        window sums to ``x``.
    """
    with tf.xla.experimental.jit_scope():
        s_conv = s[None,:,None]
        rnge = tf.range(2, s.shape[0], dtype=tf.int64)
        convs = tf.map_fn(
            lambda k: single_conv(s_conv, k),
            rnge,
            fn_output_signature=tf.TensorSpec((None,), tf.float64)
        )

    # Locate the matches once and read both coordinates from the first one,
    # rather than building the comparison and tf.where twice.
    # NOTE(review): single_conv pads each row with zeros, so x == 0 could
    # match padding rather than a real window.
    first_hit = tf.where(convs == x)[0]
    k = first_hit[0] + 2  # row 0 corresponds to window length 2
    idx = first_hit[1]    # column is the window's start position

    consecutive_values = s[idx:idx+k]

    return tf.reduce_min(consecutive_values) + tf.reduce_max(consecutive_values)


# Time part 2 end to end. solve1 is re-run inside every timed call to supply
# the target, and the part-2 answer itself is not displayed by this cell.
%timeit solve2(s, solve1(s, n, k))

293 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [58]:
import datetime

# Give every profiling run its own timestamped log directory.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/day9/%s' % stamp

# Profile one full part-1 + part-2 run between profiler start() and stop().
# stop() sits in a finally clause so a failure inside the solvers does not
# leave the profiler running for the rest of the kernel session.
tf.profiler.experimental.start(logdir)
try:
    solve2(s, solve1(s, n, k))
finally:
    tf.profiler.experimental.stop()