In [158]:
from transformers import RobertaConfig as HFRobertaConfig
from transformers import RobertaForMaskedLM as HFRobertaForMaskedLM
from levanter.models.roberta import RobertaConfig
from levanter.models.roberta import RobertaForMaskedLM

import jax.random as jrandom
import jax.numpy as jnp
import haliax as hax

In [159]:
def load_weights_from_hf(hf, my):
    # Load the Hugging Face model
    hf_model = hf.from_pretrained("roberta-base")
    hf_config = HFRobertaConfig.from_pretrained("roberta-base")

    hf_config.hidden_dropout_prob = 0
    hf_config.attention_probs_dropout_prob = 0
    # hf_config.pad_token_id = -1

    lv_config = RobertaConfig.from_hf_config(hf_config)

    converter = lv_config.hf_checkpoint_converter()

    model = converter.load_pretrained(
        lv_config.model_type,
        lv_config,
        axis_mapping=None, 
        dtype="float32",  
    )

    print(converter.Vocab)
    
    #print("Weights loaded successfully.")
    return model, lv_config, hf_model, hf_config

lv_model, lv_config, hf_model, hf_config = load_weights_from_hf(HFRobertaForMaskedLM, RobertaForMaskedLM)


Loading weights: 100%|██████████| 203/203 [00:01<00:00, 132.19it/s]


vocab(50265)




In [351]:
print(lv_config.vocab_size)
print(hf_config.vocab_size)

print(lv_config.max_position_embeddings)
print(hf_config.max_position_embeddings)

50265
50265
514
514


In [161]:
import torch
import numpy as np

In [162]:
# Compare outputs
def check(my_out, hf_out, precision=1e-4):
    acc = np.isclose(hf_out, my_out, rtol=precision, atol=precision).mean()
    diff = np.abs(my_out - hf_out).mean()
    return f"Accuracy: {acc:.4f}, Avg Difference: {diff:.6f}"

In [315]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
prompt = tokenizer("Mark <mask> is the CEO of Facebook, located in <mask> <mask>, California.", return_tensors="pt", padding='max_length', max_length=514)
# prompt = tokenizer("Paris is the <mask> of France.", return_tensors="pt", padding='max_length', max_length=514)
# prompt = tokenizer("Mark Zuckerberg is the boss of Facebook, located in Palo Alto, California.", return_tensors="pt", padding='max_length', max_length=514)

In [316]:
Batch = hax.Axis("batch", 1)
Pos = lv_config.Pos

key = jrandom.PRNGKey(42)
key_var, key_model = jrandom.split(key, 2)

lv_prompt = {k: hax.NamedArray(np.array((prompt[k])), axes = (Batch, Pos)) for k in prompt.keys()}

# input_ids = hax.NamedArray(np.array(prompt["input_ids"]), axes = (Batch, Pos))
# attention_mask = hax.NamedArray(np.array(prompt["attention_mask"]), axes = (Batch, Pos))

In [317]:
# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
#     incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
#     incremental_indices -= mask.all(axis=Pos)
#     return incremental_indices + lv_config.pad_token_id

# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     return hax.arange(axis = Pos, start = 0, dtype=jnp.int32)

# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
#     incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
#     incremental_indices -= mask.all(axis=Pos)
    # return incremental_indices # + lv_config.pad_token_id

def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
    mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
    incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
    incremental_indices -= mask.all(axis=Pos) * lv_config.pad_token_id
    return incremental_indices

In [None]:
dict = hf_model.state_dict()
print(dict["roberta.embeddings.position_embeddings.weight"].shape)
print(dict["roberta.embeddings.position_embeddings.weight"])

torch.Size([514, 768])
tensor([[-0.0115,  0.0204,  0.0197,  ...,  0.0050, -0.0274, -0.0439],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0346, -0.0169, -0.0895,  ..., -0.0542,  0.0291,  0.0173],
        ...,
        [ 0.1191, -0.0587, -0.0396,  ..., -0.0179,  0.1249,  0.0205],
        [ 0.0833, -0.0272, -0.0368,  ..., -0.0429,  0.1288, -0.0052],
        [ 0.0850,  0.2573, -0.1257,  ..., -0.0952,  0.1443, -0.0131]])


In [353]:
prompt["input_ids"].shape

torch.Size([1, 514])

In [None]:
dict = lv_model.load_state_dict()
print(dict["roberta.embeddings.position_embeddings.weight"].shape)
print(dict["roberta.embeddings.position_embeddings.weight"])

In [318]:
position_ids = create_position_ids_from_input_ids(lv_prompt["input_ids"])

lv_prompt["position_ids"] = position_ids
prompt["position_ids"] = torch.from_numpy(np.array(position_ids.array))

In [319]:
print(check(np.array(lv_prompt["input_ids"].array), np.array(prompt["input_ids"])))
print(check(np.array(lv_prompt["attention_mask"].array), np.array(prompt["attention_mask"])))
# print(check(np.array(lv_prompt["position_ids"].array), np.array(prompt["position_ids"])))

Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 1.0000, Avg Difference: 0.000000


In [320]:
# def check_dicts(my_dict, hf_dict):
#     print(my_dict.keys())
#     print(hf_dict.keys())

#     my_keys = set(my_dict)
#     hf_keys = set(hf_dict)

#     both = hf_keys.union(my_keys)

#     correct_msg = "Accuracy: 1.0000, Avg Difference: 0.000000"

#     for k in both:
#         if k not in my_keys:
#             print(f"Key {k} not in my_keys!")
#             continue

#         if k not in hf_keys:
#             print(f"Key {k} not in hf_keys!")
#             continue
        
#         check_str = check(my_dict[k], hf_dict[k])
#         if check_str != correct_msg:
#             print(check_str)

# my_dict = lv_model.to_state_dict()
# hf_dict = hf_model.state_dict()

# my_dict = {k: np.array(my_dict[k]) for k in my_dict.keys()}
# hf_dict = {k: np.array(hf_dict[k]) for k in hf_dict.keys()}

# check_dicts(my_dict, hf_dict)

# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.decoder.bias"]))
# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))
# print(check(hf_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))

In [321]:
print(hf_config.bos_token_id)
print(hf_config.pad_token_id)
print(hf_config.eos_token_id)

0
1
2


In [322]:
# word_embeddings = hf_dict["roberta.embeddings.word_embeddings.weight"]
# input_ids = lv_prompt["input_ids"]
# input_ids_np = np.array(lv_prompt["input_ids"].array)

# input_embeds_np = word_embeddings[input_ids_np]
# input_embeds_hax = hax.NamedArray(input_embeds_np, axes=(Batch, Pos, lv_config.Embed))

# cond = (input_ids == lv_config.pad_token_id)
# cond = hax.broadcast_to(cond, input_embeds_hax.axes)

# input_embeds_hax = hax.where(
#     cond,
#     hax.zeros_like(input_embeds_hax),
#     input_embeds_hax
# )

# input_embeds_torch = torch.from_numpy(np.array(input_embeds_hax.array))
# new_lv_prompt = dict()
# new_hf_prompt = dict()

# new_lv_prompt["input_embeds"] = input_embeds_hax
# new_hf_prompt["inputs_embeds"] = input_embeds_torch

# new_lv_prompt["attention_mask"] = lv_prompt["attention_mask"]
# new_hf_prompt["attention_mask"] = prompt["attention_mask"]

# new_lv_prompt["position_ids"] = lv_prompt["position_ids"]
# new_hf_prompt["position_ids"] = prompt["position_ids"]

In [323]:
# print(check(np.array(input_embeds_hax.array), np.array(input_embeds_torch)))
# print(check(np.array(input_embeds_hax.array), input_embeds_np))
# print(check(np.array(input_embeds_torch), input_embeds_np))

In [324]:
lv_result = lv_model(**lv_prompt)
hf_result = hf_model(**prompt)

# lv_result = lv_model(**new_lv_prompt)
# hf_result = hf_model(**new_hf_prompt)

In [325]:
lv_result_logits = torch.from_numpy(np.array(lv_result.array))
hf_result_logits = hf_result.logits

In [326]:
print(check(np.array(lv_result_logits), np.array(hf_result_logits.detach())))
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(hf_result_logits.argmax(dim=-1).detach())))
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(prompt["input_ids"])))
print(check(np.array(hf_result_logits.argmax(dim=-1).detach()), np.array(prompt["input_ids"])))

Accuracy: 0.0000, Avg Difference: 3.441651
Accuracy: 0.0195, Avg Difference: 21839.554475
Accuracy: 0.0175, Avg Difference: 22111.554475
Accuracy: 0.0253, Avg Difference: 272.000000


In [327]:
old_fn = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

base = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange = ["</s>Mark</s> is the CEO of Facebook, located in San Angeles, California</s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s> cast cast cast</s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast</s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s></s></s></s></s></s> chance cast cast cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s> chance cast cast cast cast cast cast cast cast chance cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast chance chance cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s></s></s></s></s> cast chance chance cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> chance cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast chance chance</s>"]

no_offset = ['<s><s>You are the age of America, located in the States, California</s></s> Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim']

In [328]:
old_fn_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

base_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange_hf = ['<s></s><s>is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>. Facebook.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

no_offset_hf = ['<s><s>Facebook is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

In [329]:
print(lv_result_logits.shape)
print(hf_result_logits.shape)

torch.Size([1, 514, 50265])
torch.Size([1, 514, 50265])


In [330]:
tokenizer.batch_decode(lv_prompt["input_ids"].array, skip_special_tokens=False)

['<s>Mark<mask> is the CEO of Facebook, located in<mask><mask>, California.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [331]:
tokenizer.batch_decode(prompt["input_ids"], skip_special_tokens=False)

['<s>Mark<mask> is the CEO of Facebook, located in<mask><mask>, California.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [332]:
tokenizer.batch_decode(lv_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s><s>You are the age of America, located in the States, California</s></s> Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Cr

In [333]:
tokenizer.batch_decode(hf_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s><s>Facebook is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

In [334]:
# from transformers import pipeline
# unmasker = pipeline('fill-mask', model='roberta-base')
# unmasker("The man worked as a <mask>.")

In [335]:
lv_result_logits

tensor([[[20.1505, -3.7458, 14.5442,  ..., -1.3559,  1.9249,  6.5880],
         [19.7139, -3.7571, 14.4782,  ..., -1.4627,  1.8345,  6.4597],
         [ 0.9419, -4.5570,  7.0148,  ..., -5.2417, -4.0212,  0.8554],
         ...,
         [ 5.6627, -3.6613,  9.3652,  ..., -1.4894,  0.8250,  1.8860],
         [ 5.6627, -3.6613,  9.3652,  ..., -1.4894,  0.8250,  1.8860],
         [ 5.6627, -3.6613,  9.3652,  ..., -1.4894,  0.8250,  1.8860]]])

In [336]:
hf_result_logits

tensor([[[33.7781, -3.9853, 21.2499,  ...,  2.9935,  5.3468, 11.6988],
         [33.7777, -3.9862, 21.2376,  ...,  2.9968,  5.3491, 11.6962],
         [ 3.5744, -3.4707,  9.2699,  ..., -2.4105, -0.3681,  2.4805],
         ...,
         [10.6002, -4.1998, 29.9508,  ..., -0.1202, -3.0393,  9.9499],
         [10.6002, -4.1998, 29.9508,  ..., -0.1202, -3.0393,  9.9499],
         [10.6002, -4.1998, 29.9508,  ..., -0.1202, -3.0393,  9.9499]]],
       grad_fn=<ViewBackward0>)

In [337]:
position_ids

NamedArray(array=Array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,

In [338]:
lv_result_logits.argmax(dim=-1)

tensor([[    0,     0,  1185,    32,     5,  1046,     9,   730,     6,  2034,
            11,     5,   532,     6,   886,     2,     2, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548, 22548,
         22548, 22548, 22548, 22548, 22548, 22548, 2

In [339]:
hf_result_logits.argmax(dim=-1)

tensor([[    0,     0, 18064,    16,     5,  1324,     9,   622,     6,  2034,
            11,   764,  2659,     6,   886,     4,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [340]:
np.array(prompt["input_ids"])

array([[    0, 10006, 50264,    16,     5,  1324,     9,   622,     6,
         2034,    11, 50264, 50264,     6,   886,     4,     2,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,
      