In [None]:
import jax.numpy as jnp
from jax import jit
import numpy as np
%load_ext line_profiler

In [None]:
import jax
# jax.config.update("jax_default_device", jax.devices("cpu")[0])

In [None]:
for use_jax, use_jit in ((False, False), (True, False), (True, True)):
    print("----------------")
    print("With Jax" if use_jax else "With Numpy")
    xp = jnp if use_jax else np

    n = int(1e8)
    fp32 = xp.ones((n,), np.float32)
    in32 = xp.ones((n, ), np.int32)
    in16 = xp.ones((n,), np.int16)
    in8 = xp.ones((n, ), np.int8)

    def process_tst(f):
        return f * f + f + 3*f + 2* f

    if use_jit:
        print("Using JIT")
        process_tst = jit(process_tst) 

    for f in (fp32, in8, in16, in32):
        print(f.dtype)
        if use_jax:
            %timeit process_tst(f).block_until_ready()
        else:
            %timeit process_tst(f)

----------------
With Numpy
float32
276 ms ± 280 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
int8
76.8 ms ± 67.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
int16
143 ms ± 484 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
int32
283 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
----------------
With Jax
float32
25.7 ms ± 8.82 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
int8
7.21 ms ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
int16
13.3 ms ± 17.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
int32
25.7 ms ± 14.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
----------------
With Jax
Using JIT
float32
3.8 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
int8
1.17 ms ± 6.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int16
1.95 ms ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int32
3.9 ms ± 31.8 µs per loop (mean ± std. dev. of 7 runs, 100

In [None]:
for use_jax, use_jit in ((False, False), (True, False), (True, True)):
    print("----------------")
    print("With Jax" if use_jax else "With Numpy")
    xp = jnp if use_jax else np

    n = int(1e6)
    in32 = np.random.randint(2**16, None, (n, ), np.uint32)
    in16 = np.random.randint(2**16, None, (n, ), np.uint16)
    in8 = np.random.randint(2**8, None, (n, ), np.uint8)

    if use_jax:
        in32 = jnp.asarray(in32)
        in16 = jnp.asarray(in16)
        in8 = jnp.asarray(in8)

    def process_tst(f):
        if use_jax:
            return jnp.bincount(f, length=2**16)
        return xp.bincount(f, minlength=2**16)

    if use_jit:
        print("Using JIT")
        process_tst = jit(process_tst) 

    for f in (in8, in16, in32):
        print(f.dtype)
        if use_jax:
            %timeit process_tst(f).block_until_ready()
        else:
            %timeit process_tst(f)

----------------
With Numpy
uint8
1.76 ms ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint16
2.21 ms ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
uint32
2.24 ms ± 7.38 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
----------------
With Jax
uint8
2.3 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
uint16
1.64 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint32
1.66 ms ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
----------------
With Jax
Using JIT
uint8
422 µs ± 6.85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint16
129 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
uint32
130 µs ± 608 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [None]:
for use_jax, use_jit in ((False, False), (True, False), (True, True)):
    print("----------------")
    print("With Jax" if use_jax else "With Numpy")
    xp = jnp if use_jax else np

    n = int(8e6)
    in32 = np.random.randint(2**20, None, (n, ), np.uint32)
    in16 = np.random.randint(2**16, None, (n, ), np.uint16)
    in8 = np.random.randint(2**8, None, (n, ), np.uint8)

    if use_jax:
        in32 = jnp.asarray(in32)
        in16 = jnp.asarray(in16)
        in8 = jnp.asarray(in8)

    def process_tst(f):
        return f.astype(xp.float32)

    if use_jit:
        print("Using JIT")
        process_tst = jit(process_tst) 

    for f in (in8, in16, in32):
        print(f.dtype)
        if use_jax:
            %timeit process_tst(f).block_until_ready()
        else:
            %timeit process_tst(f)

----------------
With Numpy
uint8
2.21 ms ± 4.05 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
uint16
2.27 ms ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
uint32
5.09 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
----------------
With Jax
uint8
409 µs ± 5.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint16
418 µs ± 3.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint32
474 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
----------------
With Jax
Using JIT
uint8
322 µs ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint16
338 µs ± 2.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
uint32
389 µs ± 2.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [None]:
@jit
def process_tst(a, b):
    """Return the weighted sum ``b[0]*a[0] + b[1]*a[1] + ...`` over the entries of ``a``.

    ``a`` and ``b`` are indexable sequences. One term is added per entry of
    ``a``, so only the first ``len(a)`` entries of ``b`` are read. The
    accumulator starts as zeros shaped like ``a[0]``.
    """
    total = jnp.zeros_like(a[0])
    for idx, a_i in enumerate(a):
        total = total + b[idx] * a_i
    return total

# Build a 10**7-element float32 input from random integers below 2**20, then
# print the lowered IR text of `process_tst` for three copies of that array
# and three integer weights.
a_arr = jnp.asarray(
    np.random.randint(2**20, None, (int(10**7),)),
    dtype=np.float32,
)
lowered = process_tst.lower([a_arr, a_arr, a_arr], [2, 3, 4])
print(lowered.as_text())

module @jit_process_tst {
  func.func public @main(%arg0: tensor<10000000xf32>, %arg1: tensor<10000000xf32>, %arg2: tensor<10000000xf32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>) -> tensor<10000000xf32> {
    %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10000000xf32>
    %2 = mhlo.convert(%arg3) : (tensor<i32>) -> tensor<f32>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10000000xf32>
    %4 = mhlo.multiply %3, %arg0 : tensor<10000000xf32>
    %5 = mhlo.add %1, %4 : tensor<10000000xf32>
    %6 = mhlo.convert(%arg4) : (tensor<i32>) -> tensor<f32>
    %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<10000000xf32>
    %8 = mhlo.multiply %7, %arg1 : tensor<10000000xf32>
    %9 = mhlo.add %5, %8 : tensor<10000000xf32>
    %10 = mhl

In [None]:
for use_jax, use_jit in ((False, False), (True, False), (True, True)):
    print("----------------")
    print("With Jax" if use_jax else "With Numpy")
    xp = jnp if use_jax else np

    n = int(1e6)
    in32 = xp.ones((n, ), np.int32)
    in16 = xp.ones((n,), np.int16)
    in8 = xp.ones((n, ), np.int8)

    def process_tst(f, b):
        c = 1
        acc = xp.zeros_like(f[0], dtype=xp.int16)
        for i in range(len(b)):
            acc = c * f[i] + acc
            c *= b[i]
        
        return acc

    b = [5, 5, 12, 6, 8]

    if use_jit:
        print("Using JIT")
        process_tst = jit(process_tst) 

    for f in (in8, in16, in32):
        print(f.dtype)

        f = [f + i for i in range(len(b))]

        if use_jax:
            # print(jit(process_tst).lower(f,b).as_text())
            %timeit process_tst(f, b).block_until_ready()
        else:
            %timeit process_tst(f, b)

----------------
With Numpy
int8
1.26 ms ± 3.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int16
1.02 ms ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int32
2.57 ms ± 5.05 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
----------------
With Jax
int8
790 µs ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int16
877 µs ± 2.98 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int32
911 µs ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
----------------
With Jax
Using JIT
int8
215 µs ± 3.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int16
273 µs ± 3.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
int32
338 µs ± 5.84 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
