**Table of contents**<a id='toc0_'></a>    
- 1. [教程](#toc1_)    
- 2. [概要](#toc2_)    
  - 2.1. [basic](#toc2_1_)    

<!-- vscode-jupyter-toc-config
	numbering=true
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# 下载和安装

In [2]:
# download and install latest haiku from github
if False:
    !pip install git+https://github.com/deepmind/dm-haiku
    !conda install dm-haiku jax -y

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp

# Report the library versions this notebook was run against.
print(f"dm-haiku version: {hk.__version__}")
print(f"Jax version: {jax.__version__}")

dm-haiku version: 0.0.11.dev
Jax version: 0.4.19


In [29]:
import numpy as np
import jax.numpy as jnp
import haiku as hk

class MyLinear(hk.Module):
    """Minimal fully connected layer computing ``jnp.dot(x, w) + b``.

    The bias is initialised to zeros and the weights from a truncated normal
    whose scale is ``1 / sqrt(input_size)``.
    """

    def __init__(self, output_size, name=None):
        """Record the layer width and hand ``name`` to the ``hk.Module`` parent.

        Args:
            output_size: Size of the last axis of the layer output.
            name: Optional module name, forwarded to ``hk.Module.__init__``.
        """
        # Subclassing hk.Module: the parent constructor is what registers the name.
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        """Apply the layer to ``x``; ``__call__`` makes the instance callable.

        Args:
            x: Input array; its last axis is treated as the feature axis.

        Returns:
            ``jnp.dot(x, w) + b`` with ``w`` of shape ``[in, out]`` and ``b`` of
            shape ``[out]``, both using ``x.dtype``.
        """
        in_features = x.shape[-1]
        out_features = self.output_size
        stddev = 1. / np.sqrt(in_features)

        # Once the enclosing function has gone through hk.transform,
        # hk.get_parameter creates the values during `init` and reuses the
        # supplied parameters during `apply`, so calling this method repeatedly
        # does not re-randomise the weights.
        weights = hk.get_parameter(
            "w",
            shape=[in_features, out_features],
            dtype=x.dtype,
            init=hk.initializers.TruncatedNormal(stddev),
        )
        bias = hk.get_parameter(
            "b", shape=[out_features], dtype=x.dtype, init=jnp.zeros
        )
        return jnp.dot(x, weights) + bias

In [34]:
# All `hk.Module`s must be initialized inside an `hk.transform`.
# 直接用会报错
# MyLinear(10, 'layer1')

In [31]:
# Wrap the module in a plain function first, so that it can be passed to hk.transform.
def forward_fn(x, name_):
    """Build a 10-unit ``MyLinear`` named ``name_`` and apply it to ``x``."""
    # Instantiate the module, then call the instance (which invokes __call__).
    return MyLinear(output_size=10, name=name_)(x)


forward = hk.transform(forward_fn)

# Alternatively, with a lambda:
# forward = hk.transform(lambda x: MyLinear(output_size=10)(x))

In [33]:
# Dummy input; it is only used so the network can work out its parameter shapes.
dummy_x = jnp.array([[1., 2., 3.]])
# PRNG key that seeds parameter initialisation in forward.init. A given seed
# always yields the same parameters, so vary it when a different random
# initialisation is wanted.
rng_key = jax.random.PRNGKey(42)
# Initialise under the name 'layer': parameters are stored under the module
# name, and the forward.apply call further down builds the module as 'layer',
# so the two names have to agree for the parameter lookup to succeed.
params = forward.init(rng=rng_key, x=dummy_x, name_='layer')
params

{'layer': {'w': Array([[-8.7532178e-03,  7.3761588e-01, -2.0879389e-01,  7.3657016e-04,
          -1.3183086e-01, -1.3437216e-01, -5.5001986e-01, -5.0030202e-01,
           2.6157841e-01, -2.7943218e-01],
         [-3.2892400e-01, -1.3924532e-01, -5.7093412e-01, -2.6730701e-01,
           1.1441944e-01, -2.6090544e-01,  4.6571079e-01, -4.6325716e-01,
          -6.8820441e-01,  5.0376415e-01],
         [ 3.7003931e-01, -8.1507877e-02,  1.4804702e-01,  5.5953854e-01,
           4.5092812e-01, -5.7936341e-01, -7.9028800e-02,  6.0494971e-01,
           2.9943559e-01,  4.0849480e-01]], dtype=float32),
  'b': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}}

In [38]:
# Run the forward pass (the module's __call__) through forward.apply, passing
# the parameter dictionary explicitly. `name_` selects the module name under
# which the parameters are looked up.
sample_x = jnp.array([[1., 2., 3.]])
output_1 = forward.apply(params=params, x=sample_x, rng=rng_key, name_='layer')
output_1

Array([[ 0.4435167 ,  0.21460162, -0.9065211 ,  1.1447382 ,  1.4497924 ,
        -2.3942733 ,  0.14431532,  0.3880328 , -0.21652362,  1.9535805 ]],      dtype=float32)

# 1. <a id='toc1_'></a>[教程](#toc0_)
```
https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html
https://zhuanlan.zhihu.com/p/471892075
```

# 2. <a id='toc2_'></a>[概要](#toc0_)
```
Google:
    TensorFlow (Sonnet)
    Haiku (JAX)
Facebook:
    PyTorch
Microsoft:
    CNTK
AWS:
    MXNet
```

## 2.1. <a id='toc2_1_'></a>[basic](#toc0_)

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
class MyLinear1(hk.Module):
  """Linear layer computing ``jnp.dot(x, w) + b`` with a ones-initialised bias."""

  def __init__(self, output_size, name=None):
    """Keep ``output_size`` and forward ``name`` to the ``hk.Module`` constructor."""
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    """Return ``jnp.dot(x, w) + b`` for input ``x``.

    ``w`` has shape ``[x.shape[-1], output_size]`` and ``b`` has shape
    ``[output_size]``; both are created with ``x.dtype``.
    """
    in_features = x.shape[-1]
    out_features = self.output_size
    # Truncated-normal weights scaled by 1 / sqrt(number of input features).
    stddev = 1. / np.sqrt(in_features)
    weights = hk.get_parameter(
        "w",
        shape=[in_features, out_features],
        dtype=x.dtype,
        init=hk.initializers.TruncatedNormal(stddev),
    )
    bias = hk.get_parameter(
        "b", shape=[out_features], dtype=x.dtype, init=jnp.ones
    )
    return jnp.dot(x, weights) + bias

In [3]:
def _forward_fn_linear1(x):
  """Apply a freshly constructed two-unit ``MyLinear1`` to ``x``."""
  return MyLinear1(output_size=2)(x)

# hk.transform turns the module-using function into an (init, apply) pair.
forward_linear1 = hk.transform(_forward_fn_linear1)

In [4]:
# Display the transformed object; its repr lists the `init` and `apply` functions.
forward_linear1

Transformed(init=<function without_state.<locals>.init_fn at 0x000001F25B7A4D30>, apply=<function without_state.<locals>.apply_fn at 0x000001F25B7A4DC0>)

In [5]:
# Example input with three features in its last axis, used to initialise the layer.
dummy_x = jnp.array([[1., 2., 3.]])
# Fixed PRNG key for the parameter initialisation.
rng_key = jax.random.PRNGKey(42)

# `init` runs the wrapped function once and returns the parameters it created.
params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)

{'my_linear1': {'w': Array([[-0.30350363,  0.5123802 ],
       [ 0.08009142, -0.3163005 ],
       [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': Array([1., 1.], dtype=float32)}}
