In [1]:
from IPython.core.interactiveshell import InteractiveShell

# Show the result of every top-level expression in a cell, not only the last
# one (this is why a cell with several `metrics.compute()` calls prints each).
InteractiveShell.ast_node_interactivity = "all"
# Autoreload mode 2: re-import modules before each cell runs, so edits to
# imported source files are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

## Ran into all kinds of errors with the mask

- Tried various things.
- The errors were hard to read because the full model state was dumped into every message.
- Truncating the output helped: `python test_training_with_weight_decay.py 2>&1 | head -50`
- The breakthrough came from testing with a small model in `prepare_training_with_decay_simple.ipynb`.
- The first success used the more complex `optax.multi_transform` with `flax.traverse_util.path_aware_map`; it was more than needed, but good learning.
- Eventually got the plain mask to work.


ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: State({

ValueError: Mismatch custom node data: ('embedding_dropout', 'h', 'lm_head', 'ln_f', 'wpe', 'wte') != ('h', 'lm_head', 'ln_f', 'wpe', 'wte'); value: State({


In [5]:
from gpt2 import GPT
import flax.nnx as nnx
import flax
# from flax.core import FrozenDict, frozen_dict

# Build the GPT module initialised from the pretrained "gpt2" checkpoint
# (the loader logs "loading weights from pretrained gpt: gpt2").
model = GPT.from_pretrained("gpt2")


def model_state_decay_mask(model: GPT):
    """Build a boolean mask over the model state selecting leaves for weight decay.

    The model state is flattened to ``path -> leaf`` form. Every path whose
    first component contains ``"dropout"`` is left out entirely; each
    remaining path maps to ``True`` unless its last component is one of
    ``bias``, ``embedding``, ``scale``, ``count`` or ``key``.

    Args:
        model: the GPT module whose state layout the mask should mirror.

    Returns:
        An ``nnx.State`` rebuilt from the flat ``path -> bool`` mapping.
    """
    # Leaf names that must never be decayed.
    no_decay_leaves = ("bias", "embedding", "scale", "count", "key")
    decay_by_path = {
        path: path[-1] not in no_decay_leaves
        for path in nnx.state(model).flat_state().keys()
        # The grads handed to the optimizer carry no dropout entries, so the
        # mask must not carry them either.
        if "dropout" not in path[0]
    }
    return nnx.State.from_flat_path(decay_by_path)


# The FrozenDict-based variant below did not work either; I tried it because the errors kept saying a FrozenDict was expected.

# def param_decay_mask(params: FrozenDict) -> FrozenDict:
#     """pytree mask for non-bias parameters"""
#     flat_params = flax.traverse_util.flatten_dict(params)
#     flat_param_mask = {
#         k: k[-1] not in ("bias", "embedding", "scale") for k in flat_params.keys()
#     }
#     param_mask = flax.traverse_util.unflatten_dict(flat_param_mask)
#     return frozen_dict.freeze(param_mask)


# Display the mask to eyeball which leaves are marked True (decayed) vs False.
model_state_decay_mask(model)

loading weights from pretrained gpt: gpt2


Length of prepared JAX modules dict: 89


State({
  'h': {
    0: {
      'attn': {
        'c_attn': {
          'bias': False,
          'kernel': True
        },
        'c_proj': {
          'bias': False,
          'kernel': True
        }
      },
      'ln_1': {
        'bias': False,
        'scale': False
      },
      'ln_2': {
        'bias': False,
        'scale': False
      },
      'mlp': {
        'c_fc': {
          'bias': False,
          'kernel': True
        },
        'c_proj': {
          'bias': False,
          'kernel': True
        }
      }
    },
    1: {
      'attn': {
        'c_attn': {
          'bias': False,
          'kernel': True
        },
        'c_proj': {
          'bias': False,
          'kernel': True
        }
      },
      'ln_1': {
        'bias': False,
        'scale': False
      },
      'ln_2': {
        'bias': False,
        'scale': False
      },
      'mlp': {
        'c_fc': {
          'bias': False,
          'kernel': True
        },
        'c_proj': {
      

In [6]:
from gpt2 import GPT
import optax
import flax.nnx as nnx


@nnx.jit
def train_step(model: GPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run a single jitted training step on one batch.

    Args:
        model: the GPT module; it is called on the batch inputs and
            differentiated with respect to the loss.
        optimizer: receives the computed gradients through ``update``.
        metrics: metric collection updated with this step's loss, logits and
            labels.
        batch: a pair ``(x, y)``; ``x`` is fed to the model and ``y`` supplies
            the integer labels.

    Returns:
        None. The step works through side effects on ``metrics`` and
        ``optimizer``.
    """
    inputs, targets = batch

    def loss_fn(gpt: GPT):
        predictions = gpt(inputs)
        # Collapse all leading axes so every position is scored as an
        # independent classification over the last (class) axis.
        num_classes = predictions.shape[-1]
        per_position = optax.softmax_cross_entropy_with_integer_labels(
            predictions.reshape([-1, num_classes]), targets.reshape([-1])
        )
        # The logits ride along as auxiliary output for the metrics.
        return per_position.mean(), predictions

    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    metrics.update(loss=loss, logits=logits, labels=targets)
    optimizer.update(grads)

In [7]:
from gpt2 import GPT
import optax
import flax.nnx as nnx
import numpy as np
import tiktoken
import datasets

# Fresh pretrained model for the training smoke test, switched to train mode.
model = GPT.from_pretrained("gpt2")
model.train()

# AdamW whose weight decay is restricted by the mask built from the model
# state (True = decay this leaf, False = leave it alone).
tx = optax.adamw(
    learning_rate=1e-4, weight_decay=1e-4, mask=model_state_decay_mask(model)
)
optimizer = nnx.Optimizer(model, tx)
# Metrics fed by train_step: accuracy from logits/labels, and the average of
# the values passed under the "loss" keyword.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)

# GPT-2 tokenizer and the batch geometry read by get_batch.
enc = tiktoken.get_encoding("gpt2")
batch_size = 6
block_size = 1024

# Each split: join the "Text" field of every row with newlines, tokenize, and
# store the ids as uint16.
# NOTE(review): uint16 assumes every token id is below 65536 - confirm this
# holds for the "gpt2" encoding.
data = datasets.load_dataset(path="Trelis/tiny-shakespeare")
train_data = "\n".join([x["Text"] for x in data["train"]])
train_data = enc.encode_ordinary(train_data)
train_data = np.array(train_data, dtype=np.uint16)
val_data = "\n".join([x["Text"] for x in data["test"]])
val_data = enc.encode_ordinary(val_data)
val_data = np.array(val_data, dtype=np.uint16)


def get_batch(split):
    """Sample a random batch of input/target token windows.

    Reads the module-level ``train_data``, ``val_data``, ``batch_size`` and
    ``block_size``.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of int32 arrays, each of shape
        ``(batch_size, block_size)``; ``y`` is ``x`` shifted one token ahead.
    """
    tokens = val_data if split != "train" else train_data
    # Start offsets are drawn so the target window (start + 1 onwards) still
    # fits inside the token array.
    starts = np.random.randint(len(tokens) - block_size, size=(batch_size,))
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(tokens[start:stop].astype(np.int32))
        targets.append(tokens[start + 1 : stop + 1].astype(np.int32))
    return np.stack(inputs), np.stack(targets)


# Smoke test: two optimisation steps on random training batches, showing the
# metrics after each one.
# NOTE(review): metrics are never reset between the two steps, so the second
# readout presumably aggregates both steps rather than the last one - confirm.
train_step(model, optimizer, metrics, get_batch("train"))
metrics.compute()

train_step(model, optimizer, metrics, get_batch("train"))
metrics.compute()

loading weights from pretrained gpt: gpt2
Length of prepared JAX modules dict: 89


{'accuracy': Array(0.33658853, dtype=float32),
 'loss': Array(4.338888, dtype=float32)}

{'accuracy': Array(0.32535806, dtype=float32),
 'loss': Array(4.1847277, dtype=float32)}