In [59]:
from transformers import RobertaConfig as HFRobertaConfig
from transformers import RobertaForMaskedLM as HFRobertaForMaskedLM
from levanter.models.roberta import RobertaConfig
from levanter.models.roberta import RobertaForMaskedLM

import jax
import jax.random as jrandom
import jax.numpy as jnp
import haliax as hax

In [60]:
def load_weights_from_hf():
    # Load the Hugging Face model
    hf_model = HFRobertaForMaskedLM.from_pretrained("roberta-base")
    hf_config = HFRobertaConfig.from_pretrained("roberta-base")

    hf_config.hidden_dropout_prob = 0
    hf_config.attention_probs_dropout_prob = 0
    # hf_config.pad_token_id = -1

    lv_config = RobertaConfig.from_hf_config(hf_config)

    converter = lv_config.hf_checkpoint_converter()

    model = converter.load_pretrained(
        lv_config.model_type,
        lv_config,
        axis_mapping=None, 
        dtype="float32",  
    )

    print(converter.Vocab)
    
    #print("Weights loaded successfully.")
    return model, lv_config, hf_model, hf_config

lv_model, lv_config, hf_model, hf_config = load_weights_from_hf()


Loading weights: 100%|██████████| 203/203 [00:01<00:00, 148.52it/s]


vocab(50265)




In [61]:
print(lv_config.vocab_size)
print(hf_config.vocab_size)

print(lv_config.max_position_embeddings)
print(hf_config.max_position_embeddings)

50265
50265
514
514


In [62]:
import torch
import numpy as np

In [63]:
def named_to_tensor(named_array):
    out_tensor = torch.tensor(np.array(named_array.array))
    return out_tensor

def tensor_to_named(in_tensor, axes):
    named_array = hax.NamedArray(np.array(in_tensor), axes)
    return named_array

In [64]:
# Compare outputs
def check(my_out, hf_out, precision=1e-4):
    acc = np.isclose(hf_out, my_out, rtol=precision, atol=precision).mean()
    diff = np.abs(my_out - hf_out).mean()
    return f"Accuracy: {acc:.4f}, Avg Difference: {diff:.6f}"

In [65]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [66]:
Batch = hax.Axis("batch", 1)
Pos = lv_config.Pos
KeyPos = lv_config.KeyPos
Vocab = hax.Axis("vocab", tokenizer.vocab_size)

key = jrandom.PRNGKey(42)
key_var, key_model = jrandom.split(key, 2)

In [67]:
prompt = tokenizer("Explaining metaphysics to the nation. I wish he would explain his explanation. You, Bob, are rather insolent, you know, At being disappointed in your wish To supersede all warblers here below, And be the only blackbird in the dish. And then you overstrain yourself, or so And tumble downward like the flying fish Gasping on deck, because you soar too high, Bob, And fall for lack of moisture quite a dry Bob. And Wordsworth in a rather long Excursion (I think the quarto holds five hundred pages) Has given a sample from the vasty version Of his new system to perplex the sages. ’Tis poetry, at least by his assertion, And may appear so when the Dog Star rages, And he who understands it would be able To add a story to the tower of Babel. You gentlemen, by dint of long seclusion From better company, have kept your own At Keswick, and through still continued fusion Of one another’s minds at last have grown To deem, as a most logical conclusion, That poesy has wreaths for you alone. There is a narrowness in such a notion, Which makes me wish you’d change your lakes for ocean. I would not imitate the petty thought, Nor coin my self-love to so base a vice,  For all the glory your conversion brought,    Since gold alone should not have been its price. You have your salary; was’t for that you wrought?    And Wordsworth has his place in the Excise. You’re shabby fellows—true—but poets still And duly seated on the immortal hill. Your bays may hide the baldness of your brows,    Perhaps some virtuous blushes; let them go. To you I envy neither fruit nor boughs,  And for the fame you would engross below, The field is universal and allows    Scope to all such as feel the inherent glow. Scott, Rogers, Campbell, Moore, and Crabbe will try ’Gainst you the question with posterity. For me, who, wandering with pedestrian Muses,    Contend not with you on the winged’ steed, I wish your fate may yield ye, when she chooses,    The fame you envy and the skill you need. And recollect a poet nothing loses    In giving to his    ", return_tensors="pt")

Token indices sequence length is longer than the specified maximum sequence length for this model (514 > 512). Running this sequence through the model will result in indexing errors


In [68]:
# prompt = tokenizer("Mark <mask> is the CEO of Facebook, located in <mask> <mask>, California.", return_tensors="pt", padding='max_length', max_length=514)
# prompt = tokenizer("Paris is the <mask> of France.", return_tensors="pt", padding='max_length', max_length=514)
# prompt = tokenizer("Mark Zuckerberg is the boss of Facebook, located in Palo Alto, California.", return_tensors="pt", padding='max_length', max_length=514)

In [69]:
lv_prompt = {"input_ids": tensor_to_named(prompt["input_ids"], (Batch, Pos)), "attention_mask": tensor_to_named(prompt["attention_mask"], (Batch, KeyPos))}

In [70]:
# prompt["input_ids"]

In [71]:
# attention_mask_tensor = torch.ones(size = (Batch.size, Pos.size), dtype = int)
# prompt["attention_mask"] = attention_mask_tensor

In [72]:
# lv_prompt = {k: hax.NamedArray(np.array((prompt[k])), axes = (Batch, Pos)) for k in prompt.keys()}

In [73]:
# input_ids = hax.random.randint(key_var, shape = (Batch, Pos), minval = lv_config.eos_token_id+1, maxval = lv_config.vocab_size)
# attention_mask = hax.ones(shape = (Batch, Pos), dtype=int)

# input_ids_tensor = named_to_tensor(input_ids)
# attention_mask_tensor = named_to_tensor(attention_mask)

In [74]:
# input_ids = haliax.random.randint(jax.random.PRNGKey(42), shape = (haliax.Axis("batch", 1), haliax.Axis("position", 512)), minval = 4, maxval = 50265)
# input_ids = torch.ones((1, 512), dtype=int) * 100

In [75]:
# input_ids_tensor = torch.randint(low = lv_config.eos_token_id+1, high = lv_config.vocab_size, size = (Batch.size, Pos.size))
# # input_ids_tensor[:, 0] = 0
# # input_ids_tensor[:, (lv_config.max_position_embeddings // 2):] = 1
# # input_ids_tensor[:, (lv_config.max_position_embeddings // 2)] = 2

# # # input_ids_tensor = torch.ones((Batch.size, Pos.size), dtype=int) * 100

# idx = torch.randint(low = 1, high=lv_config.max_position_embeddings, size = (1,))
# # idx = lv_config.max_position_embeddings // 2

# attention_mask_tensor = torch.ones(size = (Batch.size, KeyPos.size), dtype = int)
# # attention_mask_tensor[:, idx:] = 0

# attention_mask = tensor_to_named(attention_mask_tensor, (Batch, KeyPos))

# # # attention_mask_tensor = torch.ones(size = (Batch.size, Pos.size), dtype = int)
# # # attention_mask_tensor[:, idx:] = 0

# # # attention_mask = tensor_to_named(attention_mask_tensor, (Batch, KeyPos))

# input_ids = tensor_to_named(input_ids_tensor, (Batch, Pos))
# prompt = {"input_ids": input_ids_tensor, "attention_mask": attention_mask_tensor}
# lv_prompt = {"input_ids": input_ids, "attention_mask": attention_mask}

# # # prompt = {"input_ids": input_ids_tensor}
# # # lv_prompt = {"input_ids": input_ids}

In [76]:
# prompt = {"input_ids": input_ids_tensor, "attention_mask": attention_mask_tensor}
# lv_prompt = {"input_ids": input_ids, "attention_mask": attention_mask}

In [77]:
# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
#     incremental_indices = (hax.cumsum(mask, axis=lv_config.Pos).astype(mask) + past_key_values_length) * mask
#     incremental_indices -= mask.all(axis=Pos)
#     return incremental_indices + lv_config.pad_token_id

# def create_position_ids_from_input_ids(input_ids, past_key_values_length=0):
#     return hax.arange(axis = Pos, start = 0, dtype=jnp.int32)

def create_position_ids_from_input_ids(input_ids, PosInput, past_key_values_length=0):
    mask = hax.not_equal(input_ids, lv_config.pad_token_id) * 1
    incremental_indices = (hax.cumsum(mask, axis=PosInput).astype(mask) + past_key_values_length) * mask
    incremental_indices -= mask.all(axis=PosInput) * lv_config.pad_token_id
    return incremental_indices

In [78]:
# dict = hf_model.state_dict()
# print(dict["roberta.embeddings.position_embeddings.weight"].shape)
# print(dict["roberta.embeddings.position_embeddings.weight"])

In [79]:
position_ids = create_position_ids_from_input_ids(lv_prompt["input_ids"], Pos)

lv_prompt["position_ids"] = position_ids
prompt["position_ids"] = torch.from_numpy(np.array(position_ids.array))

In [80]:
print(check(np.array(lv_prompt["input_ids"].array), np.array(prompt["input_ids"])))
# print(check(np.array(lv_prompt["attention_mask"].array), np.array(prompt["attention_mask"])))
# print(check(np.array(lv_prompt["position_ids"].array), np.array(prompt["position_ids"])))

Accuracy: 1.0000, Avg Difference: 0.000000


In [81]:
# def check_dicts(my_dict, hf_dict):
#     print(my_dict.keys())
#     print(hf_dict.keys())

#     my_keys = set(my_dict)
#     hf_keys = set(hf_dict)

#     both = hf_keys.union(my_keys)

#     correct_msg = "Accuracy: 1.0000, Avg Difference: 0.000000"

#     for k in both:
#         if k not in my_keys:
#             print(f"Key {k} not in my_keys!")
#             continue

#         if k not in hf_keys:
#             print(f"Key {k} not in hf_keys!")
#             continue
        
#         check_str = check(my_dict[k], hf_dict[k])
#         if check_str != correct_msg:
#             print(check_str)

# my_dict = lv_model.to_state_dict()
# hf_dict = hf_model.state_dict()

# my_dict = {k: np.array(my_dict[k]) for k in my_dict.keys()}
# hf_dict = {k: np.array(hf_dict[k]) for k in hf_dict.keys()}

# check_dicts(my_dict, hf_dict)

# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.decoder.bias"]))
# print(check(my_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))
# print(check(hf_dict["lm_head.decoder.bias"], hf_dict["lm_head.bias"]))

In [82]:
print(hf_config.bos_token_id)
print(hf_config.pad_token_id)
print(hf_config.eos_token_id)

0
1
2


In [83]:
# word_embeddings = hf_dict["roberta.embeddings.word_embeddings.weight"]
# input_ids = lv_prompt["input_ids"]
# input_ids_np = np.array(lv_prompt["input_ids"].array)

# input_embeds_np = word_embeddings[input_ids_np]
# input_embeds_hax = hax.NamedArray(input_embeds_np, axes=(Batch, Pos, lv_config.Embed))

# cond = (input_ids == lv_config.pad_token_id)
# cond = hax.broadcast_to(cond, input_embeds_hax.axes)

# input_embeds_hax = hax.where(
#     cond,
#     hax.zeros_like(input_embeds_hax),
#     input_embeds_hax
# )

# input_embeds_torch = torch.from_numpy(np.array(input_embeds_hax.array))
# new_lv_prompt = dict()
# new_hf_prompt = dict()

# new_lv_prompt["input_embeds"] = input_embeds_hax
# new_hf_prompt["inputs_embeds"] = input_embeds_torch

# new_lv_prompt["attention_mask"] = lv_prompt["attention_mask"]
# new_hf_prompt["attention_mask"] = prompt["attention_mask"]

# new_lv_prompt["position_ids"] = lv_prompt["position_ids"]
# new_hf_prompt["position_ids"] = prompt["position_ids"]

In [84]:
# batch = hax.Axis("batch", 1)
# heads = hax.Axis("heads", 12)
# position = hax.Axis("position", 514)
# key_position = hax.Axis("key_position", 514)

# attention_scores = hax.ones({batch: batch.size, heads: heads.size, position: position.size, key_position: key_position.size})

# attention_mask = lv_prompt["attention_mask"]

# attention_mask = (attention_mask == 0) * -1e12

# attention_mask = attention_mask.rename({"position": "key_position"})
# # print(attention_mask)
# print(attention_mask.axes)

# attention = attention_scores + attention_mask

# print(attention)

In [85]:
lv_result = lv_model(**lv_prompt)
hf_result = hf_model(**prompt)

# lv_result = lv_model(**new_lv_prompt)
# hf_result = hf_model(**new_hf_prompt)

In [86]:
lv_result_logits = torch.from_numpy(np.array(lv_result.array))
hf_result_logits = hf_result.logits

In [87]:
print(f"lv_logits vs hf_logits: {check(np.array(lv_result_logits), np.array(hf_result_logits.detach()))}")
print(f"lv_tokens vs hf_tokens: {check(np.array(lv_result_logits.argmax(dim=-1)), np.array(hf_result_logits.argmax(dim=-1).detach()))}")
print(check(np.array(lv_result_logits.argmax(dim=-1)), np.array(prompt["input_ids"])))
print(check(np.array(hf_result_logits.argmax(dim=-1).detach()), np.array(prompt["input_ids"])))

lv_logits vs hf_logits: Accuracy: 1.0000, Avg Difference: 0.000005
lv_tokens vs hf_tokens: Accuracy: 1.0000, Avg Difference: 0.000000
Accuracy: 0.9689, Avg Difference: 538.085603
Accuracy: 0.9689, Avg Difference: 538.085603


In [88]:
# from levanter.models.lm_model import MaskedLmExample

# example = MaskedLmExample.masked_lm(tokens=lv_result.argmax("vocab")[{"batch": 0}], targets=lv_prompt["input_ids"][{"batch": 0}], mask_token_id=tokenizer.mask_token_id, attn_mask=lv_prompt["attention_mask"][{"batch": 0}])

In [89]:
from levanter.models.lm_model import MaskedLmExample

mask_prob = 0.15
mask_token_id = tokenizer.mask_token_id
noise_prob = 0.1

QPos = Pos
KPos = KeyPos

def _create_mlm_example(tokens, key):
    tokens_array = tokens.array
    targets = tokens_array.copy()

    if mask_prob > 0:
        this_key, key = jax.random.split(key)
        mask_shape = tokens_array.shape
        mask = jax.random.bernoulli(this_key, mask_prob, mask_shape)

        rand = jax.random.uniform(this_key, mask_shape)
        mask_token = jnp.where(rand < 0.8, mask_token_id, tokens_array)
        random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1)
        mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + noise_prob), random_tokens, mask_token)
        masked_tokens = jnp.where(mask, mask_token, tokens_array)

        # Set targets to the original tokens where mask is True, otherwise set to mask_token_id
        targets = jnp.where(mask, tokens_array, mask_token_id)

        masked_tokens_named = hax.named(masked_tokens, QPos)
        targets_named = hax.named(targets, QPos)

        attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
        attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (QPos, KPos))

        example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, mask_token_id=mask_token_id, attn_mask=attn_mask)
    else:
        targets_named = hax.named(targets, QPos)
        attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0])
        attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (QPos, KPos))

        example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, mask_token_id=mask_token_id, attn_mask=attn_mask)

    return example

example = _create_mlm_example(lv_prompt["input_ids"][{"batch": 0}], key_model)
lv_model.compute_loss(example, key=key_model, reduction=hax.mean, reduction_axis=Pos)

NamedArray(array=Array(3.7537427, dtype=float32), axes=())

In [90]:
# check(named_to_tensor(example.tokens * 1.), named_to_tensor(lv_prompt["input_ids"]* 1.))

In [91]:
# tokenizer.decode(named_to_tensor(example.tokens))

In [92]:
# from haliax.nn import cross_entropy_loss
# import torch.types

# lv_logits = lv_model(example.tokens, example.attn_mask, position_ids=lv_prompt["position_ids"], key=key)
# lv_logits = lv_logits.astype(jnp.float32)

# hf_logits = hf_model(**{"input_ids" : named_to_tensor(example.tokens)[None, :], "attention_mask" : named_to_tensor(example.attn_mask)[None, :], "position_ids" : prompt["position_ids"]}).logits
# hf_logits = tensor_to_named(hf_logits.detach().numpy(), (Batch, Pos, Vocab))

# targets = example.targets
# target_y = hax.nn.one_hot(targets, Vocab, dtype=lv_logits.dtype)

# # jnp.dtype("long")

# lv_on_lv_loss = cross_entropy_loss(
#     lv_logits, Vocab, target_y, hax.mean, reduction_axis=Pos, where=example.loss_mask
# )

# hf_on_lv_loss = cross_entropy_loss(
#     hf_logits, Vocab, target_y, hax.mean, reduction_axis=Pos, where=example.loss_mask
# )

# # lv_loss = cross_entropy_loss(
# #     logits, Vocab, target_y, hax.mean, reduction_axis=Pos
# # )
# # target_y_torch = named_to_tensor(target_y)

# targets_torch = named_to_tensor(example.targets)
# # targets_torch = named_to_tensor(example.targets.astype(jnp.dtype("longlong")))

# targets_torch = targets_torch * named_to_tensor(example.loss_mask)
# targets_torch[targets_torch == 0] = -100
# targets_torch = torch.tensor(targets_torch, dtype=torch.long)

# lv_on_torch_loss = torch.nn.functional.cross_entropy(named_to_tensor(lv_logits)[0], targets_torch)
# hf_on_torch_loss = torch.nn.functional.cross_entropy(named_to_tensor(hf_logits)[0], targets_torch)

# print(lv_on_lv_loss)
# print(hf_on_lv_loss)
# print(lv_on_torch_loss)
# print(hf_on_torch_loss)

In [93]:
lv_result_logits.argmax(dim=-1)

tensor([[    0,     2,     0, 17297,  2633,     7,     5,  1226,     4,    38,
          2813,    37,    74,  3922,    39,  8257,     4,   370,     6,  3045,
             6,    32,  1195, 23799,  1342,     6,    47,   216,     6,   497,
           145,  5779,    11,   110,  2813,   598, 31716, 12820,    70,   997,
         25274,   259,   874,     6,   178,    28,     5,   129,   909, 15886,
            11,     5,  8847,     4,   178,   172,    47,    81,  6031,  1851,
          2512,     6,    50,    98,   178, 26566, 14659,   101,     5,  4731,
          3539,   272,  9331,   154,    15,  9124,     6,   142,    47, 27283,
           350,   239,     6,  3045,     6,   178,  1136,    13,  1762,     9,
         16227, 27800,    10,  3841,  3045,     4,   178, 15690, 15142,    11,
            10,  1195,   251, 22847, 29471,    36,   100,   206,     5, 18284,
           139,  3106,   292,  6317,  6052,    43,  6233,   576,    10,  7728,
            31,     5, 18940,   219,  1732,  1525,  

In [94]:
old_fn = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

no_ids = ['</s>Markham is the city of California, located in San Diego, California</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange = ["</s>Mark</s> is the CEO of Facebook, located in San Angeles, California</s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s> cast cast cast</s></s></s></s></s></s> cast cast cast cast</s></s></s></s></s></s></s></s> cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast</s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast</s></s></s></s></s></s> chance cast cast cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast</s></s></s></s></s></s></s></s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast</s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s> chance cast cast cast cast cast cast cast cast chance cast cast cast cast cast cast</s> cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast chance chance cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s> cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast cast</s></s></s></s></s></s></s></s></s> cast chance chance cast cast cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> chance cast</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s> cast cast cast cast cast cast cast cast chance chance</s>"]

arange_no_offset = "nan"

old_fn_no_offset = ['<s><s>You are the age of America, located in the States, California</s></s> Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim Crim']

In [95]:
old_fn_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

no_ids_hf = ['<s>Mark Zuckerberg is the CEO of Facebook, located in San Alto, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange_hf = ['<s></s><s>is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>. Facebook.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

arange_no_offset = "error"

old_fn_no_offset_hf = ['<s><s>Facebook is the CEO of Facebook, located in San Francisco, California.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

In [96]:
print(lv_result_logits.shape)
print(hf_result_logits.shape)

torch.Size([1, 514, 50265])
torch.Size([1, 514, 50265])


In [97]:
tokenizer.batch_decode(lv_prompt["input_ids"].array, skip_special_tokens=False)

['<s>Explaining metaphysics to the nation. I wish he would explain his explanation. You, Bob, are rather insolent, you know, At being disappointed in your wish To supersede all warblers here below, And be the only blackbird in the dish. And then you overstrain yourself, or so And tumble downward like the flying fish Gasping on deck, because you soar too high, Bob, And fall for lack of moisture quite a dry Bob. And Wordsworth in a rather long Excursion (I think the quarto holds five hundred pages) Has given a sample from the vasty version Of his new system to perplex the sages. ’Tis poetry, at least by his assertion, And may appear so when the Dog Star rages, And he who understands it would be able To add a story to the tower of Babel. You gentlemen, by dint of long seclusion From better company, have kept your own At Keswick, and through still continued fusion Of one another’s minds at last have grown To deem, as a most logical conclusion, That poesy has wreaths for you alone. There is

In [98]:
tokenizer.batch_decode(prompt["input_ids"], skip_special_tokens=False)

['<s>Explaining metaphysics to the nation. I wish he would explain his explanation. You, Bob, are rather insolent, you know, At being disappointed in your wish To supersede all warblers here below, And be the only blackbird in the dish. And then you overstrain yourself, or so And tumble downward like the flying fish Gasping on deck, because you soar too high, Bob, And fall for lack of moisture quite a dry Bob. And Wordsworth in a rather long Excursion (I think the quarto holds five hundred pages) Has given a sample from the vasty version Of his new system to perplex the sages. ’Tis poetry, at least by his assertion, And may appear so when the Dog Star rages, And he who understands it would be able To add a story to the tower of Babel. You gentlemen, by dint of long seclusion From better company, have kept your own At Keswick, and through still continued fusion Of one another’s minds at last have grown To deem, as a most logical conclusion, That poesy has wreaths for you alone. There is

In [99]:
tokenizer.batch_decode(lv_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s></s><s>Ph presented to the nation. I wish he would explain his explanation. You, Bob, are rather insolent, you know, At being disappointed in your wish To supersede all warblers here below, And be the only blackbird in the dish. And then you overstrain yourself, or so And tumble downward like the flying fish Gasping on deck, because you soar too high, Bob, And fall for lack of moisture Quite a dry Bob. And Wordsworth in a rather long Excursion (I think the quarto holds five hundred pages) Has given a sample from the fleshy version Of his new system to perplex the sages. ’Tis poetry, at least by his estimation, And may appear so when the Dog Star rages, And he who understands it would be able To add a story to the tower of Babel. You gentlemen, by dint of long seclusion From better company, have kept your own At Keswick, and through still continued fusion Of one another’s minds at last have grown To deem, as a most logical conclusion, That poetry has wreaths for you alone. There is

In [100]:
tokenizer.batch_decode(hf_result_logits.argmax(dim=-1), skip_special_tokens=False)

['<s></s><s>Ph presented to the nation. I wish he would explain his explanation. You, Bob, are rather insolent, you know, At being disappointed in your wish To supersede all warblers here below, And be the only blackbird in the dish. And then you overstrain yourself, or so And tumble downward like the flying fish Gasping on deck, because you soar too high, Bob, And fall for lack of moisture Quite a dry Bob. And Wordsworth in a rather long Excursion (I think the quarto holds five hundred pages) Has given a sample from the fleshy version Of his new system to perplex the sages. ’Tis poetry, at least by his estimation, And may appear so when the Dog Star rages, And he who understands it would be able To add a story to the tower of Babel. You gentlemen, by dint of long seclusion From better company, have kept your own At Keswick, and through still continued fusion Of one another’s minds at last have grown To deem, as a most logical conclusion, That poetry has wreaths for you alone. There is

In [101]:
# from transformers import pipeline
# unmasker = pipeline('fill-mask', model='roberta-base')
# unmasker("The man worked as a <mask>.")

In [102]:
lv_result_logits

tensor([[[32.1017, -4.9437, 14.0598,  ..., -1.6564,  1.5099, 10.0954],
         [ 5.6817, -4.4572, 22.0748,  ..., -4.5277, -5.5077,  5.5239],
         [31.3908, -4.9682, 14.1496,  ..., -1.7012,  1.4606,  9.9348],
         ...,
         [ 3.6521, -3.7570, 17.7734,  ..., -0.4965,  0.1471,  5.7731],
         [ 2.3458, -4.1200, 17.9135,  ..., -3.8652, -2.7885,  4.4862],
         [10.6070, -4.2919, 23.0414,  ..., -2.8666, -5.3789,  6.4353]]])

In [103]:
hf_result_logits

tensor([[[32.1016, -4.9437, 14.0598,  ..., -1.6564,  1.5099, 10.0954],
         [ 5.6817, -4.4572, 22.0748,  ..., -4.5277, -5.5077,  5.5239],
         [31.3909, -4.9682, 14.1496,  ..., -1.7012,  1.4606,  9.9348],
         ...,
         [ 3.6521, -3.7570, 17.7734,  ..., -0.4965,  0.1471,  5.7731],
         [ 2.3459, -4.1200, 17.9135,  ..., -3.8652, -2.7885,  4.4862],
         [10.6070, -4.2919, 23.0414,  ..., -2.8666, -5.3789,  6.4353]]],
       grad_fn=<ViewBackward0>)

In [104]:
position_ids

NamedArray(array=Array([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
        143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
        169, 170, 171, 172, 173, 1

In [105]:
lv_result_logits.argmax(dim=-1)

tensor([[    0,     2,     0, 17297,  2633,     7,     5,  1226,     4,    38,
          2813,    37,    74,  3922,    39,  8257,     4,   370,     6,  3045,
             6,    32,  1195, 23799,  1342,     6,    47,   216,     6,   497,
           145,  5779,    11,   110,  2813,   598, 31716, 12820,    70,   997,
         25274,   259,   874,     6,   178,    28,     5,   129,   909, 15886,
            11,     5,  8847,     4,   178,   172,    47,    81,  6031,  1851,
          2512,     6,    50,    98,   178, 26566, 14659,   101,     5,  4731,
          3539,   272,  9331,   154,    15,  9124,     6,   142,    47, 27283,
           350,   239,     6,  3045,     6,   178,  1136,    13,  1762,     9,
         16227, 27800,    10,  3841,  3045,     4,   178, 15690, 15142,    11,
            10,  1195,   251, 22847, 29471,    36,   100,   206,     5, 18284,
           139,  3106,   292,  6317,  6052,    43,  6233,   576,    10,  7728,
            31,     5, 18940,   219,  1732,  1525,  

In [106]:
hf_result_logits.argmax(dim=-1)

tensor([[    0,     2,     0, 17297,  2633,     7,     5,  1226,     4,    38,
          2813,    37,    74,  3922,    39,  8257,     4,   370,     6,  3045,
             6,    32,  1195, 23799,  1342,     6,    47,   216,     6,   497,
           145,  5779,    11,   110,  2813,   598, 31716, 12820,    70,   997,
         25274,   259,   874,     6,   178,    28,     5,   129,   909, 15886,
            11,     5,  8847,     4,   178,   172,    47,    81,  6031,  1851,
          2512,     6,    50,    98,   178, 26566, 14659,   101,     5,  4731,
          3539,   272,  9331,   154,    15,  9124,     6,   142,    47, 27283,
           350,   239,     6,  3045,     6,   178,  1136,    13,  1762,     9,
         16227, 27800,    10,  3841,  3045,     4,   178, 15690, 15142,    11,
            10,  1195,   251, 22847, 29471,    36,   100,   206,     5, 18284,
           139,  3106,   292,  6317,  6052,    43,  6233,   576,    10,  7728,
            31,     5, 18940,   219,  1732,  1525,  

In [107]:
np.array(prompt["input_ids"])

array([[    0, 43043,  8173, 41724, 33823,     7,     5,  1226,     4,
           38,  2813,    37,    74,  3922,    39,  8257,     4,   370,
            6,  3045,     6,    32,  1195, 23799,  1342,     6,    47,
          216,     6,   497,   145,  5779,    11,   110,  2813,   598,
        31716, 12820,    70,   997, 25274,   259,   874,     6,   178,
           28,     5,   129,   909, 15886,    11,     5,  8847,     4,
          178,   172,    47,    81,  6031,  1851,  2512,     6,    50,
           98,   178, 26566, 14659,   101,     5,  4731,  3539,   272,
         9331,   154,    15,  9124,     6,   142,    47, 27283,   350,
          239,     6,  3045,     6,   178,  1136,    13,  1762,     9,
        16227,  1341,    10,  3841,  3045,     4,   178, 15690, 15142,
           11,    10,  1195,   251, 22847, 29471,    36,   100,   206,
            5, 18284,   139,  3106,   292,  6317,  6052,    43,  6233,
          576,    10,  7728,    31,     5,  4714,   219,  1732,  1525,
      