In [None]:
import transformers
import numpy as np
import torch as t
import jax
import jax.numpy as jnp
import os
import flax

# Hugging Face hub id of the checkpoint to port; the same string is reused
# further down as the directory name the converted weights are saved under.
name = "gpt2-medium"
hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=name
)
# Torch tensors keyed by HF parameter name, e.g. "transformer.h.0.ln_1.weight".
hf_params = hf_model.state_dict()

In [None]:
# Keep the HF config around: its n_layer / n_head / n_embd / vocab_size /
# layer_norm_epsilon fields are used to size our flax model in the next cell.
hf_config = hf_model.config
hf_config

In [None]:
from interp.model.gpt_model import Gpt
import jax
import jax.numpy as jnp
from interp.tools.interpretability_tools import batch_tokenize
# Build our flax GPT with hyperparameters read from the HF checkpoint config.
# max_sequence_len presumably sizes our position_embedding table, which cp()
# later requires to match `transformer.wpe.weight` exactly; that HF table has
# `n_positions` rows. `n_ctx` merely happens to equal it for gpt2-medium and is
# not a field every GPT-2 config carries.
model = Gpt(
    num_layers=hf_config.n_layer,
    num_heads=hf_config.n_head,
    hidden_size=hf_config.n_embd,
    max_sequence_len=hf_config.n_positions,
    vocab_size=hf_config.vocab_size,
    use_mlp=True,
    norm_type="layer_norm",
    attn_bias=True,
    layer_norm_epsilon=hf_config.layer_norm_epsilon,
)
text = "[BEGIN] \"I don't sleep right,\" Harry said."

data = batch_tokenize([text])
# Random init is only used to obtain the param tree's structure and shapes;
# its values are meant to be overwritten by the HF weights below.
our_params = jax.jit(model.init)(jax.random.PRNGKey(0), data)["params"]

In [None]:
import numpy as np
def recurse(x):
    """Convert a nested mapping of jax arrays into nested plain dicts of numpy arrays.

    Every jax-array leaf goes through ``np.array``, which returns a fresh numpy
    array that can later be overwritten in place (e.g. with ``np.copyto``).
    Any other node is treated as a mapping and rebuilt as a plain ``dict``.
    """
    if not isinstance(x, jnp.ndarray):
        return {key: recurse(child) for key, child in x.items()}
    return np.array(x)
# Numpy copy of the freshly initialised params; the copying cell below
# overwrites its leaves in place with the HF weights via np.copyto.
out_params = recurse(x=our_params)

In [None]:
def recurse_print(x):
    """Return the nested param tree `x` as plain dicts with each jax-array leaf replaced by its shape."""
    if not isinstance(x, jnp.ndarray):
        return {key: recurse_print(child) for key, child in x.items()}
    return x.shape
# Display our param tree as shapes -- handy when writing the HF -> ours key mapping by hand.
recurse_print(x=our_params)

In [None]:
# Manually copy the HF weights into our param tree. Every key in our tree should receive a weight; the HF state dict additionally holds attention-mask buffers and similar entries that we don't need.

def cp(a, b):
    """Copy the values of `a` into the preallocated numpy array `b`, in place.

    Args:
        a: source array-like (a torch tensor or numpy array in this notebook).
            Converted with ``np.asarray``, which avoids the throwaway copy that
            ``np.array`` would make before ``np.copyto`` copies again.
        b: writable numpy destination array with exactly the same shape.

    Raises:
        AssertionError: if the shapes differ.
        TypeError: if the dtypes differ -- ``casting='no'`` forbids any
            conversion, so e.g. a float32/float64 mix-up cannot slip through.
    """
    src = np.asarray(a)
    assert tuple(src.shape) == tuple(b.shape), (
        f"shape mismatch: source {tuple(src.shape)} vs destination {tuple(b.shape)}"
    )
    np.copyto(b, src, casting="no")
    
def _subtree(tree, path):
    """Return the node reached by following the keys in `path` down the nested dict `tree`."""
    for key in path:
        tree = tree[key]
    return tree


def _leaf_paths(tree, prefix=()):
    """Yield the key path (a tuple) of every non-mapping leaf under `tree`."""
    for key, child in tree.items():
        if hasattr(child, "items"):
            yield from _leaf_paths(child, prefix + (key,))
        else:
            yield prefix + (key,)


# (HF state-dict key, path into out_params) for the weights outside the blocks.
copy_plan = [
    ("transformer.wpe.weight", ("embedding", "position_embedding", "embedding")),
    ("transformer.wte.weight", ("embedding", "token_embedding", "embedding")),
    ("transformer.ln_f.bias", ("norm_output", "bias")),
    ("transformer.ln_f.weight", ("norm_output", "scale")),
]
# (HF key suffix after f"transformer.h.{i}.", path inside our f"blocks_{i}")
# for each transformer block.
layer_plan = [
    ("ln_1.weight", ("norm1", "scale")),
    ("ln_1.bias", ("norm1", "bias")),
    ("ln_2.weight", ("norm2", "scale")),
    ("ln_2.bias", ("norm2", "bias")),
    ("attn.c_attn.weight", ("attention", "attn_weights", "kernel")),
    ("attn.c_attn.bias", ("attention", "attn_weights", "bias")),
    ("attn.c_proj.weight", ("attention", "project_output", "kernel")),
    ("attn.c_proj.bias", ("attention", "project_output", "bias")),
    ("mlp.c_fc.bias", ("linear1", "bias")),
    ("mlp.c_fc.weight", ("linear1", "kernel")),
    ("mlp.c_proj.bias", ("linear2", "bias")),
    ("mlp.c_proj.weight", ("linear2", "kernel")),
]
for i in range(hf_config.n_layer):
    copy_plan.extend(
        (f"transformer.h.{i}.{hf_suffix}", (f"blocks_{i}",) + our_path)
        for hf_suffix, our_path in layer_plan
    )

copied_paths = set()
for hf_key, our_path in copy_plan:
    cp(hf_params[hf_key], _subtree(out_params, our_path))
    copied_paths.add(our_path)

# Every leaf of our tree must have received an HF weight; a leaf missing from the
# plan would otherwise silently keep its random initialisation.
uncopied = sorted(set(_leaf_paths(out_params)) - copied_paths)
assert not uncopied, f"params never copied from HF (still random init): {uncopied}"

In [None]:
# List every HF state-dict key for comparison with the ones copied above; the HF
# dict also contains entries (e.g. attention-mask buffers) that we don't copy.
hf_params.keys()

In [None]:
# Wrap the converted weights in the {"params": ...} collection layout that
# model.apply takes (the same key model.init's output was indexed with).
out_params_frozen = flax.core.frozen_dict.FrozenDict(params=out_params)

In [None]:
# Next-token probabilities from our model with the copied weights, on the token
# columns from index 1 onward (the HF call below uses the same slice).
# NOTE(review): `[0]` is assumed to pick the logits out of apply()'s output -- confirm against Gpt.
our_out = jax.nn.softmax(
    model.apply(out_params_frozen, data[:, 1:])[0],
    axis=-1,
)

In [None]:
import torch as t
# Reference probabilities from the HF model on the same token columns (index 1 onward).
# torch.no_grad(): this pass is inference only, so don't build an autograd graph.
# Token ids are passed explicitly as int64, the canonical torch index dtype.
with t.no_grad():
    hf_logits = hf_model(t.tensor(np.array(data[:, 1:]), dtype=t.long)).logits
hf_out = jax.nn.softmax(hf_logits.numpy(), axis=-1)

In [None]:
# Eyeball the first few probabilities at position 0 from both implementations.
print(hf_out[0, 0, :12], our_out[0, 0, :12])
# Loose tolerance on purpose: agreement to ~1e-3 in probability space is enough here.
# Same rtol/atol as np.allclose's check, but assert_allclose also rejects shape
# mismatches (no silent broadcasting) and reports the max abs/rel error on failure.
np.testing.assert_allclose(our_out, hf_out, rtol=1e-5, atol=1e-3, equal_nan=False)

In [None]:
# NOTE: this notebook does not write model_info.json; create it yourself by copying it from a similar model's directory.
from  flax.serialization import to_bytes
import os
# local=True saves under /home/ubuntu/..., otherwise under /home/ubuntu/rrfs/...
local = True
model_dir = f"/home/ubuntu{'' if local else '/rrfs'}/interpretability_models_jax/{name}"
# exist_ok=True tolerates an already-existing directory without a bare
# `except:`, which would also hide real failures (missing parent directory,
# permissions, even KeyboardInterrupt) until the write below.
os.makedirs(model_dir, exist_ok=True)
# Context manager so the file is flushed and closed deterministically.
with open(os.path.join(model_dir, "model.bin"), "wb") as model_file:
    model_file.write(to_bytes(out_params_frozen))

In [None]:
from interp.model.model_loading import load_model
# Reload from the same root the save cell wrote to, which depends on `local`;
# a hard-coded /home/ubuntu path would miss the /rrfs location when local=False.
# Note: this rebinds `model`, replacing the Gpt instance built earlier.
model, params, tokenizer = load_model(
    name,
    models_dir=f"/home/ubuntu{'' if local else '/rrfs'}/interpretability_models_jax",
)