In [39]:
import jax.numpy as jnp
import flax.nnx as nn
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [40]:
from transformers import GPT2LMHeadModel

model_type = 'gpt2'
model_hf = GPT2LMHeadModel.from_pretrained(model_type, cache_dir="/var/local/ML/TRAIN/STAGE")

In [41]:
# !cd /var/local/ML/TRAIN/pico_shakespeare && cat $(find ../STAGE/*gpt2* -name 'config.json')

In [42]:
model_hf.config

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.43.4",
  "use_cache": true,
  "vocab_size": 50257
}

In [43]:
sd_hf = model_hf.state_dict()
sd_hf.keys()

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.

In [44]:
type(sd_hf['transformer.wte.weight'])
sd_hf['transformer.wte.weight'].shape

torch.Tensor

torch.Size([50257, 768])

In [45]:
vocab_size, embedding_size = sd_hf['transformer.wte.weight'].shape
wte = nn.Embed(num_embeddings=vocab_size, features=embedding_size, rngs=nn.Rngs(0))
wte

Embed(
  embedding=Param(
    value=Array(shape=(50257, 768), dtype=float32)
  ),
  num_embeddings=50257,
  features=768,
  dtype=dtype('float32'),
  param_dtype=<class 'jax.numpy.float32'>,
  embedding_init=<function variance_scaling.<locals>.init at 0x7f88b8b31310>
)

In [46]:
graphdef, state = nn.split(wte)
state
sd_hf['transformer.wte.weight'].cpu().numpy()[0][0] 
state["embedding"].value[0][0]
type(state["embedding"].value)
state["embedding"].value = jnp.array(sd_hf['transformer.wte.weight'].cpu().numpy())
type(state["embedding"].value)
nn.update(wte, state)

State({
  'embedding': VariableState(
    type=Param,
    value=Array([[ 1.1953074e-02, -4.3702841e-02,  1.5583087e-02, ...,
             2.9827623e-02,  2.3185154e-02, -5.0074987e-02],
           [-7.3833148e-06, -2.3771707e-02,  3.2675751e-02, ...,
             1.2370082e-02, -1.8245960e-02, -5.5854514e-02],
           [-6.4068988e-02,  1.0926131e-02, -9.7181993e-03, ...,
            -1.6451136e-03, -1.8916763e-02, -7.8727528e-02],
           ...,
           [-4.5188973e-03, -8.1994740e-04,  1.7434264e-02, ...,
             1.5338360e-02,  2.8312072e-02,  2.1429532e-04],
           [-1.0427453e-03,  1.4039346e-02,  4.0459871e-02, ...,
            -3.7717942e-02, -1.7851518e-02, -4.7507521e-02],
           [ 2.5526977e-03, -1.7003939e-02,  2.0834690e-02, ...,
            -3.3392377e-02, -8.9475606e-04, -4.4884817e-03]], dtype=float32)
  )
})

np.float32(-0.11010301)

Array(0.01195307, dtype=float32)

jaxlib.xla_extension.ArrayImpl

jaxlib.xla_extension.ArrayImpl

In [47]:
[(x,model_hf.state_dict()[x].shape) for x in model_hf.state_dict().keys() if "h.0" in x]

[('transformer.h.0.ln_1.weight', torch.Size([768])),
 ('transformer.h.0.ln_1.bias', torch.Size([768])),
 ('transformer.h.0.attn.c_attn.weight', torch.Size([768, 2304])),
 ('transformer.h.0.attn.c_attn.bias', torch.Size([2304])),
 ('transformer.h.0.attn.c_proj.weight', torch.Size([768, 768])),
 ('transformer.h.0.attn.c_proj.bias', torch.Size([768])),
 ('transformer.h.0.ln_2.weight', torch.Size([768])),
 ('transformer.h.0.ln_2.bias', torch.Size([768])),
 ('transformer.h.0.mlp.c_fc.weight', torch.Size([768, 3072])),
 ('transformer.h.0.mlp.c_fc.bias', torch.Size([3072])),
 ('transformer.h.0.mlp.c_proj.weight', torch.Size([3072, 768])),
 ('transformer.h.0.mlp.c_proj.bias', torch.Size([768]))]

In [48]:
ln_features, = sd_hf['transformer.h.0.ln_1.weight'].shape
ln = nn.LayerNorm(num_features=ln_features, rngs=nn.Rngs(0))
ln

LayerNorm(
  scale=Param(
    value=Array(shape=(768,), dtype=float32)
  ),
  bias=Param(
    value=Array(shape=(768,), dtype=float32)
  ),
  num_features=768,
  epsilon=1e-06,
  dtype=None,
  param_dtype=<class 'jax.numpy.float32'>,
  use_bias=True,
  use_scale=True,
  bias_init=<function zeros at 0x7f88b93d3160>,
  scale_init=<function ones at 0x7f88b93d3310>,
  reduction_axes=-1,
  feature_axes=-1,
  axis_name=None,
  axis_index_groups=None,
  use_fast_variance=True
)

In [49]:
ln.scale.value = jnp.array(sd_hf['transformer.h.0.ln_1.weight'].cpu().numpy())
ln.bias.value = jnp.array(sd_hf['transformer.h.0.ln_1.bias'].cpu().numpy())
ln

LayerNorm(
  scale=Param(
    value=Array(shape=(768,), dtype=float32)
  ),
  bias=Param(
    value=Array(shape=(768,), dtype=float32)
  ),
  num_features=768,
  epsilon=1e-06,
  dtype=None,
  param_dtype=<class 'jax.numpy.float32'>,
  use_bias=True,
  use_scale=True,
  bias_init=<function zeros at 0x7f88b93d3160>,
  scale_init=<function ones at 0x7f88b93d3310>,
  reduction_axes=-1,
  feature_axes=-1,
  axis_name=None,
  axis_index_groups=None,
  use_fast_variance=True
)

In [50]:
mha = nn.MultiHeadAttention(
num_heads=12, 
in_features=768, 
qkv_features=768,
out_features=768,
decode=False, 
rngs=nn.Rngs(0)
)
mha

MultiHeadAttention(
  num_heads=12,
  in_features=768,
  qkv_features=768,
  out_features=768,
  dtype=None,
  param_dtype=<class 'jax.numpy.float32'>,
  broadcast_dropout=True,
  dropout_rate=0.0,
  deterministic=None,
  precision=None,
  kernel_init=<function variance_scaling.<locals>.init at 0x7f88b8b31040>,
  out_kernel_init=None,
  bias_init=<function zeros at 0x7f88b93d3160>,
  out_bias_init=None,
  use_bias=True,
  attention_fn=<function dot_product_attention at 0x7f88b8b31e50>,
  decode=False,
  normalize_qk=False,
  qkv_dot_general=None,
  out_dot_general=None,
  qkv_dot_general_cls=None,
  out_dot_general_cls=None,
  head_dim=64,
  query=LinearGeneral(
    in_features=(768,),
    out_features=(12, 64),
    axis=(-1,),
    batch_axis=FrozenDict({}),
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    kernel_init=<function variance_scaling.<locals>.init at 0x7f88b8b31040>,
    bias_init=<function zeros at 0x7f88b93d3160>,
    precision=None,
    

# Prep

In [51]:
from jax_gpt2 import GPT, GPTConfig
import flax.nnx as nn
config = GPTConfig()
model = GPT(config, nn.Rngs(0))
[(x[0], type(x[1]).__name__) for x in model.iter_modules()]


[(('h', 0, 'attn', 'c_attn'), 'Linear'),
 (('h', 0, 'attn', 'c_proj'), 'Linear'),
 (('h', 0, 'attn'), 'CausalSelfAttention'),
 (('h', 0, 'ln_1'), 'LayerNorm'),
 (('h', 0, 'ln_2'), 'LayerNorm'),
 (('h', 0, 'mlp', 'c_fc'), 'Linear'),
 (('h', 0, 'mlp', 'c_proj'), 'Linear'),
 (('h', 0, 'mlp'), 'MLP'),
 (('h', 0), 'Block'),
 (('h', 1, 'attn', 'c_attn'), 'Linear'),
 (('h', 1, 'attn', 'c_proj'), 'Linear'),
 (('h', 1, 'attn'), 'CausalSelfAttention'),
 (('h', 1, 'ln_1'), 'LayerNorm'),
 (('h', 1, 'ln_2'), 'LayerNorm'),
 (('h', 1, 'mlp', 'c_fc'), 'Linear'),
 (('h', 1, 'mlp', 'c_proj'), 'Linear'),
 (('h', 1, 'mlp'), 'MLP'),
 (('h', 1), 'Block'),
 (('h', 2, 'attn', 'c_attn'), 'Linear'),
 (('h', 2, 'attn', 'c_proj'), 'Linear'),
 (('h', 2, 'attn'), 'CausalSelfAttention'),
 (('h', 2, 'ln_1'), 'LayerNorm'),
 (('h', 2, 'ln_2'), 'LayerNorm'),
 (('h', 2, 'mlp', 'c_fc'), 'Linear'),
 (('h', 2, 'mlp', 'c_proj'), 'Linear'),
 (('h', 2, 'mlp'), 'MLP'),
 (('h', 2), 'Block'),
 (('h', 3, 'attn', 'c_attn'), 'Linear

In [52]:
set([type(x[1]).__name__ for x in model.iter_modules()])

{'Block', 'CausalSelfAttention', 'Embed', 'GPT', 'LayerNorm', 'Linear', 'MLP'}

In [53]:
from transformers import GPT2LMHeadModel
import jax

model_type = 'gpt2'
model_hf = GPT2LMHeadModel.from_pretrained(model_type, cache_dir="/var/local/ML/TRAIN/STAGE")

jax_modules_dict = {}
for module_pair in model.iter_modules():
    if type(module_pair[1]).__name__  in ['Block', 'CausalSelfAttention', 'GPT', 'MLP']:
        continue
    module_path = '.'.join([str(x) for x in module_pair[0]])
    module = module_pair[1]
    jax_modules_dict[module_path] = module

len(jax_modules_dict.keys())

equivalent_jax_modules = []
hf_sd = model_hf.state_dict()

for param in hf_sd:
    if 'transformer' in param:
        key = '.'.join(param.split(".")[1:-1])
    else:
        key = '.'.join(param.split(".")[:-1])
    equivalent_jax_module = jax_modules_dict[key]
    equivalent_jax_modules.append(type(equivalent_jax_module).__name__)

    if type(equivalent_jax_module).__name__ == 'Embed':
        inner = equivalent_jax_module.embedding
    elif type(equivalent_jax_module).__name__ == 'Linear':
        if 'weight' in param:
            inner = equivalent_jax_module.kernel
        else:
            inner = equivalent_jax_module.bias
    elif type(equivalent_jax_module).__name__ == 'LayerNorm':
        if 'weight' in param:
            inner = equivalent_jax_module.scale
        else:
            inner = equivalent_jax_module.bias
    
    if inner.value.shape == tuple(hf_sd[param].shape):
        inner.value = jnp.array(hf_sd[param].cpu().numpy())
    elif inner.value.shape == tuple(hf_sd[param].shape)[::-1]:
        print("Transposing ", key)
        inner.value = jnp.array(hf_sd[param].cpu().numpy().T)

assert len(equivalent_jax_modules) == len(model_hf.state_dict())

len(equivalent_jax_modules)

76

Transposing  lm_head


149