In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=".cache")
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=".cache")
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [2]:
def top_k_token(prompt, k=5):
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits = outputs.logits[:, -1, :]
    topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
    top_tokens = tokenizer.batch_decode(topk_indices[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # print(topk_indices[0][:k])
    
    return top_tokens, topk_values[0]

In [3]:
prompt = "1. 子"
top_k = 10

top_tokens, top_values = top_k_token(prompt, k=top_k)
for i, (token, score) in enumerate(zip(top_tokens, top_values)):
    print(f"{i+1}. {token} (logit: {score.item():.4f})")

1. 宫 (logit: 16.9809)
2. 宮 (logit: 15.8867)
3. 女 (logit: 14.2693)
4. � (logit: 13.2507)
5. � (logit: 13.2188)
6. 午 (logit: 12.5235)
7. 母 (logit: 12.4848)
8. � (logit: 12.1210)
9. 孙 (logit: 12.1109)
10. 分 (logit: 11.9912)


In [4]:
# string_data = "1. 子曰：「學而時習之，不亦說乎？有朋自遠方來，不亦樂乎？人不知而不慍，不亦君子乎？」"
# string_data = "蓋聞天地之數，有十二萬九千六百歲為一元。將一元分為十二會，乃子、丑、寅、卯、辰、巳、午、未、申、酉、戌、亥之十二支也。"
string_data = "Alice was beginning to get very tired of sitting by her sister on the bank"
# string_data = "En 1815, M. Charles-François-Bienvenu Myriel était évêque de Digne. C'était un vieillard d'environ soixante-quinze ans; il occupait le siège de Digne depuis 1806."
top_k = 10

stored_data = []
prompt = ""
i = 0
while i < len(string_data):
    print(f"\n\nCurrent index: {i}")
    print(f"Prompt ({i}): {prompt}")
    print(f"stored_data ({i}): {stored_data}")
    if prompt == "":
        prompt = string_data[i]
        stored_data.append(string_data[i])
        i += 1
    else:
        top_tokens, top_values = top_k_token(prompt, k=top_k)
        for j, (token, score) in enumerate(zip(top_tokens, top_values)):
            print(f"{j+1}. '{token}' (logit: {score.item():.4f})")
            token_len = len(token)
            print(f"\tToken length: {token_len}")
            print(f"\tComparing with string_data[{i}:{i+token_len}] = '{string_data[i:i+token_len]}'")
            if token == string_data[i:i+token_len]:
                stored_data.append(j)
                prompt += token
                i += token_len - 1
                break
            elif j == top_k - 1:
                print(f"Character '{string_data[i]}' not found in top {top_k} tokens.")
                stored_data.append(string_data[i])
                prompt += string_data[i]
        i += 1
        # if i == 6:
        #     break

            




Current index: 0
Prompt (0): 
stored_data (0): []


Current index: 1
Prompt (1): A
stored_data (1): ['A']
1. ' ' (logit: 16.7928)
	Token length: 1
	Comparing with string_data[1:2] = 'l'
2. ' new' (logit: 15.5755)
	Token length: 4
	Comparing with string_data[1:5] = 'lice'
3. ' few' (logit: 14.8041)
	Token length: 4
	Comparing with string_data[1:5] = 'lice'
4. ' New' (logit: 14.2138)
	Token length: 4
	Comparing with string_data[1:5] = 'lice'
5. ' man' (logit: 14.1144)
	Token length: 4
	Comparing with string_data[1:5] = 'lice'
6. '.' (logit: 13.9978)
	Token length: 1
	Comparing with string_data[1:2] = 'l'
7. 'str' (logit: 13.8647)
	Token length: 3
	Comparing with string_data[1:4] = 'lic'
8. ' lot' (logit: 13.8384)
	Token length: 4
	Comparing with string_data[1:5] = 'lice'
9. 'frican' (logit: 13.8226)
	Token length: 6
	Comparing with string_data[1:7] = 'lice w'
10. ' couple' (logit: 13.7873)
	Token length: 7
	Comparing with string_data[1:8] = 'lice wa'
Character 'l' not found in top 10 t

In [7]:
print(f"\nFinal store_data: {stored_data}")
print(f"original string_data: {string_data}")


Final store_data: ['A', 'l', 'i', 'c', 'e', ' ', 'w', 'a', 's', 4, 'b', 'e', 'g', 9, 'n', 2, 'g', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
original string_data: Alice was beginning to get very tired of sitting by her sister on the bank
