In [1]:
! /opt/bin/nvidia-smi

Tue Jun  1 08:42:14 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [16]:
import jax.numpy as jnp
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [8]:
user_dim = 10000
item_dim = 1000
epoch = 100

rating_mat = np.random.random(size=(user_dim, item_dim))

## Sklearn Version

In [4]:
%%time
sk_result = cosine_similarity(rating_mat)

CPU times: user 5.11 s, sys: 431 ms, total: 5.54 s
Wall time: 2.96 s


## Numpy Version

In [5]:
%%time
l2_vector = np.linalg.norm(rating_mat, axis=1)
numerator = np.tensordot(rating_mat, rating_mat, [1, 1])
denominator = np.einsum('i,j->ij', l2_vector, l2_vector)
np_result = numerator / denominator

CPU times: user 5.42 s, sys: 550 ms, total: 5.97 s
Wall time: 3.53 s


## Jax Version

In [7]:
%%time
jnp_rating_mat = jnp.array(rating_mat).astype(float)
jnp_rating_mat

jnp_l2_vector = jnp.linalg.norm(jnp_rating_mat, axis=1)
numerator = jnp.tensordot(jnp_rating_mat, jnp_rating_mat, [1, 1])
denominator = jnp.einsum('i,j->ij', jnp_l2_vector, jnp_l2_vector)
jnp_result = numerator / denominator

CPU times: user 19.7 ms, sys: 96 µs, total: 19.8 ms
Wall time: 22.6 ms


## JAX With JIT

In [17]:
from jax import jit

In [26]:
@jit
def jnp_cosine(rating_mat):
    jnp_l2_vector = jnp.linalg.norm(jnp_rating_mat, axis=1)
    numerator = jnp.tensordot(jnp_rating_mat, jnp_rating_mat, [1, 1])
    denominator = jnp.einsum('i,j->ij', jnp_l2_vector, jnp_l2_vector)
    jnp_result = numerator / denominator
    # return jnp_result

jnp_rating_mat = jnp.array(rating_mat).astype(float)

In [28]:
%%time
jnp_cosine(jnp_rating_mat)

CPU times: user 113 µs, sys: 17 µs, total: 130 µs
Wall time: 138 µs
