# Basic usage of `SuperOp`

In [1]:
import tensorflow as tf
import numpy as np

from super_op import SuperOp

# Problem size (the trailing `# 100` is presumably an alternative value the author also tried).
N = 3 # 100

# Start every element at 1.0; `a` = 81, so the Newton iteration below converges to sqrt(81) = 9.
x = tf.ones((N,))
a = tf.ones((N,)) * 81

# Project-local SuperOp instance, used below as a decorator on `f`.
super_op = SuperOp()

@super_op
def f(x):
    """One Heron/Newton step toward sqrt(a): (x + a / x) / 2.

    Reads the module-level `a`. Because of `@super_op`, calls are presumably
    recorded for later evaluation by `super_op.compute` rather than evaluated
    eagerly (the loop cell displays a graph Tensor, not values) — TODO confirm.
    """
    return (x + a / x) * 0.5


In [2]:
# Apply ten Newton steps through the SuperOp-wrapped `f`. The displayed `y` is a
# symbolic graph Tensor (see output), not concrete values.
y = x
for i in range(10):
    y = f(y)
y

Tensor("Identity:0", shape=(3,), dtype=float32)(139772279188496)

In [3]:
# Evaluate the recorded computation. `results` here is a nested structure (a list
# holding a dict and a tensor); the output mirrors that structure with concrete
# values: y -> 9.0 (= sqrt(81)) and x -> 1.0.
super_op.compute(results=[{1: y}, x])

[{1: <tf.Tensor: shape=(3,), dtype=float32, numpy=array([9., 9., 9.], dtype=float32)>},
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>]

In [4]:
from super_op import SuperOp

# Larger problem size for the timing cells below.
N = 100

x = tf.ones((N,))
a = tf.ones((N,)) * 81

@tf.function(autograph=False, experimental_compile=True)
def fn(x, a, n):
    """Run `n` Newton steps toward sqrt(a) through a fresh SuperOp, XLA-compiled.

    Args:
        x: starting value(s).
        a: value(s) whose square root is sought.
        n: number of steps. It is iterated with Python `range`, so it must be a
            Python int; each distinct value triggers a separate trace of `fn`.

    Returns:
        Whatever `super_op.compute(results=x)` yields for the final step.
        Judging by `y[0].numpy()` printing a scalar later, this is a tensor
        shaped like `x` — TODO confirm.
    """
    super_op = SuperOp()

    @super_op
    def f(x, a):
        return (x + a / x) * 0.5

    # With autograph=False this Python loop is unrolled at trace time into
    # `n` calls of the SuperOp-wrapped `f`.
    for i in range(n):
        x = f(x, a)
    return super_op.compute(results=x)

In [5]:
%%time
# First call after defining `fn` with n=500; NOTE(review): presumably includes
# tracing + XLA compilation. Compare with the %%timeit cell for steady state.
y = fn(x, a, 500)

CPU times: user 812 ms, sys: 62.5 ms, total: 875 ms
Wall time: 851 ms


In [6]:
%%timeit
# Steady-state per-call cost (the trace for n=500 already exists from the previous cell).
y = fn(x, a, 500)

8.51 ms ± 58.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# First element of the 500-step result; expected sqrt(81) = 9.
y[0].numpy()

9.0

# Testing `tf.while_loop` parallelism with TensorArrays (TAs)

In [28]:
n_steps = 10

@tf.function(autograph=True, experimental_compile=True)
def fn(x, a):
    """Advance two independent Newton chains toward sqrt(a) inside one while_loop.

    Even iterations advance the chain stored in ``ta0`` and odd iterations the
    chain stored in ``ta1``. Both chains start from ``x``, and each receives
    ``n_steps`` updates ``v -> (v + a / v) / 2``. Presumably this probes whether
    XLA overlaps the two independent chains.

    Args:
        x: float32 starting value(s), shared by both chains.
        a: value(s) whose square root is sought; must broadcast with ``x``.

    Returns:
        ``(ta0[n_steps], ta1[n_steps])``: the final value of each chain.
    """
    ta0 = tf.TensorArray(tf.float32, size=n_steps + 1)
    ta1 = tf.TensorArray(tf.float32, size=n_steps + 1)
    # Seed each chain in its own array; TensorArray.write returns the updated array.
    ta0 = ta0.write(0, x)
    ta1 = ta1.write(0, x)

    def newton_step(v):
        return (v + a / v) * 0.5

    def body(i, ta0, ta1):
        # Each chain advances once every two iterations, so slot `pos` holds the
        # current value of the chain selected by `i % 2`.
        pos = i // 2

        def f0():
            return i + 1, ta0.write(pos + 1, newton_step(ta0.read(pos))), ta1

        def f1():
            return i + 1, ta0, ta1.write(pos + 1, newton_step(ta1.read(pos)))

        # `i` is int32, so `i % 2` indexes the branches directly. A cast to the
        # Python `int` type is unnecessary and not accepted as a dtype by many
        # TF versions.
        return tf.switch_case(i % 2, branch_fns=[f0, f1])

    # The loop variables are the two seeded arrays (ta0, ta1).
    _, ta0, ta1 = tf.while_loop(
        lambda i, ta0, ta1: tf.less(i, 2 * n_steps),
        body,
        [tf.constant(0), ta0, ta1],
    )
    return ta0.read(n_steps), ta1.read(n_steps)

In [47]:
n_steps = 1000

@tf.function(autograph=True, experimental_compile=True)
def fn(x, a):
    """Two interleaved Newton chains toward sqrt(a) stored in a single TensorArray.

    Slots 0 and 1 are both seeded with ``x``. Iteration ``i`` computes
    ``ta[i + 2] = (ta[i] + a / ta[i]) / 2``, so the even and odd slots form two
    independent chains. Presumably this checks whether XLA overlaps them.

    Args:
        x: float32 starting value(s).
        a: value(s) whose square root is sought; must broadcast with ``x``.

    Returns:
        ``ta[n_steps]``.
    """
    # Iteration i writes slot i + 2. The last write (i = n_steps - 1) therefore
    # lands at n_steps + 1, so the array needs n_steps + 2 slots. With only
    # n_steps + 1 slots that write is out of bounds: an error outside XLA, and
    # under XLA the index is presumably clamped, silently corrupting slot n_steps.
    ta = tf.TensorArray(tf.float32, size=n_steps + 2)
    ta = ta.write(0, x).write(1, x)

    def body(i, ta):
        x = ta.read(i)
        return i + 1, ta.write(i + 2, (x + a / x) * 0.5)

    _, ta = tf.while_loop(
        lambda i, ta: tf.less(i, n_steps),
        body,
        [tf.constant(0), ta],
    )
    return ta.read(n_steps)

In [48]:
%%timeit
# Per-call cost of the single-TensorArray two-chain `fn` defined above
# (uses `x`, `a` from the N=100 cell).
y = fn(x, a)

951 µs ± 16.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [258]:
n_steps = 1000
n = 4

@tf.function(autograph=True, experimental_compile=True)
def fn(x, a):
    """Chain of ``n_steps`` increments, each dispatched through an ``n``-way switch_case.

    Every branch is the same function (read slot ``i``, add 1), so the result is
    ``x + n_steps``. This presumably isolates the cost of ``tf.switch_case``
    dispatch inside a TensorArray while_loop.

    Args:
        x: float32 starting value(s).
        a: unused; kept so the signature matches the other ``fn(x, a)`` variants.

    Returns:
        ``ta[n_steps]``, i.e. ``x + n_steps``.
    """
    ta = tf.TensorArray(tf.float32, size=n_steps + 1)
    ta = ta.write(0, x)

    def body(i, ta):
        def fi():
            x = ta.read(i)
            return x + 1  # alternative payload: (x + a / x) * 0.5

        v = tf.switch_case(
            i % tf.constant(n),
            branch_fns=[fi for _ in range(n)],
        )
        return i + 1, ta.write(i + 1, v)

    # Iteration i writes slot i + 1. Stopping at i < n_steps makes the last
    # write land on slot n_steps, the last valid index of this (n_steps + 1)-slot
    # array and the slot returned below. One more iteration would write out of
    # bounds.
    _, ta = tf.while_loop(
        lambda i, ta: tf.less(i, n_steps),
        body,
        [tf.constant(0), ta],
    )
    return ta.read(n_steps)

In [260]:
%%time
# Total time for 1000 calls of the switch_case TensorArray `fn`. NOTE(review):
# if `fn` had not been called yet, this also includes the one-time trace/compile.
for _ in range(1000):
    y = fn(x, a)

CPU times: user 156 ms, sys: 0 ns, total: 156 ms
Wall time: 171 ms


In [1]:
import tensorflow as tf
import numpy as np

# Shared configuration for the matmul-chain benchmarks below:
#   n_steps: loop iterations per call (one `@ W` each),
#   n: number of identical switch_case branches in fn_case,
#   N: side length of the square matrices.
n_steps = 1000
n = 10
N = 20

x = tf.ones((N, N))
# Identity weights: `x @ W == x`, so values stay at 1.0 however many steps run.
W = tf.eye(N)

def fn_grad(fn):
    """Decorator that replaces a forward function with its gradient w.r.t. ``W``.

    The returned callable takes the same single argument ``x`` as ``fn``. It runs
    ``fn(x)`` under a GradientTape that watches the module-level ``W`` and
    returns the gradient of that output with respect to ``W``. For a non-scalar
    output, GradientTape differentiates the sum of its elements. The returned
    callable is a ``tf.function`` with ``autograph=False`` and XLA compilation
    requested.
    """
    def wrapped(x):
        with tf.GradientTape() as tape:
            tape.watch(W)
            forward = fn(x)
        return tape.gradient(forward, W)

    return tf.function(wrapped, autograph=False, experimental_compile=True)

# @fn_grad
@tf.function(autograph=False, experimental_compile=False)
def fn(x):
    """Baseline: multiply ``x`` by ``W`` ``n_steps`` times inside one tf.while_loop.

    NOTE(review): unlike the variants below, this one is neither XLA-compiled
    nor gradient-wrapped (the ``@fn_grad`` toggle above is commented out).
    Keep that in mind when comparing its timings with theirs.
    """
    def keep_going(step, _acc):
        return tf.less(step, n_steps)

    def advance(step, acc):
        return step + 1, acc @ W

    _, out = tf.while_loop(
        keep_going,
        advance,
        [tf.constant(0), x],
        maximum_iterations=n_steps,
    )
    return out

@fn_grad
@tf.function(autograph=False, experimental_compile=True)
def fn_case(x):
    """The ``fn`` loop, but each step is dispatched through an ``n``-way tf.switch_case.

    Every branch computes the same ``x @ W``, so the forward computation matches
    ``fn``. Presumably the point is to measure switch_case dispatch overhead.
    Wrapped by ``fn_grad``, so calling it returns the gradient w.r.t. ``W``.
    """
    def body(step, acc):
        def matmul_branch():
            return acc @ W

        branches = [matmul_branch] * n
        return step + 1, tf.switch_case(step % tf.constant(n), branch_fns=branches)

    _, out = tf.while_loop(
        lambda step, _acc: tf.less(step, n_steps),
        body,
        [tf.constant(0), x],
        maximum_iterations=n_steps,
    )
    return out

@fn_grad
@tf.function(autograph=False, experimental_compile=True)
def fn_unroll(x):
    """Python-unrolled matmul chain with no while_loop.

    With autograph=False the loop runs at trace time, so the traced graph holds
    one separate matmul per step (``n_steps`` in total). Wrapped by ``fn_grad``.
    """
    acc = x
    for _ in range(n_steps):
        acc = acc @ W
    return acc

@fn_grad
@tf.function(autograph=False, experimental_compile=True)
def fn_case_3(x):
    """Three-slot cyclic variant driven by a 3-way tf.switch_case.

    On step ``i``, branch ``i % 3`` runs. Branch k reads only slot k
    (x1, x2, x3 for k = 0, 1, 2) and writes ``slot_k @ W`` into the next slot
    cyclically (x1 -> x2, x2 -> x3, x3 -> x1). It copies slot k unchanged into
    the other two slots, so exactly one matmul executes per step. Returns the
    final x2; wrapped by ``fn_grad``, so calling it returns the gradient
    w.r.t. ``W``.

    NOTE(review): the intended dataflow is not stated anywhere. Confirm this
    cyclic copy/advance pattern is what was meant.
    """
    def body(i, x1, x2, x3):

        # Unpack the branch's 3-tuple into the new (x1, x2, x3).
        return (i + 1, *tf.switch_case(
            i % tf.constant(3),
            branch_fns=[
                lambda: (x1, x1 @ W, x1),
                lambda: (x2, x2, x2 @ W),
                lambda: (x3 @ W, x3, x3),
            ],
        ))

    _, x1, x2, x3 = tf.while_loop(
        lambda i, *xs: tf.less(i, n_steps),
        body,
        [tf.constant(0), x, x, x],
        maximum_iterations=n_steps,
    )
    return x2

@fn_grad
@tf.function(autograph=False, experimental_compile=True)
def fn_hop(x):
    """Two-slot loop: each step maps (first, second) -> (second, first @ W).

    Both slots start at ``x``; the second slot is returned after ``n_steps``
    steps. Wrapped by ``fn_grad``, so calling it returns the gradient
    w.r.t. ``W``.
    """
    def cond(step, *_slots):
        return tf.less(step, n_steps)

    def body(step, first, second):
        return step + 1, second, first @ W

    _, _, result = tf.while_loop(
        cond,
        body,
        [tf.constant(0), x, x],
        maximum_iterations=n_steps,
    )
    return result

In [2]:
from super_op import SuperOp

# @fn_grad
@tf.function(autograph=False, experimental_compile=False)
def fn_super(x):
    """SuperOp variant of the matmul chain: ``n_steps`` calls of ``f(x) = x @ W``.

    With autograph=False the Python loop is unrolled at trace time into
    ``n_steps`` calls of the SuperOp-wrapped ``f``. ``super_op.compute`` then
    produces the final result. Not XLA-compiled and not gradient-wrapped (the
    ``@fn_grad`` toggle above is commented out).

    NOTE(review): how SuperOp groups or executes the recorded calls is defined
    in ``super_op`` and is not visible here.
    """
    super_op = SuperOp()

    @super_op
    def f(x):
        return x @ W

    for i in range(n_steps):
        x = f(x)
    return super_op.compute(results=x)

In [16]:
%%time
# Baseline `fn` (while_loop, no XLA). This is a forward pass only, unlike the
# fn_grad-wrapped variants timed below, which also compute the gradient.
y = fn(x)

CPU times: user 15.6 ms, sys: 0 ns, total: 15.6 ms
Wall time: 8.23 ms


In [18]:
# Profile 1000 calls of the baseline `fn`; the trace is written to the 'logdir2' directory.
with tf.profiler.experimental.Profile('logdir2'):
    for _ in range(1000):
        y = fn(x)

In [5]:
%%time
# SuperOp variant (n_steps unrolled calls, no XLA, forward pass only).
y = fn_super(x)

CPU times: user 93.8 ms, sys: 0 ns, total: 93.8 ms
Wall time: 82.7 ms


In [7]:
# Profile 10 calls of `fn_super`; the trace is written to the 'logdir' directory.
with tf.profiler.experimental.Profile('logdir'):
    for _ in range(10):
        y = fn_super(x)

In [8]:
%%time
# Two-slot hop loop; `y` is the gradient w.r.t. W (fn_grad-wrapped, XLA).
y = fn_hop(x)

CPU times: user 15.6 ms, sys: 15.6 ms, total: 31.2 ms
Wall time: 4.68 ms


In [7]:
%%time
# n-way switch_case loop; `y` is the gradient w.r.t. W (fn_grad-wrapped, XLA).
y = fn_case(x)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 3.83 ms


In [9]:
%%time
# Fully unrolled chain. NOTE(review): the multi-second time is presumably
# dominated by tracing/compiling n_steps separate matmuls plus their gradients.
y = fn_unroll(x)

CPU times: user 8.8 s, sys: 125 ms, total: 8.92 s
Wall time: 8.97 s


In [11]:
%%time
# Three-slot switch_case loop; `y` is the gradient w.r.t. W (fn_grad-wrapped, XLA).
y = fn_case_3(x)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 5.84 ms
