<a href="https://colab.research.google.com/github/rahul-art/jax_flax/blob/main/multisteps_serialization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -U flax optax

In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax

from typing import Sequence
from flax.training.train_state import TrainState
from flax.training.checkpoints import save_checkpoint, restore_checkpoint

In [None]:
# Synthetic regression problem: 64 rows of (i, i); targets are sum(2*x + 1) over the feature axis.
rng = jax.random.PRNGKey(842)
rng, data_rng = jax.random.split(rng)

x = jnp.array([[i, i] for i in range(64)], dtype=jnp.float32)
# Targets are derived from the clean inputs, before any noise is added.
y = jnp.sum(2 * x + 1, axis=-1, keepdims=True)
# Perturb the inputs with noise drawn from data_rng.
x = x + jax.random.normal(data_rng, x.shape)


def data_gen():
    """Yield the one and only batch: the full ``(x, y)`` dataset."""
    yield x, y
    
class MLP(nn.Module):
    """Stack of Dense layers whose widths are given by ``features``.

    Every layer except the last is followed by a ReLU; the output of the last
    Dense layer is returned unchanged.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        for width in hidden_widths:
            x = nn.Dense(width)(x)
            x = nn.relu(x)
        return nn.Dense(self.features[-1])(x)

# Dense(4) -> ReLU -> Dense(1), initialised against the training inputs.
model = MLP([4, 1])
params = model.init(jax.random.PRNGKey(0), x)

# AdamW (lr=0.01) wrapped in MultiSteps with k=4.
optimizer = optax.MultiSteps(optax.adamw(0.01), 4)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

  lax._check_user_dtype_supported(dtype, "zeros")


In [None]:
def compute_loss(params, batch):
    """Mean squared error between the model output and the targets.

    ``batch[0]`` holds the inputs and ``batch[1]`` the targets. The forward
    pass goes through the notebook-global ``state.apply_fn``.
    """
    predictions = state.apply_fn(params, batch[0])
    residual = predictions - batch[1]
    return jnp.mean(residual ** 2)


# Evaluates compute_loss and its gradient w.r.t. the first argument (params).
grad_fn = jax.value_and_grad(compute_loss)

@jax.jit
def train_step(state, batch):
    """Run one optimisation step on ``batch`` and report the loss.

    Args:
        state: train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        batch: indexable as ``(inputs, targets)``.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` is ``{"loss": loss}``.
    """
    def mse(params):
        predictions = state.apply_fn(params, batch[0])
        return jnp.mean((predictions - batch[1]) ** 2)

    loss, grad = jax.value_and_grad(mse)(state.params)
    return state.apply_gradients(grads=grad), {"loss": loss}

In [None]:
# Run 8 training steps. A fresh data_gen() generator is created on every
# iteration, so each step is fed the same (x, y) batch.
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])

  lax._check_user_dtype_supported(dtype, "zeros")


15254.637
15254.637
15254.637
15254.637
14935.168
14935.168
14935.168
14935.168


In [None]:
# Persist the whole TrainState under ./_tmp/ as step 8; overwrite=True lets it
# replace a checkpoint left by a previous run.
save_checkpoint('./_tmp/', state, 8, overwrite=True)

'_tmp/checkpoint_8'

In [None]:
# Rebuild model, params and optimizer from scratch so that `state` is a fresh,
# untrained TrainState to restore the checkpoint into.
model = MLP([4, 1])
params = model.init(jax.random.PRNGKey(0), x)

optimizer = optax.MultiSteps(optax.adamw(0.01), 4)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

  lax._check_user_dtype_supported(dtype, "zeros")


In [None]:
from flax.serialization import from_state_dict

In [None]:
# Restore with target=None — presumably this yields the raw state dict rather
# than a TrainState (it is passed to from_state_dict afterwards); TODO confirm.
state_ckpt = restore_checkpoint('./_tmp/', None)

In [None]:
# Load the checkpointed values into the freshly created TrainState, which acts as the template.
state = from_state_dict(state, state_ckpt)

In [None]:
# Try to resume training from the restored state.
# NOTE: in the recorded run this cell ends in a TypeError (see the output below).
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])

  lax._check_user_dtype_supported(dtype, "zeros")


TypeError: ignored

## A creepy solution :)

In [None]:
def _zeros_tree_like(inp_tree):
    return jax.tree_map(jnp.zeros_like, inp_tree)

In [None]:
# Workaround: push an all-zero update through the inner optimizer once and keep
# only the inner state it returns (the produced updates are discarded).
fake_updates = _zeros_tree_like(state.params)
_, new_inner_opt_state = state.tx.inner_opt.update(fake_updates, state.opt_state.inner_opt_state, state.params)
opt_state = state.opt_state
# Rebuild the MultiSteps state: every field is carried over unchanged except
# inner_opt_state, which is replaced by the one obtained above.
new_opt_state = optax.MultiStepsState(mini_step=opt_state.mini_step, 
                                      gradient_step=opt_state.gradient_step, 
                                      inner_opt_state=new_inner_opt_state,
                                      acc_grads=opt_state.acc_grads)

In [None]:
# Swap the rebuilt optimizer state into the train state.
state = state.replace(opt_state=new_opt_state)

In [None]:
# Resume training with the rebuilt optimizer state: 8 steps on the same batch.
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])

  lax._check_user_dtype_supported(dtype, "zeros")


14611.871
14611.871
14611.871
14611.871
14333.213
14333.213
14333.213
14333.213


## Reinstantiating the components of inner_opt_state resolves the issue

In [None]:
# The state of the optimizer wrapped by MultiSteps, as stored in the train state.
inner_opt_state = state.opt_state.inner_opt_state

In [None]:
def reinstantiate_states(opt_state):
    """Rebuild every element of ``opt_state`` as a new instance of its optax class.

    For each element, the class with the same name is looked up on the
    ``optax`` module and constructed from the element's ``_fields`` values.

    Returns:
        A list with the rebuilt elements, in their original order.
    """
    def rebuild(sub_state):
        optax_cls = getattr(optax, type(sub_state).__name__)
        field_values = {name: getattr(sub_state, name) for name in sub_state._fields}
        return optax_cls(**field_values)

    return [rebuild(sub_state) for sub_state in opt_state]

In [None]:
# Rebuild each component of the inner optimizer state from its optax class.
new_inner_opt_state = reinstantiate_states(inner_opt_state)

In [None]:
# Copy every field of the current optimizer state into a dict, then swap in the
# re-instantiated inner optimizer state.
ms_state_dict = {k:getattr(state.opt_state, k) for k in state.opt_state._fields}
ms_state_dict["inner_opt_state"] = new_inner_opt_state

In [None]:
# Construct a new MultiStepsState from the collected fields and install it in the train state.
state = state.replace(opt_state=optax.MultiStepsState(**ms_state_dict))

In [None]:
# Resume training with the re-instantiated optimizer state: 8 steps on the same batch.
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])

13954.496
13954.496
13954.496
13954.496
13620.805
13620.805
13620.805
13620.805


In [None]:
# Display the `_opt` attribute of the MultiSteps wrapper (private attribute; inspection only).
state.tx._opt

GradientTransformation(init=<function chain.<locals>.init_fn at 0x7ff2539e23b0>, update=<function chain.<locals>.update_fn at 0x7ff2539e2440>)

## With a Hugging Face model

In [None]:
%%capture
!pip install -U transformers

In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax

from typing import Sequence
from flax.training.train_state import TrainState
from flax.training.common_utils import onehot
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

In [None]:
from flax.serialization import to_bytes, from_bytes
import os
import json

In [None]:
def mb_item(x):
    """Return ``x.item()`` when ``x`` has an ``item`` attribute, otherwise ``x`` itself."""
    if not hasattr(x, "item"):
        return x
    return x.item()

In [None]:
#checkpoint functions
def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
    """Save the model weights and, optionally, the optimizer state and step counter.

    If `push_to_hub` is True, files are written directly to `save_dir`.
    Otherwise they are written to `save_dir/ckpt-{step - 1}`.

    Args:
        model: object exposing `save_pretrained(save_dir, params=..., push_to_hub=..., commit_message=...)`.
        save_dir: base output directory.
        state: train state providing `step`, `params` and `opt_state`.
        with_opt: when True, also write `opt_state.msgpack` and `training_state.json`.
        push_to_hub: forwarded to `model.save_pretrained`.
    """
    # `state.step` may be an array or a plain Python int (it is an int after
    # `restore_model_checkpoint`, which reads it from JSON), so always go
    # through `mb_item` instead of calling `.item()` directly.
    step = mb_item(state.step)
    if not push_to_hub:
        save_dir = f"{save_dir}/ckpt-{step-1}"
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs at step {step-1}",
    )
    if with_opt:
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": step}, f)


def reinstantiate_states(opt_state):
    """Rebuild every element of ``opt_state`` as a new instance of its optax class.

    For each element, the class with the same name is looked up on the
    ``optax`` module and constructed from the element's ``_fields`` values.

    Returns:
        A list with the rebuilt elements, in their original order.
    """
    def rebuild(sub_state):
        optax_cls = getattr(optax, type(sub_state).__name__)
        field_values = {name: getattr(sub_state, name) for name in sub_state._fields}
        return optax_cls(**field_values)

    return [rebuild(sub_state) for sub_state in opt_state]

def restore_model_checkpoint(save_dir, state):
    """Restore params, optimizer state and step counter from `save_dir` into `state`.

    Args:
        save_dir: directory containing `flax_model.msgpack`, `opt_state.msgpack`
            and `training_state.json`.
        state: freshly created train state; its `params` and `opt_state` are
            used as deserialisation templates.

    Returns:
        `state` with `step`, `params` and `opt_state` replaced by the restored values.
    """
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    # reinstantiate inner opt state to avoid type conflict
    if hasattr(opt_state, "inner_opt_state"):
        print("restoring multisteps optimizer")
        # Take every field from the RESTORED opt_state, not from the template
        # `state.opt_state`; otherwise the checkpointed mini_step, gradient_step
        # and acc_grads would be silently replaced by freshly initialised values.
        ms_state_dict = {k: getattr(opt_state, k) for k in opt_state._fields}
        ms_state_dict["inner_opt_state"] = reinstantiate_states(opt_state.inner_opt_state)
        opt_state = optax.MultiStepsState(**ms_state_dict)

    return state.replace(step=step, params=params, opt_state=opt_state)

In [None]:
# Random integer token ids of shape (8, 128), drawn with minval=0 and maxval=256.
rng = jax.random.PRNGKey(842)
rng, data_rng = jax.random.split(rng)
input_ids = jax.random.randint(data_rng, (8, 128), 0, 256)
# Labels are a plain copy of the inputs; the one-token shift happens inside loss_fn.
labels = input_ids.copy()

In [None]:
# Pretrained distilgpt2 with AdamW (lr=1e-3) wrapped in MultiSteps with k=4.
model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
optimizer = optax.MultiSteps(optax.adamw(1e-3), 4)

state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

  lax._check_user_dtype_supported(dtype, "zeros")


In [None]:
def loss_fn(logits, labels):
    """Mean softmax cross-entropy for next-token prediction.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``: the last logit and the first label are dropped.
    """
    pred_logits = logits[..., :-1, :]
    targets = labels[..., 1:]
    num_classes = pred_logits.shape[-1]
    per_token = optax.softmax_cross_entropy(pred_logits, onehot(targets, num_classes))
    return per_token.mean()

def train_step(state, batch, rng):
    """Run one optimisation step of the causal-LM model.

    Args:
        state: train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        batch: dict holding ``"labels"`` plus the keyword inputs of ``state.apply_fn``.
        rng: PRNG key; it is split into a dropout key and a key handed back to the caller.

    Returns:
        ``(new_state, loss, rng)``.
    """
    dropout_rng, rng = jax.random.split(rng)
    # Plain Python side effect: under jax.jit it runs only while the function is traced.
    print("compiling...")

    # Separate labels from the model inputs without mutating the caller's dict
    # (the previous `batch.pop("labels")` modified the argument in place).
    labels = batch["labels"]
    model_inputs = {k: v for k, v in batch.items() if k != "labels"}

    def compute_loss(params):
        logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return loss_fn(logits, labels)

    grad_fn = jax.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grad)

    return new_state, loss, rng

train_step = jax.jit(train_step)

In [None]:
def data_gen():
    """Yield a single batch dict built from the global ``input_ids`` and ``labels``."""
    batch = {"input_ids": input_ids, "labels": labels}
    yield batch

In [None]:
# Run 24 training steps on the same batch (a fresh data_gen() generator is
# created on every iteration), threading the PRNG key through train_step.
for i in range(24):
    batch = next(data_gen())
    state, loss, rng = train_step(state, batch, rng)
    print(loss)

8.4307
8.407024
8.3607645
8.444751
28.090708
28.100212
28.028742
28.039408
18.009052
18.227276
18.357918
18.488722
9.798409
9.931963
9.774198
9.863355
8.712209
8.647808
8.6835375
8.694314
7.514835
7.5347486
7.5292873
7.491831


In [None]:
# Write weights, optimizer state and step under ./_hf/ckpt-{step - 1}
# (with_opt and push_to_hub keep their defaults).
save_model_checkpoint(model, "./_hf", state)

In [None]:
# Keep a reference to the trained state so the restored one can be compared against it.
old_state = state

In [None]:
# Start again from the pretrained weights and a fresh optimizer, so `state` is
# an untrained template to restore the checkpoint into.
model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
optimizer = optax.MultiSteps(optax.adamw(1e-3), 4)

state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

  lax._check_user_dtype_supported(dtype, "zeros")


In [None]:
# save_model_checkpoint writes to ckpt-{step - 1}; derive that directory from
# the state that was saved instead of hard-coding a step number, so the cell
# still finds the checkpoint when the number of training steps changes.
state = restore_model_checkpoint(f"./_hf/ckpt-{mb_item(old_state.step)-1}", state)

restoring multisteps optimizer


In [None]:
def verify(state1, state2):
    """Compare two pytrees of identical structure leaf by leaf.

    Returns:
        A pytree with the structure of ``state1`` whose leaves are boolean
        scalars, True where the corresponding leaves are element-wise equal.

    ``jax.tree_multimap`` is no longer part of JAX; ``jax.tree_util.tree_map``
    accepts several trees and performs the same mapping.
    """
    return jax.tree_util.tree_map(lambda a, b: (a == b).all(), state1, state2)

In [None]:
# Check that the restored inner optimizer state matches the one held before saving.
verify(old_state.opt_state.inner_opt_state, state.opt_state.inner_opt_state)

[ScaleByAdamState(count=DeviceArray(True, dtype=bool), mu={'transformer': {'h': {'0': {'attn': {'c_attn': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}, 'c_proj': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}}, 'ln_1': {'bias': DeviceArray(True, dtype=bool), 'scale': DeviceArray(True, dtype=bool)}, 'ln_2': {'bias': DeviceArray(True, dtype=bool), 'scale': DeviceArray(True, dtype=bool)}, 'mlp': {'c_fc': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}, 'c_proj': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}}}, '1': {'attn': {'c_attn': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}, 'c_proj': {'bias': DeviceArray(True, dtype=bool), 'kernel': DeviceArray(True, dtype=bool)}}, 'ln_1': {'bias': DeviceArray(True, dtype=bool), 'scale': DeviceArray(True, dtype=bool)}, 'ln_2': {'bias': DeviceArray(True, dtype=bool), 'scale': Dev

In [None]:
# Resume training from the restored state: 16 steps on the same batch.
for i in range(16):
    batch = next(data_gen())
    state, loss, rng = train_step(state, batch, rng)
    print(loss)

compiling...


  lax._check_user_dtype_supported(dtype, "zeros")


7.0651155
7.087094
7.08247
7.0764003
6.989064
7.010929
6.989328
7.0078263
6.875692
6.870382
6.8747725
6.8677526
6.7281094
6.713859
6.739676
6.72533


## One-go

In [None]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax

from typing import Sequence
from flax.training.train_state import TrainState
from flax.training.checkpoints import save_checkpoint, restore_checkpoint


# Synthetic regression problem: 64 rows of (i, i); targets are sum(2*x + 1) over the feature axis.
rng = jax.random.PRNGKey(842)
rng, data_rng = jax.random.split(rng)

x = jnp.array([[i, i] for i in range(64)], dtype=jnp.float32)
# Targets are derived from the clean inputs, before any noise is added.
y = jnp.sum(2 * x + 1, axis=-1, keepdims=True)
# Perturb the inputs with noise drawn from data_rng.
x = x + jax.random.normal(data_rng, x.shape)


def data_gen():
    """Yield the one and only batch: the full ``(x, y)`` dataset."""
    yield x, y
    
class MLP(nn.Module):
    """Stack of Dense layers whose widths are given by ``features``.

    Every layer except the last is followed by a ReLU; the output of the last
    Dense layer is returned unchanged.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        for width in hidden_widths:
            x = nn.Dense(width)(x)
            x = nn.relu(x)
        return nn.Dense(self.features[-1])(x)

# Dense(4) -> ReLU -> Dense(1), initialised against the training inputs.
model = MLP([4, 1])
params = model.init(jax.random.PRNGKey(0), x)

# AdamW (lr=0.01) wrapped in MultiSteps with k=4.
optimizer = optax.MultiSteps(optax.adamw(0.01), 4)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


def compute_loss(params, batch):
    """Mean squared error between the model output and the targets.

    ``batch[0]`` holds the inputs and ``batch[1]`` the targets. The forward
    pass goes through the module-level ``state.apply_fn``.
    """
    predictions = state.apply_fn(params, batch[0])
    residual = predictions - batch[1]
    return jnp.mean(residual ** 2)


# Evaluates compute_loss and its gradient w.r.t. the first argument (params).
grad_fn = jax.value_and_grad(compute_loss)

@jax.jit
def train_step(state, batch):
    """Run one optimisation step on ``batch`` and report the loss.

    Args:
        state: train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        batch: indexable as ``(inputs, targets)``.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` is ``{"loss": loss}``.
    """
    def mse(params):
        predictions = state.apply_fn(params, batch[0])
        return jnp.mean((predictions - batch[1]) ** 2)

    loss, grad = jax.value_and_grad(mse)(state.params)
    return state.apply_gradients(grads=grad), {"loss": loss}

# Train the model for 8 steps on the same batch, then save a checkpoint as step 8.
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])
save_checkpoint('./_tmp/', state, 8, overwrite=True)

# Restore the checkpoint into the current state and resume training.
# This fails: the recorded output below ends in a TypeError.
state = restore_checkpoint('./_tmp/', state)
for i in range(8):
    batch = next(data_gen())
    state, metrics = train_step(state, batch)
    print(metrics["loss"])

  lax._check_user_dtype_supported(dtype, "zeros")


15254.637
15254.637
15254.637
15254.637
14935.168
14935.168
14935.168
14935.168


TypeError: ignored