
In [1]:
# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done

In [2]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

In [3]:
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

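Only the number of output features (`features`) needs to be given. The input size is inferred from the data when the parameters are initialized.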

# Model Parameters & initialization

- Parameters are not stored in the model. They have to be initialized by calling `init`, which takes a PRNGKey and a dummy input.

In [4]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

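- Like NumPy, JAX/Flax uses a row-based convention: vectors are row vectors, so the kernel has shape `(in_features, out_features)` and the layer computes `x @ kernel + bias`.

- The `init_with_output` method additionally returns the output of the forward pass on the dummy input.

- The parameters are stored in a `FrozenDict` instance. It is immutable, which prevents accidental in-place changes and makes the user aware of this.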



In [5]:
try:
    params['new_key'] = jnp.ones((2,2))
except ValueError as e:
    print("Error: ", e)

Error:  FrozenDict is immutable.


In [6]:
model.apply(params, x)

Array([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 , -1.1271116 ],      dtype=float32)

In [7]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [8]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [9]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.02363979
Loss step 0:  35.343876
Loss step 10:  0.5143469
Loss step 20:  0.11384159
Loss step 30:  0.039326735
Loss step 40:  0.019916208
Loss step 50:  0.014209135
Loss step 60:  0.012425654
Loss step 70:  0.01185039
Loss step 80:  0.011661784
Loss step 90:  0.011599409
Loss step 100:  0.011578695


# Optimizing with Optax

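- Optax also supports scheduling (changing optimizer hyperparameters, such as the learning rate, over time) and masking (applying different updates to different parts of the parameter tree).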

In [10]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

- The optimizer state also has to be initialized (`tx.init(params)`).

In [11]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011577628
Loss step 10:  0.26143155
Loss step 20:  0.07675027
Loss step 30:  0.03644055
Loss step 40:  0.022012806
Loss step 50:  0.016178599
Loss step 60:  0.0130028
Loss step 70:  0.012026141
Loss step 80:  0.011764516
Loss step 90:  0.0116460435
Loss step 100:  0.011585529


Parameters are updated with `optax.apply_updates`, which applies the updates returned by `tx.update` to `params`.

# Serializing the result

In [12]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4555768 , -2.0277991 ,  2.0790975 ,  1.2186145 , -0.99809754],      dtype=float32), 'kernel': Array([[ 1.0098814 ,  0.18934374,  0.04454996, -0.9280221 ,  0.3478402 ],
       [ 1.7298453 ,  0.9879368 ,  1.1640464 ,  1.1006076 , -0.10653935],
       [-1.2029463 ,  0.28635228,  1.4155979 ,  0.11870951, -1.3141483 ],
       [-1.1941489 , -0.18958491,  0.03413862,  1.3169426 ,  0.0806038 ],
       [ 0.1385241 ,  1.3713038 , -1.3187183 ,  0.53152674, -2.2404997 ],
       [ 0.56294024,  0.8122311 ,  0.3175201 ,  0.53455096,  0.9050039 ],
       [-0.37926027,  1.7410393 ,  1.0790287 , -0.5039833 ,  0.9283062 ],
       [ 0.9706492 , -1.3153403 ,  0.33681503,  0.8099344 , -1.2018458 ],
       [ 1.0194312 , -0.6202479 ,  1.0818833 , -1.838974  , -0.45805007],
       [-0.6436537 ,  0.45666698, -1.1329137 , -0.6853864 ,  0.16829035]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14WP\xba\xbfv\xc7\x01\

In [13]:
serialization.from_bytes(params, bytes_output)

FrozenDict({
    params: {
        bias: array([-1.4555768 , -2.0277991 ,  2.0790975 ,  1.2186145 , -0.99809754],
              dtype=float32),
        kernel: array([[ 1.0098814 ,  0.18934374,  0.04454996, -0.9280221 ,  0.3478402 ],
               [ 1.7298453 ,  0.9879368 ,  1.1640464 ,  1.1006076 , -0.10653935],
               [-1.2029463 ,  0.28635228,  1.4155979 ,  0.11870951, -1.3141483 ],
               [-1.1941489 , -0.18958491,  0.03413862,  1.3169426 ,  0.0806038 ],
               [ 0.1385241 ,  1.3713038 , -1.3187183 ,  0.53152674, -2.2404997 ],
               [ 0.56294024,  0.8122311 ,  0.3175201 ,  0.53455096,  0.9050039 ],
               [-0.37926027,  1.7410393 ,  1.0790287 , -0.5039833 ,  0.9283062 ],
               [ 0.9706492 , -1.3153403 ,  0.33681503,  0.8099344 , -1.2018458 ],
               [ 1.0194312 , -0.6202479 ,  1.0818833 , -1.838974  , -0.45805007],
               [-0.6436537 ,  0.45666698, -1.1329137 , -0.6853864 ,  0.16829035]],
              dtype=float32

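The model is loaded by using the previously created `params` as a template: it supplies the target structure, and its values are replaced by the deserialized ones.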

# Defining your own models

In [14]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810348 -0.0255094   0.02151717 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


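`setup()` is where everything the model needs (submodules, variables, parameters) is registered. Flax calls it automatically after `__post_init__`, lazily, once the module is bound inside `init` or `apply`.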

In [15]:
try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)

"ExplicitMLP" object has no attribute "layers". If "layers" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


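Because the module structure and its parameters are not tied to each other, you cannot call `model(x)` directly; the module has to be run through `init` or `apply`.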

In [16]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810348 -0.0255094   0.02151717 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


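Since the model is simple, the `@nn.compact` decorator can replace `setup`: submodules are declared inline in `__call__`.

Differences between the two approaches:
- With `setup`, sublayers can be named explicitly and kept as attributes for later use.
- To use several methods, use `setup`: `@nn.compact` may decorate only one method per module.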

In [17]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: Array([[ 0.61506   , -0.22728713,  0.6054702 ],
               [-0.29617992,  1.1232013 , -0.879759  ],
               [-0.35162622,  0.3806491 ,  0.6893246 ],
               [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32),
        bias: Array([0., 0., 0.], dtype=float32),
    },
})
output:
 [[-0.02996203  1.102088   -0.6660265 ]
 [-0.31092793  0.63239413 -0.53678817]
 [ 0.01424009  0.9424717  -0.63561463]
 [ 0.3681896   0.3586519  -0.00459218]]


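Without a built-in `Dense` layer, the same layer can be written by hand as above, declaring each parameter with `self.param(name, init_fn, *init_args)`. Flax calls `init_fn` with a PRNG key followed by `init_args`.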
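On the `# TODO Why not jnp.dot?` question: for a 2-D kernel, the `dot_general` call contracts the last axis of `inputs` with axis 0 of `kernel` and uses no batch dimensions, which should give the same result as `jnp.dot` (and `@`); `dot_general` just makes the contracted axes explicit. A quick check with arbitrary shapes, including an extra leading axis:

```python
import jax.numpy as jnp
from jax import lax, random

k1, k2 = random.split(random.PRNGKey(0))
inputs = random.uniform(k1, (2, 4, 4))
kernel = random.normal(k2, (4, 3))

# Contract the last axis of `inputs` with axis 0 of `kernel`; no batch dims.
y_general = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
y_dot = jnp.dot(inputs, kernel)
y_matmul = inputs @ kernel
print(y_general.shape)  # (2, 4, 3)
```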

In [18]:
class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    mean = ra_mean.value # This will either get the value or trigger init
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

initialized variables:
 FrozenDict({
    batch_stats: {
        mean: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
    params: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})


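`self.variable(collection, name, init_fn, *init_args)` declares variables other than parameters, here a running mean in the `batch_stats` collection. Unlike `self.param`, its `init_fn` does not receive a PRNG key.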

In [19]:
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = variables.pop('params')
  variables = freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part

updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),
    },
})
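The printed means follow the update `m <- decay * m + (1 - decay) * mean(x)` with `decay = 0.99`, starting from the initialized mean of 0 and feeding back the updated state each iteration. A plain-Python replay; `running_mean` is a hypothetical helper, not part of Flax:

```python
def running_mean(decay, batch_means, init=0.0):
  """Replays the EMA update used by BiasAdderWithRunningMean."""
  m = init
  history = []
  for xbar in batch_means:
    m = decay * m + (1.0 - decay) * xbar
    history.append(m)
  return history

history = running_mean(0.99, [1.0, 2.0, 3.0])
print([round(v, 6) for v in history])  # [0.01, 0.0299, 0.059601]
```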


In [20]:
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
state, params = variables.pop('params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)

Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: Array([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),
    },
})


# Managing Parameters and State

In [21]:
class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

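The tricky part is that the state variables (such as `batch_stats`) and the parameters being optimized have to be managed separately: only `params` is passed to the optimizer, while the state is threaded through `apply` via `mutable`.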

In [29]:
def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum() # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state

In [30]:
dummy_input = random.normal(key1, (10,))
num_epochs = 10

In [34]:
model = BiasAdderWithRunningMean()
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = variables.pop('params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)

## `vmap` across the batch dimension

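Using BatchNorm under `vmap` requires two changes:
1. Name the batch axis in the model definition (here `axis_name="batch"` in `nn.BatchNorm`). When writing a custom layer, you may need to pass the `axis_name` argument directly to `lax.pmean()`.
2. Give the same name to the `vmap` call in the training code (`axis_name='batch'`).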
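This is the mechanism `nn.BatchNorm` relies on: per-example code computes batch statistics through a collective over the named axis. A minimal sketch with a hand-written `lax.pmean`; the axis name `'batch'` only has to match between the two places, and each column of `centered` should end up with zero mean.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(6.0).reshape(3, 2)  # a batch of 3 examples with 2 features

def center(x):
  # Inside vmap, x is a single example; pmean averages over the named batch axis.
  return x - lax.pmean(x, axis_name='batch')

centered = jax.vmap(center, axis_name='batch')(xs)
print(centered)
```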

In [38]:
class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # Name batch dim
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y

In [None]:
def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum() # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state

Compare the non-`vmap` `update_step` above with the `vmap` version below.

In [41]:
def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):

  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, train=True, mutable=list(state.keys())  # train=True: BatchNorm uses batch stats
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # Do not vmap `updated_state`.
      axis_name='batch'  # Name batch dim
    )(x_batch, y_batch)  # vmap only `x`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss

`loss_fn` is vmapped over the batch. `updated_state` is the same for every example, so it is not vmapped; hence `out_axes=(0, None)`.
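A minimal sketch of `out_axes=(0, None)`: the first output is stacked along the batch, while the second is a single unmapped value. Here the hypothetical `shared_state` does not depend on the per-example input, so `vmap` can return it unmapped; a second output that did depend on `x` (without a collective such as `pmean`) would not be allowed with `None`.

```python
import jax
import jax.numpy as jnp

shared_state = {'steps': jnp.array(0)}  # hypothetical state shared by the whole batch

def per_example(x):
  loss = jnp.sum(x ** 2)
  new_state = {'steps': shared_state['steps'] + 1}  # independent of x
  return loss, new_state

xs = jnp.ones((4, 3))
losses, new_shared_state = jax.vmap(per_example, out_axes=(0, None))(xs)
print(losses.shape, new_shared_state['steps'])
```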

In [None]:
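# Repeated from above for comparison only: this loop calls the earlier
# 5-argument update_step, not the vmap version defined in the previous cell.
model = BiasAdderWithRunningMean()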
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = variables.pop('params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)

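Compare the single-input training loop above with the batched loop below, which uses the `vmap` version of `update_step`.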

In [42]:
model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = variables.pop('params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, x_samples, y_samples, opt_state, params, state)

# `setup` vs. `compact`

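- `setup`: explicit and similar to PyTorch. It makes porting models from PyTorch easier and allows defining more than one "forward pass" method.
- `compact`: inline. The entire forward pass is written in a single method, with submodules and parameters defined next to where they are used. The code is shorter and can look closer to mathematical notation. Parameter shapes may depend on the shapes of the inputs (shape inference), which is not possible with `setup`.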

# Dealing with Flax Module arguments

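In Flax Linen, module arguments can be passed either as dataclass attributes or as arguments to methods (usually `__call__`):

- Fully static properties, such as the choice of kernel initializer or the number of output features, are hyperparameters and should be defined as dataclass attributes. In general, two module instances with different hyperparameters cannot share parameters in a meaningful way.
- Dynamic properties, such as input data and mode switches like `train=True/False`, should be passed as arguments to `__call__` or another method.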

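However, some cases are less clear-cut. Take the Dropout module as an example.

Some of its arguments are clear:
- Hyperparameters: the dropout rate and the axes along which the dropout mask is generated.
- Call-time arguments: the input to be masked by dropout, and the RNG used for random masking.

The `deterministic` argument, however, is ambiguous: it can reasonably be treated either way.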


In [None]:
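# Option 1: treat `deterministic` as a hyperparameter, fixed when the submodule is constructed.
from functools import partial
from flax import linen as nn

class ResidualModel(nn.Module):
  drop_rate: float

  @nn.compact
  def __call__(self, x, *, train):
    dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
    for i in range(10):
      # ResidualBlock and its remaining arguments are elided in this sketch.
      x += ResidualBlock(dropout=dropout, ...)(x)
    return x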

In [None]:
# Option 2: treat `deterministic` as a method (call-time) argument.
class SomeModule(nn.Module):
  drop_rate: float

  def setup(self):
    self.dropout = nn.Dropout(rate=self.drop_rate)

  @nn.compact
  def __call__(self, x, *, train):
    # ...
    x = self.dropout(x, deterministic=not train)
    # ...

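So `deterministic` is handled as follows.
In this example, `nn.merge_param` ensures that exactly one of `self.deterministic` and the `deterministic` argument is set, but not both; it raises an error if both or neither are given.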



In [None]:
from typing import Optional

from flax import linen as nn

class MyDropout(nn.Module):
  drop_rate: float
  deterministic: Optional[bool] = None

  @nn.compact
  def __call__(self, x, deterministic=None):
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
    # ...