In [1]:
from transformers import RobertaConfig as HFRobertaConfig
from transformers import RobertaForMaskedLM as HFRobertaForMaskedLM
from levanter.models.roberta import RobertaConfig
from levanter.models.roberta import RobertaForMaskedLM

import jax.random as jrandom
import jax.numpy as jnp
import haliax as hax

  from .autonotebook import tqdm as notebook_tqdm
2025-02-27 18:06:07,216	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
def load_weights_from_hf(hf, my):
    # Load the Hugging Face model
    hf_model = hf.from_pretrained("roberta-base")
    hf_config = HFRobertaConfig.from_pretrained("roberta-base")

    hf_config.hidden_dropout_prob = 0
    hf_config.attention_probs_dropout_prob = 0
    # hf_config.pad_token_id = -1

    # Convert Hugging Face config to Levanter config
    lv_config = RobertaConfig.from_hf_config(hf_config)

    # Create an instance of your custom RobertaForMaskedLM
    # key = jrandom.PRNGKey(0)
    # vocab_size = hf_config.vocab_size
    # Vocab = hax.Axis("vocab", vocab_size)
    # custom_model = my.init(Vocab=Vocab, config=lv_config, key=key)

    # Load weights from Hugging Face into your model
    converter = lv_config.hf_checkpoint_converter()

    model = converter.load_pretrained(
        lv_config.model_type,
        lv_config,
        axis_mapping=None, 
        dtype="float32",  
    )
    
    #print("Weights loaded successfully.")
    return model, lv_config, hf_model, hf_config

lv_model, lv_config, hf_model, hf_config = load_weights_from_hf(HFRobertaForMaskedLM, RobertaForMaskedLM)


Loading weights: 100%|██████████| 203/203 [00:01<00:00, 147.36it/s]


In [3]:
import torch
import numpy as np

In [4]:
# Compare outputs
def check(my_out, hf_out, precision=1e-4):
    acc = np.isclose(hf_out, my_out, rtol=precision, atol=precision).mean()
    diff = np.abs(my_out - hf_out).mean()
    return f"Accuracy: {acc:.4f}, Avg Difference: {diff:.6f}"

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
prompt = tokenizer("Mark <mask> is the CEO of Facebook, located in <mask> <mask>, California.", return_tensors="pt", padding='max_length', max_length=514)

In [6]:
Batch = hax.Axis("batch", 1)
Pos = lv_config.Pos

key = jrandom.PRNGKey(42)
key_var, key_model = jrandom.split(key, 2)

lv_prompt = {k: hax.NamedArray(np.array((prompt[k])), axes = (Batch, Pos)) for k in prompt.keys()}

# input_ids = hax.NamedArray(np.array(prompt["input_ids"]), axes = (Batch, Pos))
# attention_mask = hax.NamedArray(np.array(prompt["attention_mask"]), axes = (Batch, Pos))

In [7]:
def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
    mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
    incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
    # incremental_indices -= mask.all(axis=self.Pos)
    return incremental_indices + lv_config.pad_token_id

In [8]:
position_ids = create_position_ids_from_input_ids(lv_prompt["input_ids"])

lv_prompt["position_ids"] = position_ids
prompt["position_ids"] = torch.from_numpy(np.array(position_ids.array))

In [9]:
print(check(np.array(lv_prompt["input_ids"].array), np.array(prompt["input_ids"])))
print(check(np.array(lv_prompt["attention_mask"].array), np.array(prompt["attention_mask"])))
print(check(np.array(lv_prompt["position_ids"].array), np.array(prompt["position_ids"])))

Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 1.0000, Avg Difference: 0.000000


In [10]:
def check_dicts(my_dict, hf_dict):
    print(my_dict.keys())
    print(hf_dict.keys())

    my_keys = set(my_dict)
    hf_keys = set(hf_dict)

    both = hf_keys.union(my_keys)

    correct_msg = "Accuracy: 1.0000, Avg Difference: 0.000000"

    for k in both:
        if k not in my_keys:
            print(f"Key {k} not in my_keys!")
            continue

        if k not in hf_keys:
            print(f"Key {k} not in hf_keys!")
            continue
        
        check_str = check(my_dict[k], hf_dict[k])
        if check_str != correct_msg:
            print(check_str)

my_dict = lv_model.to_state_dict()
hf_dict = hf_model.state_dict()

my_dict = {k: np.array(my_dict[k]) for k in my_dict.keys()}
hf_dict = {k: np.array(hf_dict[k]) for k in hf_dict.keys()}

check_dicts(my_dict, hf_dict)

print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.decoder.bias"]))
print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))
print(check(hf_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))

dict_keys(['roberta.encoder.layer.0.attention.self.query.weight', 'roberta.encoder.layer.0.attention.self.query.bias', 'roberta.encoder.layer.0.attention.self.key.weight', 'roberta.encoder.layer.0.attention.self.key.bias', 'roberta.encoder.layer.0.attention.self.value.weight', 'roberta.encoder.layer.0.attention.self.value.bias', 'roberta.encoder.layer.0.attention.output.dense.weight', 'roberta.encoder.layer.0.attention.output.dense.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.intermediate.dense.weight', 'roberta.encoder.layer.0.intermediate.dense.bias', 'roberta.encoder.layer.0.output.dense.weight', 'roberta.encoder.layer.0.output.dense.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.self.query.weight', 'roberta.encoder.layer.1.attention.self.query.bias', 'roberta.encoder.layer.1.attentio

In [11]:
print(hf_config.bos_token_id)
print(hf_config.pad_token_id)
print(hf_config.eos_token_id)

0
1
2


In [50]:
word_embeddings = hf_dict["roberta.embeddings.word_embeddings.weight"]
input_ids = lv_prompt["input_ids"]
input_ids_np = np.array(lv_prompt["input_ids"].array)

input_embeds_np = word_embeddings[input_ids_np]
input_embeds_hax = hax.NamedArray(input_embeds_np, axes=(Batch, Pos, lv_config.Embed))

cond = (input_ids == lv_config.pad_token_id)
cond = hax.broadcast_to(cond, input_embeds_hax.axes)

input_embeds_hax = hax.where(
    cond,
    hax.zeros_like(input_embeds_hax),
    input_embeds_hax
)

input_embeds_torch = torch.from_numpy(np.array(input_embeds_hax.array))
new_lv_prompt = dict()
new_hf_prompt = dict()

new_lv_prompt["input_embeds"] = input_embeds_hax
new_hf_prompt["inputs_embeds"] = input_embeds_torch

new_lv_prompt["attention_mask"] = lv_prompt["attention_mask"]
new_hf_prompt["attention_mask"] = prompt["attention_mask"]

new_lv_prompt["position_ids"] = lv_prompt["position_ids"]
new_hf_prompt["position_ids"] = prompt["position_ids"]

In [53]:
print(check(np.array(input_embeds_hax.array), np.array(input_embeds_torch)))
print(check(np.array(input_embeds_hax.array), input_embeds_np))
print(check(np.array(input_embeds_torch), input_embeds_np))

Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 0.0331, Avg Difference: 0.012707
Accuracy: 0.0331, Avg Difference: 0.012707


In [36]:
# lv_result = lv_model(**lv_prompt)
# hf_result = hf_model(**prompt)

lv_result = lv_model(**new_lv_prompt)
hf_result = hf_model(**new_hf_prompt)

In [37]:
lv_result_logits = torch.from_numpy(np.array(lv_result.array))
hf_result_logits = hf_result.logits

In [38]:
print(check(np.array(lv_result_logits), np.array(hf_result_logits.detach())))
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(hf_result_logits.argmax(dim=-1).detach())))
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(prompt["input_ids"])))
print(check(np.array(hf_result_logits.argmax(dim=-1).detach()), np.array(prompt["input_ids"])))

Accuracy: 0.0002, Avg Difference: 2.350392
Accuracy: 0.9844, Avg Difference: 90.013619
Accuracy: 0.0175, Avg Difference: 286.669261
Accuracy: 0.0272, Avg Difference: 196.655642


In [39]:
tokenizer.batch_decode(lv_result_logits.argmax(dim=-1), skip_special_tokens=False)

['</s>Markham is the city of California, located in San Diego, California</s> California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>

In [40]:
tokenizer.batch_decode(hf_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s>Mark Zuckerberg is the CEO of Facebook, located in Palo Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

In [41]:
lv_result_logits.argmax(dim=-1)

tensor([[    2, 10006,  1908,    16,     5,   343,     9,   886,     6,  2034,
            11,   764,  3402,     6,   886,     2,   886,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [42]:
hf_result_logits.argmax(dim=-1)

tensor([[    0, 10006, 10741,    16,     5,  1324,     9,   622,     6,  2034,
            11, 21065, 18402,     6,   886,     4,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [43]:
np.array(prompt["input_ids"])

array([[    0, 10006, 50264,    16,     5,  1324,     9,   622,     6,
         2034,    11, 50264, 50264,     6,   886,     4,     2,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
      