In [1]:
from transformers import RobertaConfig as HFRobertaConfig
from transformers import RobertaForMaskedLM as HFRobertaForMaskedLM
from levanter.models.roberta import RobertaConfig
from levanter.models.roberta import RobertaForMaskedLM

import jax
import jax.random as jrandom
import jax.numpy as jnp
import haliax as hax

  from .autonotebook import tqdm as notebook_tqdm
2025-04-01 19:03:57,848	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
def load_weights_from_hf():
    """Load `roberta-base` as both a Hugging Face model and a Levanter model.

    The Hugging Face config is created first, with both dropout probabilities
    set to zero, and is then passed to ``from_pretrained`` so that the HF model
    is built from exactly the config object that is returned and converted for
    Levanter. (Previously the config was edited only after the HF model had
    already been constructed, so the overrides never reached that model.)

    Returns:
        tuple: ``(model, lv_config, hf_model, hf_config)`` - the Levanter
        model, the Levanter config, the Hugging Face model and the Hugging
        Face config.
    """
    # Build the config first so the overrides below apply to the HF model too.
    hf_config = HFRobertaConfig.from_pretrained("roberta-base")
    hf_config.hidden_dropout_prob = 0
    hf_config.attention_probs_dropout_prob = 0

    hf_model = HFRobertaForMaskedLM.from_pretrained("roberta-base", config=hf_config)

    # Derive the Levanter config from the same (already overridden) HF config.
    lv_config = RobertaConfig.from_hf_config(hf_config)
    converter = lv_config.hf_checkpoint_converter()

    model = converter.load_pretrained(
        lv_config.model_type,
        lv_config,
        axis_mapping=None,
        dtype="float32",
    )

    print(converter.Vocab)

    return model, lv_config, hf_model, hf_config

# Load both models once; later cells reuse these four module-level names.
lv_model, lv_config, hf_model, hf_config = load_weights_from_hf()


Loading weights: 100%|██████████| 203/203 [00:02<00:00, 82.99it/s] 


vocab(50265)




In [3]:
# Sanity check: within each pair, the Levanter value is printed first and the
# Hugging Face value second; the two are expected to be equal.
print(lv_config.vocab_size, hf_config.vocab_size, sep="\n")

print(lv_config.max_position_embeddings, hf_config.max_position_embeddings, sep="\n")

50265
50265
514
514


In [4]:
import torch
import numpy as np

In [5]:
def named_to_tensor(named_array):
    """Copy the data stored in ``named_array.array`` into a new torch tensor."""
    as_numpy = np.array(named_array.array)
    return torch.tensor(as_numpy)

In [6]:
# Compare outputs
def check(my_out, hf_out, precision=1e-4):
    """Summarise how closely two equally shaped arrays agree.

    Args:
        my_out: values produced by the implementation under test.
        hf_out: reference values to compare against.
        precision: used as both ``rtol`` and ``atol`` for floating-point
            comparison. Ignored when both inputs are integer/bool arrays,
            which are compared for exact equality (a relative tolerance would
            otherwise treat nearby large token ids as "close").

    Returns:
        str: ``"Accuracy: <fraction matching>, Avg Difference: <mean |diff|>"``.

    Raises:
        ValueError: if the two inputs do not have the same shape (instead of
            silently broadcasting one against the other).
    """
    my_out = np.asarray(my_out)
    hf_out = np.asarray(hf_out)
    if my_out.shape != hf_out.shape:
        raise ValueError(
            f"Shape mismatch: my_out has shape {my_out.shape}, hf_out has shape {hf_out.shape}"
        )

    # dtype kinds: b = bool, i = signed int, u = unsigned int.
    discrete = my_out.dtype.kind in "biu" and hf_out.dtype.kind in "biu"
    if discrete:
        # Widen first so bool/unsigned subtraction below cannot fail or wrap.
        my_out = my_out.astype(np.int64)
        hf_out = hf_out.astype(np.int64)
        acc = (my_out == hf_out).mean()
    else:
        acc = np.isclose(hf_out, my_out, rtol=precision, atol=precision).mean()

    diff = np.abs(my_out - hf_out).mean()
    return f"Accuracy: {acc:.4f}, Avg Difference: {diff:.6f}"

In [7]:
from transformers import AutoTokenizer

# Tokenizer for the same "roberta-base" checkpoint name used when loading the models.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [46]:
# Pad to the full Levanter position axis so the batch can later be wrapped
# with (Batch, lv_config.Pos) axes; deriving the length from the axis keeps the
# two from drifting apart (it was previously a hard-coded 514).
prompt = tokenizer(
    "Mark <mask> is the CEO of Facebook, located in <mask> <mask>, California.",
    return_tensors="pt",
    padding="max_length",
    max_length=lv_config.Pos.size,
)
# prompt = tokenizer("Paris is the <mask> of France.", return_tensors="pt", padding='max_length', max_length=514)
# prompt = tokenizer("Mark Zuckerberg is the boss of Facebook, located in Palo Alto, California.", return_tensors="pt", padding='max_length', max_length=514)

In [47]:
# prompt["input_ids"]

In [48]:
# Named axes used to wrap the tokenizer output: a batch axis of size one and
# the position axis taken from the Levanter config.
Batch = hax.Axis("batch", 1)
Pos = lv_config.Pos

# PRNG keys. NOTE(review): key_var / key_model are not used by any active
# (non-commented) code visible in this part of the notebook.
key = jrandom.PRNGKey(42)
key_var, key_model = jrandom.split(key, 2)

In [49]:
# attention_mask_tensor = torch.ones(size = (Batch.size, Pos.size), dtype = int)
# prompt["attention_mask"] = attention_mask_tensor

In [50]:
# Mirror every tokenizer output as a named array with (Batch, Pos) axes.
lv_prompt = {
    field: hax.NamedArray(np.array(tensor), axes=(Batch, Pos))
    for field, tensor in prompt.items()
}

In [51]:
# input_ids = hax.random.randint(key_var, shape = (Batch, Pos), minval = lv_config.eos_token_id+1, maxval = lv_config.vocab_size)
# attention_mask = hax.ones(shape = (Batch, Pos), dtype=int)

# input_ids_tensor = named_to_tensor(input_ids)
# attention_mask_tensor = named_to_tensor(attention_mask)

In [52]:
# input_ids_tensor = torch.randint(low = lv_config.eos_token_id+1, high = lv_config.vocab_size, size = (Batch.size, Pos.size))
# input_ids_tensor[:, 0] = 0
# input_ids_tensor[:, (lv_config.max_position_embeddings // 2):] = 1
# input_ids_tensor[:, (lv_config.max_position_embeddings // 2)] = 2

# # idx = torch.randint(low = 1, high=lv_config.max_position_embeddings, size = (1,))
# idx = lv_config.max_position_embeddings // 2

# attention_mask_tensor = torch.ones(size = (Batch.size, Pos.size), dtype = int)
# attention_mask_tensor[:, idx:] = 0

# prompt = {"input_ids": input_ids_tensor, "attention_mask": attention_mask_tensor}
# lv_prompt = {k: hax.NamedArray(np.array((prompt[k])), axes = (Batch, Pos)) for k in prompt.keys()}

In [53]:
# prompt = {"input_ids": input_ids_tensor, "attention_mask": attention_mask_tensor}
# lv_prompt = {"input_ids": input_ids, "attention_mask": attention_mask}

In [81]:
def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
    """Build position ids from padded token ids.

    Padding tokens receive ``lv_config.pad_token_id``. Non-padding tokens are
    numbered consecutively along ``lv_config.Pos``, starting at
    ``pad_token_id + 1`` (or at ``pad_token_id`` for a sequence that contains
    no padding at all, see the shift below). The pad id and the position axis
    are read from the module-level ``lv_config``.

    Args:
        input_ids: named array of token ids that includes the ``lv_config.Pos`` axis.
        past_key_values_length: offset added to every non-padding index.

    Returns:
        Named array of position ids.
    """
    pos_axis = lv_config.Pos
    pad_id = lv_config.pad_token_id

    # 1 where the token is real, 0 where it is padding.
    mask = hax.not_equal(input_ids, pad_id) * 1
    # astype expects a dtype, so pass mask.dtype rather than the array itself.
    running_count = hax.cumsum(mask, axis=pos_axis).astype(mask.dtype)
    incremental_indices = (running_count + past_key_values_length) * mask
    # Sequences with no padding at all are shifted down by one.
    # NOTE(review): the Hugging Face reference implementation has no such shift - confirm it is intended.
    incremental_indices -= mask.all(axis=pos_axis)
    return incremental_indices + pad_id

# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     return hax.arange(axis = Pos, start = 0, dtype=jnp.int32)

# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
#     incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
#     incremental_indices -= mask.all(axis=Pos) * lv_config.pad_token_id
#     return incremental_indices

In [82]:
# dict = hf_model.state_dict()
# print(dict["roberta.embeddings.position_embeddings.weight"].shape)
# print(dict["roberta.embeddings.position_embeddings.weight"])

In [83]:
# Compute the position ids once, from the Levanter-side input ids.
position_ids = create_position_ids_from_input_ids(lv_prompt["input_ids"])

# Give both models the same position ids: the named array for Levanter and a
# torch copy of its data for Hugging Face.
lv_prompt["position_ids"] = position_ids
prompt["position_ids"] = torch.from_numpy(np.array(position_ids.array))

In [84]:
# The named-array copies of the inputs should match the tokenizer's tensors exactly.
print(check(np.array(lv_prompt["input_ids"].array), np.array(prompt["input_ids"])))
print(check(np.array(lv_prompt["attention_mask"].array), np.array(prompt["attention_mask"])))
# print(check(np.array(lv_prompt["position_ids"].array), np.array(prompt["position_ids"])))

Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 1.0000, Avg Difference: 0.000000


In [85]:
# def check_dicts(my_dict, hf_dict):
#     print(my_dict.keys())
#     print(hf_dict.keys())

#     my_keys = set(my_dict)
#     hf_keys = set(hf_dict)

#     both = hf_keys.union(my_keys)

#     correct_msg = "Accuracy: 1.0000, Avg Difference: 0.000000"

#     for k in both:
#         if k not in my_keys:
#             print(f"Key {k} not in my_keys!")
#             continue

#         if k not in hf_keys:
#             print(f"Key {k} not in hf_keys!")
#             continue
        
#         check_str = check(my_dict[k], hf_dict[k])
#         if check_str != correct_msg:
#             print(check_str)

# my_dict = lv_model.to_state_dict()
# hf_dict = hf_model.state_dict()

# my_dict = {k: np.array(my_dict[k]) for k in my_dict.keys()}
# hf_dict = {k: np.array(hf_dict[k]) for k in hf_dict.keys()}

# check_dicts(my_dict, hf_dict)

# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.decoder.bias"]))
# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))
# print(check(hf_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))

In [86]:
# Special token ids from the Hugging Face config, one per line: bos, pad, eos.
print(
    hf_config.bos_token_id,
    hf_config.pad_token_id,
    hf_config.eos_token_id,
    sep="\n",
)

0
1
2


In [87]:
# word_embeddings = hf_dict["roberta.embeddings.word_embeddings.weight"]
# input_ids = lv_prompt["input_ids"]
# input_ids_np = np.array(lv_prompt["input_ids"].array)

# input_embeds_np = word_embeddings[input_ids_np]
# input_embeds_hax = hax.NamedArray(input_embeds_np, axes=(Batch, Pos, lv_config.Embed))

# cond = (input_ids == lv_config.pad_token_id)
# cond = hax.broadcast_to(cond, input_embeds_hax.axes)

# input_embeds_hax = hax.where(
#     cond,
#     hax.zeros_like(input_embeds_hax),
#     input_embeds_hax
# )

# input_embeds_torch = torch.from_numpy(np.array(input_embeds_hax.array))
# new_lv_prompt = dict()
# new_hf_prompt = dict()

# new_lv_prompt["input_embeds"] = input_embeds_hax
# new_hf_prompt["inputs_embeds"] = input_embeds_torch

# new_lv_prompt["attention_mask"] = lv_prompt["attention_mask"]
# new_hf_prompt["attention_mask"] = prompt["attention_mask"]

# new_lv_prompt["position_ids"] = lv_prompt["position_ids"]
# new_hf_prompt["position_ids"] = prompt["position_ids"]

In [88]:
# print(check(np.array(input_embeds_hax.array), np.array(input_embeds_torch)))
# print(check(np.array(input_embeds_hax.array), input_embeds_np))
# print(check(np.array(input_embeds_torch), input_embeds_np))

In [89]:
# batch = hax.Axis("batch", 1)
# heads = hax.Axis("heads", 12)
# position = hax.Axis("position", 514)
# key_position = hax.Axis("key_position", 514)

# attention_scores = hax.ones({batch: batch.size, heads: heads.size, position: position.size, key_position: key_position.size})

In [90]:
# attention_mask = lv_prompt["attention_mask"]

# attention_mask = (attention_mask == 0) * -1e12

# attention_mask = attention_mask.rename({"position": "key_position"})
# # print(attention_mask)
# print(attention_mask.axes)

# attention = attention_scores + attention_mask

# print(attention)

In [91]:
# Run the same inputs through both models.
lv_result = lv_model(**lv_prompt)
# Inference only: skip autograd bookkeeping for the Hugging Face forward pass
# (later .detach() calls on its logits remain valid).
with torch.no_grad():
    hf_result = hf_model(**prompt)

# lv_result = lv_model(**new_lv_prompt)
# hf_result = hf_model(**new_hf_prompt)

In [92]:
# Bring the Levanter output into torch so both sets of logits can be handled the same way.
lv_result_logits = torch.from_numpy(np.array(lv_result.array))
hf_result_logits = hf_result.logits

In [93]:
# Levanter vs. Hugging Face: raw logits first, then the argmax token ids.
print(f"lv_logits vs hf_logits: {check(np.array(lv_result_logits), np.array(hf_result_logits.detach()))}")
print(f"lv_tokens vs hf_tokens: {check(np.array(lv_result_logits.argmax(dim=-1)), np.array(hf_result_logits.argmax(dim=-1).detach()))}")
# Each model's argmax token ids vs. the original input ids (Levanter first, then Hugging Face).
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(prompt["input_ids"])))
print(check(np.array(hf_result_logits.argmax(dim=-1).detach()), np.array(prompt["input_ids"])))

lv_logits vs hf_logits: Accuracy: 1.0000, Avg Difference: 0.000005
lv_tokens vs hf_tokens: Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 0.0272, Avg Difference: 236.151751
Accuracy: 0.0272, Avg Difference: 236.151751


In [94]:
old_fn = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></
s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

no_ids = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></
s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange = ["</s>Mark</s> is the CEO of Facebook, located in San Angeles, California</s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s> cast cast cast</s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast</s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s></s></s></s></s></s> chance cast cast cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s> chance cast cast cast cast cast cast cast cast chance cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast chance chance cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast 
cast</s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s></s></s></s></s> cast chance chance cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> chance cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast chance chance</s>"]

arange_no_offset = "nan"

old_fn_no_offset = ['<s><s>You are the age of America, located in the States, California</s></s> Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim 
Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim']

In [95]:
old_fn_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s
></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

no_ids_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s
></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange_hf = ['<s></s><s>is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>. Facebook.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

# Hugging Face-side note for the "arange, no offset" variant. The _hf suffix
# matches the other names in this cell and stops this assignment from
# overwriting the Levanter-side `arange_no_offset` recorded in the cell above.
arange_no_offset_hf = "error"

old_fn_no_offset_hf = ['<s><s>Facebook is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

In [96]:
# Logit shapes, Levanter first and Hugging Face second.
print(lv_result_logits.shape, hf_result_logits.shape, sep="\n")

torch.Size([1, 514, 50265])
torch.Size([1, 514, 50265])


In [97]:
# Decode the Levanter-side input ids, keeping special tokens visible.
tokenizer.batch_decode(lv_prompt["input_ids"].array, skip_special_tokens=False)

['<s>Mark<mask> is the CEO of Facebook, located in<mask><mask>, California.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [98]:
# Decode the Hugging Face-side input ids, keeping special tokens visible.
tokenizer.batch_decode(prompt["input_ids"], skip_special_tokens=False)

['<s>Mark<mask> is the CEO of Facebook, located in<mask><mask>, California.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [99]:
# Decode the highest-scoring token at every position of the Levanter logits.
tokenizer.batch_decode(lv_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s

In [100]:
# Decode the highest-scoring token at every position of the Hugging Face logits.
tokenizer.batch_decode(hf_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s

In [101]:
# from transformers import pipeline
# unmasker = pipeline('fill-mask', model='roberta-base')
# unmasker("The man worked as a <mask>.")

In [102]:
# Display the Levanter logits (as a torch tensor) for visual comparison.
lv_result_logits

tensor([[[32.7624, -3.9994, 19.7142,  ...,  2.9184,  4.8333, 10.7573],
         [ 4.9514, -4.4551, 14.2476,  ..., -1.5525, -0.6854,  1.8734],
         [ 1.2989, -2.7265,  8.4715,  ...,  2.7949,  3.4709,  2.1556],
         ...,
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683],
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683],
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683]]])

In [103]:
# Display the Hugging Face reference logits; these should agree with the
# Levanter logits up to small floating-point differences.
hf_result_logits

tensor([[[32.7624, -3.9994, 19.7141,  ...,  2.9185,  4.8333, 10.7573],
         [ 4.9514, -4.4551, 14.2476,  ..., -1.5525, -0.6854,  1.8734],
         [ 1.2989, -2.7265,  8.4715,  ...,  2.7949,  3.4709,  2.1556],
         ...,
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683],
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683],
         [12.5755, -3.8982, 33.2602,  ...,  1.6317, -2.0114,  9.8683]]],
       grad_fn=<ViewBackward0>)

In [104]:
# Inspect the position ids used for the Levanter forward pass.
# NOTE(review): presumably this follows the RoBERTa convention where real tokens
# count up from padding_idx + 1 (= 2) and padded slots keep padding_idx (= 1) --
# confirm against the cell that builds `position_ids`.
position_ids

NamedArray(array=Array([[ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,

In [105]:
# Highest-scoring index along the last axis for each position of the Levanter logits.
torch.argmax(lv_result_logits, dim=-1)

tensor([[    0, 10006, 10741,    16,     5,  1324,     9,   622,     6,  2034,
            11,   764, 18402,     6,   886,     4,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [106]:
# Highest-scoring index along the last axis for each position of the Hugging Face logits.
torch.argmax(hf_result_logits, dim=-1)

tensor([[    0, 10006, 10741,    16,     5,  1324,     9,   622,     6,  2034,
            11,   764, 18402,     6,   886,     4,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [107]:
# Show the token ids that were fed to both models, as a NumPy array, for
# comparison with the predicted ids.
# NOTE(review): 50264 is presumably the <mask> id and 1 the <pad> id for
# roberta-base -- verify with tokenizer.mask_token_id / tokenizer.pad_token_id.
np.array(prompt["input_ids"])

array([[    0, 10006, 50264,    16,     5,  1324,     9,   622,     6,
         2034,    11, 50264, 50264,     6,   886,     4,     2,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
      