**Table of contents**<a id='toc0_'></a>    
- 1. [概要](#toc1_)    
  - 1.1. [非纯函数](#toc1_1_)    
  - 1.2. [纯函数](#toc1_2_)    
- 2. [jax的一般特性](#toc2_)    
  - 2.1. [加速-jit()](#toc2_1_)    
  - 2.2. [自动微分-grad()](#toc2_2_)    
    - 2.2.1. [同时获取函数值与导数-jax.value_and_grad()](#toc2_2_1_)    
    - 2.2.2. [多元函数求导](#toc2_2_2_)    
    - 2.2.3. [多返回值函数的求导](#toc2_2_3_)    
    - 2.2.4. [高阶导函数](#toc2_2_4_)    
  - 2.3. [自动向量化-vmap()/pmap()](#toc2_3_)    
    - 2.3.1. [例1](#toc2_3_1_)    
    - 2.3.2. [例2](#toc2_3_2_)    
    - 2.3.3. [例3](#toc2_3_3_)    
    - 2.3.4. [数组索引](#toc2_3_4_)    
      - 2.3.4.1. [二维数组](#toc2_3_4_1_)    
      - 2.3.4.2. [三维数组](#toc2_3_4_2_)    
    - 2.3.5. [vmap操作](#toc2_3_5_)    
        - 2.3.5.1.1. [vmap实战案例](#toc2_3_5_1_1_)    
          - 2.3.5.1.1.1. [batch的实现](#toc2_3_5_1_1_1_)    
- 3. [jax的高级特性](#toc3_)    
  - 3.1. [jax.numpy特性](#toc3_1_)    
    - 3.1.1. [赋值](#toc3_1_1_)    
    - 3.1.2. [数组规范](#toc3_1_2_)    
    - 3.1.3. [算数运算](#toc3_1_3_)    
  - 3.2. [jax控制分支](#toc3_2_)    
    - 3.2.1. [分支对grad影响](#toc3_2_1_)    
    - 3.2.2. [分支对jit影响](#toc3_2_2_)    
    - 3.2.3. [条件判断-jax.lax.cond(True, func1, func2, args)](#toc3_2_3_)    
    - 3.2.4. [循环-jax.lax.while_loop()](#toc3_2_4_)    
    - 3.2.5. [循环-jax.lax.fori_loop()](#toc3_2_5_)    
  - 3.3. [jax.nn包含的函数](#toc3_3_)    
  - 3.4. [jax.example_libraries](#toc3_4_)    
- 4. [多层感知机](#toc4_)    
  - 4.1. [准备数据集](#toc4_1_)    
    - 4.1.1. [mnist](#toc4_1_1_)    
    - 4.1.2. [独热码（one-hot）](#toc4_1_2_)    
      - 4.1.2.1. [自定义函数实现](#toc4_1_2_1_)    
      - 4.1.2.2. [jax.nn.one_hot()实现](#toc4_1_2_2_)    
  - 4.2. [JAX实现全连接层](#toc4_2_)    

<!-- vscode-jupyter-toc-config
	numbering=true
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

In [10]:
import jax
jax.__version__

'0.4.19'

# tensorflow_datasets

In [1]:
# !pip install tensorflow_datasets==4.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds

  from .autonotebook import tqdm as notebook_tqdm


# 1. <a id='toc1_'></a>[概要](#toc0_)
```
numpy:python中高效的科学计算模块，但是支持cpu加速；
jax:google推出的可以利用cpu、gpu、tpu进行加速的计算模块：
    random          产生随机数
    numpy           数学计算，同python中的numpy模块，接口使用非常类似，可以无缝衔接
    scipy           统计分析
    nn              神经网络类的计算库
    experimental    一些实验形式的内容
    jit
    grad

JAX: 有个很强的特性，支支持“纯函数”，不支持OOP；
    纯函数：
        1. 相同输入总是返回相同输出：
            返回值只和函数参数有关，与外部无关。（无论外部如何变化，函数的返回值都不会发生变化。）
        2. 不产生副作用：
            函数执行过程中对外部产生了可观察的变化，称之为 函数产生了副作用。
        3. 不依赖于外部状态：
            函数执行的过程中不会对外部产生可观察到的变化。
为此google推出了dm-Haiku集成jax的library，支持OOP；
```

## 1.1. <a id='toc1_1_'></a>[非纯函数](#toc0_)

In [14]:
# 因为print的存在，编程了非纯函数
import jax
import jax.numpy as jnp

def impure_print_side_effect(x):
    print(f"实施函数计算")
    return x

print(f"First call: {jax.jit(impure_print_side_effect)(4.0)}")
print('-'*100)

print(f"Second call: {jax.jit(impure_print_side_effect)(5.0)}") # 与第一次数据类型一致，出发JIT缓存机制，不答应print内容了
print('-'*100)

print(f"Third call: {jax.jit(impure_print_side_effect)([5.0])}")
print('-'*100)

实施函数计算
First call: 4.0
----------------------------------------------------------------------------------------------------
Second call: 5.0
----------------------------------------------------------------------------------------------------
实施函数计算
Third call: [Array(5., dtype=float32, weak_type=True)]
----------------------------------------------------------------------------------------------------


In [8]:
# 函数内部参数影响外部参数
g = 0.
def impure_saves_global(x):
    global g
    g = x
    return x

print(f"First call: {jax.jit(impure_saves_global)(4.0)}")

print(f"Saved global: {g}")

First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


## 1.2. <a id='toc1_2_'></a>[纯函数](#toc0_)

In [16]:
# 计算结果没有对外部函数做出任何影响，因此未纯函数
def pure_uses_internal_state(x):
    state = dict(even=0, odd=0)
    for i in range(10):
        state['even' if i % 2 == 0 else 'odd'] += x
    
    return state['even'] + state['odd']

print(f"{jax.jit(pure_uses_internal_state)(3.)}")
print(f"{jax.jit(pure_uses_internal_state)(6.)}")

print(f"{jax.jit(pure_uses_internal_state)(jnp.array([5.0]))}")

30.0
60.0
[50.]


In [9]:
help(jax.grad)

Help on function grad in module jax._src.api:

grad(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False, reduce_axes: 'Sequence[AxisName]' = ()) -> 'Callable'
    Creates a function that evaluates the gradient of ``fun``.
    
    Args:
      fun: Function to be differentiated. Its arguments at positions specified by
        ``argnums`` should be arrays, scalars, or standard Python containers.
        Argument arrays in the positions specified by ``argnums`` must be of
        inexact (i.e., floating-point or complex) type. It
        should return a scalar (which includes arrays with shape ``()`` but not
        arrays with shape ``(1,)`` etc.)
      argnums: Optional, integer or sequence of integers. Specifies which
        positional argument(s) to differentiate with respect to (default 0).
      has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
        first element is conside

# 2. <a id='toc2_'></a>[jax的一般特性](#toc0_)
## 2.1. <a id='toc2_1_'></a>[加速-jit()](#toc0_)
```
JIT：及时编译器。当采用标准的Python、Numpy函数，经过JIT编译后可在加速器上高效运行。
```

In [50]:
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x>0, x, alpha*jnp.exp(x) - alpha)

x = jax.random.normal(jax.random.PRNGKey(17), (10000000,))

# 普通计算
start = time.time()
selu(x)
stop = time.time(); print(f"耗时：{stop - start}")

# 方法一：jax.jit()编译
start = time.time()
selu_jited = jax.jit(selu) # 利用jax.jit()编译
selu_jited(x)
stop = time.time(); print(f"耗时：{stop - start}")

# 方法二： @jax.jit装饰器修饰
@jax.jit
def selu_jit(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x>0, x, alpha*jnp.exp(x) - alpha)
start = time.time()
selu_jit(x)
stop = time.time(); print(f"耗时：{stop - start}")

耗时：0.04858279228210449
耗时：0.05484318733215332
耗时：0.0653378963470459


## 2.2. <a id='toc2_2_'></a>[自动微分-grad()](#toc0_)
```
grad:
    1. 必须使用浮点型数值计算
```

In [64]:
help(jax.grad)

Help on function grad in module jax._src.api:

grad(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False, reduce_axes: 'Sequence[AxisName]' = ()) -> 'Callable'
    Creates a function that evaluates the gradient of ``fun``.
    
    Args:
      fun: Function to be differentiated. Its arguments at positions specified by
        ``argnums`` should be arrays, scalars, or standard Python containers.
        Argument arrays in the positions specified by ``argnums`` must be of
        inexact (i.e., floating-point or complex) type. It
        should return a scalar (which includes arrays with shape ``()`` but not
        arrays with shape ``(1,)`` etc.)
      argnums: Optional, integer or sequence of integers. Specifies which
        positional argument(s) to differentiate with respect to (default 0).
      has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
        first element is conside

In [30]:
import jax
import jax.numpy as jnp

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.) # 必须是浮点型

Dfn = jax.grad(sum_logistic)

Dfn_jited = jax.jit(Dfn) # 对导函数加速

In [31]:
%timeit Dfn(x_small)

8.86 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [32]:
%timeit Dfn_jited(x_small) # 使用jit加速

9.66 µs ± 389 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### 2.2.1. <a id='toc2_2_1_'></a>[同时获取函数值与导数-jax.value_and_grad()](#toc0_)

In [35]:
import jax
import jax.numpy as jnp

def body_fn(x):
    return x**2

jax.value_and_grad(body_fn)(1.)
# 函数值：1
# 导函数值：2

(Array(1., dtype=float32, weak_type=True),
 Array(2., dtype=float32, weak_type=True))

### 2.2.2. <a id='toc2_2_2_'></a>[多元函数求导](#toc0_)
```
argnums=(0,1,2)
```

In [42]:
import jax
import jax.numpy as jnp

def body_fun(x, y):
    return x*y
grad_body_fun = jax.grad(body_fun)
x = (2.)
y = (3.)

dx, dy = (jax.grad(body_fun, argnums=(0, 1))(x, y))
dx, dy

(Array(3., dtype=float32, weak_type=True),
 Array(2., dtype=float32, weak_type=True))

In [54]:
import jax
import jax.numpy as jnp

def body_fun(x, y, z):
    return x*y*z

x = 2.
y = 3.
z = 4.
dx, dy, dz = jax.grad(body_fun, argnums=(0, 1, 2))(x, y, z)   # dx, dy, dz
# dx, dy = jax.grad(body_fun, argnums=(0, 1))(x, y, z)          # dx, dy
# dx = jax.grad(body_fun, argnums=(0))(x, y, z)                 # dx, dy, dz
dx, dy, dz

(Array(12., dtype=float32, weak_type=True),
 Array(8., dtype=float32, weak_type=True),
 Array(6., dtype=float32, weak_type=True))

### 2.2.3. <a id='toc2_2_3_'></a>[多返回值函数的求导](#toc0_)
```
has_aux=True
```

In [63]:
import jax
import jax.numpy as jnp

def body_fn(x, y):
    return x*y, x**2+y**2

x = 2.
y = 3.
jax.grad(body_fn, has_aux=True)(x, y)

(Array(3., dtype=float32, weak_type=True),
 Array(13., dtype=float32, weak_type=True))

### 2.2.4. <a id='toc2_2_4_'></a>[高阶导函数](#toc0_)

In [66]:
import jax

fun = lambda x: x**3 + 2*x**2 - 3*x + 1

dx = jax.grad(fun) # 一阶导函数
dx(1.)

Array(4., dtype=float32, weak_type=True)

In [67]:
ddx = jax.grad(dx) # 二阶导函数
ddx(1.)

Array(10., dtype=float32, weak_type=True)

In [69]:
dddx = jax.grad(ddx) # 三阶导函数
dddx(1.)

Array(6., dtype=float32, weak_type=True)

## 2.3. <a id='toc2_3_'></a>[自动向量化-vmap()/pmap()](#toc0_)
```
什么是向量化？
    vmap 在单张GPU的多个CUDA核心上并行计算
    pmap 在单台机器的多个GPU计算卡上并行计算
```

### 2.3.1. <a id='toc2_3_1_'></a>[例1](#toc0_)

In [3]:
import jax
import jax.numpy as jnp
import time

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

In [6]:
start = time.time()

x_small = jnp.arange(1024000.)
Dfn = (jax.grad(sum_logistic))

stop = time.time()
print(f"耗时：{stop - start}")

耗时：0.0020017623901367188


In [7]:
start = time.time()

x_small = jnp.arange(1024000.)
Dfn = jax.vmap(jax.grad(sum_logistic))

stop = time.time()
print(f"耗时：{stop - start}")

耗时：0.0019969940185546875


In [9]:
start = time.time()

x_small = jnp.arange(1024000.)
Dfn = jax.jit(jax.vmap(jax.grad(sum_logistic)))

stop = time.time()
print(f"耗时：{stop - start}")

耗时：0.003172159194946289


### 2.3.2. <a id='toc2_3_2_'></a>[例2](#toc0_)

In [2]:
import jax
import jax.numpy as jnp

def vec_vec(x, y):
    return jnp.dot(x, y)

x = jnp.array([1, 2, 1])
y = jnp.array([2, 1, 2])

vec_vec(x, y)

Array(6, dtype=int32)

### 2.3.3. <a id='toc2_3_3_'></a>[例3](#toc0_)

In [1]:
import jax

jax.device_count()

1

### 2.3.4. <a id='toc2_3_4_'></a>[数组索引](#toc0_)
```
知乎教程：https://zhuanlan.zhihu.com/p/476098317
规律：可见索引的规律如下：索引的轴维度消失，前后维度合并。
```

#### 2.3.4.1. <a id='toc2_3_4_1_'></a>[二维数组](#toc0_)

In [95]:
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(55), (3, 2))
x, x.shape

(Array([[ 0.45765412, -1.8838878 ],
        [-1.7656637 , -0.32822677],
        [ 0.04516343,  1.7529849 ]], dtype=float32),
 (3, 2))

In [72]:
# 索引第0行，第0/1个元素, 类似矩阵索引。
x[0,0], x[0,1]

(Array(0.45765412, dtype=float32), Array(-1.8838878, dtype=float32))

In [74]:
# 索引一整行:
x[0, :]

Array([ 0.45765412, -1.8838878 ], dtype=float32)

In [75]:
# 索引一整列:
x[:, 0]

Array([ 0.45765412, -1.7656637 ,  0.04516343], dtype=float32)

In [76]:
# 对axes=0 索引：
for i in range(x.shape[0]):
    print(x[i,:])

[ 0.45765412 -1.8838878 ]
[-1.7656637  -0.32822677]
[0.04516343 1.7529849 ]


In [77]:
# 对axes=1 每个元素索引：
for i in range(x.shape[1]):
    print(x[:, i])

[ 0.45765412 -1.7656637   0.04516343]
[-1.8838878  -0.32822677  1.7529849 ]


#### 2.3.4.2. <a id='toc2_3_4_2_'></a>[三维数组](#toc0_)

In [79]:
# 生成一个三维数组
x = jax.random.normal(jax.random.PRNGKey(55), (3, 4, 2))
x

Array([[[-1.2363433 ,  0.711691  ],
        [ 1.1999513 , -0.3833862 ],
        [-1.9615409 , -0.91742545],
        [-0.2196888 ,  1.1890033 ]],

       [[-1.6743991 ,  0.1835342 ],
        [-1.372812  , -1.2838745 ],
        [-0.56314117,  0.11130438],
        [-0.9001647 , -0.612242  ]],

       [[-0.2662137 , -1.3849598 ],
        [-1.0626214 , -0.24122705],
        [-0.14088325, -0.3180565 ],
        [ 0.36471063,  0.46731815]]], dtype=float32)

In [84]:
# for i in range(x.shape[0]):
#     print(x[i,:,:])
x[0, :, :]

Array([[-1.2363433 ,  0.711691  ],
       [ 1.1999513 , -0.3833862 ],
       [-1.9615409 , -0.91742545],
       [-0.2196888 ,  1.1890033 ]], dtype=float32)

In [82]:
# 索引所有矩阵的第0行组成的矩阵：
x[:, 0, :]

Array([[-1.2363433,  0.711691 ],
       [-1.6743991,  0.1835342],
       [-0.2662137, -1.3849598]], dtype=float32)

In [83]:
# 索引所有矩阵的第0列组成的矩阵的转置：
x[:, :, 0]

Array([[-1.2363433 ,  1.1999513 , -1.9615409 , -0.2196888 ],
       [-1.6743991 , -1.372812  , -0.56314117, -0.9001647 ],
       [-0.2662137 , -1.0626214 , -0.14088325,  0.36471063]],      dtype=float32)

### 2.3.5. <a id='toc2_3_5_'></a>[vmap操作](#toc0_)
```
vmap(fun: 'F', in_axes: 'int | None | Sequence[Any]' = 0, out_axes: 'Any' = 0, axis_name: 'AxisName | None' = None, axis_size: 'int | None' = None, spmd_axis_name: 'AxisName | tuple[AxisName, ...] | None' = None) -> 'F'
    Vectorizing map. Creates a function which maps ``fun`` over argument axes.

fun: 代表你需要进行向量化操作的具体函数；
in_axes：输入格式为元组，代表fun中每个输入参数中，使用哪一个维度进行向量化；
out_axes: 经过fun计算后，每组输出在哪个维度输出。

在介绍了numpy多维数组的索引之后，vmap的in_axes参数的功能就一目了然了。其本质就是对某轴进行索引，得到n个新的数组，将这些数组传递给fun函数进行操作后叠加，用通用公式表示：

# 1. 定义某种函数，以dot为例:
fun = lambda x,y :  jnp.dot(x, y)

# 2. 切片后数据有fun进行操作
for i in range(x.shape[in_axes]):
	fun(x[:, i, :], y[:, i, :])
	
# 3. 所有数组stack在out_axes
```

In [109]:
# 定义jnp.dot(x,y)函数
f = lambda x,y : jnp.dot(x,y)

# 初始化x,y样本
x = jax.random.normal(jax.random.PRNGKey(55), (4, 3))
y = jax.random.normal(jax.random.PRNGKey(42), (3, 2))
x, x.shape, y, y.shape

(Array([[ 0.40211347,  1.0316547 ,  0.24331902],
        [-1.1584883 , -1.2835754 , -0.56345284],
        [-0.01159265,  0.17644508, -0.5234676 ],
        [-0.01921092, -0.66263014, -0.22824208]], dtype=float32),
 (4, 3),
 Array([[ 0.6122652 ,  1.1225883 ],
        [ 1.1373317 , -0.8127325 ],
        [-0.890405  ,  0.12623145]], dtype=float32),
 (3, 2))

In [101]:
# 先看看原始的jnp.dot的计算结果的shape和我们直接用矩阵乘法预期一致：
z = f(x,y)
z, z.shape

(Array([[ 1.2028812 , -0.35633698],
        [-1.667452  , -0.3284274 ],
        [ 0.65967697, -0.2224945 ],
        [-0.5621646 ,  0.48816377]], dtype=float32),
 (4, 2))

In [106]:
# 接下来，我们使用vmap来对x的第0轴索引后产生的每个新数组，与y进行jnp.dot():
jax.vmap(f, in_axes=(0,None), out_axes=0)(x,y)

Array([[ 1.2028812 , -0.35633698],
       [-1.667452  , -0.3284274 ],
       [ 0.65967697, -0.2224945 ],
       [-0.5621646 ,  0.48816377]], dtype=float32)

In [125]:
# x 第0维有4个array
for i in range(x.shape[0]):
	
	print(x[i, :])
	print(jnp.dot(x[i,:], y))
	print("===============================")

[0.40211347 1.0316547  0.24331902]
[ 1.2028812  -0.35633698]
[-1.1584883  -1.2835754  -0.56345284]
[-1.667452  -0.3284274]
[-0.01159265  0.17644508 -0.5234676 ]
[ 0.65967697 -0.2224945 ]
[-0.01921092 -0.66263014 -0.22824208]
[-0.5621646   0.48816377]


In [126]:
jax.vmap(f, (None, 1), 0)(x,y)

Array([[ 1.2028812 , -1.667452  ,  0.65967697, -0.5621646 ],
       [-0.35633698, -0.3284274 , -0.2224945 ,  0.48816377]],      dtype=float32)

##### 2.3.5.1.1. <a id='toc2_3_5_1_1_'></a>[vmap实战案例](#toc0_)
###### 2.3.5.1.1.1. <a id='toc2_3_5_1_1_1_'></a>[batch的实现](#toc0_)

In [129]:
# 定义函数：
f = lambda x,w : jnp.dot(w,x)

# 定义x, batch_x, w。
x = jax.random.normal(jax.random.PRNGKey(55), (5, 3))
x_batch = jax.random.normal(jax.random.PRNGKey(55), (4, 5, 3))
w = jax.random.normal(jax.random.PRNGKey(42), (100, x.shape[0]))

5

In [130]:
# 激活值a=w@x:
jnp.dot(w, x).shape

(100, 3)

In [140]:
# 这样手写batch, 使用for loop:
def for_loop_batch(x_batch):
    for x in x_batch:
        # print(x)
        jnp.dot(w, x)    # shape: (100, 3)

# for_loop_batch(x_batch)
%timeit for_loop_batch(x_batch)

98.8 µs ± 6.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [139]:
# batch_a = jax.vmap(f, in_axes=(0,None), out_axes=0)(x_batch, w)
# print(batch_a.shape)
%timeit jax.vmap(f, in_axes=(0, None), out_axes=0)(x_batch, w)

1.34 ms ± 71.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# 3. <a id='toc3_'></a>[jax的高级特性](#toc0_)
## 3.1. <a id='toc3_1_'></a>[jax.numpy特性](#toc0_)
```
jax.numpy和Numpy用法非常相似，可以做到无缝切换；
```

### 3.1.1. <a id='toc3_1_1_'></a>[赋值](#toc0_)

In [142]:
import jax.numpy as jnp
import numpy as np

In [145]:
x = jnp.array([1,2,3])
y = np.array([1,2,3])
x, x.shape, y, y.shape

(Array([1, 2, 3], dtype=int32), (3,), array([1, 2, 3]), (3,))

In [146]:
jnp.arange(10), np.arange(10)

(Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

In [147]:
jnp.arange(10).reshape(2,5), np.arange(10).reshape(2,5)

(Array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]], dtype=int32),
 array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]))

In [149]:
jnp.linspace(0,9,10), np.linspace(0,9,10) 

(Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32),
 array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]))

In [307]:
tem = jnp.arange(10).reshape(2, 5)
tem, jnp.split(tem, 2, axis=0)
# jnp.split()

(Array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]], dtype=int32),
 [Array([[0, 1, 2, 3, 4]], dtype=int32),
  Array([[5, 6, 7, 8, 9]], dtype=int32)])

### 3.1.2. <a id='toc3_1_2_'></a>[数组规范](#toc0_)
```
jax.numpy.arrary被设置好后就不能改变，只能通过at的方式进行更改。
jax.numpy.ndarray.at:
    x = x.at[idx].set(y)            ->          x[idx] = y
    x = x.at[idx].add(y)            ->          x[idx] += y
    x = x.at[idx].multiply(y)       ->          x[idx] *= y
    x = x.at[idx].divide(y)         ->          x[idx] /= y
    x = x.at[idx].power(y)          ->          x[idx] **= y
    x = x.at[idx].min(y)            ->          x[idx] = minimum(x[idx], y)
    x = x.at[idx].max(y)            ->          x[idx] = maximum(x[idx], y)
    x = x.at[idx].apply(ufunc)      ->          ufunc.at(x, idx)
    x = x.at[idx].get()             ->          x = x[idx]
```

In [24]:
import jax
import jax.numpy as jnp

jax_array = jnp.zeros((3,3), dtype=jnp.float32)
jax_array

ImportError: cannot import name 'index_update' from 'jax.ops' (d:\ProgramFiles\miniconda3\envs\tensorflow2\lib\site-packages\jax\ops\__init__.py)

In [20]:
jax_array[1, :]

Array([0., 0., 0.], dtype=float32)

In [22]:
# 直接更改将报错
jax_array[1, :] = 1.0

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [29]:
new_jax_array = jax_array.at[1,:].set(1.0)
new_jax_array

Array([[0., 0., 0.],
       [1., 1., 1.],
       [0., 0., 0.]], dtype=float32)

### 3.1.3. <a id='toc3_1_3_'></a>[算数运算](#toc0_)
```
jnp.add()
jnp.subtract()
jnp.multiply()
jnp.divide()
jnp.power()
jnp.exp()
```

In [193]:
jnp.add(3,2)
jnp.subtract(3,2)
jnp.multiply(3,2)
jnp.divide(3,3)
jnp.power(2,2)
jnp.exp(2)
jnp.log(100)
jnp.log10(100)
jnp.sin(2)
jnp.cos(2)
jnp.tan(3)
# jnp.dot([1,2,3], [3,4,5]) # 报错
jnp.dot(jnp.array([1,2,3]), jnp.array([3,4,5]))
jnp.where(True, 1, 0)

Array(1, dtype=int32, weak_type=True)

In [172]:
rng = jax.random.PRNGKey(seed=0) # rng:随机数生成器

input = jax.random.normal(rng, shape=(10, 2))

# weight = jax.random.normal(rng, shape=(2, 5))
weight = jnp.ones(shape=(2, 5))

# bias = jax.random.normal(rng, shape=(5,))
bias = jnp.ones(shape=(5,))

input, weight, bias

(Array([[ 1.0545162 , -0.96928865],
        [-0.5946021 , -0.03188572],
        [ 2.4109333 , -1.8784491 ],
        [-0.7847696 , -0.31370842],
        [ 0.3337089 ,  1.7677035 ],
        [-1.0277646 ,  1.4111718 ],
        [-0.5084971 , -0.5263775 ],
        [ 0.5031504 ,  1.0549793 ],
        [-0.08740733,  0.7958167 ],
        [ 2.6565616 , -0.5822906 ]], dtype=float32),
 Array([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], dtype=float32),
 Array([1., 1., 1., 1., 1.], dtype=float32))

In [173]:
jnp.dot(input, weight)

Array([[ 0.08522755,  0.08522755,  0.08522755,  0.08522755,  0.08522755],
       [-0.62648785, -0.62648785, -0.62648785, -0.62648785, -0.62648785],
       [ 0.5324842 ,  0.5324842 ,  0.5324842 ,  0.5324842 ,  0.5324842 ],
       [-1.0984781 , -1.0984781 , -1.0984781 , -1.0984781 , -1.0984781 ],
       [ 2.1014125 ,  2.1014125 ,  2.1014125 ,  2.1014125 ,  2.1014125 ],
       [ 0.38340724,  0.38340724,  0.38340724,  0.38340724,  0.38340724],
       [-1.0348747 , -1.0348747 , -1.0348747 , -1.0348747 , -1.0348747 ],
       [ 1.5581298 ,  1.5581298 ,  1.5581298 ,  1.5581298 ,  1.5581298 ],
       [ 0.70840937,  0.70840937,  0.70840937,  0.70840937,  0.70840937],
       [ 2.074271  ,  2.074271  ,  2.074271  ,  2.074271  ,  2.074271  ]],      dtype=float32)

In [171]:
jnp.dot(input, weight) + bias

Array([[ 1.0852275 ,  1.0852275 ,  1.0852275 ,  1.0852275 ,  1.0852275 ],
       [ 0.37351215,  0.37351215,  0.37351215,  0.37351215,  0.37351215],
       [ 1.5324842 ,  1.5324842 ,  1.5324842 ,  1.5324842 ,  1.5324842 ],
       [-0.09847808, -0.09847808, -0.09847808, -0.09847808, -0.09847808],
       [ 3.1014125 ,  3.1014125 ,  3.1014125 ,  3.1014125 ,  3.1014125 ],
       [ 1.3834072 ,  1.3834072 ,  1.3834072 ,  1.3834072 ,  1.3834072 ],
       [-0.03487468, -0.03487468, -0.03487468, -0.03487468, -0.03487468],
       [ 2.5581298 ,  2.5581298 ,  2.5581298 ,  2.5581298 ,  2.5581298 ],
       [ 1.7084093 ,  1.7084093 ,  1.7084093 ,  1.7084093 ,  1.7084093 ],
       [ 3.074271  ,  3.074271  ,  3.074271  ,  3.074271  ,  3.074271  ]],      dtype=float32)

In [None]:
jnp.swapaxes(msa_act, -2, -3)
jnp.transpose(act, [1, 0, 2])
jnp.expand_dims(msa_mask, axis=-1)
jnp.asarray(loss)
jnp.minimum(num_iter, self.config.num_recycle)
jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
jnp.where(bias, logits, _SOFTMAX_MASK)
jnp.concatenate([evoformer_input['msa'], template_activations], axis=0)
[jnp.zeros_like(x) for x in unit_vector]

## 3.2. <a id='toc3_2_'></a>[jax控制分支](#toc0_)
```
grad：python分支和jax分支都支持
jit：支支持jax分支

jax控制流：
    lax.cond: 等同于if
    lax.while_loop: 等同于while
    lax.fori_loop:等同于for
    lax.scan：对数组进行操作的函数
```
### 3.2.1. <a id='toc3_2_1_'></a>[分支对grad影响](#toc0_)

In [208]:
import jax
import jax.numpy as jnp

def f(x):
    if x < 3:
        return 3.0 * x**2
    else:
        return -4.0 * x
    
Df = jax.grad(f)
Df(2.0), Df(3.0)
# 不影响

(Array(12., dtype=float32, weak_type=True),
 Array(-4., dtype=float32, weak_type=True))

### 3.2.2. <a id='toc3_2_2_'></a>[分支对jit影响](#toc0_)

In [None]:
f_jited = jax.jit(f)
f_jited(2.0) # 会报错

### 3.2.3. <a id='toc3_2_3_'></a>[条件判断-jax.lax.cond(True, func1, func2, args)](#toc0_)
```
jax.lax.cond(判断条件, True对应的执行函数, False对应的执行)
```

In [213]:
# 将python分支改写成jax分支
import jax
import jax.numpy as jnp

def laxf(x):
    return jax.lax.cond(x<3, lambda x: 3.0 * x**2, lambda x:-4*x, x)

Dlaxf = jax.grad(laxf)
laxf_jited = jax.jit(Dlaxf)
laxf_jited(2.0)

Array(12., dtype=float32, weak_type=True)

### 3.2.4. <a id='toc3_2_4_'></a>[循环-jax.lax.while_loop()](#toc0_)
```
jax.lax.while_loop(cond_fun, body_fun, init_val)
```

In [214]:
init_val = 0

def cond_fun(x):
    return x < 17

def body_fun(x):
    return x + 1

jax.lax.while_loop(cond_fun, body_fun, init_val)

Array(17, dtype=int32, weak_type=True)

### 3.2.5. <a id='toc3_2_5_'></a>[循环-jax.lax.fori_loop()](#toc0_)
```
jax.lax.fori_loop(start, stop, body_fun, init_val)
```

In [216]:
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x : x + i

jax.lax.fori_loop(start, stop, body_fun, init_val)

Array(45, dtype=int32, weak_type=True)

## 3.3. <a id='toc3_3_'></a>[jax.nn包含的函数](#toc0_)

In [277]:
# jax.nn.one_hot()
# jax.nn.normalize()
# jax.nn.sigmoid()
# jax.nn.tanh()
# jax.nn.softmax()
# jax.nn.relu()

## 3.4. <a id='toc3_4_'></a>[jax.example_libraries](#toc0_)

In [280]:
# import  jax.example_libraries.optimizers as optimezers
# import jax.example_libraries.stax as stax

# optimezers.adam()
# optimezers.sgd()
# optimezers.adagrad()
# optimezers.rmsprop()

# stax.BatchNorm()
# stax.Dense()
# stax.Dropout()
# stax.Conv()

# 4. <a id='toc4_'></a>[多层感知机](#toc0_)

## 4.1. <a id='toc4_1_'></a>[准备数据集](#toc0_)

### 4.1.1. <a id='toc4_1_1_'></a>[mnist](#toc0_)

In [262]:
import jax.numpy as jnp
import jax

X_train = jnp.load('Minist/mnist_train_x.npy')
y_train = jnp.load("Minist/mnist_train_y.npy")

In [263]:
X_train, X_train.shape

(Array([[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        ...,
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],

In [264]:
y_trian, y_train.shape

(Array([5, 0, 4, ..., 5, 6, 8], dtype=uint8), (60000,))

### 4.1.2. <a id='toc4_1_2_'></a>[独热码（one-hot）](#toc0_)
```
离散数据之间若无相关关系，则在数值表示上最好也不要有顺序等相关关系，所有用独热码的形式表示。
```
#### 4.1.2.1. <a id='toc4_1_2_1_'></a>[自定义函数实现](#toc0_)

In [265]:
y_train_3 = y_train[0:3] # 取前三行
y_train_3

Array([5, 0, 4], dtype=uint8)

In [266]:
y_train_3[:,None] # 行向量编程列向量

Array([[5],
       [0],
       [4]], dtype=uint8)

In [267]:
jnp.arange(10) # 0-9的10个行向量

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [268]:
y_train_3[:,None] == jnp.arange(10) # 匹配上的显示True，否则为False

Array([[False, False, False, False, False,  True, False, False, False,
        False],
       [ True, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False]], dtype=bool)

In [269]:
jnp.array([False,False,True]) # 默认变成了bool

Array([False, False,  True], dtype=bool)

In [270]:
jnp.array([False,False,True], dtype=jnp.float32) # 指定数据类型

Array([0., 0., 1.], dtype=float32)

In [271]:
# 用函数框起来
def one_hot_nojit(x, k=10, dtype=jnp.float32):
    return jnp.array(x[:,None] == jnp.arange(k), dtype)

y_trian[0:5], one_hot_nojit(y_trian)[0:5]

(Array([5, 0, 4, 1, 9], dtype=uint8),
 Array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32))

#### 4.1.2.2. <a id='toc4_1_2_2_'></a>[jax.nn.one_hot()实现](#toc0_)

In [274]:
# jax.nn.one_hot()
## 简单实用

jax.nn.one_hot(y_train, num_classes=10, dtype=jnp.float32)[0:5]

Array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)

## 全连接训练mnist数据集

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds

import jax
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import stax, optimizers

from IPython import display
import matplotlib.pyplot as plt
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# 1. 定义网络结构，并计算出预测值(y_hat)
 
# {Dense(1024) -> ReLU}x2 -> Dense(10) -> LogSoftmax
init_random_params, predict = stax.serial(
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax
    )
## 初始化网络参数
_, init_params = init_random_params(rng=random.PRNGKey(0), input_shape=(-1, 28*28))
# init_params

In [7]:
# 2. 构造损失函数：真实值(y) - 预测值(y_hat)

@jit
def loss(params, batch):
    """ Cross-entropy loss over a minibatch. """
    inputs, targets = batch
    return jnp.mean(jnp.sum(-targets * predict(params, inputs), axis=1))

In [8]:
# 3. 定义优化器：用来更新权重（w)和偏置(b)

opt = optimizers.adam(step_size=2e-4)
# opt_state = opt.init_fn(网络参数)        # 初始化优化器参数
# params = opt.params_fn(opt_state)                # 获得优化器参数 
# opt_state = opt.update_fn()

## 传入网络的参数后初始化优化器参数
opt_state = opt.init_fn(init_params) 

In [9]:
@jit
def pred_check(params, batch):
    """ Correct predictions over a minibatch. """
    inputs, targets = batch
    predict_result = predict(params, inputs)
    predicted_class = jnp.argmax(predict_result, axis=1)
    targets = jnp.argmax(targets, axis=1)
    return jnp.sum(predicted_class == targets)


# 准备数据集
x_train = jnp.load("Minist/mnist_train_x.npy")
y_train = jnp.load("Minist/mnist_train_y.npy")
## y_trian要做独热编码处理
y_train = jax.nn.one_hot(x=y_train, num_classes=len(jnp.unique(y_train)))

## x y配对并打乱顺序
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train = tfds.as_numpy(ds_train)


# 4. 训练过程
x = []
train_acc = [] # 训练准确度暂存列表
for epoch in range(5):
    itercount = 0
    for batch_raw in ds_train:
        data = batch_raw[0].reshape((-1, 28 * 28))
        targets = batch_raw[1].reshape((-1, 10))
        # 更新权重和偏置
        opt_state = opt.update_fn(itercount,
                                  grad(loss)(opt.params_fn(opt_state), (data,targets)), 
                                  opt_state
                                  )
        itercount += 1
    params = opt.params_fn(opt_state)
    #上面是训练部分，这里是存档部分，这里直接仿照numpy进行存档即可

    #上面是载入部分，直接仿照numpy中数据进行载入即可
    #params = jnp.load("params.npy",allow_pickle =True)
    # Train Acc
    correct_preds = 0.0
    for batch_raw in ds_train:
        data = batch_raw[0].reshape((-1, 28 * 28))
        targets = batch_raw[1]
        correct_preds += pred_check(params, (data, targets))

    acc = correct_preds / float(len(y_train))
    train_acc.append(acc)
    print(f"{epoch}) Training set accuracy: {acc}")

    # # 绘图
    x.append(epoch)
    plt.clf()
    plt.plot(x, train_acc)
    plt.xlabel('epoch')
    plt.ylabel('acc')
    plt.title('train curve')
    plt.pause(0.0001)  # 暂停一段时间，不然画的太快会卡住显示不出来
    display.clear_output(wait=True)

0) Training set accuracy: 0.0


: 