
# Lecture 10 — LLM Training Pipeline Notebook

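Companion notebook for the lecture slides "Large Language Models Explained VII" (《大语言模型解析 VII》). It walks through the core steps from data preparation to the training loop.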



## Prerequisites
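- Python 3.10+; a GPU-enabled environment is recommended
- Install `datasets`, `transformers`, and `torch`
- If you need to download datasets or models from Hugging Face, configure network access or a local cache in advance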


In [None]:
# Install dependencies if needed (commented out by default; uncomment as required)
# %pip install datasets transformers torch accelerate peft

In [None]:
from datasets import load_dataset, ClassLabel
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

print('Libraries imported.')


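## Step 1 · Load the Alpaca dataset
The Alpaca dataset has `instruction` / `input` / `output` fields, which makes it a convenient example for instruction fine-tuning. In an offline environment, download the dataset to a local directory first and load it with `load_dataset("json", data_files=...)`.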


In [None]:
dataset = load_dataset('tatsu-lab/alpaca')
train_raw = dataset['train']
print(train_raw)

In [None]:
# Inspect one raw sample
display(train_raw[0])


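## Step 2 · Rule-based filtering
Filter out anomalous samples according to your requirements, for example empty answers, answers that are too short, or samples containing sensitive terms. The cell below applies only a minimum-length check, which also removes empty answers.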
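
The three rules named above can be combined into a single predicate. The sketch below is one possible shape, assuming a hypothetical blocklist `BLOCKED_TERMS` and a hypothetical threshold `MIN_ANSWER_CHARS`; neither comes from the course material, and real sensitive-term filtering usually needs more than substring matching.

```python
# Hypothetical values for illustration only; adjust to your own requirements.
BLOCKED_TERMS = ("password", "credit card")
MIN_ANSWER_CHARS = 5


def keep_example_strict(example, blocked_terms=BLOCKED_TERMS, min_chars=MIN_ANSWER_CHARS):
    """Return True if the sample passes the empty, length, and blocklist rules."""
    answer = (example.get("output") or "").strip()
    if not answer:                      # rule 1: empty answer
        return False
    if len(answer) <= min_chars:        # rule 2: answer too short
        return False
    lowered = answer.lower()
    # rule 3: case-insensitive substring match against the blocklist
    return not any(term in lowered for term in blocked_terms)


samples = [
    {"output": ""},
    {"output": "ok"},
    {"output": "Never share your password with anyone."},
    {"output": "Paris is the capital of France."},
]
kept = [s for s in samples if keep_example_strict(s)]
print(len(kept), "of", len(samples), "samples kept")
```

The same predicate could be passed to `train_raw.filter(...)` in place of `keep_example`.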


In [None]:
def keep_example(example):
    answer = example['output'].strip()
    return len(answer) > 5

filtered = train_raw.filter(keep_example)
print(filtered)


## Step 3 · Unify the record structure
The course recommends organizing each sample as a `messages` list, which is compatible with templates such as LLaMA Factory and ChatML.
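
To make the compatibility point concrete, here is a minimal sketch that renders a `messages` list in the ChatML layout (`<|im_start|>role`, the content, then `<|im_end|>`). The helper name `render_chatml` is hypothetical, and this is not LLaMA Factory's own template code; it only illustrates why a role/content list is a convenient intermediate format.

```python
def render_chatml(messages):
    """Render a list of {'role', 'content'} dicts as ChatML-style text."""
    parts = []
    for message in messages:
        parts.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n")
    return "".join(parts)


example_messages = [
    {"role": "user", "content": "Give three tips for staying healthy."},
    {"role": "assistant", "content": "Eat well, sleep enough, and exercise."},
]
print(render_chatml(example_messages))
```

Because the record stores only roles and contents, switching to another template means swapping the rendering function rather than rebuilding the dataset.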


In [None]:
def build_messages(example):
    prompt = example['instruction'].strip()
    if example['input']:
        prompt += "\n" + example['input'].strip()
    return {
        'messages': [
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': example['output'].strip()},
        ]
    }

structured = filtered.map(build_messages)
structured = structured.remove_columns([col for col in structured.column_names if col != 'messages'])
print(structured)
display(structured[0])


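## Step 4 · Bucket by length and split into train / validation
Adding a `len_bucket` field enables length-stratified sampling, so that short and long contexts keep the same proportions in each split.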
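
As a rough illustration of what stratifying by a bucket column achieves, the pure-Python sketch below splits indices bucket by bucket. The helper `stratified_split` and the toy data are hypothetical, and this is not how `datasets` implements `stratify_by_column` internally.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_fraction, seed=42):
    """Split indices so each label keeps roughly the same share in both parts."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    train_idx, test_idx = [], []
    for label in sorted(by_label):
        indices = by_label[label]
        rng.shuffle(indices)
        # Aim for at least one test item per label, but always leave one for training
        n_test = min(max(1, round(len(indices) * test_fraction)), len(indices) - 1)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy data: 80 short prompts (bucket 0) and 20 long prompts (bucket 1)
toy_buckets = [0] * 80 + [1] * 20
train_idx, test_idx = stratified_split(toy_buckets, test_fraction=0.1)
print("train:", Counter(toy_buckets[i] for i in train_idx))
print("test: ", Counter(toy_buckets[i] for i in test_idx))
```

A bucket with a single sample cannot appear in both parts, which is why the notebook cell drops buckets with fewer than two samples before stratifying.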


In [None]:
from collections import Counter

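def add_length_bucket(example):
    # Approximate length by the whitespace-separated word count of the user turn
    user_len = len(example['messages'][0]['content'].split())
    # 200-word buckets, capped at bucket 4 (800+ words)
    bucket = min(user_len // 200, 4)
    example['len_bucket'] = bucket
    return example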

bucketed = structured.map(add_length_bucket)

def _bucket_counts(ds):
    return Counter(ds['len_bucket'])

bucket_counts = _bucket_counts(bucketed)
print('Raw bucket counts:', bucket_counts)

valid_buckets = sorted(bucket for bucket, count in bucket_counts.items() if count >= 2)
if len(valid_buckets) < len(bucket_counts):
    dropped = {bucket: bucket_counts[bucket] for bucket in bucket_counts if bucket not in valid_buckets}
    print('Dropping buckets with too few samples:', dropped)
    bucketed = bucketed.filter(lambda example: example['len_bucket'] in valid_buckets)
    bucket_counts = _bucket_counts(bucketed)
    valid_buckets = sorted(bucket_counts.keys())
    print('Bucket counts after filtering:', bucket_counts)

if len(valid_buckets) >= 2:
    bucket_id_map = {bucket: idx for idx, bucket in enumerate(valid_buckets)}

    def remap_bucket(example):
        example['len_bucket'] = bucket_id_map[example['len_bucket']]
        return example

    bucketed = bucketed.map(remap_bucket, desc='Remap buckets')
    bucketed = bucketed.cast_column('len_bucket', ClassLabel(num_classes=len(valid_buckets)))

    split_data = bucketed.train_test_split(
        test_size=0.02,
        seed=42,
        stratify_by_column='len_bucket',
    )
    print('Split completed with stratified sampling.')
else:
    print('Not enough usable buckets; falling back to a random split.')
    split_data = bucketed.train_test_split(
        test_size=0.02,
        seed=42,
    )

train_data = split_data['train']
val_data = split_data['test']
print('Train size:', len(train_data))
print('Validation size:', len(val_data))

if len(valid_buckets) >= 2:
    print('Train bucket distribution:', Counter(train_data['len_bucket']))
    print('Val bucket distribution:', Counter(val_data['len_bucket']))



## Step 5 · Tokenizer alignment
Load a LLaMA-family tokenizer and define a template function that turns `messages` into a single text segment.
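
The tokenization cell in this step sets the label of every prompt token to `-100` and keeps `labels` aligned one-to-one with `input_ids`. Hugging Face causal LM heads shift the labels internally, so the prediction at position `t` is scored against token `t + 1`. The pure-Python sketch below uses made-up token IDs and a hypothetical helper `supervised_pairs` to list which (input, target) pairs would actually contribute to the loss.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def supervised_pairs(input_ids, labels, ignore_index=IGNORE_INDEX):
    """List (input token, target token) pairs that contribute to the loss.

    Mirrors the one-position shift a causal LM applies internally:
    the prediction at position t is compared with labels[t + 1].
    """
    pairs = []
    for t in range(len(input_ids) - 1):
        target = labels[t + 1]
        if target != ignore_index:
            pairs.append((input_ids[t], target))
    return pairs


# Made-up IDs: three prompt tokens, then two response tokens and an EOS (2)
prompt_ids = [11, 12, 13]
response_ids = [21, 22, 2]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids

for source, target in supervised_pairs(input_ids, labels):
    print(f"after token {source} the model should predict {target}")
```

Note that the last prompt token still serves as the input for predicting the first response token; only the targets inside the prompt are masked.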


In [None]:
# Gated repository: accept the license on Hugging Face and log in before downloading
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print('Tokenizer vocab size:', tokenizer.vocab_size)

In [None]:
PROMPT_PREFIX = "<|user|>\n"
PROMPT_SUFFIX = "\n<|assistant|>\n"
RESPONSE_SUFFIX = tokenizer.eos_token

def build_prompt_parts(example):
    user = example['messages'][0]['content']
    assistant = example['messages'][1]['content']
    prompt_text = f"{PROMPT_PREFIX}{user}{PROMPT_SUFFIX}"
    response_text = f"{assistant}{RESPONSE_SUFFIX}"
    return prompt_text, response_text

prompt_text, response_text = build_prompt_parts(train_data[0])
print('Prompt preview:')
print('\n'.join(prompt_text.splitlines()[:4]))
print('\nResponse preview:')
print('\n'.join(response_text.splitlines()[:4]))

In [None]:
def tokenize(example, max_length=2048):
    prompt_text, response_text = build_prompt_parts(example)
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)['input_ids']
    response_ids = tokenizer(response_text, add_special_tokens=False)['input_ids']

    max_response_len = max_length - len(prompt_ids)
    if max_response_len <= 0:
        # The prompt alone fills the window: keep a truncated prompt with every
        # label masked (such samples carry no training signal)
        input_ids = prompt_ids[:max_length]
        attention_mask = [1] * len(input_ids)
        labels = [-100] * len(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

    # Truncate the response so that prompt + response fits within max_length
    response_ids = response_ids[:max_response_len]
    input_ids = prompt_ids + response_ids
    attention_mask = [1] * len(input_ids)
    # Mask prompt positions with -100 so only response tokens contribute to the loss
    labels = [-100] * len(prompt_ids) + response_ids

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

tokenized_train = train_data.map(
    tokenize,
    remove_columns=train_data.column_names,
)

tokenized_val = val_data.map(
    tokenize,
    remove_columns=val_data.column_names,
)

print(tokenized_train)
print(tokenized_val)


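## Step 6 · Collate function
The collate function pads each batch to a common length, setting `attention_mask` to 0 and `labels` to the ignored value (-100) at the padded positions.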
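
The padding rule can be seen without tensors. The sketch below is a pure-Python illustration with made-up token IDs; `pad_batch` and `PAD_ID` are hypothetical names, and the notebook's own collate function relies on `pad_sequence` from PyTorch instead.

```python
def pad_batch(sequences, pad_value):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]


# Made-up batch of two samples with different lengths
toy_batch = [
    {"input_ids": [5, 6, 7, 8], "labels": [-100, -100, 7, 8], "attention_mask": [1, 1, 1, 1]},
    {"input_ids": [9, 10], "labels": [-100, 10], "attention_mask": [1, 1]},
]
PAD_ID = 0  # hypothetical pad token ID

padded = {
    "input_ids": pad_batch([ex["input_ids"] for ex in toy_batch], PAD_ID),
    "labels": pad_batch([ex["labels"] for ex in toy_batch], -100),
    "attention_mask": pad_batch([ex["attention_mask"] for ex in toy_batch], 0),
}
for key, rows in padded.items():
    print(key, rows)
```

Padding `labels` with -100 rather than the pad token ID matters when `pad_token` is set to `eos_token`, as in Step 5: otherwise the padded positions would be trained as EOS targets.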


In [None]:
def causal_lm_collate(batch, pad_id, label_pad_id=-100):
    # Convert each per-sample list to a 1-D tensor before padding
    input_ids = [torch.tensor(example['input_ids']) for example in batch]
    labels = [torch.tensor(example['labels']) for example in batch]
    attention = [torch.tensor(example['attention_mask']) for example in batch]
    # Right-pad to the longest sequence in the batch
    padded_input = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=label_pad_id)
    padded_attention = pad_sequence(attention, batch_first=True, padding_value=0)

    return {
        'input_ids': padded_input,
        'labels': padded_labels,
        'attention_mask': padded_attention,
    }

collate_preview = causal_lm_collate([
    tokenized_train[0],
    tokenized_train[1],
], pad_id=tokenizer.pad_token_id)

for key, value in collate_preview.items():
    print(key, value.shape)
print('Label sample:', collate_preview['labels'][0])
print('Input IDs sample:', collate_preview['input_ids'][0])
print('Attention mask sample:', collate_preview['attention_mask'][0])

In [None]:
train_loader = DataLoader(
    tokenized_train,
    batch_size=4,
    shuffle=True,
    collate_fn=lambda batch: causal_lm_collate(batch, pad_id=tokenizer.pad_token_id),
)

batch = next(iter(train_loader))
for key, value in batch.items():
    print(key, value.shape)
print('First labels row:', batch['labels'][0][:90])


## Step 7 · Trainer integration
For demonstration purposes this step uses a tiny Hugging Face test model; in practice, replace it with real LLaMA / Qwen / Baichuan weights.
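
The `TrainingArguments` cell in this step pairs `per_device_train_batch_size=1` with `gradient_accumulation_steps=8`. The short sketch below works out what that means for one optimizer update. The helper names and the sample count are made up for illustration, and the way a trailing partial accumulation group is handled can differ between library versions, so treat the update count as approximate.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of samples that contribute to one optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_updates_per_epoch(num_samples, per_device_batch, grad_accum_steps, num_devices=1):
    """Approximate optimizer updates per epoch (a partial final group counts as one)."""
    per_update = effective_batch_size(per_device_batch, grad_accum_steps, num_devices)
    return math.ceil(num_samples / per_update)


# Batch settings from this step; the sample count is a made-up example.
print(effective_batch_size(per_device_batch=1, grad_accum_steps=8))
print(approx_updates_per_epoch(num_samples=50_000, per_device_batch=1, grad_accum_steps=8))
```

Accumulating gradients keeps peak memory at the level of a single-sample batch while the optimizer still sees gradients averaged over eight samples.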


In [None]:
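model = AutoModelForCausalLM.from_pretrained(
    'hf-internal-testing/tiny-random-OPTForCausalLM'
)
# The tokenizer (LLaMA) and this demo model (OPT) come from different families,
# so resize the embedding matrix to cover every tokenizer ID before training.
# Inputs must also fit within model.config.max_position_embeddings.
model.resize_token_embeddings(len(tokenizer))

print('Model loaded with parameters:', sum(p.numel() for p in model.parameters()))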

In [None]:
training_args = TrainingArguments(
    output_dir='checkpoints/lecture10-demo',
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    save_strategy='no',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=lambda batch: causal_lm_collate(batch, pad_id=tokenizer.pad_token_id),
)

print('Trainer ready.')


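> **Note**: Training large models such as LLaMA in practice requires substantial GPU memory and compute; consider distributed training, QLoRA, DeepSpeed, or similar techniques. This example only demonstrates the data and code flow.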



## Appendix · Connecting to LLaMA Factory
- Register the `build_messages` / `tokenize` logic as a `preprocess_func`
- Set fields such as `dataset`, `template`, and `data_collator` in `config/*.yaml`
- For a LoRA/QLoRA configuration example, see `lora_qlora.yaml` in the course materials
