In [1]:
import optax
from flax.training import train_state, checkpoints
from flax import traverse_util
from transformers import FlaxBertForMaskedLM
from typing import Callable
import jax.numpy as jnp

2023-01-05 09:40:46.358568: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64
2023-01-05 09:40:46.358716: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def create_learning_rate_fn(
        num_total_train_steps: int, num_warmup_steps: int, learning_rate: float, schedule_type: str = 'linear',
) -> Callable[[int], jnp.ndarray]:
    """Return a learning-rate schedule: linear warmup followed by a configurable decay.

    The schedule ramps linearly from 0 to ``learning_rate`` over the first
    ``num_warmup_steps`` steps. It then switches to the decay schedule selected by
    ``schedule_type``, which runs over the remaining
    ``num_total_train_steps - num_warmup_steps`` steps.

    Args:
        num_total_train_steps: Total number of optimisation steps, warmup included.
        num_warmup_steps: Number of linear-warmup steps; must lie in
            ``[0, num_total_train_steps]``.
        learning_rate: Peak learning rate, reached at the end of warmup.
        schedule_type: Post-warmup behaviour, one of
            ``'constant'`` (hold at ``learning_rate``),
            ``'linear'`` (linear decay to 0),
            ``'polynomial'`` (power-2 polynomial decay to 0), or
            ``'cosine'`` (cosine decay to ``0.1 * learning_rate``).

    Returns:
        An optax schedule mapping a step count to a learning rate.

    Raises:
        ValueError: If ``schedule_type`` is not recognised, or if ``num_warmup_steps``
            lies outside ``[0, num_total_train_steps]``.
    """
    # Validate before touching optax. Previously an unknown ``schedule_type`` fell
    # through the if/elif chain and surfaced as an UnboundLocalError on ``decay_fn``.
    if not 0 <= num_warmup_steps <= num_total_train_steps:
        raise ValueError(
            f"num_warmup_steps must be in [0, num_total_train_steps={num_total_train_steps}], "
            f"got {num_warmup_steps}"
        )
    decay_steps = num_total_train_steps - num_warmup_steps

    # Lazily-built decay schedules: only the selected one is constructed.
    decay_factories = {
        'constant': lambda: optax.constant_schedule(value=learning_rate),
        'linear': lambda: optax.linear_schedule(
            init_value=learning_rate, end_value=0, transition_steps=decay_steps
        ),
        'polynomial': lambda: optax.polynomial_schedule(
            init_value=learning_rate, end_value=0, power=2, transition_steps=decay_steps
        ),
        # alpha=0.1: the schedule bottoms out at 10% of learning_rate, not at 0.
        'cosine': lambda: optax.cosine_decay_schedule(
            init_value=learning_rate, decay_steps=decay_steps, alpha=0.1
        ),
    }
    if schedule_type not in decay_factories:
        raise ValueError(
            f"Unknown schedule_type {schedule_type!r}; expected one of {sorted(decay_factories)}"
        )

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = decay_factories[schedule_type]()
    # join_schedules hands over to decay_fn at step num_warmup_steps and evaluates it at
    # (step - num_warmup_steps), so the decay starts from its own step 0.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

In [3]:
# Directory the train state is checkpointed to (and restored from further down).
CKPT_DIR = "ckpts"

# Tiny BERT test model from the Hugging Face hub (~0.5MB of weights),
# loaded from PyTorch weights (from_pt=True).
model = FlaxBertForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)

# Placeholder schedule: 5 warmup steps, then linear decay over the remaining 5 of 10 steps.
linear_decay_lr_schedule_fn = create_learning_rate_fn(10, 5, 2e-10, "linear")

# AdamW whose learning rate follows the schedule above.
optimizer = optax.adamw(learning_rate=linear_decay_lr_schedule_fn)

# Bundle apply function, params and optimizer into a train state, then persist it as step 0.
state = train_state.TrainState.create(params=model.params, tx=optimizer, apply_fn=model.__call__)
# NOTE(review): keep=0 is unusual (flax's default is keep=1); confirm how the installed
# flax version interprets it before relying on checkpoint retention behaviour.
checkpoints.save_checkpoint(CKPT_DIR, state, 0, overwrite=True, keep=0);

Some weights of the model checkpoint at hf-internal-testing/tiny-random-bert were not used when initializing FlaxBertForMaskedLM: {('cls', 'seq_relationship', 'bias'), ('cls', 'predictions', 'decoder', 'kernel'), ('qa_outputs', 'kernel'), ('cls', 'seq_relationship', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'bias'), ('cls', 'predictions', 'decoder', 'bias'), ('classifier', 'kernel'), ('qa_outputs', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Reload the train state saved above, using `state` as the structural template.
# NOTE(review): flax serialises only pytree fields (step, params, opt_state). The
# `apply_fn` and `tx` of the result are presumably taken from `state` rather than the file,
# so the LR schedule is not stored in the checkpoint. Confirm against the flax docs.
loaded_state = checkpoints.restore_checkpoint(CKPT_DIR, state)

In [5]:
# Replacement schedule: 1 warmup step up to a peak of 2e-5,
# then linear decay over the remaining 4 of 5 total steps.
new_linear_decay_lr_schedule_fn = create_learning_rate_fn(5, 1, 2e-5, "linear")

In [6]:
# Impose new LR schedule on loaded train state?..