In [1]:
CUDA_VISIBLE_DEVICES=""

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
import numpy as np
from numpy import random

jax.config.update('jax_platform_name', 'cpu')

In [3]:
jax.devices()

[CpuDevice(id=0)]

In [4]:
@jit
#@jax.value_and_grad
def TV(x):
    """Computes the total variation norm and its gradient. From jcjohnson/cnn-vis."""
    eps = jnp.float32(1e-8)
    x_diff = x - jnp.roll(x, -1, axis=1)
    y_diff = x - jnp.roll(x, -1, axis=0)
    z_diff = x - jnp.roll(x, -1, axis=2)

    k = x_diff**2 + y_diff**2 + z_diff**2 + eps
    norm = jnp.sum(jnp.sqrt(k))
    return norm

grad_TV = grad(TV)

In [5]:
arr = np.array(random.rand(300, 300, 300), dtype=np.float32)
#arr_gpu = np.array(random.rand(300, 300, 300), dtype=np.float32)
#arr_gpu = device_put(arr_gpu)

In [12]:
%timeit TV(arr).block_until_ready()

55.2 ms ± 180 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit grad_TV(arr).block_until_ready()

445 ms ± 4.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%timeit TV(arr).block_until_ready();

54.7 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
%timeit grad_TV(arr_gpu)[1].block_until_ready();

TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(17671286., dtype=float32), DeviceArray([[[ 0.96534854, -3.4273279 , -1.4476018 , ...,  1.0855145 ,
               -2.0105066 ,  2.2322562 ],
              [ 2.7970405 , -2.5609775 ,  2.7353256 , ...,  1.0619489 ,
                0.37275392,  2.3334022 ],
              [-1.9987123 ,  3.177341  ,  1.9881148 , ..., -3.2968678 ,
               -0.0631274 ,  3.221321  ],
              ...,
              [-1.166106  ,  3.5511615 , -3.2676556 , ...,  3.5613794 ,
               -3.0508475 , -3.0942762 ],
              [-2.4430842 , -2.0590582 , -2.1686807 , ..., -1.0712436 ,
                2.6748548 ,  2.767263  ],
              [ 3.0198536 ,  0.63562846, -1.8635169 , ...,  3.8508499 ,
               -4.119897  , -0.6722129 ]],

             [[-0.647869  ,  0.9406335 ,  2.7573462 , ..., -0.54172194,
               -2.4327872 , -0.86857235],
              [ 2.5939012 , -2.079701  , -2.8692634 , ..., -4.126902  ,
                2.6809306 , -3.1200614 ],
              [-1.5869396 ,  2.1581225 , -3.1563294 , ..., -0.4987429 ,
               -1.2416193 , -3.3769696 ],
              ...,
              [ 3.8213437 , -0.7240492 ,  3.6887422 , ..., -3.968338  ,
                2.373451  , -1.9470073 ],
              [-2.3910794 ,  2.3106031 , -1.3819578 , ...,  2.6784809 ,
                1.4013394 ,  2.0291133 ],
              [-3.4726071 ,  3.0056937 ,  3.0869412 , ...,  0.14068782,
                1.1234863 ,  3.0702653 ]],

             [[-3.1660814 , -3.201804  , -2.9225874 , ..., -2.6274424 ,
                3.5036032 ,  3.1592984 ],
              [-0.93464255,  4.0913296 ,  1.6997986 , ...,  2.464168  ,
               -3.8156023 ,  0.40644357],
              [-2.1809666 , -3.1216664 ,  0.5178102 , ..., -2.6765482 ,
                2.9823837 , -2.0915003 ],
              ...,
              [-2.868036  , -3.0273376 ,  2.5319588 , ...,  2.8445148 ,
               -3.9747362 ,  3.0566368 ],
              [ 4.0908422 , -2.3205836 ,  2.3634949 , ..., -1.2834951 ,
                1.1942139 , -3.6900454 ],
              [ 1.5513325 ,  2.439921  ,  0.2977907 , ..., -2.529748  ,
                1.6824632 , -2.8373716 ]],

             ...,

             [[ 3.8147616 , -3.7542758 ,  1.5250633 , ..., -2.7279952 ,
                0.35568264, -3.847764  ],
              [-3.2667716 ,  2.7230728 , -3.1768975 , ...,  2.633297  ,
                0.51907516, -1.7000432 ],
              [ 2.250659  ,  0.01782985, -2.0749784 , ...,  3.5229945 ,
               -2.3341224 ,  3.1259098 ],
              ...,
              [-1.5475612 ,  0.51448226, -3.7034788 , ..., -2.037002  ,
               -2.8363485 , -1.9102011 ],
              [-2.377885  , -2.2744548 , -2.4146771 , ..., -3.132443  ,
                0.638082  ,  2.6628275 ],
              [-1.726836  ,  2.237823  ,  3.331149  , ..., -2.3354738 ,
                0.7264925 ,  2.1241164 ]],

             [[ 0.11794241,  3.8662171 , -1.075875  , ..., -0.15200943,
               -3.6879306 ,  3.4212592 ],
              [-1.5799185 , -3.193065  , -2.1996033 , ...,  0.91134626,
                0.01865485, -0.9803102 ],
              [ 3.5488555 ,  2.8206198 , -2.1120589 , ..., -3.2985582 ,
               -0.40856424, -1.0444915 ],
              ...,
              [-3.4074855 ,  3.50187   , -0.4910264 , ..., -1.5920814 ,
                0.24366108,  2.5394835 ],
              [ 4.1576076 ,  1.9923874 , -0.8566738 , ...,  4.0228434 ,
                0.6101773 ,  0.26053023],
              [-0.97552145, -3.1702454 , -2.728284  , ..., -0.8402343 ,
                3.1847768 ,  0.6829329 ]],

             [[-3.9316044 , -2.452777  ,  3.6367161 , ...,  3.8063226 ,
               -1.2200758 ,  0.44495025],
              [-1.0987899 ,  3.6874304 ,  0.3721829 , ..., -3.0892088 ,
                2.315102  ,  2.3488479 ],
              [ 1.3735487 ,  1.0644883 , -3.3185163 , ..., -0.6398553 ,
                4.26058   , -2.6378174 ],
              ...,
              [-2.4458096 ,  1.1328197 ,  4.0723324 , ..., -0.2550786 ,
                3.0678935 ,  2.4384172 ],
              [-1.9811436 , -3.1672482 ,  1.2112219 , ..., -2.9548995 ,
               -2.841776  ,  0.8732394 ],
              [ 3.6999073 ,  1.8002665 ,  0.09203893, ..., -1.9298139 ,
                2.8303607 , -4.125231  ]]], dtype=float32)).

In [22]:
%timeit TV(arr_gpu).block_until_ready();

895 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [24]:
%timeit grad_TV(arr_gpu).block_until_ready();

10.4 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
