<a href="https://colab.research.google.com/github/talkin24/jaxflax_lab/blob/main/Flax_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Upgrade pip itself together with JAX and jaxlib to their latest releases.
!pip install --upgrade -q pip jax jaxlib
# Install Flax from the head of its GitHub repository rather than a release.
!pip install --upgrade -q git+https://github.com/google/flax.git

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.8/2.1 MB[0m [31m26.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.1/2.1 MB[0m [31m42.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.0/79.0 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for flax (pyproject.toml) ... [?25l[?25hdone
[0m

In [2]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

In [3]:
# Create a single dense layer instance. `features` is the size of the layer's
# output; no input size is given at construction time.
model = nn.Dense(features=5)

출력 차원(`features`)만 지정하면 됨. 입력 차원은 `init`에 넘기는 입력의 shape으로부터 정해짐.

# Model Parameters & initialization

- 파라미터는 모델에 저장되지 않음. init 함수를 호출하여 초기화해주어야 함. 이때 PRNGKey와 더미 인풋을 사용함.

In [9]:
# Derive two keys from one root key: key1 for the dummy input, key2 for init.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input data (a length-10 vector)
params = model.init(key2, x) # Initialization call; returns the variable tree
jax.tree_util.tree_map(lambda x: x.shape, params) # Replace each leaf by its shape to inspect the layout

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

- JAX/FLAX는 numpy처럼 row-based 시스템. 벡터는 행 벡터 기준임

- `init_with_output` 메소드로 더미 인풋의 forward pass 출력도 함께 리턴 가능

- 출력은 FrozenDict 인스턴스로 저장됨. 이는 immutable하며 사용자가 이를 인식하게 도와줌.



In [10]:
# Item assignment on the params container is attempted on purpose; the
# resulting ValueError is caught and its message printed.
try:
    params['new_key'] = jnp.ones((2,2))
except ValueError as e:
    print("Error: ", e)

Error:  FrozenDict is immutable.


In [11]:
# Forward pass: the parameters are passed explicitly alongside the input.
model.apply(params, x)

Array([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 , -1.1271116 ],      dtype=float32)

In [12]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree, using the same
# {'params': {'bias', 'kernel'}} layout as the Dense layer's variables.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
# NOTE(review): k1 was already consumed above to draw W and is split again
# here. JAX guidance is to use each key only once; consider splitting `key`
# into enough subkeys up front. Doing so changes the sampled data and
# therefore every loss value computed from it.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# Targets follow the linear model x @ W + b plus normal noise scaled by 0.1.
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [13]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  """Return the batch mean of the halved squared error of `model`.

  Args:
    params: variable tree handed to `model.apply`.
    x_batched: inputs stacked along the leading axis.
    y_batched: targets stacked along the leading axis, one per input.

  Returns:
    The mean over the leading axis of the per-sample losses.
  """
  def half_squared_residual(sample, target):
    # Loss of one (x, y) pair: <r, r> / 2 with r = target - prediction.
    residual = target - model.apply(params, sample)
    return jnp.inner(residual, residual) / 2.0

  # vmap maps the per-pair loss over the leading axis of both arrays.
  per_sample_loss = jax.vmap(half_squared_residual)(x_batched, y_batched)
  return jnp.mean(per_sample_loss, axis=0)

In [21]:
learning_rate = 0.3  # Gradient step size.
# Reference value: loss of the generating parameters on the noisy samples.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# Wrap mse so that one call yields both the loss value and its gradients.
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  """Apply one plain gradient-descent step to every leaf of `params`.

  Args:
    params: pytree of current parameter values.
    learning_rate: step size multiplying each gradient.
    grads: pytree of gradients with the same structure as `params`.

  Returns:
    A new pytree whose leaves are `param - learning_rate * grad`.
  """
  def descend(leaf, leaf_grad):
    return leaf - learning_rate * leaf_grad

  return jax.tree_util.tree_map(descend, params, grads)

# Manual gradient descent on the Dense layer's params for 101 steps.
for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  # Report every 10th step; loss_val was computed before this step's update.
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.02363979
Loss step 0:  35.343876
Loss step 10:  0.5143469
Loss step 20:  0.11384159
Loss step 30:  0.039326735
Loss step 40:  0.019916208
Loss step 50:  0.014209135
Loss step 60:  0.012425654
Loss step 70:  0.01185039
Loss step 80:  0.011661784
Loss step 90:  0.011599409
Loss step 100:  0.011578695


# Optimizing with Optax

- Optax는 스케줄링(시간에 따라 옵티마이저 하이퍼파라미터 변경), 마스킹(파라미터 트리의 부분별로 서로 다른 업데이트 적용)도 지원함

In [22]:
import optax
# Adam optimizer, reusing the learning_rate defined for the manual loop.
tx = optax.adam(learning_rate=learning_rate)
# The optimizer state is built from the current params.
opt_state = tx.init(params)
# Identical to the earlier definition; presumably repeated so this cell can
# be run on its own.
loss_grad_fn = jax.value_and_grad(mse)

- Optimizer state도 설정해주어야 함

In [23]:
# Optax loop: gradients -> optimizer updates -> new params, for 101 steps.
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  # tx.update yields the parameter deltas and the next optimizer state.
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  # Report every 10th step.
  if not i % 10:
    print(f'Loss step {i}: ', loss_val)

Loss step 0:  0.011577628
Loss step 10:  0.26143155
Loss step 20:  0.07675027
Loss step 30:  0.03644055
Loss step 40:  0.022012806
Loss step 50:  0.016178599
Loss step 60:  0.0130028
Loss step 70:  0.012026141
Loss step 80:  0.011764516
Loss step 90:  0.0116460435
Loss step 100:  0.011585529


`tx.update`로 업데이트 값을 계산한 뒤 `optax.apply_updates`로 파라미터에 적용함

# Serializing the result

In [26]:
from flax import serialization
# Serialize the params tree to a byte string.
bytes_output = serialization.to_bytes(params)
# Convert the params tree to a plain nested state dict.
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4561517, -2.0283859,  2.0785248,  1.2180303, -0.9993835],      dtype=float32), 'kernel': Array([[ 1.009248  ,  0.18867432,  0.04391698, -0.928683  ,  0.34787187],
       [ 1.7301217 ,  0.98819846,  1.1643261 ,  1.1008728 , -0.1068265 ],
       [-1.2040868 ,  0.28517762,  1.4144661 ,  0.11754218, -1.3146015 ],
       [-1.1944853 , -0.18990842,  0.03379964,  1.3166167 ,  0.08113391],
       [ 0.13801418,  1.3707805 , -1.319223  ,  0.5310065 , -2.2408094 ],
       [ 0.5634556 ,  0.812729  ,  0.31803787,  0.5350531 ,  0.9062198 ],
       [-0.37820524,  1.7420965 ,  1.0800813 , -0.5029262 ,  0.9282519 ],
       [ 0.96965903, -1.3163381 ,  0.33582455,  0.8089384 , -1.2017919 ],
       [ 1.0207036 , -0.6189928 ,  1.0831568 , -1.8377141 , -0.45602012],
       [-0.64296687,  0.4573661 , -1.1322317 , -0.6846899 ,  0.16773744]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14.c\xba\xbf\x13\xd1\x01\xc

In [27]:
# Restore from the bytes; the existing params act as the target template.
serialization.from_bytes(params, bytes_output)

FrozenDict({
    params: {
        bias: array([-1.4561517, -2.0283859,  2.0785248,  1.2180303, -0.9993835],
              dtype=float32),
        kernel: array([[ 1.009248  ,  0.18867432,  0.04391698, -0.928683  ,  0.34787187],
               [ 1.7301217 ,  0.98819846,  1.1643261 ,  1.1008728 , -0.1068265 ],
               [-1.2040868 ,  0.28517762,  1.4144661 ,  0.11754218, -1.3146015 ],
               [-1.1944853 , -0.18990842,  0.03379964,  1.3166167 ,  0.08113391],
               [ 0.13801418,  1.3707805 , -1.319223  ,  0.5310065 , -2.2408094 ],
               [ 0.5634556 ,  0.812729  ,  0.31803787,  0.5350531 ,  0.9062198 ],
               [-0.37820524,  1.7420965 ,  1.0800813 , -0.5029262 ,  0.9282519 ],
               [ 0.96965903, -1.3163381 ,  0.33582455,  0.8089384 , -1.2017919 ],
               [ 1.0207036 , -0.6189928 ,  1.0831568 , -1.8377141 , -0.45602012],
               [-0.64296687,  0.4573661 , -1.1322317 , -0.6846899 ,  0.16773744]],
              dtype=float32),
  

이전에 생성된 params을 템플릿으로 활용하여 모델 load

# Defining your own models

In [28]:
class ExplicitMLP(nn.Module):
  """MLP whose Dense sub-layers are created explicitly in `setup`.

  ReLU is applied after every layer except the last one.
  """
  # Output width of each Dense layer, in order.
  features: Sequence[int]

  def setup(self):
    # Assigning a list of submodules to an attribute registers all of them.
    self.layers = [nn.Dense(width) for width in self.features]
    # A single submodule would simply be assigned directly:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    """Feed `inputs` through the layers with ReLU between consecutive layers."""
    final_idx = len(self.layers) - 1
    hidden = inputs
    for idx, layer in enumerate(self.layers):
      hidden = layer(hidden)
      if idx < final_idx:
        hidden = nn.relu(hidden)
    return hidden

# key1 draws the dummy input, key2 seeds parameter initialization.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

# Three Dense layers with output widths 3, 4 and 5.
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

# unfreeze turns the FrozenDict into a plain dict before mapping jnp.shape over it.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810348 -0.0255094   0.02151717 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


`setup()` 메서드는 `__post_init__`의 마지막에 호출되며, 여기서 서브모듈, 변수, 파라미터 등 모델에 필요한 것들을 등록함

In [30]:
# Calling the module directly, outside init/apply, is expected to fail here:
# attributes defined in setup() are not accessible, so the AttributeError is
# caught and its message printed.
try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)

"ExplicitMLP" object has no attribute "layers". If "layers" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


모델 구조와 파라미터가 직접 연결되어 있지 않으므로 바로 `model(x)`를 call 할 수 없음

In [31]:
class SimpleMLP(nn.Module):
  """MLP whose Dense layers are declared inline under `@nn.compact`.

  ReLU is applied after every layer except the last one.
  """
  # Output width of each Dense layer, in order.
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    final_idx = len(self.features) - 1
    hidden = inputs
    for idx, width in enumerate(self.features):
      # Naming the layer is optional; the automatic names would be
      # "Dense_0", "Dense_1", ...
      dense = nn.Dense(width, name=f'layers_{idx}')
      hidden = dense(hidden)
      if idx < final_idx:
        hidden = nn.relu(hidden)
    return hidden

# key1 draws the dummy input, key2 seeds parameter initialization.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

# Three Dense layers with output widths 3, 4 and 5.
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

# unfreeze turns the FrozenDict into a plain dict before mapping jnp.shape over it.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810348 -0.0255094   0.02151717 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


모델이 간단하므로 `@nn.compact` 어노테이션을 사용하여 setup 대체

두 가지 방법의 차이
- `setup`에서는 sublayer를 네이밍할 수 있고, 향후 사용을 위해 킵해둘 수 있음
- 여러 메서드를 사용하려면 `setup`을 사용. `@nn.compact` 어노테이션은 하나의 메서드에만 허용됨

In [32]:
class SimpleDense(nn.Module):
  """Hand-written dense layer computing `inputs . kernel + bias`.

  Attributes:
    features: size of the last axis of the output.
    kernel_init: initializer used for the kernel parameter.
    bias_init: initializer used for the bias parameter.
  """
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    in_features = inputs.shape[-1]
    # self.param(name, init_fn, *init_args) declares a parameter; the kernel
    # is declared before the bias.
    kernel = self.param('kernel', self.kernel_init,
                        (in_features, self.features))
    bias = self.param('bias', self.bias_init, (self.features,))
    # Contract the last axis of inputs with axis 0 of kernel; no batch dims.
    # TODO Why not jnp.dot?
    contracting_dims = ((inputs.ndim - 1,), (0,))
    batch_dims = ((), ())
    projected = lax.dot_general(inputs, kernel,
                                (contracting_dims, batch_dims))
    return projected + bias

# key1 draws the dummy input, key2 seeds parameter initialization.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

# Dense layer with 3 output features.
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

# Print the full parameter values here rather than only their shapes.
print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: Array([[ 0.61506   , -0.22728713,  0.6054702 ],
               [-0.29617992,  1.1232013 , -0.879759  ],
               [-0.35162622,  0.3806491 ,  0.6893246 ],
               [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32),
        bias: Array([0., 0., 0.], dtype=float32),
    },
})
output:
 [[-0.02996203  1.102088   -0.6660265 ]
 [-0.31092793  0.63239413 -0.53678817]
 [ 0.01424009  0.9424717  -0.63561463]
 [ 0.3681896   0.3586519  -0.00459218]]


Dense layer 없으면 이렇게 직접 사용. `param` 함수 사용

In [33]:
class BiasAdderWithRunningMean(nn.Module):
  """Subtract a running mean kept in 'batch_stats' and add a learned bias."""
  # Weight given to the previous running mean on each update.
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    # (checked before self.variable below declares the entry)
    is_initialized = self.has_variable('batch_stats', 'mean')
    # Declare 'mean' in the 'batch_stats' collection, initialised to zeros of
    # shape x.shape[1:].
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    # NOTE(review): `mean` is not read again in this method.
    mean = ra_mean.value # This will either get the value or trigger init
    # Learned bias parameter, zero-initialised with shape x.shape[1:].
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      # Exponential moving average of the mean of x over axis 0.
      # NOTE(review): keepdims=True yields shape (1, ...) while the initial
      # value has shape x.shape[1:], so the stored shape changes on the first
      # update -- confirm this is intended.
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    # Uses the running mean as updated above, when an update happened.
    return x - ra_mean.value + bias


# NOTE(review): key2 is not used in this cell; only key1 is passed to init.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
# Marking 'batch_stats' as mutable makes apply return (output, updated state).
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

initialized variables:
 FrozenDict({
    batch_stats: {
        mean: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
    params: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})


`self.variable` 메서드를 사용하여 파라미터가 아닌 상태 변수(예: running mean)를 선언

In [34]:
# Feed three constant batches and carry the updated state between calls.
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  # pop returns (remaining collections, popped 'params'); only params is
  # reused below, old_state is discarded.
  old_state, params = variables.pop('params')
  # Rebuild the variables from the unchanged params and the new state.
  variables = freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part

updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),
    },
})


In [35]:
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):
  """Run one optimisation step and advance the mutable model state.

  Args:
    tx: optimizer exposing `update(grads, opt_state)`; static under jit.
    apply_fn: model apply function; static under jit.
    x: input batch, which the loss also compares the output against.
    opt_state: current optimizer state.
    params: current 'params' collection.
    state: mapping holding the remaining variable collections.

  Returns:
    The tuple `(opt_state, params, state)` after the step.
  """
  mutable_collections = list(state.keys())

  def loss_fn(candidate_params):
    # Sum of squared differences between the input and the model output;
    # the updated collections are returned as auxiliary data.
    variables = {'params': candidate_params, **state}
    y, updated_state = apply_fn(variables, x, mutable=mutable_collections)
    return ((x - y) ** 2).sum(), updated_state

  value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, state), grads = value_and_grad_fn(params)
  updates, opt_state = tx.update(grads, opt_state)
  return opt_state, optax.apply_updates(params, updates), state

x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
# Split the variables into trainable params and the remaining state.
state, params = variables.pop('params')
del variables
# Plain SGD; its optimizer state is built from the params alone.
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

# Three training steps, threading optimizer state, params and model state.
for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)

Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),
    },
})
