# SFT by PyTorch

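This notebook trains on the data built in `lecture/lc6_sft/Supervised_Finetuning_Dataset.ipynb`.

It can run fine-tuning on Colab, CPU, GPU, or Mac. It was run on a Mac, and the resulting fine-tuned model can hold a normal conversation.

It defines a custom dataset and a custom training loop; after training, the model can generate text.

For usage of the `transformers` library, treat the [official documentation](https://huggingface.co/docs/transformers/index) as the primary reference when looking up APIs.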

In [69]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B', 
                                          local_dir='~/.cache/huggingface/', # Remove this line if you can reach Hugging Face directly.
                                         )

# Load an off-the-shelf pretrained model for full-parameter fine-tuning
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen3-0.6B',
                                             local_files_only=True, )
# print(model)

In [70]:
# Forward
dummy_input_ids = torch.randint(100,(2,3))

output = model(input_ids=dummy_input_ids)
print(output.logits.shape)

torch.Size([2, 3, 151936])
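
The logits above have shape `(batch, seq_len, vocab_size)`, and only the last position is needed to pick the next token. A minimal sketch of that step, using random logits as a stand-in for `output.logits` (the sizes here are toy values chosen for illustration, not the model's):

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 3, 11   # toy sizes; the real vocabulary above is 151936

# Stand-in for output.logits from the forward pass
logits = torch.randn(batch_size, seq_len, vocab_size)

# The logits at the last position score every candidate for the next token
next_token_logits = logits[:, -1, :]                  # (batch, vocab_size)
probs = torch.softmax(next_token_logits, dim=-1)      # each row sums to 1
next_token = probs.argmax(dim=-1)                     # greedy choice, shape (batch,)

print(next_token.shape)
```

Greedy decoding repeats this step, appending `next_token` to the input each time.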


In [71]:
# Generation

# Prompt: "How do I learn Python?"
input_ids = tokenizer(['如何学习python?'], return_tensors='pt')['input_ids']
print(input_ids)
output_ids = model.generate(
    input_ids, 
    max_new_tokens=128,
    do_sample=False,
    temperature=1.0
)
result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(result)

tensor([[100007, 100134,  12669,     30]])
如何学习python? 有哪些资源推荐？ 有哪些书籍推荐？ 有哪些在线课程？ 有哪些工具推荐？ 有哪些社区？ 有哪些学习计划？ 有哪些学习方法？ 有哪些学习资源？ 有哪些学习工具？ 有哪些学习社区？ 有哪些学习计划？ 有哪些学习方法？ 有哪些学习资源？ 有哪些学习工具？ 有哪些学习社区？ 有哪些学习计划？ 有哪些学习方法？ 有哪些学习资源？ 有哪些学习工具？ 有哪些学习社区？ 有哪些学习计划？ 有哪些学习方法？ 有哪些学习资源？ 有哪些学习工具？ 有哪些学习社区？ 有哪些学习计划




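1. Pretrained models often repeat themselves or never stop generating, which is commonly called the "parrot" (repeater) phenomenon.
2. One cause is that not every pretraining sample ends with an <EOS> token.
3. Question: what methods could make a pretrained model answer questions the way we expect (produce logically coherent output and predict EOS)?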
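
The role of EOS in the points above can be illustrated with a toy greedy decoding loop. This is only a sketch: `looping_model` and `stopping_model` are hypothetical hand-written stand-ins for a language model, and the token ids are made up.

```python
EOS = 0  # hypothetical end-of-sequence token id for this toy example

def greedy_decode(next_token_fn, prompt_ids, max_new_tokens=8, eos_id=EOS):
    """Append one token at a time until EOS appears or the budget runs out."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = next_token_fn(ids)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids

def looping_model(ids):
    """Toy 'base model': cycles 1 -> 2 -> 3 -> 1 and never emits EOS."""
    return ids[-1] % 3 + 1

def stopping_model(ids):
    """Toy 'fine-tuned model': emits EOS right after token 3."""
    return EOS if ids[-1] == 3 else ids[-1] + 1

looped = greedy_decode(looping_model, [1])
stopped = greedy_decode(stopping_model, [1])
print(looped)    # [1, 2, 3, 1, 2, 3, 1, 2, 3]
print(stopped)   # [1, 2, 3, 0]
```

The first sequence only ends because `max_new_tokens` is exhausted, which mirrors the truncated, repetitive output shown earlier.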

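The code below fine-tunes the pretrained model on instruction data so that the fine-tuned model gains instruction-following ability, meaning it can answer general questions properly.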

## dataset

In [72]:
from datasets import load_dataset
dataset = load_dataset('tatsu-lab/alpaca',
                       # data_dir='default',
                      cache_dir="~/.cache/huggingface",)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 52002
    })
})


In [73]:
print('<instruction>',dataset['train'][0]['instruction'])
print('<input>',dataset['train'][0]['input'])
print('<output>',dataset['train'][0]['output'])

<instruction> Give three tips for staying healthy.
<input> 
<output> 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. 
2. Exercise regularly to keep your body active and strong. 
3. Get enough sleep and maintain a consistent sleep schedule.


In [74]:
DEFINED_EOS_TOKEN = '<|im_end|>'
DEFINED_SOS_TOKEN = '<|im_start|>'
DEFINED_PAD_TOKEN = '<|endoftext|>'
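# English: "You are the Xiaodonggua agent; please answer the questions of the user (USER) safely and in detail."
DEFINIED_SYSTEM_PROMPT='你是小冬瓜智能体,请安全详细回答用户 USER 的问题'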

In [75]:
def dataset_to_messageslist(dataset):
    messages = []
    for item in dataset['train']:
        messages_inst = [
            {'role':'SYSTEM', 'content':DEFINIED_SYSTEM_PROMPT},
            {'role':'USER', 'content': item['instruction'] + item['input'] },
            {'role':'ASSISTANT', 'content': item['output'] },
        ]
        messages.append(messages_inst)
    return messages
dataset_instruction = dataset_to_messageslist(dataset)

In [76]:
def ChatTemplateToken(example, tokenizer):
    """Tokenize one conversation and mark which tokens are trained on.

    Returns (input_ids, is_labels). is_labels is 1 for assistant content and
    its closing EOS, and 0 for the SOS token, role headers, and other roles.
    """
    sos_token_id = tokenizer(DEFINED_SOS_TOKEN).input_ids[0]
    eos_token_id = tokenizer(DEFINED_EOS_TOKEN).input_ids[0]

    input_ids = [ sos_token_id ]
    is_labels = [ 0 ]

    for item in example:
        if item['role'] == 'ASSISTANT':
            # The role header is context only; the reply and its EOS are supervised.
            prompt = '\n#' + item['role'] + ':'
            prompt_token_ids = tokenizer(prompt).input_ids
            content_token_ids = tokenizer(item['content']).input_ids

            is_labels += [0]*len(prompt_token_ids) + [1]*len(content_token_ids) + [1] # last [1] is eos
            input_ids += prompt_token_ids + content_token_ids + [eos_token_id]
        else:
            # SYSTEM and USER turns are context only and never contribute to the loss.
            prompt = '\n#' + item['role'] + ':' + item['content']
            prompt_token_ids = tokenizer(prompt).input_ids
            input_ids += prompt_token_ids
            is_labels += [0]*len(prompt_token_ids)

    return input_ids, is_labels
    
print(dataset_instruction[0])
prompt, is_label = ChatTemplateToken(dataset_instruction[0], tokenizer)
print(prompt)
print(is_label)
print(tokenizer.decode(prompt))

[{'role': 'SYSTEM', 'content': '你是小冬瓜智能体,请安全详细回答用户 USER 的问题'}, {'role': 'USER', 'content': 'Give three tips for staying healthy.'}, {'role': 'ASSISTANT', 'content': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.'}]
[151644, 198, 2, 46487, 25, 105043, 30709, 99949, 100857, 100168, 31914, 11, 14880, 99464, 100700, 102104, 20002, 13872, 43589, 86119, 198, 2, 6448, 25, 35127, 2326, 10414, 369, 19429, 9314, 13, 198, 2, 4939, 3846, 2821, 25, 16, 5142, 266, 264, 23831, 9968, 323, 1281, 2704, 311, 2924, 11260, 315, 25322, 323, 23880, 13, 715, 17, 13, 32818, 15502, 311, 2506, 697, 2487, 4541, 323, 3746, 13, 715, 18, 13, 2126, 3322, 6084, 323, 10306, 264, 12966, 6084, 9700, 13, 151645]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
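
The token ids above are hard to read, so here is a small sketch of the same masking idea with readable tokens. It assumes a whitespace tokenizer (`str.split`) and placeholder `<sos>` / `<eos>` strings; `build_example` is a hypothetical helper written for this illustration, not the notebook's function.

```python
def build_example(messages, tokenize, sos="<sos>", eos="<eos>"):
    """Return (tokens, is_label); only assistant content and its EOS get label 1."""
    tokens, is_label = [sos], [0]
    for m in messages:
        header = tokenize("#" + m["role"] + ":")
        body = tokenize(m["content"])
        if m["role"] == "ASSISTANT":
            tokens += header + body + [eos]
            is_label += [0] * len(header) + [1] * (len(body) + 1)
        else:
            tokens += header + body
            is_label += [0] * (len(header) + len(body))
    return tokens, is_label

messages = [
    {"role": "USER", "content": "say hi"},
    {"role": "ASSISTANT", "content": "hi there"},
]
tokens, is_label = build_example(messages, str.split)
supervised = [t for t, flag in zip(tokens, is_label) if flag == 1]

print(tokens)      # ['<sos>', '#USER:', 'say', 'hi', '#ASSISTANT:', 'hi', 'there', '<eos>']
print(is_label)    # [0, 0, 0, 0, 0, 1, 1, 1]
print(supervised)  # ['hi', 'there', '<eos>']
```

Including EOS among the supervised tokens is what teaches the model to stop after its answer.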

In [77]:
import torch
from torch.utils.data import Dataset, DataLoader

class TokenSFTDataset(Dataset):
    def __init__(self, messages_list, tokenizer):
        data_list = [ ChatTemplateToken(messages, tokenizer) for messages in messages_list ]
        self.data = []
        for data in data_list:
            self.data.append( 
                [torch.tensor(data[0], dtype=torch.long), 
                 torch.tensor(data[1], dtype=torch.long)]
            )
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Token data is usually returned as tensors.
        """
        return {
            'input_ids': self.data[idx][0],
            'is_label':self.data[idx][1]
        }



In [78]:
def paddding_collate_fn(batch_data, pad_token_id=None, ignore_index=-100):
    """Right-pad a batch to its longest sequence and build shifted labels."""
    bs = len(batch_data)

    # Pad to the longest sequence in this batch
    input_lens = [data['input_ids'].shape[0] for data in batch_data]
    max_input_len = max(input_lens)

    # Right padding: start from all-pad / all-ignore tensors, then copy in the real tokens
    input_ids = torch.full((bs, max_input_len), pad_token_id, dtype=torch.long)
    attention_masks = torch.zeros(bs, max_input_len, dtype=torch.long)
    labels = torch.full((bs, max_input_len), ignore_index, dtype=torch.long)

    for i in range(bs):
        input_ids[i, :input_lens[i]] = batch_data[i]['input_ids']
        attention_masks[i, :input_lens[i]] = 1

        # Shift left by one: the logit at position t is trained to predict token t+1
        idx = torch.where( batch_data[i]['is_label'] != 0)[0]
        labels[i, idx-1] = batch_data[i]['input_ids'][idx]

    return {
        'input_ids': input_ids,
        'attention_masks': attention_masks,
        'labels': labels,
    }

class PaddingCollateFunction:
    def __init__(self, pad_token_id: int, ignore_index: int):
        self.pad_token_id = pad_token_id
        self.ignore_index = ignore_index

    def __call__(self, batch) -> dict:
        batch = paddding_collate_fn(batch, self.pad_token_id, self.ignore_index )
        return batch
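
The collate function writes each answer token's id at the position one step earlier, so logits and labels line up without further shifting. A minimal sketch of that shift on a toy sequence, together with a check that `ignore_index` restricts the loss to the labelled positions (the token ids and vocabulary size are made up for illustration):

```python
import torch
import torch.nn as nn

IGNORE_INDEX = -100
input_ids = torch.tensor([5, 6, 7, 8, 9])   # toy sequence
is_label = torch.tensor([0, 0, 0, 1, 1])    # only the last two tokens are the answer

labels = torch.full_like(input_ids, IGNORE_INDEX)
idx = torch.where(is_label != 0)[0]          # answer positions: [3, 4]
labels[idx - 1] = input_ids[idx]             # position t must predict token t+1
print(labels.tolist())                       # [-100, -100, 8, 9, -100]

# Random logits stand in for model output: one row of scores per position
torch.manual_seed(0)
vocab_size = 10
logits = torch.randn(5, vocab_size)

loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss_masked = loss_fn(logits, labels)

# Same value computed by hand over the two supervised positions only
loss_manual = nn.functional.cross_entropy(logits[2:4], input_ids[3:5])
print(torch.allclose(loss_masked, loss_manual))  # True
```

Positions holding `-100` (prompt tokens and padding) contribute nothing to the averaged loss.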

In [79]:
DEFINE_IGNORE_INDEX=-100
pad_token_id = tokenizer(DEFINED_PAD_TOKEN).input_ids[0]
print(pad_token_id)
        
dataset_train = TokenSFTDataset(dataset_instruction, tokenizer)
collate_fn = PaddingCollateFunction(pad_token_id=pad_token_id, ignore_index=DEFINE_IGNORE_INDEX)

batch_size = 8
dataloader = DataLoader(dataset_train, 
                    batch_size=batch_size, 
                    shuffle=True,
                    collate_fn = collate_fn)

151643


In [80]:
# Inspect the label mask of the first sample and the shifted labels of one batch
print(dataset_train[0]['is_label'])
for k, batch in enumerate(dataloader):
    print(batch['labels'])
    break

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,  10048,  77990,   6832,    323,   6975,    454,
          77990,   6832,    525,   1378,   1887,  19827,    311,   5662,   6832,
             13,    220,  61824,   4056,   6832,   7460,  29829,    821,    323,
          17167,    315,  25185,    429,    646,   3960,    311,   1281,  19898,
            504,    279,   6350,    821,    738,     13,    758,  12872,    

In [82]:
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

learning_rate = 1e-5
epochs = 1
vocab_size = model.config.vocab_size
grad_accumulation_steps = 10  # defined for reference; the loop below updates on every batch

# Use a distinct name so the `optim` module is not shadowed when the cell is re-run
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss(ignore_index=DEFINE_IGNORE_INDEX)

debug_max_step = 1000  # stop early for a quick debugging run
total_step = 0

model.train()
for i in range(epochs):

    # Progress bar
    train_dataloader_tqdm = tqdm(
            dataloader,  # data loader
            desc=f'Epoch {i+1}/{epochs}',  # progress-bar prefix
            ncols=100,    # progress-bar width
            ascii=' =',   # ASCII fill characters
        )

    for k, batch in enumerate(train_dataloader_tqdm):
        # train step
        optimizer.zero_grad()
        bsz, seq_len = batch['input_ids'].shape
        # The model's keyword argument is `attention_mask` (singular)
        output = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_masks'])
        logits = output.logits

        # Labels were already shifted in the collate function, so flatten and compare directly
        loss = loss_fn(logits.view(bsz*seq_len, vocab_size),
                       batch['labels'].view(bsz*seq_len)
                      )
        loss.backward()
        optimizer.step()

        # Update the progress bar every 10 steps
        total_step = total_step + 1
        if total_step % 10 == 0:
            train_dataloader_tqdm.set_postfix(
                loss=f'{loss.item():.4f}',
            )

        if total_step == debug_max_step:
            break

        # test step

model.eval()  # switch back to inference mode before generating

Epoch 1/1:   4%|=                                | 999/26001 [19:48<8:15:32,  1.19s/it, loss=1.0297]
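
The training cell defines a gradient-accumulation setting of 10 but steps the optimizer on every batch. A minimal sketch of how accumulation is commonly wired in, shown on a toy linear model rather than the language model (`model_toy`, `accum_steps`, and the random batches are assumptions made for this illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model_toy = nn.Linear(4, 3)                       # hypothetical stand-in for the LM
optimizer = torch.optim.AdamW(model_toy.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 4                                    # illustrative value

# Eight random micro-batches: (features, class labels)
batches = [(torch.randn(2, 4), torch.randint(0, 3, (2,))) for _ in range(8)]

num_updates = 0
optimizer.zero_grad()
for step, (x, y) in enumerate(batches, start=1):
    # Scale the loss so the accumulated gradient is an average over micro-batches
    loss = loss_fn(model_toy(x), y) / accum_steps
    loss.backward()                                # gradients add up in .grad
    if step % accum_steps == 0:
        optimizer.step()                           # one update per accum_steps micro-batches
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)  # 2
```

This trades more forward and backward passes per update for a larger effective batch size at the same memory cost.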


In [83]:
# Without the chat template
input_ids = tokenizer(['如何学习python?'], return_tensors='pt')['input_ids']
print(input_ids)
output_ids = model.generate(
    input_ids, 
    max_new_tokens=128,
    do_sample=False,
    temperature=1.0
)
result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(result)

tensor([[100007, 100134,  12669,     30]])
如何学习python? 1. 从基础开始，学习Python的语法和基本概念。 2. 学习一些Python库，如NumPy、Pandas、Matplotlib、Scikit-Learn等。 3. 学习一些Python的框架，如NumPy、Scikit-Learn、Pandas、Matplotlib、Scikit-Learn、Pandas、NumPy、Matplotlib、Scikit-Learn、Pandas、NumPy、Scikit-Learn、Pandas、Matplotlib、Scikit-Learn、Pandas、NumPy、Scikit-Learn、Pandas、Matplotlib、Scikit


In [86]:
def generate(model, tokenizer, prompt, max_new_tokens=128):
    messages_inst = [
            {'role':'SYSTEM', 'content':DEFINIED_SYSTEM_PROMPT},
            {'role':'USER', 'content': prompt},
            {'role':'ASSISTANT', 'content': ''},
        ]
    input_ids, is_label = ChatTemplateToken(messages_inst, tokenizer)

    input_ids = torch.tensor( [input_ids], dtype=torch.long)
    input_ids = input_ids[:, :-1] # drop the trailing EOS so the model writes the assistant reply itself
    
    output_ids = model.generate(
        input_ids, 
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=1.0
    )
    result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_ids, result

In [88]:
prompt='如何学习python?'
output_ids, result = generate(model, tokenizer, prompt, max_new_tokens=128)
print(result)
print(output_ids.shape)


#SYSTEM:你是小冬瓜智能体,请安全详细回答用户 USER 的问题
#USER:如何学习python?
#ASSISTANT:学习Python可以通过阅读书籍、在线课程、实践项目以及参加讨论组。阅读书籍可以帮助理解基本概念和语法，而在线课程可以提供更深入的指导。实践项目可以将所学知识应用到实际问题中，而讨论组则可以提供交流和讨论的平台。
torch.Size([1, 95])
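
For a decoder-only model, `generate` returns the prompt tokens followed by the new tokens, which is why the decoded result above still contains the template. A small sketch of slicing off the prompt to keep only the reply, using made-up token ids in place of real model output:

```python
import torch

# Made-up ids: three prompt tokens, then three generated tokens
prompt_ids = torch.tensor([[11, 12, 13]])
output_ids = torch.tensor([[11, 12, 13, 21, 22, 2]])

prompt_len = prompt_ids.shape[1]
new_tokens = output_ids[0, prompt_len:]   # keep only what the model generated

print(new_tokens.tolist())  # [21, 22, 2]
```

Passing `new_tokens` to `tokenizer.decode` instead of `output_ids[0]` would then yield the answer text alone.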


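## Questions to think about

1. Write code for batched generation.
2. Why does the trained model generalize well even though the model is so small and the training data so limited?
3. A pretrained model cannot hold a conversation; does that mean it is a poor model?
4. Without using the `generate` function from the `transformers` library, write generation code that uses a KV cache.
5. Building on 4, write generation code that combines right-padded inputs, a KV cache, and generation.
6. How should an SFT model be evaluated?
7. How can fine-tuning be made more efficient?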