## 一些简单的工具

Reference: https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html

### random

In [17]:
from jax import jit
from jax import random

### ---- random test -----
# https://github.com/google/jax/issues/968#issuecomment-508179312
# Demonstrates the split-then-sample pattern for JAX PRNG keys; timed in the next line.
key = random.PRNGKey(0)  # <- initial


@jit
def split_and_sample(key):
  """Split ``key`` and draw one normal sample with the fresh subkey.

  Args:
    key: a JAX PRNG key (e.g. from ``random.PRNGKey``).

  Returns:
    (new_key, val): ``new_key`` is the key to pass to the next call;
    ``val`` is an array of shape (3,) drawn with ``random.normal``.
  """
  key, subkey = random.split(key)  # <- split
  val = random.normal(subkey, shape=(3,))  # <- usage
  return key, val


def sample_repeatedly_with_split(key, n_iters=10000):
  """Call ``split_and_sample`` ``n_iters`` times, threading the key through.

  Args:
    key: initial JAX PRNG key.
    n_iters: number of split+sample steps (default 10000).

  Returns:
    The final key, after waiting for it to be computed.
  """
  for _ in range(n_iters):
    key, _ = split_and_sample(key)
  # JAX dispatches work asynchronously; block on the final key so that timing
  # this function (e.g. with %timeit) includes the actual computation.
  key.block_until_ready()
  return key

%timeit -r2 -n5 sample_repeatedly_with_split(key)

130 ms ± 34.8 ms per loop (mean ± std. dev. of 2 runs, 5 loops each)


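Why is `random.split` needed? JAX random functions are pure: calling `random.normal` twice with the same key returns exactly the same numbers, as the next cell shows.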

In [19]:
# Reusing one key is deterministic: both draws below print identical values.
key = random.PRNGKey(0)
for draw_idx in range(2):
    print(random.normal(key, shape=(3,)))

[ 1.8160863  -0.48262316  0.33988908]
[ 1.8160863  -0.48262316  0.33988908]


In [25]:
# Correct usage: split off a fresh subkey before every draw and keep the
# returned key for the next split, so consecutive samples differ.
for draw_idx in range(2):
    key, subkey = random.split(key)
    print(random.normal(subkey, shape=(3,)))

[ 0.06513273  0.9301752  -1.0317023 ]
[ 0.8145605  1.0554591 -0.3506622]


### Timeit的测试

In [12]:
%timeit -r3 -n5 sample_repeatedly_with_split(key)  # only in jupyter notebook

# 其他方法
import time, timeit
# method2
ttt = timeit.timeit(lambda : sample_repeatedly_with_split(key), number=10)
print('method2: takes %.4f secs' % (ttt / 10))

# method3
now = time.time()
sample_repeatedly_with_split(key)
print('method3: takes %.4f secs' % (time.time() - now))

103 ms ± 2 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)
takes 0.1007 secs
takes 0.0941 secs


### Autoreload功能


In [2]:
# Local imports from current directory - auto reload.
# Any changes you make to train.py will appear automatically.
%load_ext autoreload
%autoreload 2
from mycode.test import myNumber

In [5]:
# Call the helper imported from mycode.test; with `%autoreload 2` active,
# re-running this cell after editing that module uses the updated code.
myNumber()

1