In [1]:
# Hide all CUDA devices so JAX falls back to the CPU.
# This must run before JAX is imported/initialised. Comment out the
# os.environ line below (and restart the kernel) to make GPUs visible again.
import os

CUDA_VISIBLE_DEVICES = ""
# A bare Python variable is invisible to the CUDA runtime; the setting only
# takes effect once it is exported through the process environment.
os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np
from numpy import random

# Select the platform JAX uses by default. 'cpu' is active here; swap the two
# lines (and restart the kernel) to run the benchmarks on the GPU instead.
jax.config.update('jax_platform_name', 'cpu')
#jax.config.update('jax_platform_name', 'gpu')

In [3]:
# Show the devices of the selected platform as a sanity check before timing.
jax.devices()

[CpuDevice(id=0)]

In [4]:
@jit
def TV(x):
    """Return the smoothed isotropic total-variation norm of a volume.

    Differences are taken along axes 0, 1 and 2 with ``jnp.roll``, so the last
    element along each axis is differenced against the first one (periodic
    wrap-around). Only the scalar norm is returned; use ``grad_TV`` for the
    gradient. Adapted from jcjohnson/cnn-vis.

    Args:
        x: array with at least three axes (axes 0, 1 and 2 are differenced).

    Returns:
        Scalar ``sum(sqrt(dx**2 + dy**2 + dz**2 + eps))`` over all elements.
    """
    # Small offset keeps sqrt (and its derivative) finite where all
    # differences vanish.
    eps = jnp.float32(1e-8)
    x_diff = x - jnp.roll(x, -1, axis=1)
    y_diff = x - jnp.roll(x, -1, axis=0)
    z_diff = x - jnp.roll(x, -1, axis=2)

    k = x_diff**2 + y_diff**2 + z_diff**2 + eps
    return jnp.sum(jnp.sqrt(k))


# Compile the gradient too: a bare grad(TV) re-runs the autodiff transformation
# in Python on every call, which dominates the run time for small inputs.
grad_TV = jit(grad(TV))

In [5]:
@jit
def TV_2d(x):
    """Return the smoothed isotropic total-variation norm of an image.

    Differences are taken along axes 0 and 1 with ``jnp.roll``, so the last
    row/column is differenced against the first one (periodic wrap-around).
    Only the scalar norm is returned; use ``grad_TV_2d`` for the gradient.
    Adapted from jcjohnson/cnn-vis.

    Args:
        x: array with at least two axes (axes 0 and 1 are differenced).

    Returns:
        Scalar ``sum(sqrt(dx**2 + dy**2 + eps))`` over all elements.
    """
    # Small offset keeps sqrt (and its derivative) finite where all
    # differences vanish.
    eps = jnp.float32(1e-8)
    x_diff = x - jnp.roll(x, -1, axis=1)
    y_diff = x - jnp.roll(x, -1, axis=0)

    k = x_diff**2 + y_diff**2 + eps
    return jnp.sum(jnp.sqrt(k))


# Compile the gradient too: a bare grad(TV_2d) re-runs the autodiff
# transformation in Python on every call, which dominates the run time for
# inputs this small.
grad_TV_2d = jit(grad(TV_2d))

In [6]:
# Seeded generator: the benchmark volume is identical on every
# Restart & Run All. Drawing float32 directly avoids materialising a
# float64 temporary of the same shape first.
arr = np.random.default_rng(0).random((300, 300, 300), dtype=np.float32)
# Same data, already converted to a JAX array on the default device, so both
# timing paths work on identical input.
arr_gpu = device_put(arr)

In [7]:
# Forward pass on the host NumPy array `arr`. block_until_ready() makes the
# measurement wait for the result rather than stop when the call returns.
%timeit TV(arr).block_until_ready()

48.5 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Gradient of the 3-D TV norm on the host NumPy array `arr`, timed to completion.
%timeit grad_TV(arr).block_until_ready()

365 ms ± 6.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Forward pass on the array that went through device_put.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours, so the recorded timing presumably comes from a run with the
# 'gpu' platform selected - re-run on a fresh kernel to confirm.
%timeit TV(arr_gpu).block_until_ready();

1.04 ms ± 11.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
# Gradient of the 3-D TV norm on the array that went through device_put.
# NOTE(review): execution count is out of sequence; the recorded timing
# presumably comes from a run with the 'gpu' platform selected - confirm.
%timeit grad_TV(arr_gpu).block_until_ready();

11.4 ms ± 63.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# Seeded generator: the benchmark image is identical on every
# Restart & Run All. Drawing float32 directly avoids a float64 temporary.
arr_2d = np.random.default_rng(1).random((512, 512), dtype=np.float32)
# Same data, already converted to a JAX array on the default device, so both
# timing paths work on identical input.
arr_2d_gpu = device_put(arr_2d)

In [10]:
# Forward pass of the 2-D TV norm on the host NumPy array `arr_2d`.
# block_until_ready() makes the measurement wait for the result.
%timeit TV_2d(arr_2d).block_until_ready()

137 µs ± 17.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Gradient of the 2-D TV norm on the host NumPy array `arr_2d`, timed to completion.
%timeit grad_TV_2d(arr_2d).block_until_ready()

1.12 ms ± 165 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
# Forward pass of the 2-D TV norm on the array that went through device_put.
# NOTE(review): execution count is out of sequence; the recorded timing
# presumably comes from a run with the 'gpu' platform selected - confirm.
%timeit TV_2d(arr_2d_gpu).block_until_ready()

28.3 µs ± 7.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [17]:
# Gradient of the 2-D TV norm on the array that went through device_put.
# NOTE(review): execution count is out of sequence; the recorded timing
# presumably comes from a run with the 'gpu' platform selected - confirm.
%timeit grad_TV_2d(arr_2d_gpu).block_until_ready()

910 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
