In [11]:
# Notebook-wide display and reload configuration.
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not just the last
# one; later cells rely on this to show several shapes/dtypes from one cell.
InteractiveShell.ast_node_interactivity = "all"
# (Re)load the autoreload extension and use mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [12]:
from transformers import FlaxGPT2LMHeadModel
from flax.core import unfreeze
from flax.traverse_util import flatten_dict

# Load the pretrained Hugging Face Flax GPT-2 checkpoint.
model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')

# Take the 'transformer' parameter subtree, unfreeze it, and collapse the
# nested dict into a single level whose keys are dot-joined paths
# (e.g. 'h.0.attn.c_attn.bias').
params = flatten_dict(unfreeze(model.params['transformer']), sep='.')

params.keys()

dict_keys(['h.0.attn.c_attn.bias', 'h.0.attn.c_attn.kernel', 'h.0.attn.c_proj.bias', 'h.0.attn.c_proj.kernel', 'h.0.ln_1.bias', 'h.0.ln_1.scale', 'h.0.ln_2.bias', 'h.0.ln_2.scale', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_fc.kernel', 'h.0.mlp.c_proj.bias', 'h.0.mlp.c_proj.kernel', 'h.1.attn.c_attn.bias', 'h.1.attn.c_attn.kernel', 'h.1.attn.c_proj.bias', 'h.1.attn.c_proj.kernel', 'h.1.ln_1.bias', 'h.1.ln_1.scale', 'h.1.ln_2.bias', 'h.1.ln_2.scale', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_fc.kernel', 'h.1.mlp.c_proj.bias', 'h.1.mlp.c_proj.kernel', 'h.10.attn.c_attn.bias', 'h.10.attn.c_attn.kernel', 'h.10.attn.c_proj.bias', 'h.10.attn.c_proj.kernel', 'h.10.ln_1.bias', 'h.10.ln_1.scale', 'h.10.ln_2.bias', 'h.10.ln_2.scale', 'h.10.mlp.c_fc.bias', 'h.10.mlp.c_fc.kernel', 'h.10.mlp.c_proj.bias', 'h.10.mlp.c_proj.kernel', 'h.11.attn.c_attn.bias', 'h.11.attn.c_attn.kernel', 'h.11.attn.c_proj.bias', 'h.11.attn.c_proj.kernel', 'h.11.ln_1.bias', 'h.11.ln_1.scale', 'h.11.ln_2.bias', 'h.11.ln_2.scale', 'h.11.mlp.c_f

In [13]:
# Print one "<name> Shape: <shape>" line for every flattened parameter.
for param, param_array in params.items():
    print(f"{param} Shape: {param_array.shape}")

h.0.attn.c_attn.bias Shape: (2304,)
h.0.attn.c_attn.kernel Shape: (2304, 768)
h.0.attn.c_proj.bias Shape: (768,)
h.0.attn.c_proj.kernel Shape: (768, 768)
h.0.ln_1.bias Shape: (768,)
h.0.ln_1.scale Shape: (768,)
h.0.ln_2.bias Shape: (768,)
h.0.ln_2.scale Shape: (768,)
h.0.mlp.c_fc.bias Shape: (3072,)
h.0.mlp.c_fc.kernel Shape: (3072, 768)
h.0.mlp.c_proj.bias Shape: (768,)
h.0.mlp.c_proj.kernel Shape: (768, 3072)
h.1.attn.c_attn.bias Shape: (2304,)
h.1.attn.c_attn.kernel Shape: (2304, 768)
h.1.attn.c_proj.bias Shape: (768,)
h.1.attn.c_proj.kernel Shape: (768, 768)
h.1.ln_1.bias Shape: (768,)
h.1.ln_1.scale Shape: (768,)
h.1.ln_2.bias Shape: (768,)
h.1.ln_2.scale Shape: (768,)
h.1.mlp.c_fc.bias Shape: (3072,)
h.1.mlp.c_fc.kernel Shape: (3072, 768)
h.1.mlp.c_proj.bias Shape: (768,)
h.1.mlp.c_proj.kernel Shape: (768, 3072)
h.10.attn.c_attn.bias Shape: (2304,)
h.10.attn.c_attn.kernel Shape: (2304, 768)
h.10.attn.c_proj.bias Shape: (768,)
h.10.attn.c_proj.kernel Shape: (768, 768)
h.10.ln_1.bi

In [14]:
from jax_gpt2 import GPT, GPTConfig
import flax.nnx as nn

# NOTE: this rebinds `model`, which an earlier cell bound to the Hugging Face
# FlaxGPT2LMHeadModel; from here on it is the local nnx GPT implementation.
model = GPT(GPTConfig(), nn.Rngs(0))

# Map each submodule's dot-joined path (e.g. 'h.0.attn.c_attn') to the module
# object, leaving out modules whose class name is Block, CausalSelfAttention,
# GPT or MLP (presumably the composite containers -- confirm in jax_gpt2).
jax_modules_dict = {}
for path_parts, module in model.iter_modules():
    if type(module).__name__ not in ('Block', 'CausalSelfAttention', 'GPT', 'MLP'):
        module_path = '.'.join(map(str, path_parts))
        jax_modules_dict[module_path] = module

jax_modules_dict.keys()

dict_keys(['h.0.attn.c_attn', 'h.0.attn.c_proj', 'h.0.ln_1', 'h.0.ln_2', 'h.0.mlp.c_fc', 'h.0.mlp.c_proj', 'h.1.attn.c_attn', 'h.1.attn.c_proj', 'h.1.ln_1', 'h.1.ln_2', 'h.1.mlp.c_fc', 'h.1.mlp.c_proj', 'h.2.attn.c_attn', 'h.2.attn.c_proj', 'h.2.ln_1', 'h.2.ln_2', 'h.2.mlp.c_fc', 'h.2.mlp.c_proj', 'h.3.attn.c_attn', 'h.3.attn.c_proj', 'h.3.ln_1', 'h.3.ln_2', 'h.3.mlp.c_fc', 'h.3.mlp.c_proj', 'h.4.attn.c_attn', 'h.4.attn.c_proj', 'h.4.ln_1', 'h.4.ln_2', 'h.4.mlp.c_fc', 'h.4.mlp.c_proj', 'h.5.attn.c_attn', 'h.5.attn.c_proj', 'h.5.ln_1', 'h.5.ln_2', 'h.5.mlp.c_fc', 'h.5.mlp.c_proj', 'h.6.attn.c_attn', 'h.6.attn.c_proj', 'h.6.ln_1', 'h.6.ln_2', 'h.6.mlp.c_fc', 'h.6.mlp.c_proj', 'h.7.attn.c_attn', 'h.7.attn.c_proj', 'h.7.ln_1', 'h.7.ln_2', 'h.7.mlp.c_fc', 'h.7.mlp.c_proj', 'h.8.attn.c_attn', 'h.8.attn.c_proj', 'h.8.ln_1', 'h.8.ln_2', 'h.8.mlp.c_fc', 'h.8.mlp.c_proj', 'h.9.attn.c_attn', 'h.9.attn.c_proj', 'h.9.ln_1', 'h.9.ln_2', 'h.9.mlp.c_fc', 'h.9.mlp.c_proj', 'h.10.attn.c_attn', 'h.10.att

In [15]:
# Shape of the Hugging Face 'wte.embedding' parameter
# ((50257, 768) in the recorded output).
hf_wte = params['wte.embedding']
hf_wte.shape

# Shape of the nnx model's 'lm_head' kernel ((768, 50257) in the recorded
# output, i.e. the two shapes are transposes of each other).
nnx_lm_head_kernel = jax_modules_dict['lm_head'].kernel.value
nnx_lm_head_kernel.shape

(50257, 768)

(768, 50257)

In [16]:
# Compare the first block's c_attn bias between the nnx model and the
# Hugging Face parameters: dtype first, then shape, nnx before HF each time.
nnx_bias = jax_modules_dict['h.0.attn.c_attn'].bias.value
nnx_bias.dtype
hf_bias = params['h.0.attn.c_attn.bias']
hf_bias.dtype

nnx_bias.shape
hf_bias.shape

dtype('float32')

dtype('float32')

(2304,)

(2304,)