## SFT训练Qwen小模型

### wandb

In [None]:
%pip install wandb
import wandb
wandb.login(key="your_wandb_api_key_here")  # 替换为你的wandb API密钥
project_name = "SFT-Qwen2.5-3B" # todo: wandb项目命名
# project_name = "SFT-Qwen2.5-0.5B-final"
wandb.init(project = project_name)

### 导入库与系统提示词

In [None]:
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
import json
from modelscope import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import LoraConfig

# Sanity check: report whether PyTorch can see a CUDA device.
print('cuda support:',torch.cuda.is_available())

# System prompt placed as the first (system) turn of every training
# conversation built by load_distill_dataset. It casts the model as an
# elementary-school math tutor and fixes the output format: step-by-step
# reasoning inside <reasoning>...</reasoning>, then a first-person explanation
# inside <answer>...</answer>. The assistant turns built from the distilled
# data use the same two tags, so this text and that template must stay in sync.
# NOTE: this exact string (including the leading and trailing newline) is part
# of the training data; editing it changes what the model is trained on.
SYSTEM_PROMPT='''
# Role
I am an elementary school math tutor. Currently, I will give a clear and easy-to-understand explanation for a specific elementary school math word problem.
Note that the response should be in the form of spoken textual expression, and avoid including the formats such as pictures, hyperlinks, etc.

# Precautions
1. Fully consider the cognitive level and knowledge reserve of primary school students.
2. If the problem is complex, break it down into simple steps. Avoid using professional terms and explain it in plain language.
3. Make more use of real-life examples and intuitive teaching aids to assist in teaching.

# Teaching Style
Adopt a step-by-step teaching method, combined with interactive Q&A sessions and classroom exercises; encourage students through positive feedback to strengthen their understanding of mathematical concepts.

# Answer Format
<reasoning>
Break down, analyze, and reflect on the problem step by step to sort out the thinking process of the solution.
</reasoning>
<answer>
Start explaining the problem from the first-person perspective of a math tutor.
</answer>
'''

### 加载模型、分词器和蒸馏数据

In [None]:
# Load the dataset produced by distillation.
def load_distill_dataset(path='r1_distill.txt', system_prompt=None):
    """Load distilled samples from a JSON Lines file into a chat-format Dataset.

    Each non-empty line of ``path`` must be a JSON object with the keys
    ``question``, ``reasoning`` and ``answer``. Every record becomes one
    three-turn conversation (system / user / assistant). The assistant turn
    wraps the reasoning and the answer in the ``<reasoning>`` / ``<answer>``
    tags described in the system prompt's "Answer Format" section.

    Args:
        path: JSON Lines file to read. Defaults to ``'r1_distill.txt'``.
        system_prompt: System message used for every sample. ``None`` (the
            default) uses the module-level ``SYSTEM_PROMPT``.

    Returns:
        A ``datasets.Dataset`` with a single ``messages`` column, one
        conversation (list of role/content dicts) per row.

    Raises:
        json.JSONDecodeError: if a non-empty line is not valid JSON.
        KeyError: if a record lacks ``question``, ``reasoning`` or ``answer``.
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT

    conversations = []
    # Explicit UTF-8: the distilled text may contain non-ASCII characters, and
    # the platform default encoding (e.g. cp936/cp1252 on Windows) may not
    # decode them.
    with open(path, 'r', encoding='utf-8') as f:
        for raw_line in f:  # stream the file instead of readlines()
            raw_line = raw_line.strip()
            if not raw_line:  # tolerate blank lines, e.g. a trailing newline
                continue
            record = json.loads(raw_line)
            conversations.append([
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': record['question']},
                {'role': 'assistant', 'content': (
                    f"<reasoning>\n{record['reasoning']}\n</reasoning>\n"
                    f"<answer>\n{record['answer']}\n</answer>"
                )},
            ])
    return Dataset.from_dict({'messages': conversations})

In [None]:
# Base model to fine-tune (TODO: choose the model). Weights and tokenizer are
# read from the local directory models/<model_name>.
model_name = 'Qwen2.5-3B-Instruct'
model_dir = f'models/{model_name}'

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype="auto",
    device_map="auto",
)

# Chat-format training conversations built from the distilled samples.
dataset = load_distill_dataset()

In [None]:
# Display the loaded Dataset in the notebook output for a quick sanity check.
dataset

### 训练

#### SFT 和 LoRA 配置

In [None]:
# Supervised fine-tuning (SFT) configuration: batching, LR schedule,
# logging/checkpointing and mixed precision.
#
# Mixed precision: the model is loaded with torch_dtype="auto" (whatever dtype
# the checkpoint declares), so hard-coding fp16 can autocast bf16 weights down
# to fp16 and risk overflow. Prefer bf16 on GPUs with native bf16 (compute
# capability >= 8, i.e. Ampere and newer); fall back to fp16 on older GPUs.
_use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
sft_config = SFTConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch of 4 samples per device
    lr_scheduler_type='linear',
    warmup_ratio=0.1,
    # NOTE(review): 5e-6 is low for LoRA adapters (often ~1e-4); confirm intent.
    learning_rate=5e-6,
    # NOTE(review): samples longer than this are truncated, which can cut off
    # the <answer> part of long reasoning traces. Newer TRL releases rename
    # this argument to max_length -- check the installed TRL version.
    max_seq_length=500,
    logging_steps=1,
    save_steps=0.3,  # a float < 1 is read as a fraction of total training steps
    num_train_epochs=2,
    report_to='wandb',
    bf16=_use_bf16,
    fp16=not _use_bf16,
    max_grad_norm=0.1,
    output_dir='models/' + model_name + '-SFT',
)

# LoRA (Low-Rank Adaptation) configuration: instead of updating the full
# weights, train small rank-r adapter matrices on the listed modules.
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    # Adapter rank and scaling numerator (effective scale alpha / r = 2).
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    # Modules that receive adapters -- presumably Qwen2's attention (q/k/v/o)
    # and MLP (gate/up/down) projections; verify against the model's layers.
    target_modules='q_proj k_proj v_proj o_proj gate_proj up_proj down_proj'.split(),
)

#### 创建SFT训练器

In [None]:
# Build the SFT trainer from the training arguments, base model, LoRA config,
# training conversations and tokenizer (TRL takes it as `processing_class`).
trainer = SFTTrainer(
    args=sft_config,
    model=model,
    peft_config=lora_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

#### 开始训练

In [None]:
# Run training; metrics go to the wandb run opened by wandb.init earlier.
trainer.train()

# Save the trained result. Reuse the directory configured in sft_config rather
# than rebuilding the path by hand, so the two can never drift apart.
# NOTE(review): because the trainer was given a peft_config, this presumably
# writes the LoRA adapter rather than merged full weights -- merge separately
# if a standalone model is needed.
output_dir = sft_config.output_dir
trainer.save_model(output_dir)

# Close the wandb run so it is not left open in the notebook session.
wandb.finish()