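# Part 1: Supervised Fine-Tuning of a Base Model on Instruction Data
In the first part of this assignment, we use the Qwen2.5-0.5B base model and the alpaca instruction dataset to get hands-on experience with instruction-tuning an LLM.

> For a basic tutorial on using Transformers, see the official [LLM Course](https://huggingface.co/learn/llm-course/chapter2/3). This assignment requires you to write the training code by hand; the Trainer API presented there is not allowed. For how to train a model with native PyTorch, see [this tutorial](https://huggingface.co/docs/transformers/v4.49.0/en/training#train-in-native-pytorch).

> If you are doing the assignment on Kaggle, here is a short [Kaggle basics](https://www.kaggle.com/code/cnlnpjhsy/kaggle-transformers) tutorial for reference.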

In [None]:
# If any required library is missing, install it with the command below
!pip install -U torch transformers datasets accelerate

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets

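## Loading the model, tokenizer, and dataset
In this assignment we fine-tune Tongyi Qianwen's Qwen2.5-0.5B pretrained model. If you are running locally, download the model files beforehand; if you are working on Kaggle, follow the Kaggle tutorial and set `MODEL_PATH` and `DATASET_PATH` to the paths under Input.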

In [5]:
MODEL_PATH = "./Qwen2.5-0.5B"  # make sure the Qwen2.5-0.5B folder exists in the server's current directory
DATASET_PATH = "./alpaca-cleaned/alpaca_data_cleaned.json"

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype="auto")
print(model)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

dataset = datasets.Dataset.from_json(DATASET_PATH)
for sample in dataset.select(range(10)):    # Inspect the first 10 samples. Think about how each sample should be assembled into one complete text.
    print(sample)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2

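Qwen also ships a chat template for the base model. The template contains special tokens that help separate speaker turns (think about why the turns need to be separated). We can build a sample text directly in the turn-based format below.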

In [15]:
tokenizer.apply_chat_template([
    {"role": "user", "content": "This is a question."},
    {"role": "assistant", "content": "I'm the answer!"}
], tokenize=False
)

"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nThis is a question.<|im_end|>\n<|im_start|>assistant\nI'm the answer!<|im_end|>\n"

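As shown, every turn ends with the `<|im_end|>` token. The base model, however, has not been optimized for dialogue and does not recognize this token as a terminator. We therefore need to change the tokenizer's end-of-sequence token so that it knows which token marks the end of a turn.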

In [13]:
print(tokenizer.eos_token)  # the original end-of-sequence token
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
tokenizer.pad_token_id = tokenizer.eos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id  # the model's end-of-sequence token must be changed too

<|endoftext|>
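
To make the turn structure concrete, the string printed by `apply_chat_template` earlier can be approximated by hand. The sketch below is a hand-written stand-in, not the tokenizer's actual template: the helper name `render_chatml` and the rule "prepend a default system turn when none is given" are assumptions inferred from the sample output.

```python
def render_chatml(messages, add_generation_prompt=False,
                  default_system="You are a helpful assistant."):
    """Render a list of {"role", "content"} dicts as a ChatML-style string."""
    # Assumption taken from the sample output: a default system turn is
    # prepended when the conversation does not start with one.
    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": default_system}] + list(messages)
    text = ""
    for m in messages:
        # Every turn opens with <|im_start|>role and closes with <|im_end|>
        text += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Open an assistant turn so the model continues with the answer
        text += "<|im_start|>assistant\n"
    return text


example = render_chatml([
    {"role": "user", "content": "This is a question."},
    {"role": "assistant", "content": "I'm the answer!"},
])
prompt = render_chatml(
    [{"role": "user", "content": "This is a question."}],
    add_generation_prompt=True,
)
print(example)
print(prompt)
```

Because each turn closes with `<|im_end|>`, a model trained on text in this format learns to emit that token when its answer is complete, which is why generation should stop on it.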


To compare against the fine-tuned model later, we first use the model's built-in generate method to see what kind of text the base model produces:

In [12]:
messages = [
    {"role": "user", "content": "Give me a brief introduction to Shanghai Jiao Tong University."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
with torch.no_grad():
    lm_inputs_src = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
    generate_ids = model.generate(**lm_inputs_src, max_new_tokens=150, do_sample=False)
pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
print(pred_str)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Shanghai Jiao Tong University (JTU) is a public university in Shanghai, China. It was founded in 1905 and is one of the oldest universities in China. The university has a rich history and is known for its academic excellence, research, and teaching. It offers a wide range of undergraduate and graduate programs in various fields, including engineering, business, and social sciences. The university also has a strong international presence and is home to several prestigious research centers and institutes. Overall, Shanghai Jiao Tong University is a highly regarded institution that has made significant contributions to the development of China's education and research.
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional
-transitional



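## Processing the dataset
The raw alpaca dataset is plain text rather than tokens the model can accept. We must tokenize the text before passing it to the model.

During instruction tuning, we usually want the model to be optimized only on the part it has to generate (the answer) and not on the question text, which requires specially designed labels. Complete the `tokenize_function` below: tokenize each instruction sample and return the `input_ids` fed to the model together with the `labels` used to <b>compute the loss only on the output part</b>.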

In [14]:
import copy

def tokenize_function(sample):
    instruction = sample["instruction"]
    input_text = sample.get("input", "")
    output_text = sample["output"]

    # Merge the instruction and the optional input into a single user turn
    if input_text:
        user_content = f"{instruction}\n\n{input_text}"
    else:
        user_content = instruction

    # Full conversation: the prompt plus the reference answer
    messages_full = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": output_text},
    ]
    full_text = tokenizer.apply_chat_template(messages_full, tokenize=False)

    # Prompt only, ending with the assistant header; its length tells us how many tokens to mask
    messages_prompt = [{"role": "user", "content": user_content}]
    prompt_text = tokenizer.apply_chat_template(
        messages_prompt, tokenize=False, add_generation_prompt=True
    )

    full_ids = tokenizer(full_text, add_special_tokens=False)["input_ids"]
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

    # Supervise only the answer: prompt positions get -100, which the loss ignores
    labels = copy.copy(full_ids)
    ignore_len = len(prompt_ids)
    labels[:ignore_len] = [-100] * ignore_len

    return {"input_ids": full_ids, "labels": labels}


tokenized_dataset = dataset.map(
    tokenize_function, remove_columns=dataset.column_names
).filter(
    lambda x: len(x["input_ids"]) <= 512,
)
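
The label masking used in `tokenize_function` can be checked on a toy sequence without a tokenizer. This is a minimal sketch with made-up token ids; the helper name `mask_prompt` is hypothetical.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def mask_prompt(full_ids, prompt_len):
    """Copy full_ids into labels and hide the first prompt_len positions."""
    labels = list(full_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels


# Made-up ids: 4 prompt tokens followed by 3 answer tokens
full_ids = [11, 12, 13, 14, 21, 22, 23]
labels = mask_prompt(full_ids, prompt_len=4)
print(labels)  # [-100, -100, -100, -100, 21, 22, 23]
```

The approach relies on the prompt tokens being an exact prefix of the full token sequence. That holds when the prompt text is a string prefix of the full text and the tokenizer splits at the same boundary, which is worth verifying on a few real samples by decoding the unmasked part of `labels`.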

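Define a DataLoader that yields tokenized inputs the model can process.  
> <b>[Bonus 1] (3 points)</b> Fetching data from the dataloader in batches improves computational efficiency. Can you design a `collate_fn` that supports `batch_size > 1`?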

In [15]:
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Pad every sample to the longest sequence in the batch; padded positions
    # get attention_mask = 0 and label = -100 so they affect neither attention nor loss

    input_ids_list = [sample["input_ids"] for sample in batch]
    labels_list = [sample["labels"] for sample in batch]

    max_len = max(len(ids) for ids in input_ids_list)

    padded_input_ids = []
    padded_labels = []
    attention_masks = []

    pad_id = tokenizer.pad_token_id

    for ids, lbs in zip(input_ids_list, labels_list):
        pad_len = max_len - len(ids)
        
        padded_ids = ids + [pad_id] * pad_len
        padded_input_ids.append(padded_ids)

        padded_lbs = lbs + [-100] * pad_len
        padded_labels.append(padded_lbs)

        attn_mask = [1] * len(ids) + [0] * pad_len
        attention_masks.append(attn_mask)

    input_ids = torch.tensor(padded_input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_masks, dtype=torch.long)
    labels = torch.tensor(padded_labels, dtype=torch.long)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    

train_dataloader = DataLoader(tokenized_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)  # Optimized for H200 (Batch 128)


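## Training the model
With the tokenized data ready, we can train the model. Write the training loop by hand, compute the loss, and backpropagate.

When labels are passed to the model, Transformers models compute the loss internally. To help you understand how the loss is computed, however, we require that you **do not pass labels to the model's forward; instead, manually compare the model's final output logits against the labels and compute the loss.**  
> <b>[Bonus 1]</b> After fetching a batch from the dataloader, feed the whole batch into the model at once (rather than looping over its samples one by one), obtain the loss over all samples, and compute the loss correctly.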
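
The manual loss boils down to a shifted cross-entropy: the logits at position t are scored against the token at position t + 1, and positions labelled -100 are skipped. The framework-free sketch below works on one toy sequence with made-up numbers; the function names are hypothetical, and a real implementation would use tensor operations instead of Python loops.

```python
import math

IGNORE_INDEX = -100


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def shifted_cross_entropy(logits, labels):
    """logits: [T][V] for one sequence, labels: [T]. Mean NLL over supervised tokens."""
    total, count = 0.0, 0
    # Position t predicts token t + 1, so pair logits[:-1] with labels[1:]
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue  # prompt or padding position: no loss
        total += -log_softmax(row)[target]
        count += 1
    return total / count if count else 0.0


# Toy sequence: T = 4 positions, V = 3 vocabulary entries
logits = [
    [2.0, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [0.0, 0.0, 2.0],
    [1.0, 1.0, 1.0],  # the last position has no next token to predict
]
labels = [-100, -100, 2, 0]  # the first two tokens belong to the prompt
loss = shifted_cross_entropy(logits, labels)
print(loss)
```

For a batch, the same pairing is applied to every sequence and the average is taken over all non-ignored tokens in the batch, which matches the default `mean` reduction of `torch.nn.CrossEntropyLoss` with `ignore_index=-100`.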

In [16]:
from tqdm.notebook import tqdm
from torch.optim import AdamW
from transformers import get_scheduler
from accelerate.test_utils.testing import get_backend

step = 0
optimizer = AdamW(model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

num_training_steps = len(train_dataloader) * 3  # assuming 3 training epochs
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

device, _, _ = get_backend()
model.to(device)

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(3):
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        batch = {k: v.to(model.device) for k, v in batch.items()}

        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits          # [B, T, V]

        shift_logits = logits[:, :-1, :].contiguous()   # [B, T-1, V]
        shift_labels = labels[:, 1:].contiguous()       # [B, T-1]

        loss = loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        step += 1
        if step % 100 == 0:
            print(f"Step {step}\t| Loss: {loss.item()}")

    model.save_pretrained(f"output/checkpoint-epoch-{epoch + 1}")
    tokenizer.save_pretrained(f"output/checkpoint-epoch-{epoch + 1}")

  0%|          | 0/38298 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/12766 [00:00<?, ?it/s]

Step 100	| Loss: 1.390625
Step 200	| Loss: 1.5
Step 300	| Loss: 1.1640625
Step 400	| Loss: 1.6171875
Step 500	| Loss: 1.5078125
Step 600	| Loss: 1.59375
Step 700	| Loss: 1.625
Step 800	| Loss: 1.3046875
Step 900	| Loss: 1.7265625
Step 1000	| Loss: 1.703125
Step 1100	| Loss: 1.8203125
Step 1200	| Loss: 1.40625
Step 1300	| Loss: 1.703125
Step 1400	| Loss: 1.3984375
Step 1500	| Loss: 1.3359375
Step 1600	| Loss: 0.98046875
Step 1700	| Loss: 1.7578125
Step 1800	| Loss: 1.4765625
Step 1900	| Loss: 1.5859375
Step 2000	| Loss: 1.8359375
Step 2100	| Loss: 1.2734375
Step 2200	| Loss: 1.5703125
Step 2300	| Loss: 1.6640625
Step 2400	| Loss: 1.2109375
Step 2500	| Loss: 1.1484375
Step 2600	| Loss: 1.9765625
Step 2700	| Loss: 1.5
Step 2800	| Loss: 1.046875
Step 2900	| Loss: 1.625
Step 3000	| Loss: 2.046875
Step 3100	| Loss: 0.7890625
Step 3200	| Loss: 1.1953125
Step 3300	| Loss: 1.5703125
Step 3400	| Loss: 0.94140625
Step 3500	| Loss: 1.65625
Step 3600	| Loss: 1.375
Step 3700	| Loss: 1.578125
Step 38

Epoch 2:   0%|          | 0/12766 [00:00<?, ?it/s]

Step 12800	| Loss: 1.46875
Step 12900	| Loss: 1.2578125
Step 13000	| Loss: 0.86328125
Step 13100	| Loss: 0.96484375
Step 13200	| Loss: 1.5078125
Step 13300	| Loss: 1.4453125
Step 13400	| Loss: 1.2578125
Step 13500	| Loss: 0.96484375
Step 13600	| Loss: 0.90234375
Step 13700	| Loss: 1.0703125
Step 13800	| Loss: 1.2421875
Step 13900	| Loss: 1.09375
Step 14000	| Loss: 0.78515625
Step 14100	| Loss: 0.93359375
Step 14200	| Loss: 1.3203125
Step 14300	| Loss: 0.81640625
Step 14400	| Loss: 1.328125
Step 14500	| Loss: 1.015625
Step 14600	| Loss: 1.0703125
Step 14700	| Loss: 1.1875
Step 14800	| Loss: 1.0234375
Step 14900	| Loss: 0.8515625
Step 15000	| Loss: 1.1953125
Step 15100	| Loss: 1.0078125
Step 15200	| Loss: 0.87890625
Step 15300	| Loss: 1.1796875
Step 15400	| Loss: 1.1953125
Step 15500	| Loss: 0.90625
Step 15600	| Loss: 0.91796875
Step 15700	| Loss: 0.53515625
Step 15800	| Loss: 0.99609375
Step 15900	| Loss: 1.15625
Step 16000	| Loss: 0.6640625
Step 16100	| Loss: 0.9765625
Step 16200	| Los

Epoch 3:   0%|          | 0/12766 [00:00<?, ?it/s]

Step 25600	| Loss: 0.62890625
Step 25700	| Loss: 1.0078125
Step 25800	| Loss: 0.86328125
Step 25900	| Loss: 0.9765625
Step 26000	| Loss: 1.3671875
Step 26100	| Loss: 0.55859375
Step 26200	| Loss: 1.2109375
Step 26300	| Loss: 1.046875
Step 26400	| Loss: 1.125
Step 26500	| Loss: 0.42578125
Step 26600	| Loss: 1.140625
Step 26700	| Loss: 1.0
Step 26800	| Loss: 0.953125
Step 26900	| Loss: 0.96484375
Step 27000	| Loss: 0.92578125
Step 27100	| Loss: 0.75
Step 27200	| Loss: 0.68359375
Step 27300	| Loss: 1.1796875
Step 27400	| Loss: 1.0
Step 27500	| Loss: 0.98828125
Step 27600	| Loss: 1.1484375
Step 27700	| Loss: 0.50390625
Step 27800	| Loss: 0.98046875
Step 27900	| Loss: 1.09375
Step 28000	| Loss: 1.03125
Step 28100	| Loss: 0.369140625
Step 28200	| Loss: 1.109375
Step 28300	| Loss: 0.7734375
Step 28400	| Loss: 1.0703125
Step 28500	| Loss: 1.3671875
Step 28600	| Loss: 1.1328125
Step 28700	| Loss: 1.328125
Step 28800	| Loss: 0.462890625
Step 28900	| Loss: 0.7734375
Step 29000	| Loss: 1.0390625
S

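Test the fine-tuned model. If training went well, the model should answer in fluent sentences and stop generating naturally once the answer is complete.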

In [25]:
sft_model = AutoModelForCausalLM.from_pretrained("output/checkpoint-epoch-3", device_map="auto", torch_dtype="auto")
messages = [
    {"role": "user", "content": "Give me a brief introduction to Shanghai Jiao Tong University."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
with torch.no_grad():
    lm_inputs_src = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(sft_model.device)
    generate_ids = sft_model.generate(**lm_inputs_src, max_new_tokens=150, do_sample=False)
pred_str = tokenizer.decode(generate_ids[0][lm_inputs_src.input_ids.size(1):], skip_special_tokens=True)
print(pred_str)

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


Shanghai Jiao Tong University (Jiao Tong University) is a prestigious research university founded in 1921 in Shanghai, China. It is one of the top-tier universities in China and is known for its innovative and research-based programs in fields such as engineering, business, and social sciences. The university has a strong commitment to internationalization and has a large international student body. It is also home to several prestigious research centers and institutes, including the Center for Advanced Materials Research, the Center for International Development, and the Center for Chinese Language and Culture.


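If the model behaves as expected, you can move on to Part 2 of the assignment!

# Part 2: Inference with the LLM and Decoding into Natural Text
In this part, we see how an LLM generates token by token and how those tokens are decoded into natural text. We will implement a `generate` function by hand that takes the user's natural-language text as input and likewise replies in natural-language text.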

In [5]:
MODEL_PATH = "output/checkpoint-epoch-3"    # path to your trained model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
tokenizer.pad_token_id = tokenizer.eos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id

`torch_dtype` is deprecated! Use `dtype` instead!


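## Implementing generate
Implement the generate function below, performing model inference, generation, and decoding by hand.

The generate function must at least accept a string `query` as input, cap the number of generated tokens with `max_new_tokens`, and use `do_sample` to choose between sampling and greedy search. When sampling, it should allow the basic sampling parameters `temperature`, `top_p`, and `top_k` to be set. To learn how the different generation strategies work, read this [blog post](https://huggingface.co/blog/how-to-generate).  
**Using the model's built-in `model.generate` method is prohibited!**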
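
The three sampling parameters can be illustrated on a four-token toy distribution. This is a pure-Python sketch with made-up probabilities, not the tensor implementation the assignment asks for; all helper names are hypothetical, and top-p is taken here to mean "the smallest set of most likely tokens whose cumulative probability reaches `top_p`".

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def keep_and_renormalise(probs, keep):
    """Zero out every index not in `keep`, then rescale so the result sums to 1."""
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]


def top_k_filter(probs, k):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return keep_and_renormalise(probs, set(order[:k]))


def top_p_filter(probs, top_p):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)  # the token that crosses the threshold is still kept
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return keep_and_renormalise(probs, keep)


probs = [0.5, 0.3, 0.15, 0.05]
after_top_k = top_k_filter(probs, k=3)    # drops only the least likely token
nucleus = top_p_filter(probs, top_p=0.7)  # keeps the two most likely tokens
sampled = random.choices(range(len(nucleus)), weights=nucleus, k=1)[0]
print(after_top_k, nucleus, sampled)
```

Filtered-out tokens have zero weight, so the sampled index always comes from the kept set.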

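> <b>Bonus 2 (3 points)</b> Can you exploit the model's batched-input capability (rather than looping over the inputs one by one) to feed in a batch of texts and generate new tokens for all of them at the same time? In this case `query` should accept a list of strings.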
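
One detail matters when prompts of different lengths are batched: a loop that reads the logits at position -1 continues from whatever token sits last in each row. The sketch below uses made-up ids and a hypothetical `pad` helper to show that left padding keeps each prompt's last real token there, while right padding puts a pad token there for the shorter prompts.

```python
def pad(seqs, pad_id, side):
    """Pad lists of token ids to equal length on the given side."""
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        fill = [pad_id] * (max_len - len(s))
        out.append(fill + s if side == "left" else s + fill)
    return out


PAD = 0  # made-up pad id
prompts = [[5, 6, 7], [8, 9]]

right = pad(prompts, PAD, "right")
left = pad(prompts, PAD, "left")

last_right = [row[-1] for row in right]  # the second row ends in a pad token
last_left = [row[-1] for row in left]    # every row ends in a real token
print(right, last_right)
print(left, last_left)
```

With a Hugging Face tokenizer, the padding side can be selected by setting `tokenizer.padding_side = "left"` before tokenizing the batch.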

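> <b>Bonus 3 (3 points)</b> Beam search keeps several suboptimal sequences alive during decoding. By maintaining these sequences throughout generation, the model can produce sentences that are more reasonable overall, mitigating greedy search's tendency to get stuck in local optima. Can you implement beam search on top of the existing greedy and sampling strategies? `num_beams` should then accept values greater than 1.  
For beam search, this [visualization demo](https://huggingface.co/spaces/m-ric/beam_search_visualizer) shows how it works.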
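
The difference between greedy search and beam search shows up even on a tiny hand-made "language model". In the sketch below, the transition table `NEXT` and its probabilities are invented for illustration, every sequence has exactly two generated tokens, and EOS handling and the length penalty are left out.

```python
import math

# Invented next-token probabilities, keyed by the previous token
NEXT = {
    "<s>": {"A": 0.6, "B": 0.4},
    "A": {"C": 0.55, "D": 0.45},
    "B": {"C": 0.9, "D": 0.1},
}


def beam_search(num_beams, steps=2):
    """Return (score, tokens) of the best sequence; num_beams=1 is greedy search."""
    beams = [(0.0, ["<s>"])]  # (sum of log-probabilities, tokens so far)
    for _ in range(steps):
        candidates = []
        for score, tokens in beams:
            for token, prob in NEXT[tokens[-1]].items():
                candidates.append((score + math.log(prob), tokens + [token]))
        # Keep only the num_beams highest-scoring partial sequences
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:num_beams]
    return beams[0]


greedy_score, greedy_tokens = beam_search(num_beams=1)
beam_score, beam_tokens = beam_search(num_beams=2)
print(greedy_tokens, math.exp(greedy_score))
print(beam_tokens, math.exp(beam_score))
```

Greedy search commits to `A` because it is the most likely first token, and ends with a sequence of probability 0.6 × 0.55 = 0.33. With two beams, `B` survives the first step and its continuation `C` gives 0.4 × 0.9 = 0.36, so the overall better sequence is found.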

In [None]:
import torch
from typing import Union, List
from torch.nn.functional import softmax, log_softmax

def top_p_top_k_filtering(
    logits: torch.Tensor,
    top_p: float = 0.9,
    top_k: int = 50,
) -> torch.Tensor:
    """
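    Apply top-k and top-p filtering to the logits and return sampling probabilities.
    Args:
        logits: logits output by the model, shape [B, V]
        top_p: the top-p (nucleus) threshold
        top_k: the number of highest-scoring tokens to keep
    Returns:
        Filtered and renormalised probabilities (not logits), shape [B, V]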
    """
    # top-k
    if top_k > 0:
        top_k_vals, top_k_idx = torch.topk(logits, top_k, dim=-1)
        mask = torch.full_like(logits, float("-inf"))
        mask.scatter_(1, top_k_idx, top_k_vals)
        logits = mask

    probs = softmax(logits, dim=-1)

    # top-p (nucleus) 
    if top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        old = sorted_indices_to_remove.clone()
        sorted_indices_to_remove[..., 1:] = old[..., :-1]
        sorted_indices_to_remove[..., 0] = 0
        # set the p that exceed top_p to zero
        probs.scatter_(
            1,
            sorted_indices,
            sorted_probs * (~sorted_indices_to_remove)
        )
        probs = probs / probs.sum(dim=-1, keepdim=True)
        
    return probs


def generate(
    model,
    query: Union[str, List[str]],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k: int = 50,
    num_beams: int = 1,
    length_penalty: float = 1.0,
) -> Union[str, List[str]]:
    
    """
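    Generate text with the model `model`.
    Args:
        model: the language model used for generation
        query: the user's query, either a single string or a list of strings
        max_new_tokens: maximum number of new tokens to generate
        do_sample: whether to sample. The temperature, top_p, and top_k parameters take effect only when this is True
        temperature: sampling temperature
        top_p: top-p parameter for sampling
        top_k: top-k parameter for sampling
        num_beams: number of beams maintained by beam search. Beam search is enabled only when `num_beams > 1`
        length_penalty: length penalty coefficient used by beam search
    Returns:
        The generated text: a single string if the input is a single string, a list of strings if the input is a list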
    """
    
    global tokenizer

    if isinstance(query, str):
        queries = [query]
        single_input = True
    elif isinstance(query, list):
        queries = query
        single_input = False
    else:
        raise ValueError("query must be a string or a list of strings")

    device = model.device

    messages_list = [[{"role": "user", "content": q}] for q in queries]
    texts = [
        tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
        for m in messages_list
    ]

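    # Left-pad so that position -1 of every row holds the prompt's last real token;
    # with right padding, shorter prompts would be continued from a pad token
    tokenizer.padding_side = "left"
    enc = tokenizer(
        texts,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
    )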
    input_ids = enc["input_ids"].to(device)          # [B, T]
    attention_mask = enc["attention_mask"].to(device)

    batch_size = input_ids.size(0)

    
    generated = input_ids

    if num_beams > 1 and not do_sample:
        generated = input_ids.repeat_interleave(num_beams, dim=0)
        attention_mask = attention_mask.repeat_interleave(num_beams, dim=0)
        
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=device)
        beam_scores[:, 1:] = -1e9 
        
        beam_scores = beam_scores.view(-1) 

        vocab_size = -1 

        with torch.no_grad():
            for step in range(max_new_tokens):
                outputs = model(input_ids=generated, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :] 
                
                if vocab_size == -1:
                    vocab_size = next_token_logits.shape[-1]

                next_token_scores = log_softmax(next_token_logits, dim=-1) # [B * num_beams, V]

                next_scores = beam_scores.unsqueeze(-1) + next_token_scores

                next_scores = next_scores.view(batch_size, num_beams * vocab_size)

                topk_scores, topk_indices = torch.topk(next_scores, num_beams, dim=1, largest=True, sorted=True)

                beam_indices = topk_indices // vocab_size  # [B, num_beams]
                token_indices = topk_indices % vocab_size  # [B, num_beams]

                beam_scores = topk_scores.view(-1) # Flatten back to [B * num_beams]

                batch_offset = (torch.arange(batch_size, device=device) * num_beams).unsqueeze(1)
                
                global_beam_indices = (beam_indices + batch_offset).view(-1) # [B * num_beams]

                generated = generated[global_beam_indices]
                attention_mask = attention_mask[global_beam_indices]

                new_tokens = token_indices.view(-1, 1)
                generated = torch.cat([generated, new_tokens], dim=-1)
                
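                attention_mask = torch.cat(
                    [attention_mask, torch.ones((batch_size * num_beams, 1), dtype=attention_mask.dtype, device=device)],
                    dim=1
                )

                # Simplification: finished beams are not frozen. Decoding stops early only when
                # every beam emits EOS at the same step; otherwise it runs to max_new_tokens
                if (new_tokens.squeeze(-1) == tokenizer.eos_token_id).all():
                    break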

        # beam_scores has shape [B * num_beams]
        final_scores = beam_scores.view(batch_size, num_beams)
        
        # All beams share the same length here, so this normalisation does not change their ranking
        current_len = generated.shape[1]
        final_scores = final_scores / (current_len ** length_penalty)

        best_beam_idx = torch.argmax(final_scores, dim=-1) # [B]

        final_generated = []
        for i in range(batch_size):
            global_idx = i * num_beams + best_beam_idx[i].item()
            final_generated.append(generated[global_idx])
        
        generated = final_generated
    else:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = model(input_ids=generated, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]          # logits at the last position of each sample, [B, V]

                if do_sample:
                    logits = logits / temperature


                    probs = top_p_top_k_filtering(
                        logits,
                        top_p=top_p,
                        top_k=top_k,
                    )
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # [B]
                else:
                    next_tokens = torch.argmax(logits, dim=-1)  # [B]


                next_tokens_unsqueezed = next_tokens.unsqueeze(1)  # [B, 1]
                generated = torch.cat([generated, next_tokens_unsqueezed], dim=1)  # [B, T+1]
                attention_mask = torch.cat(
                    [attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=device)],
                    dim=1,
                )

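                # Stop once every sequence contains EOS among its newly generated tokens
                # (the EOS tokens need not appear at the same step)
                if (generated[:, input_ids.size(1):] == tokenizer.eos_token_id).any(dim=1).all():
                    break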

    gen_texts = []
    for i in range(batch_size):
        seq = generated[i]
        prompt_len = enc["input_ids"][i].size(0)
        new_tokens = seq[prompt_len:]
        if tokenizer.eos_token_id in new_tokens:
            eos_pos = (new_tokens == tokenizer.eos_token_id).nonzero(as_tuple=True)[0][0].item()
            new_tokens = new_tokens[:eos_pos]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        gen_texts.append(text)

    return gen_texts[0] if single_input else gen_texts



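## Testing generate
Run the cells below to test your implementation. Beyond the sentences listed here, you can also try your own inputs to explore how the model behaves under different decoding strategies.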

In [9]:
print("#1 贪心解码")
query1 = ["Give me a brief introduction to Shanghai Jiao Tong University.", "介绍一下上海交通大学。", "What is the capital of China?"]
# If you did not implement Bonus 2, decode each string in query1 in a loop and print the results one by one
for i, response in enumerate(generate(model, query1, max_new_tokens=256, do_sample=False)):
    print(f"[{i}] 问：{query1[i]}\n答：{response}")

print("\n#2 采样解码")
query2 = "Tell me a joke about computers."
for i in range(5):
    response = generate(model, query2, do_sample=True, temperature=0.7, top_p=0.9, top_k=50)    # try adjusting these sampling hyperparameters
    print(f"[{i}] 问：{query2}\n答：{response}")

print("\n#3 【附加3】束搜索解码")
query3 = "What is the sum of the first 100 natural numbers? Please think step by step."
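# Note: num_beams=1 falls back to greedy decoding; pass a value greater than 1 to exercise beam search
response = generate(model, query3, num_beams=1, length_penalty=1.0)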
print(f"问：{query3}\n答：{response}")

print("\n========== 补充测试 ==========")

print("\n### 1. Decoding Strategy Comparison ###")
test_query = "Write a very short story about a robot who wants to be a chef."

print("[Greedy]")
print(generate(model, test_query, max_new_tokens=256, do_sample=False))

print("\n[Sampling Temp=0.7]")
print(generate(model, test_query, max_new_tokens=256, do_sample=True, temperature=0.7))

print("\n[Sampling Temp=1.5]")
print(generate(model, test_query, max_new_tokens=256, do_sample=True, temperature=1.5))

print("\n[Beam Search (Beams=4)]")
try:
    print(generate(model, test_query, max_new_tokens=256, num_beams=4))
except Exception as e:
    print(f"Beam search failed or not fully implemented: {e}")

print("\n### 2. Knowledge & Logic Checks ###")
queries = [
    "What is 12 + 8 * 2?", 
    "Who wrote the play Romeo and Juliet?",
    "Tell me about the history of the United States of Antarctica.", # Hallucination check
    "Translate 'Hello' into French."
]

for q in queries:
    print(f"\nQ: {q}")
    print(f"A: {generate(model, q, max_new_tokens=256, do_sample=False)}")


#1 贪心解码
[0] 问：Give me a brief introduction to Shanghai Jiao Tong University.
答：Shanghai Jiao Tong University (Jiao Tong University) is a prestigious research university founded in 1921 in Shanghai, China. It is one of the top-tier universities in China and is known for its innovative and research-based programs in fields such as engineering, business, and social sciences. The university has a strong commitment to internationalization and has a large international student body. It is also home to several prestigious research centers and institutes, including the Center for Advanced Materials Research, the Center for International Development, and the Center for Chinese Language and Culture.
[1] 问：介绍一下上海交通大学。
答：
上海交通大学，原名复旦交通大学，是一所以工、理、管、法、工、农、管、经、管、工、经、法、商、管理、艺术等门类为特色的综合型研究型大学。它于1922年在上海成立，1927年更名复旦，1936年复旦附设，1940年复旦附设，1946年复旦附设，1950年复旦附设，1953年复旦附设，1957年复旦附设，1960年复旦附设，1963年复旦附设，1967年复旦附设，1970年复旦附设，1973年复旦附设，1978年复旦附设，1983年复旦附设，1988年复旦附设，1993年复旦附设，1998年复旦附设，2004年复旦附设，2009年复旦附设，2014
[2] 问