In [6]:
! pip install trl
! pip install -U bitsandbytes



In [7]:
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer

# 加载基础模型和分词器

In [9]:
# Model identifier on the Hugging Face Hub.
base_model_id = "gpt2"  # replace with your model, e.g. "NousResearch/Llama-2-7b-hf"

# Running Python code shipped inside a Hub repository is opt-in. "gpt2" is
# served by classes built into transformers, so this stays off; switch it on
# only for a repository whose custom modelling code you have reviewed.
TRUST_REMOTE_CODE = False

# Optional: quantization config
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=TRUST_REMOTE_CODE)
# Some tokenizers define no padding token; reuse EOS so batches can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    # quantization_config=bnb_config,  # uncomment when using quantization
    device_map="cuda:0" if torch.cuda.is_available() else "auto",  # first GPU if available, otherwise automatic placement
    trust_remote_code=TRUST_REMOTE_CODE,
)
model.config.use_cache = False  # recommended for fine-tuning

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

# 加载数据集

In [10]:
# Hub identifier of the instruction dataset, and the slice of it to load.
dataset_name = "databricks/databricks-dolly-15k"
split_spec = "train[:2000]"  # first 2000 rows of the train split
dataset = load_dataset(dataset_name, split=split_spec)

In [11]:
# Row count of the loaded slice.
len(dataset)

2000

In [12]:
# Show the first raw record to inspect the available fields.
dataset[0]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [13]:
def create_text_column(example):
    """Render one record as a single prompt/response training string.

    Args:
        example: Mapping with optional ``instruction``, ``context`` and
            ``response`` entries. Missing or ``None`` values are treated as
            empty strings.

    Returns:
        A dict with a single ``"text"`` key. The ``### Context:`` section is
        included only when the record has a non-empty context.
    """
    # `or ''` also covers keys that are present but hold None, which
    # `.get(key, '')` alone would format as the literal string "None".
    instr = example.get('instruction') or ''
    resp = example.get('response') or ''
    ctx = example.get('context') or ''
    if ctx:
        return {"text": f"### Instruction:\n{instr}\n\n### Context:\n{ctx}\n\n### Response:\n{resp}"}
    return {"text": f"### Instruction:\n{instr}\n\n### Response:\n{resp}"}


# Build the "text" column only when it is absent, so re-running this cell
# does not map the dataset a second time.
if 'text' not in dataset.column_names:
    dataset = dataset.map(create_text_column, remove_columns=dataset.column_names)

print("样本数据集条目：")
print(dataset[0]['text'])

样本数据集条目：
### Instruction:
When did Virgin Australia start operating?

### Context:
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.


In [16]:
output_dir = "./sft_model_output"

# fp16 mixed precision and the paged 8-bit (bitsandbytes) optimizer are only
# enabled when CUDA is present; without a GPU fall back to full precision and
# the stock PyTorch AdamW, matching the CPU fallback used when loading the model.
use_cuda = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,  # adjust to GPU memory
    gradient_accumulation_steps=4,  # effective batch size = batch size * accumulation steps
    learning_rate=2e-5,
    logging_steps=20,               # log metrics every 20 steps
    num_train_epochs=1,             # number of passes over the dataset
    max_steps=-1,                   # set > 0 to override the epoch count
    save_strategy="epoch",          # save a checkpoint at the end of each epoch
    # save_steps=50,                # or save every N steps
    report_to="wandb",              # or "tensorboard", "none"; wandb asks for an API key on first use
    fp16=use_cuda,                  # mixed precision (needs a compatible GPU)
    # bf16=True,                    # BF16 instead (needs an Ampere+ GPU) - pick one of the two
    optim="paged_adamw_8bit" if use_cuda else "adamw_torch",  # memory-efficient optimizer on GPU
    lr_scheduler_type="cosine",     # learning-rate schedule
    warmup_ratio=0.03,              # fraction of total training steps used for warmup
)

In [17]:
# Assemble the supervised fine-tuning trainer. Packing is not passed, so the
# trainer's default applies; for datasets made of many short sequences,
# packing=True may speed training up but needs careful dataset preparation.
sft_trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer = SFTTrainer(**sft_trainer_kwargs)

Tokenizing train dataset:   0%|          | 0/2000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [18]:
# Run fine-tuning; the two prints bracket the run in the cell output.
print("开始 SFT 训练...")
trainer.train()
print("训练完成。")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


开始 SFT 训练...


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mliqing20[0m ([33mliqing20-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
20,3.3299
40,3.0318
60,2.8814
80,2.8077
100,2.8781
120,2.8311
140,2.7736
160,2.7904
180,2.8247
200,2.8187


训练完成。


In [19]:
# Path where the final model is saved (a sub-directory of the training output dir).
final_model_path = os.path.join(output_dir, "final_sft_model")

print(f"正在将最终 SFT 模型保存到 {final_model_path}")
trainer.save_model(final_model_path)
tokenizer.save_pretrained(final_model_path) # save the tokenizer alongside the model
print("模型保存成功。")

正在将最终 SFT 模型保存到 ./sft_model_output/final_sft_model
模型保存成功。


In [20]:
# Load the fine-tuned model from disk for inference.
print("正在加载微调模型进行推理...")
sft_pipe = pipeline("text-generation", model=final_model_path, tokenizer=final_model_path, device_map="auto")

# Example prompt in the same layout as the SFT training text
# (the variant without a context section).
prompt_text = "### Instruction:\nExplain the main benefit of using version control systems like Git.\n\n### Response:\n"

print(f"\n正在为提示生成响应：\n{prompt_text}")

# Sampling is stochastic: seed PyTorch's generators right before generating
# so that a fresh "Restart & Run All" should yield the same completion.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)

# Generate the response; adjust the generation parameters as needed.
output = sft_pipe(prompt_text, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7)

print("\n生成的响应：")
# generated_text contains the prompt followed by the model's completion.
print(output[0]['generated_text'])

正在加载微调模型进行推理...


Device set to use cuda:0



正在为提示生成响应：
### Instruction:
Explain the main benefit of using version control systems like Git.

### Response:


生成的响应：
### Instruction:
Explain the main benefit of using version control systems like Git.

### Response:
Version control systems are often used to automate the production of software. They allow you to control how a particular version of a software works. The main benefit of using version control systems is that they allow you to manage and maintain your software in the best way possible.
