In [1]:
import dill

import jax
import jax.numpy as jnp
import numpy as np

## Load/create some data

In [2]:
# Load the object stored in data/pmf_model.dill (presumably a fitted PMF
# model, going by the file name -- confirm).
# NOTE: dill.load, like pickle.load, can execute arbitrary code while
# deserializing, so only open model files that come from a trusted source.
with open("data/pmf_model.dill", "rb") as f:
    pmf = dill.load(f)
    
# Length of the loaded trace; shown as the cell's output.
len(pmf.trace)

300

In [None]:
""" Train-test split """
# NOTE(review): neither `train_test_split` nor `pivoted_ratings` is defined or
# imported in the visible part of this notebook, and this cell has no execution
# count -- on a fresh kernel it would raise NameError. `train` and `test` are
# also not read by any later visible cell. Confirm where these names come from,
# or whether this cell is still needed.
train, test = train_test_split(pivoted_ratings.values, frac_test=0.1)

In [3]:
# Keep only the last 150 entries of the trace for each factor array.
U = pmf.trace["U"][-150:]

# Swap the two trailing axes of V (equivalent to a transpose with
# axes=[0, 2, 1]) so that U[s] @ Vt[s] is a valid matrix product per sample.
Vt = np.swapaxes(pmf.trace["V"][-150:], 1, 2)

print(U.shape, Vt.shape)

(150, 610, 10) (150, 10, 9724)


## Naive Python mapping

In [4]:
%%time
py_map_R = map(np.matmul, U, Vt)
py_map_R = np.mean(tuple(py_map_R), axis=0)
print(py_map_R.shape)

(610, 9724)
CPU times: user 8.16 s, sys: 25.2 s, total: 33.4 s
Wall time: 53.4 s


## Failed attempt with `np.tensordot`
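Jupyter straight up crashed when `np.tensordot` was asked to contract only the latent axis (`axes=[[2], [1]]`): that pairs every sample of `U` with every sample of `Vt`, so the result would have shape `(150, 610, 150, 9724)` — about $1.3 \times 10^{11}$ elements, far more than fits in memory.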

In [None]:
%%time
np_pred_R = np.tensordot(U, Vt, axes = [[2], [1]])

## `jax.vmap`

In [6]:
"""
    Transform jnp.matmul with jax.vmap, which vectorizes
    the function across the axes specified in `in_axes`. 
    
    In other words, `in_axes` specifies the batch dimension
    in each function argument --- of which there are two here.
"""

# out_axes=0 is vmap's default; it is written out so the batch axis of the
# result is as explicit as the batch axes of the two inputs.
batch_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)

In [7]:
%%time
jax_R = batch_matmul(U, Vt)
jax_R = jnp.mean(jax_R, axis=0).block_until_ready()

print(jax_R.shape)



(610, 9724)
CPU times: user 3.46 s, sys: 7.06 s, total: 10.5 s
Wall time: 2.91 s


In [None]:
# NOTE(review): this looks like a stray, never-executed cell. `j` is not defined
# anywhere in the visible notebook, so Restart & Run All would end with a
# NameError here -- confirm and delete if it is leftover typing.
j