**Table of contents**<a id='toc0_'></a>    
- 1. [下载和安装](#toc1_)    
- 2. [教程](#toc2_)    
- 3. [概要](#toc3_)    
- 4. [神经网络搭建八股](#toc4_)    
- 5. [basic](#toc5_)    

<!-- vscode-jupyter-toc-config
	numbering=true
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# 1. <a id='toc1_'></a>[下载和安装](#toc0_)

In [3]:
import haiku as hk
print(f"dm-haiku version: {hk.__version__}")

import jax
print(f"Jax version: {jax.__version__}")

import jax.numpy as jnp

import tensorflow as tf
print(f"tensorflow version: {tf.__version__}")

import tensorflow_datasets as tfds
print(f"tensorflow_datasets version: {tfds.__version__}")

dm-haiku version: 0.0.11
Jax version: 0.4.20
tensorflow version: 2.10.0
tensorflow_datasets version: 1.2.0


# 3. <a id='toc3_'></a>[概要](#toc0_)
```
google:
    Tensorflow(Sonnet)
    Haiku(JAX)
facebook:
    Pytorch
Microsoftware:
    CNTK
AWA:
    MXnet

dm-haiku教程：
https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html
https://zhuanlan.zhihu.com/p/471892075
```



# 4. <a id='toc4_'></a>[神经网络搭建八股](#toc0_)
```
1. 定义网络结构，计算预测值(y_hat);
2. 构造loss函数；
3. 训练（迭代）：更新权重(w)和偏置(b)。
```
## 自己摸索（便于理解）

In [None]:
import haiku as hk

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

from IPython import display
%matplotlib inline

rng = jax.random.PRNGKey(0)

# 1.定义神经网络结构
class Network(hk.Module):
    def __init__(self, name=None):
        super().__init__()

    def __call__(self, x):
        mlp = hk.Sequential([hk.Linear(768, name='hidden1'),
                            jax.nn.relu,
                            hk.Linear(10, name='hidden2'),
                            jax.nn.relu,
                            jax.nn.softmax])
        logits = mlp(x)
        return logits
# 2.初始化网络获得params:w,b
model = hk.transform(lambda x: Network(name='TestNetwork')(x))
params = model.init(rng=rng, x=jnp.ones((256, 28*28), dtype=jnp.float32))
# params

# 3.定义loss和优化器：自动包括损失函数、反向传播（自动微分、更新权重）
def loss(x, params, y):
    logits = model.apply(x=x, params=params, rng=None)
    cce = jnp.mean(-jnp.sum(jnp.log(logits) * y))
    return cce

opt = optimizers.sgd(step_size=0.001)
opt_state = opt.init_fn(params) # 需要接受网络结构的相关信息
# opt_state

# 4.准备数据集
x_train = jnp.load("Minist/mnist_train_x.npy")
y_train = jnp.load("Minist/mnist_train_y.npy")
y_train = jax.nn.one_hot(x=y_train, num_classes=len(jnp.unique(y_train))) ## y_trian要做独热编码处理
## x y配对并打乱顺序
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train = tfds.as_numpy(ds_train)

# 5.开始训练
acc_list = []
for epoch in range(100): 
    # epochs=100，所有数据迭代100次
    step = 0
    for batch in ds_train:
        # 一个整体分为batch_size=256份，迭代256次去更新params
        features = batch[0].reshape((-1, 28*28))
        features = features/255
        labels = batch[1].reshape((-1, 10))
        # print(features.shape, labels.shape)

        grads = jax.grad(loss, argnums=(1))(features, params, labels)       # argnums=(0,1,2,3)指定需要求导的自变量（这里是params)
        opt_state = opt.update_fn(step, grads, opt_state)                   # 更新优化器的opt_state
        params = opt.params_fn(opt_state)                                   # 用新的opt_state去生成新的params
        step += 1

    ## 计算acc
    prediction = model.apply(x=x_train.reshape(-1,28*28)/255, params=params, rng=None)
    # print(prediction.shape) # (60000, 10)
    # print(y_trains) # (60000, 10)
    pred_targets = jnp.argmax(prediction, axis=1)           # 返回最大数字的下标(预测)
    y_targets = jnp.argmax(y_train, axis=1)                 # 返回最大数字的下标(真实)
    ok = jnp.sum(pred_targets == y_targets)                 # 比较下表是否一致，是则ok否则不ok
    acc = jnp.divide(ok, y_train.shape[0])                  # ok的占比
    acc_list.append(acc)
    print(acc)

    ## 动态绘图
    if epoch %10 == 0:
        plt.clf()
        plt.plot(acc_list)
        plt.xlabel('epoch')
        plt.ylabel('acc')
        plt.pause(0.000001)
        display.clear_output(wait=True)

: 

## 利用jit（推荐）

In [2]:
import haiku as hk

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

from IPython import display
%matplotlib inline

rng = jax.random.PRNGKey(0)

# 1定义神经网络结构
class Network(hk.Module):
    def __init__(self, name=None):
        super().__init__()

    def __call__(self, x):
        mlp = hk.Sequential([hk.Linear(768, name='hidden1'),
                            jax.nn.relu,
                            hk.Linear(10, name='hidden2'),
                            jax.nn.relu,
                            jax.nn.softmax])
        logits = mlp(x)
        return logits
# 2初始化网络获得params:w,b
model = hk.transform(lambda x: Network(name='TestNetwork')(x))
params = model.init(rng=rng, x=jnp.ones((256, 28*28), dtype=jnp.float32))
# params

# 3定义loss和优化器：自动包括损失函数、反向传播（自动微分、更新权重）
def loss(x, params, y):
    logits = model.apply(x=x, params=params, rng=None)
    cce = jnp.mean(-jnp.sum(jnp.log(logits) * y))
    return cce

opt = optimizers.sgd(step_size=0.001)
opt_state = opt.init_fn(params) # 需要接受网络结构的相关信息
# opt_state

@jax.jit
def update(step, params, features, labels, opt_state):
    grads = jax.grad(loss, argnums=(1))(features, params, labels)       # argnums=(0,1,2,3)指定需要求导的自变量（这里是params)
    opt_state = opt.update_fn(step, grads, opt_state)                   # 更新优化器的opt_state
    params = opt.params_fn(opt_state)                                   # 用新的opt_state去生成新的params
    return opt_state, params                               

# 4准备数据集
x_train = jnp.load("Minist/mnist_train_x.npy")
y_train = jnp.load("Minist/mnist_train_y.npy")
## y_trian要做独热编码处理
y_train = jax.nn.one_hot(x=y_train, num_classes=len(jnp.unique(y_train)))
## x y配对并打乱顺序
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train = tfds.as_numpy(ds_train)

# 5训练
acc_list = []
for epoch in range(100):
    step = 0
    for batch in ds_train:
        features = batch[0].reshape((-1, 28*28))
        features = features/255
        labels = batch[1].reshape((-1, 10))

        opt_state, params = update(step, params, features, labels, opt_state)   # 更新参数
        # jnp.save('model_params/params.npy', params)                             # 保存params参数，此步骤严重耗时

        step += 1

    ## 计算acc
    # params = jnp.load('model_params/params.npy')            # 加载参数

    prediction = model.apply(x=x_train.reshape(-1,28*28)/255, params=params, rng=None)
    pred_targets = jnp.argmax(prediction, axis=1)           # 返回最大数字的下标(预测)
    y_targets = jnp.argmax(y_train, axis=1)                 # 返回最大数字的下标(真实)
    ok = jnp.sum(pred_targets == y_targets)                 # 比较下表是否一致，是则ok否则不ok
    acc = jnp.divide(ok, y_train.shape[0])                  # ok的占比
    acc_list.append(acc)

    ## 动态绘图
    if epoch % 10 == 0:
        print(acc)
        plt.clf()
        plt.plot(acc_list)
        plt.xlabel('epoch')
        plt.ylabel('acc')
        plt.pause(0.000001)
        display.clear_output(wait=True)

0.7356667


: 

## 可视化网络结构

In [13]:
import haiku as hk

# 定义网络结构和计算logits
class Test(hk.Module):
    def __init__(self, name=None):
        super().__init__()

    def __call__(self, x):
        mlp = hk.Sequential([
            hk.Linear(300, name='hidden1'),
            jax.nn.relu,
            hk.Linear(10, name='hidden2'),
            jax.nn.softmax
        ])
        return mlp(x)

# 初始化网络和参数（w，b）
net = hk.transform(lambda x: Test(name='Test')(x))
x = jnp.ones((2,3))
params = net.init(jax.random.PRNGKey(0), x=x)

# 打印网络结构图
print(hk.experimental.tabulate(net)(x))

+---------------------------------+------------------------------------------------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                          | Config                                                                             | Module params   | Input      | Output     |   Param count |   Param bytes |
| test (Test)                     | Test(name='Test')                                                                  |                 | f32[2,3]   | f32[2,10]  |         4,210 |      16.84 KB |
+---------------------------------+------------------------------------------------------------------------------------+-----------------+------------+------------+---------------+---------------+
| test/sequential (Sequential)    | Sequential(                                                                        |                 | f32[2,3]   | f32[2,10]  |         4,210 |      16.84 KB |
|  └ test (Test

# 5. <a id='toc5_'></a>[hk.transform()](#toc0_)
```
# 网络实例化
net = hk.transform()

## 对象初始化w，b
params = net.init(rng, x)

## apply计算__call__函数体
logits = net.apply(rng, params, x)
```
## def函数

In [39]:
# hk.transforma(函数)

import haiku as hk
import jax.numpy as jnp
import jax

def forward(x):
    mlp = hk.nets.MLP([300, 100, 10])
    return mlp(x)
forward = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28*28])

params = forward.init(rng, x)
logits = forward.apply(params, rng, x)
logits

Array([[ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579

## @hk.transform装饰器

In [40]:
# @：装饰器修饰

import haiku as hk

import jax
import jax.numpy as jnp

@hk.transform
def forward2(x):
    mlp = hk.nets.MLP([300, 100, 10])
    return mlp(x)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28*28])

params2 = forward2.init(rng, x)
logits = forward2.apply(params2, rng, x)
logits

Array([[ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579

## lambda匿名函数

In [1]:
# hk.transforma(lambda函数)

import haiku as hk

import jax
import jax.numpy as jnp

class Forward3(hk.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        mlp = hk.nets.MLP([300, 100, 10])
        return mlp(x)
    
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28*28])

forward3 = hk.transform(lambda x: Forward3()(x))
params3 = forward3.init(rng, x)
logits = forward3.apply(params3, rng, x)
logits

Array([[ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579, -0.49413946, -0.07068619],
       [ 0.45413703, -0.3273952 , -0.36023346,  0.58912265,  0.24682917,
        -0.08996333, -0.01644108,  0.06207579

# dm-haiku优势
```
rng = jax.random.PRNkey(0)

net = hk.transform()
params = net.init(rng, x) # 初始化很简单
net.apply(params, rng, x) # 运算起来也很多简单
```
```
1. 先搭建模型函数model
2. 用hk.transform转换模型函数，得到model_transform，如果正向inference不用随机数，可以再套一层hk.without_apply_rng
model_transform.init()初始化模型，返回初始化参数
3. 使用返回的参数，初始化optax包中的优化器
4. 之后使用训练数据训练模型，代码为model_transform.apply()，根据logit计算出loss之后，使用jax.grad()计算出梯度，用opt.update更新优化器状态和参数更新值，用optax.apply_updates()更新模型参数 （或则手动用jax.tree_multimap更新网络参数） 
```

## CNN

In [None]:
import haiku as hk

import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

from IPython import display
%matplotlib inline

rng = jax.random.PRNGKey(0)

# 1定义神经网络结构
class Network(hk.Module):
    def __init__(self, name=None):
        super().__init__()

    def __call__(self, x):
        # mlp = hk.Sequential([hk.Linear(768, name='hidden1'),
        #                     jax.nn.relu,
        #                     hk.Linear(10, name='hidden2'),
        #                     jax.nn.relu,
        #                     jax.nn.softmax])
        mlp = hk.Sequential([
            hk.Conv2D()
        ])
        logits = mlp(x)
        return logits
# 2初始化网络获得params:w,b
model = hk.transform(lambda x: Network(name='TestNetwork')(x))
params = model.init(rng=rng, x=jnp.ones((256, 28*28), dtype=jnp.float32))
# params

# 3定义loss和优化器：自动包括损失函数、反向传播（自动微分、更新权重）
def loss(x, params, y):
    logits = model.apply(x=x, params=params, rng=None)
    cce = jnp.mean(-jnp.sum(jnp.log(logits) * y))
    return cce

opt = optimizers.sgd(step_size=0.001)
opt_state = opt.init_fn(params) # 需要接受网络结构的相关信息
# opt_state

@jax.jit
def update(step, params, features, labels, opt_state):
    grads = jax.grad(loss, argnums=(1))(features, params, labels)       # argnums=(0,1,2,3)指定需要求导的自变量（这里是params)
    opt_state = opt.update_fn(step, grads, opt_state)                   # 更新优化器的opt_state
    params = opt.params_fn(opt_state)                                   # 用新的opt_state去生成新的params
    return opt_state, params                               

# 4准备数据集
x_train = jnp.load("Minist/mnist_train_x.npy")
y_train = jnp.load("Minist/mnist_train_y.npy")
## y_trian要做独热编码处理
y_train = jax.nn.one_hot(x=y_train, num_classes=len(jnp.unique(y_train)))
## x y配对并打乱顺序
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1024).batch(256).prefetch(tf.data.experimental.AUTOTUNE)
ds_train = tfds.as_numpy(ds_train)

# 5训练
acc_list = []
for epoch in range(100):
    step = 0
    for batch in ds_train:
        features = batch[0].reshape((-1, 28*28))
        features = features/255
        labels = batch[1].reshape((-1, 10))

        opt_state, params = update(step, params, features, labels, opt_state)   # 更新参数
        # jnp.save('model_params/params.npy', params)                             # 保存params参数，此步骤严重耗时

        step += 1

    ## 计算acc
    # params = jnp.load('model_params/params.npy')            # 加载参数

    prediction = model.apply(x=x_train.reshape(-1,28*28)/255, params=params, rng=None)
    pred_targets = jnp.argmax(prediction, axis=1)           # 返回最大数字的下标(预测)
    y_targets = jnp.argmax(y_train, axis=1)                 # 返回最大数字的下标(真实)
    ok = jnp.sum(pred_targets == y_targets)                 # 比较下表是否一致，是则ok否则不ok
    acc = jnp.divide(ok, y_train.shape[0])                  # ok的占比
    acc_list.append(acc)

    ## 动态绘图
    if epoch % 10 == 0:
        print(acc)
        plt.clf()
        plt.plot(acc_list)
        plt.xlabel('epoch')
        plt.ylabel('acc')
        plt.pause(0.000001)
        display.clear_output(wait=True)