In [1]:
import jax
import jax.numpy as np
import numpy as npo


N = 100


def single(key):
    x = jax.random.normal(key, shape=(1,))
    for i in range(10):
        x = 0.9*np.sin(x)
    return x

def vectorized(key):
    x = jax.random.normal(key, shape=(N,))
    for i in range(10):
        x = 0.9*np.sin(x)
    return x

def vmapped(key):
    keys = jax.random.split(key, N)
    return jax.vmap(single)(keys)

js = jax.jit(single)
print(npo.array(js(jax.random.PRNGKey(0))))

def vjmapped(key):
    keys = jax.random.split(key, N)
    return jax.vmap(js)(keys)



[-0.0695273]


In [2]:
key0 = jax.random.PRNGKey(0)
keys = jax.random.split(key0, N)

In [3]:
single(key0)

DeviceArray([-0.0695273], dtype=float32)

In [4]:
vectorized(key0)

DeviceArray([-0.22113046,  0.21539974,  0.20367654,  0.01224993,
             -0.20849706,  0.2000417 ,  0.01088641, -0.12210669,
              0.0046047 , -0.1296929 , -0.22392079, -0.22453459,
              0.21145895,  0.19531998, -0.12734865,  0.2104892 ,
             -0.18043596, -0.19388075, -0.22146116,  0.05319687,
             -0.16321823,  0.2293054 ,  0.09361042,  0.17292249,
              0.02110242,  0.12417676, -0.07623257, -0.18289421,
              0.17586501, -0.22675812,  0.21495506, -0.22154386,
             -0.02405544,  0.16726421, -0.1206226 ,  0.10500219,
              0.09256326, -0.22916071,  0.2023129 , -0.19418928,
              0.11861441, -0.13776794,  0.04658704,  0.22273225,
              0.22800104,  0.2115433 , -0.0094401 , -0.22212198,
             -0.2092242 ,  0.2109681 ,  0.07433756,  0.14468569,
              0.22858517,  0.06808888, -0.16344416, -0.12271435,
              0.01742771,  0.21626122,  0.05558302, -0.18094029,
              0.1449989 ,

In [5]:
vmapped(key0)

DeviceArray([[ 0.21117093],
             [-0.22931617],
             [-0.22925569],
             [-0.18119279],
             [ 0.0666384 ],
             [-0.19494936],
             [-0.19547391],
             [ 0.2206912 ],
             [-0.03270767],
             [ 0.0915747 ],
             [ 0.2261649 ],
             [ 0.12471877],
             [-0.2200183 ],
             [-0.0991018 ],
             [ 0.22880913],
             [ 0.07258789],
             [-0.20554072],
             [-0.22866319],
             [-0.05997632],
             [ 0.06897736],
             [-0.22918601],
             [-0.00252465],
             [-0.07360218],
             [ 0.19401576],
             [ 0.204105  ],
             [-0.14778195],
             [-0.14712855],
             [-0.01202745],
             [ 0.01280874],
             [-0.20001645],
             [ 0.22448607],
             [ 0.19981967],
             [ 0.2051334 ],
             [-0.21738507],
             [ 0.22784998],
             [-0.182

In [6]:
vjmapped(key0)

DeviceArray([[ 0.21117093],
             [-0.22931617],
             [-0.22925569],
             [-0.18119279],
             [ 0.0666384 ],
             [-0.19494936],
             [-0.19547391],
             [ 0.2206912 ],
             [-0.03270767],
             [ 0.0915747 ],
             [ 0.2261649 ],
             [ 0.12471877],
             [-0.2200183 ],
             [-0.0991018 ],
             [ 0.22880913],
             [ 0.07258789],
             [-0.20554072],
             [-0.22866319],
             [-0.05997632],
             [ 0.06897736],
             [-0.22918601],
             [-0.00252465],
             [-0.07360218],
             [ 0.19401576],
             [ 0.204105  ],
             [-0.14778195],
             [-0.14712855],
             [-0.01202745],
             [ 0.01280874],
             [-0.20001645],
             [ 0.22448607],
             [ 0.19981967],
             [ 0.2051334 ],
             [-0.21738507],
             [ 0.22784998],
             [-0.182

In [7]:
import time

tstart = time.time()
jvec = jax.jit(vectorized)
for i in range(100):
    print(npo.array(jvec(jax.random.PRNGKey(i))).shape)
    print("iter {} time {}".format(i, time.time() - tstart))
tfin = time.time()
print("total time ", tfin - tstart)

(100,)
iter 0 time 0.7711019515991211
(100,)
iter 1 time 0.7740590572357178
(100,)
iter 2 time 0.7774081230163574
(100,)
iter 3 time 0.7812991142272949
(100,)
iter 4 time 0.7835369110107422
(100,)
iter 5 time 0.785358190536499
(100,)
iter 6 time 0.7906949520111084
(100,)
iter 7 time 0.796807050704956
(100,)
iter 8 time 0.8005430698394775
(100,)
iter 9 time 0.8044760227203369
(100,)
iter 10 time 0.8088340759277344
(100,)
iter 11 time 0.8118741512298584
(100,)
iter 12 time 0.8147993087768555
(100,)
iter 13 time 0.816993236541748
(100,)
iter 14 time 0.818878173828125
(100,)
iter 15 time 0.8208551406860352
(100,)
iter 16 time 0.8229122161865234
(100,)
iter 17 time 0.8253819942474365
(100,)
iter 18 time 0.8274581432342529
(100,)
iter 19 time 0.8296782970428467
(100,)
iter 20 time 0.8320682048797607
(100,)
iter 21 time 0.833981990814209
(100,)
iter 22 time 0.8359420299530029
(100,)
iter 23 time 0.8381912708282471
(100,)
iter 24 time 0.8400919437408447
(100,)
iter 25 time 0.8418729305267334
(

In [8]:
tstart = time.time()
jvmap = jax.jit(vmapped)
for i in range(100):
    print(npo.array(jvmap(jax.random.PRNGKey(i))).shape)
    print("iter {} time {}".format(i, time.time() - tstart))
tfin = time.time()
print("total time ", tfin - tstart)

(100, 1)
iter 0 time 1.5715088844299316
(100, 1)
iter 1 time 1.5739519596099854
(100, 1)
iter 2 time 1.5758638381958008
(100, 1)
iter 3 time 1.5780689716339111
(100, 1)
iter 4 time 1.580514907836914
(100, 1)
iter 5 time 1.5826067924499512
(100, 1)
iter 6 time 1.5856959819793701
(100, 1)
iter 7 time 1.5877890586853027
(100, 1)
iter 8 time 1.5897247791290283
(100, 1)
iter 9 time 1.5918428897857666
(100, 1)
iter 10 time 1.5939199924468994
(100, 1)
iter 11 time 1.5962097644805908
(100, 1)
iter 12 time 1.5981900691986084
(100, 1)
iter 13 time 1.600430965423584
(100, 1)
iter 14 time 1.6038970947265625
(100, 1)
iter 15 time 1.6057658195495605
(100, 1)
iter 16 time 1.6082777976989746
(100, 1)
iter 17 time 1.6102497577667236
(100, 1)
iter 18 time 1.6123027801513672
(100, 1)
iter 19 time 1.6146018505096436
(100, 1)
iter 20 time 1.6166329383850098
(100, 1)
iter 21 time 1.6190109252929688
(100, 1)
iter 22 time 1.621640920639038
(100, 1)
iter 23 time 1.6249828338623047
(100, 1)
iter 24 time 1.62710

In [9]:
tstart = time.time()
jvjmap = jax.jit(vjmapped)
for i in range(100):
    print(npo.array(jvjmap(jax.random.PRNGKey(i))).shape)
    print("iter {} time {}".format(i, time.time() - tstart))
tfin = time.time()
print("total time ", tfin - tstart)

(100, 1)
iter 0 time 1.883882999420166
(100, 1)
iter 1 time 1.8861749172210693
(100, 1)
iter 2 time 1.8886032104492188
(100, 1)
iter 3 time 1.8911170959472656
(100, 1)
iter 4 time 1.8930580615997314
(100, 1)
iter 5 time 1.8953239917755127
(100, 1)
iter 6 time 1.8973948955535889
(100, 1)
iter 7 time 1.8995299339294434
(100, 1)
iter 8 time 1.9017770290374756
(100, 1)
iter 9 time 1.9043021202087402
(100, 1)
iter 10 time 1.9072351455688477
(100, 1)
iter 11 time 1.9090681076049805
(100, 1)
iter 12 time 1.9113450050354004
(100, 1)
iter 13 time 1.9133188724517822
(100, 1)
iter 14 time 1.9153988361358643
(100, 1)
iter 15 time 1.9175679683685303
(100, 1)
iter 16 time 1.9196960926055908
(100, 1)
iter 17 time 1.922590970993042
(100, 1)
iter 18 time 1.9249532222747803
(100, 1)
iter 19 time 1.9271249771118164
(100, 1)
iter 20 time 1.9292511940002441
(100, 1)
iter 21 time 1.9314749240875244
(100, 1)
iter 22 time 1.9336950778961182
(100, 1)
iter 23 time 1.9358389377593994
(100, 1)
iter 24 time 1.9382

In [10]:
N = 100

def vectorized(key):
    a = np.array([0.])
    def inner(a):
        x = jax.random.normal(key, shape=(N,))
        for i in range(10):
            x = 0.9*np.sin(x + a * np.ones_like(x))
        return x.sum()
    for j in range(10):
        a = a - 0.01 * jax.grad(inner)(a)
    return a

def vmapped(key):
    keys = jax.random.split(key, N)
    a = np.array([0.])
    def inner_with_key(key, a):
        x = jax.random.normal(key, shape=(1,))
        for i in range(10):
            x = 0.9*np.sin(x + a * np.ones_like(x))
        return x.sum()
    
    def inner(a):
        return jax.vmap(inner_with_key, (0, None))(keys, a).sum()

    for j in range(10):
        a = a - 0.01 * jax.grad(inner)(a)
    return a

In [11]:
import time

tstart = time.time()
jvec = jax.jit(vectorized)
for i in range(100):
    print(npo.array(jvec(jax.random.PRNGKey(i))))
    print("iter {} time {}".format(i, time.time() - tstart))
tfin = time.time()
print("total time ", tfin - tstart)

[-1.1915913]
iter 0 time 188.9974389076233
[-1.145646]
iter 1 time 189.00037813186646
[-1.2541505]
iter 2 time 189.0032389163971
[-1.1513368]
iter 3 time 189.00603795051575
[-1.085938]
iter 4 time 189.00866413116455
[-1.1784648]
iter 5 time 189.0116889476776
[-1.0929024]
iter 6 time 189.0150170326233
[-1.1642891]
iter 7 time 189.01845622062683
[-1.0899792]
iter 8 time 189.02158117294312
[-1.0224978]
iter 9 time 189.02431416511536
[-1.081123]
iter 10 time 189.02697086334229
[-1.2290983]
iter 11 time 189.0298728942871
[-1.1880515]
iter 12 time 189.03252720832825
[-1.1630895]
iter 13 time 189.03510308265686
[-1.1443801]
iter 14 time 189.03792691230774
[-1.0529486]
iter 15 time 189.04055190086365
[-1.2250537]
iter 16 time 189.04353404045105
[-1.1913869]
iter 17 time 189.04935312271118
[-1.1156772]
iter 18 time 189.0523920059204
[-1.0557729]
iter 19 time 189.05539512634277
[-1.274738]
iter 20 time 189.05837607383728
[-1.167709]
iter 21 time 189.06147503852844
[-1.110923]
iter 22 time 189.06

In [12]:
tstart = time.time()
jvmap = jax.jit(vmapped)
for i in range(100):
    print(npo.array(jvmap(jax.random.PRNGKey(i))))
    print("iter {} time {}".format(i, time.time() - tstart))
tfin = time.time()
print("total time ", tfin - tstart)

[-1.1216568]
iter 0 time 191.3051872253418
[-1.1484584]
iter 1 time 191.30813217163086
[-1.0793839]
iter 2 time 191.31074023246765
[-1.106757]
iter 3 time 191.31344318389893
[-1.1128056]
iter 4 time 191.31700921058655
[-1.0794543]
iter 5 time 191.3196301460266
[-1.0981183]
iter 6 time 191.32279324531555
[-1.1954511]
iter 7 time 191.3255021572113
[-1.1200957]
iter 8 time 191.32800102233887
[-1.1268663]
iter 9 time 191.33052110671997
[-1.0237478]
iter 10 time 191.33359622955322
[-1.1707941]
iter 11 time 191.33603620529175
[-1.184675]
iter 12 time 191.33862709999084
[-1.1539623]
iter 13 time 191.34112310409546
[-1.2537794]
iter 14 time 191.34379816055298
[-1.1543491]
iter 15 time 191.34639835357666
[-1.0891943]
iter 16 time 191.3492453098297
[-1.1141614]
iter 17 time 191.35185027122498
[-1.1560342]
iter 18 time 191.35451912879944
[-1.2130715]
iter 19 time 191.35735607147217
[-1.2341095]
iter 20 time 191.35994005203247
[-1.0370938]
iter 21 time 191.36227917671204
[-1.2235575]
iter 22 time 