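I made a small modification to modeling_llama.py in the current directory so that the kv_cache it produces contains no positional information. In theory, this makes it possible to concatenate kv_caches.
Every line I changed in modeling_llama.py is marked with a `# (ddw)` comment.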

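Below, I compare the modified model with the original model on a few simple inputs. The results look reasonable so far, but further testing is needed.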

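In the task of summarizing a three-sentence story, the modified model used most of the information, whereas the original model used only the first sentence. I take this as evidence that the modified model can still supply positional information after the kv caches are concatenated.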
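
To make the idea concrete, here is a minimal sketch of why position-free keys can be concatenated. It assumes the cache stores keys before the rotary embedding (RoPE) is applied and that the rotation happens at attention time. That is one plausible design, not a description of the actual `# (ddw)` changes. The 2-D vectors, `THETA`, `rope`, and `score` are all hypothetical toy names. In a real model, keys in deeper layers also depend on the context seen during prefill, so a concatenated cache is not identical to a joint prefill even when positions are handled correctly.

```python
import math

THETA = 0.5  # hypothetical rotation frequency for this toy example


def rope(vec, pos):
    """Rotate a 2-D vector by the angle pos * THETA (a one-pair rotary embedding)."""
    x, y = vec
    a = pos * THETA
    return (x * math.cos(a) - y * math.sin(a),
            x * math.sin(a) + y * math.cos(a))


def score(q, q_pos, k, k_pos):
    """Attention logit between a query and a key, both rotated by their positions."""
    rq, rk = rope(q, q_pos), rope(k, k_pos)
    return rq[0] * rk[0] + rq[1] * rk[1]


q = (1.0, 2.0)
k = (0.5, -1.0)

# The logit depends only on the relative offset between query and key.
same_offset = math.isclose(score(q, 5, k, 3), score(q, 2, k, 0))

# Suppose k is the first token of a second document that was prefilled alone.
# Its local position is 0, but after concatenation behind a 2-token document
# its global position is 2.
doc_a_len = 2
q_pos = 4
stale = score(q, q_pos, k, 0)              # position baked in at prefill time
fresh = score(q, q_pos, k, doc_a_len + 0)  # position assigned at attention time

print(same_offset, stale, fresh)
```

A cache that bakes in the local position gives the `stale` logit, while a position-free cache lets the key be rotated by its global position and gives the `fresh` one.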

In [2]:
import torch
from torch import LongTensor, FloatTensor

# from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import AutoTokenizer
from transformers import LlamaForCausalLM as OriginalLlamaForCausalLM
from modeling_llama import LlamaForCausalLM



In [3]:
checkpoint = "HuggingFaceTB/SmolLM2-360M-Instruct"

model0 = OriginalLlamaForCausalLM.from_pretrained(checkpoint)
model = LlamaForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [4]:
def prefill(model: LlamaForCausalLM, input_ids: LongTensor) -> tuple[tuple[FloatTensor]]:
    return model(input_ids).past_key_values

def kv_cat(kv1: tuple[tuple[FloatTensor]], kv2: tuple[tuple[FloatTensor]]) -> tuple[tuple[FloatTensor]]:
    # print(kv1[0][0].shape, kv2[0][0].shape)
    n_layers = len(kv1)
    
    ret = tuple()
    for i in range(n_layers):
        k = torch.cat((kv1[i][0], kv2[i][0]), dim=2)
        v = torch.cat((kv1[i][1], kv2[i][1]), dim=2)
        ret = ret + ((k,v),)
        
    # print(ret[0][0].shape)
    return ret

def kv_slice(kv: tuple[tuple[FloatTensor]], l: int, r: int) -> tuple[tuple[FloatTensor]]:
    n_layers = len(kv)
    
    ret = tuple()
    for i in range(n_layers):
        k = kv[i][0][:,:,l:r,:]
        v = kv[i][1][:,:,l:r,:]
        ret = ret + ((k,v),)

    return ret

def kv_len(kv: tuple[tuple[FloatTensor]]) -> int:
    return kv[0][0].shape[-2]
    

# def greedy_search(model: LlamaForCausalLM, 
#                 input_ids: LongTensor,
#                 past_key_values: tuple[tuple[FloatTensor]],
#                 eos_token_id: int,
#                 max_length: int = 10) -> tuple[LongTensor, tuple[tuple[FloatTensor]]]:
#     """
#     past_key_values may be None. The length of the kv cache must be less than or equal to
#     the length of input_ids.

#     Return (output_ids, new_key_values)
#     """
#     if past_key_values is None:
#         past_key_values = model(input_ids).past_key_values

#     last_id = input_ids[0][-1]
#     while len(input_ids[0]) < max_length and last_id != eos_token_id:
#         model_outputs = model(LongTensor([[last_id]]), past_key_values=past_key_values)
#         logits = model_outputs.logits
#         past_key_values = model_outputs.past_key_values
#         # print((input_ids.shape, torch.argmax(logits, dim=2)[:,-1].unsqueeze(0).shape))
#         input_ids = torch.cat((input_ids, torch.argmax(logits, dim=2)[:,-1].unsqueeze(1)), dim=1)
#         last_id = input_ids[0][-1]

#     return input_ids, past_key_values


# This function is largely due to ChatGPT. I only made a few changes.
def greedy_search(model: LlamaForCausalLM, 
                  input_ids: LongTensor,
                  past_key_values: tuple[tuple[FloatTensor]] = None,
                  eos_token_id: int = 2,
                  max_length: int = 10) -> tuple[LongTensor, tuple[tuple[FloatTensor]]]:
    """
    Generates text using greedy search. Selects the most probable token at each step.
    
    Parameters:
    - model: The causal language model to generate text with.
    - input_ids: The input prompt tokens.
    - past_key_values: Cache from previous forward pass (if available).
    - eos_token_id: ID of the EOS token.
    - max_length: Maximum number of tokens to generate.
    
    Returns:
    - Tuple of output_ids (generated tokens) and new_key_values (updated cache).
    """
    
    input_len = len(input_ids[0])
    if past_key_values:
        input_kv_len = kv_len(past_key_values)
    else:
        input_kv_len = 0

    assert input_kv_len <= input_len

    if input_kv_len == input_len:
        past_key_values = kv_slice(past_key_values, 0, input_kv_len - 1)
        input_kv_len -= 1

    # Initialize output with the initial input_ids
    output_ids = input_ids
    new_key_values = past_key_values  # To store updated past_key_values during generation

    # Loop until max_length is reached or EOS token is generated
    for _ in range(max_length):

        # Forward pass through the model.
        # With no cache (input_kv_len == 0) the slice covers the whole prompt.
        outputs = model(input_ids=output_ids[:, input_kv_len - input_len:],  # Feed only the tokens not yet in the cache
                        past_key_values=new_key_values,
                        use_cache=True)
        
        # Update cache with new key/values
        new_key_values = outputs.past_key_values

        input_kv_len = kv_len(new_key_values)
        input_len += 1

        # Get the logits for the last generated token
        logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
        
        # Select the token with the highest probability (greedy choice)
        next_token_id = torch.argmax(logits, dim=-1, keepdim=True)  # Shape: (batch_size, 1)
        
        # Append to the generated sequence
        output_ids = torch.cat([output_ids, next_token_id], dim=-1)
        
        # Stop if EOS token is generated
        if next_token_id.item() == eos_token_id:
            break
    
    return output_ids, new_key_values
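
The cache bookkeeping at the top of `greedy_search` can be checked without a model. The sketch below mirrors that arithmetic on plain Python lists. `first_step_plan` and `uncached_suffix` are hypothetical helper names introduced only for illustration.

```python
def first_step_plan(input_len: int, cache_len: int) -> tuple[int, int]:
    """Return (cache_len_after_trim, n_tokens_fed) for the first forward pass.

    Hypothetical helper mirroring the checks at the top of greedy_search.
    """
    assert cache_len <= input_len
    if cache_len == input_len:
        # Drop the last cached position so that at least one token is fed
        # and the model can produce logits for the next token.
        cache_len -= 1
    return cache_len, input_len - cache_len


def uncached_suffix(ids: list[int], cache_len: int) -> list[int]:
    """Same slice as output_ids[:, input_kv_len - input_len:], on a plain list."""
    return ids[cache_len - len(ids):]


ids = list(range(10))
partial = uncached_suffix(ids, 7)      # cache covers 7 of 10 tokens
empty_cache = uncached_suffix(ids, 0)  # no cache: the whole prompt is fed
untrimmed = uncached_suffix(ids, 10)   # start index 0: also the whole prompt

print(first_step_plan(191, 148), partial)
```

The `untrimmed` case shows why the cache is shortened by one position when it already covers the whole prompt: without the trim, the slice would start at index 0 and feed every token again on top of a full cache. With the sizes printed by the comparison cells (148 cached positions, 191 prompt tokens), the first pass feeds the remaining 43 tokens.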


In [5]:
messages = [{"role": "user", "content": "What is the capital of France."}]
input_text=tokenizer.apply_chat_template(messages, tokenize=False)
print(input_text)

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of France.<|im_end|>



Test case provided by ChatGPT.

Sure! Here’s a story you can use for testing your chatbot’s summarization abilities. It has a mix of characters, events, and a bit of conflict, making it a good test for comprehension.

---

**Title:** *The Lost City of Aurelia*

In the year 1897, a young archaeologist named Elara Thorne set off on an expedition to find the fabled Lost City of Aurelia, an ancient civilization rumored to be hidden deep within the Amazon rainforest. Elara, inspired by her grandfather’s stories of mythical cities, had spent years studying maps, analyzing ancient texts, and gathering clues. Along with a small team of explorers and a local guide named Mateo, she embarked on a journey filled with mystery and danger.

The team ventured through treacherous terrain—raging rivers, dense forests, and poisonous creatures—to uncover any signs of the lost city. After weeks of struggle, they finally came across a series of stone markers etched with unfamiliar symbols. Elara recognized them from her grandfather’s notes and realized they were close.

But as they ventured deeper, the team encountered a band of mercenaries led by an infamous treasure hunter named Victor Blackwood. Blackwood had been searching for Aurelia for years, driven by greed and a desire for fame. He threatened Elara and her team, demanding they share any information about the city’s location. Forced to comply but unwilling to give up, Elara secretly left false clues, leading Blackwood’s team in the wrong direction.

Finally, after several tense days, Elara and her group arrived at a hidden waterfall. Behind it lay the entrance to a vast underground city, with towering stone pillars and golden statues—Aurelia at last! The city was everything Elara had dreamed of: a place of incredible beauty and history, with intricate carvings, forgotten technology, and treasures beyond imagination.

But their joy was short-lived. Blackwood and his mercenaries found their way back, and a confrontation ensued. In a desperate attempt to protect the city, Elara and her team triggered a hidden mechanism that caused parts of the city to collapse, sealing off the main chambers and preserving Aurelia’s secrets. Blackwood, narrowly escaping, fled empty-handed.

Elara left the jungle with only a few artifacts and her memories of the city. She knew that Aurelia’s true wealth was its history, and she wanted it to remain unspoiled. Returning to her university, she shared her findings with scholars but never revealed the city’s location, keeping its mysteries safe from the world.

---

This story has plenty of details for testing a summarizer: character motivations, a sequence of events, conflict, and a resolution. Good luck with your testing!

In [None]:
def wrap_doc(doc: str) -> str:
    return f"""<|im_start|>system
This is a piece of information that may be useful in question answering: {doc}<|im_end|>
"""


docs = [
    # "Former Arkansas governor and Baptist minister Mike Huckabee has been named US ambassador to Israel.",
    # "John Ratcliffe, a former Texas congressman and former director of national intelligence, has been nominated as the next CIA director.",
    # "Fox News commentator and army veteran Pete Hegseth is Trump's pick to be the next US secretary of defence.",
    
    # "David Du is a master student of computer technology at Nankai University.",
    # "David Du is a 22-year-old chinese national living in Tianjin, China.",
    # "David Du is currently working on LLM inference systems.",

    "In the year 1897, a young archaeologist named Elara Thorne set off on an expedition to find the fabled Lost City of Aurelia, an ancient civilization rumored to be hidden deep within the Amazon rainforest.",
    "Along with a small team of explorers and a local guide named Mateo, she embarked on a journey filled with mystery and danger.",
    "Elara left the jungle with only a few artifacts and her memories of the city. ",
]
docs = [wrap_doc(d) for d in docs]
# docs = ["".join(docs)]

prompt = """
<|im_start|>system
You are a helpful AI assistant named SmolLM. Please answer questions basing on the above infos.<|im_end|>
<|im_start|>user
Who has been named US ambassador to Israel?<|im_end|>
"""

prompt = """
<|im_start|>system
You are a helpful AI assistant named SmolLM. Please answer questions basing on the above infos.<|im_end|>
<|im_start|>user
Tell me everything you know about David Du.<|im_end|>
"""

prompt = """
<|im_start|>system
You are a helpful AI assistant named SmolLM. Please answer questions basing on the above infos.<|im_end|>
<|im_start|>user
Briefly Summarize the story of Elara.<|im_end|>
"""

kv = None
concated_input_ids = None
for d in docs:
    input = tokenizer(d, return_tensors="pt").input_ids
    kv = prefill(model, input) if kv is None else kv_cat(kv, prefill(model, input))
    concated_input_ids = input if concated_input_ids is None else torch.cat((concated_input_ids, input), dim=1)

input = tokenizer(prompt, return_tensors="pt").input_ids
# kv = kv_cat(kv, prefill(model, input))
concated_input_ids = torch.cat((concated_input_ids, input), dim=1)

print(kv_len(kv), concated_input_ids.shape[1])
# kv = None

print("Modified model:")
generate_ids, _ = greedy_search(model, concated_input_ids, kv, tokenizer.eos_token_id, 100)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(outputs[0])


148 191
Modified model:
system
This is a piece of information that may be useful in question answering: In the year 1897, a young archaeologist named Elara Thorne set off on an expedition to find the fabled Lost City of Aurelia, an ancient civilization rumored to be hidden deep within the Amazon rainforest.
system
This is a piece of information that may be useful in question answering: Along with a small team of explorers and a local guide named Mateo, she embarked on a journey filled with mystery and danger.
system
This is a piece of information that may be useful in question answering: Elara left the jungle with only a few artifacts and her memories of the city. 

system
You are a helpful AI assistant named SmolLM. Please answer questions basing on the above infos.
user
Briefly Summarize the story of Elara.
assistant
Elara is a young archaeologist who embarks on a perilous journey to find the fabled Lost City of Aurelia, a city rumored to be hidden deep within the Amazon rainforest. 

In [12]:
# test original model (model0) in this block

kv = None
concated_input_ids = None
for d in docs:
    input = tokenizer(d, return_tensors="pt").input_ids
    kv = prefill(model0, input) if kv is None else kv_cat(kv, prefill(model0, input))
    concated_input_ids = input if concated_input_ids is None else torch.cat((concated_input_ids, input), dim=1)

input = tokenizer(prompt, return_tensors="pt").input_ids
# kv = kv_cat(kv, prefill(model0, input))
concated_input_ids = torch.cat((concated_input_ids, input), dim=1)

print(kv_len(kv), concated_input_ids.shape[1])
# kv = None

print("Original model:")
generate_ids, _ = greedy_search(model0, concated_input_ids, kv, tokenizer.eos_token_id, 150)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(outputs[0])


148 191
Original model:
system
This is a piece of information that may be useful in question answering: In the year 1897, a young archaeologist named Elara Thorne set off on an expedition to find the fabled Lost City of Aurelia, an ancient civilization rumored to be hidden deep within the Amazon rainforest.
system
This is a piece of information that may be useful in question answering: Along with a small team of explorers and a local guide named Mateo, she embarked on a journey filled with mystery and danger.
system
This is a piece of information that may be useful in question answering: Elara left the jungle with only a few artifacts and her memories of the city. 

system
You are a helpful AI assistant named SmolLM. Please answer questions basing on the above infos.
user
Briefly Summarize the story of Elara.
assistant
Elara is a young archaeologist who embarks on an expedition to find the fabled Lost City of Aurelia, an ancient civilization rumored to be hidden deep within the Amazon 