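# Building a Chinese LLM from scratch | 🚀 Day07

## SFT data preparation

The `TinyStories` dataset also provides an [Instruct version](https://huggingface.co/datasets/roneneldan/TinyStoriesInstruct), so I can use it to instruction-tune the model pretrained earlier.

First, a look at the data format: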

In [None]:
! head -10 ../../Data/TinyStoriesInstruct/TinyStories-Instruct-valid.txt

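The instructions come in four types:
1. A list of words that must appear in the story.
2. A sentence that should appear somewhere in the story.
3. A list of features (possible features: dialogue, bad ending, moral value, plot twist, foreshadowing, conflict).
4. A short summary (1-2 lines).

That leaves two problems:
- The dataset is in English, so I need a way to turn it into Chinese.
- Its layout differs from mainstream SFT datasets, so it needs some adapting.

  > My reading is that the task here is essentially a single instruction (generate a story) that varies only in its constraints, which is why the authors simply concatenated the fields.
  >
  > For learning purposes, I will still move toward the mainstream SFT format.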
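
To make the adaptation concrete, here is a minimal sketch of turning one raw block into an instruction / input / output record. It assumes the constraint lines come before a `Story:` marker, as in the sample shown later; `block_to_sft_record` and the English instruction text are hypothetical names chosen for illustration, not part of the dataset or of the code used below.

```python
def block_to_sft_record(block, instruction="Write a story that satisfies the constraints below."):
    """Split one raw block into an instruction / input / output record."""
    head, marker, story = block.partition("Story:")
    if not marker:
        return None  # no story marker found, so the block is skipped
    # Everything before the marker is treated as constraint text.
    constraints = [line.strip() for line in head.splitlines() if line.strip()]
    return {
        "instruction": instruction,
        "input": "\n".join(constraints),
        "output": story.strip(),
    }


# Abridged block in the same layout as the dataset sample.
raw_block = """Random sentence: They are very excited and want to fly too.
Features: Dialogue
Story:
Tom and Anna are brother and sister. They are flying!"""

record = block_to_sft_record(raw_block)
print(record["input"])
print(record["output"])
```

Keeping the constraints in `input` and the story in `output` is what lets one fixed instruction cover every example.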

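### Testing Andrew Ng's translation agent

I tried [Andrew Ng's translation-agent](https://github.com/andrewyng/translation-agent) project directly (the `translation-agent.py` file), using the `gpt-4o-mini` API. I also tried `qwen14b` and `qwen7b` served locally through `Ollama`, but they were less stable.

A single translation takes about 10 seconds, because the agent logic makes several API calls per translation. To be able to run many translations concurrently later, I converted all of the code to `async` calls.

If you have another translation API or model, feel free to swap it in; this was mostly a spur-of-the-moment experiment.

The `translation-agent` project is really just one `utils.py` file, but the modified version is too long to paste here; if you are interested, see the repository.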
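
The sketch below illustrates the general shape of such an async pipeline: several sequential model calls per text, with many texts in flight at once. It is not the project's code. `fake_llm`, `translate_sketch`, and `translate_many` are hypothetical names, the model call is a stub, and the draft / critique / revise split is my understanding of the agent's workflow rather than something confirmed here.

```python
import asyncio

calls = []  # records every prompt sent to the stand-in model


async def fake_llm(prompt):
    """Hypothetical stand-in for a real chat-completion request."""
    calls.append(prompt)
    await asyncio.sleep(0)  # yield to the event loop, as a network call would
    return f"[reply {len(calls)}]"


async def translate_sketch(source_text, llm=fake_llm):
    """Three sequential calls: draft, critique, revised translation."""
    draft = await llm(f"Translate into Chinese:\n{source_text}")
    critique = await llm(f"List improvements for this translation:\n{draft}")
    return await llm(f"Revise the translation.\nDraft: {draft}\nNotes: {critique}")


async def translate_many(texts, max_concurrency=2):
    """Run several translations at once, capped by a semaphore."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def worker(text):
        async with semaphore:
            return await translate_sketch(text)

    return await asyncio.gather(*(worker(text) for text in texts))


results = asyncio.run(translate_many(["They are flying!", "They are very happy."]))
print(len(results), len(calls))
```

Inside a notebook, where an event loop is already running, the last step would be `await translate_many(...)` instead of `asyncio.run(...)`.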

In [None]:
from translation_agent import translate

text = """
Random sentence: They are very excited and want to fly too.
Features: Dialogue
Summary: Tom and Anna are excited to go on a holiday with their parents, and they fly on a big plane to a place with sun and sand.
Story: 
Tom and Anna are brother and sister. They like to play with their toys and read books. They are very happy because they are going on a holiday with their mum and dad. They will fly on a big plane to a place with a lot of sun and sand.
The day of the holiday comes and they pack their bags. They go to the airport and wait for their plane. They see many other planes flying in the sky. They are very excited and want to fly too.
"Look, Anna, that plane is so big and fast!" Tom says.
"Yes, Tom, and it has wings and a tail. I wonder where it is going," Anna says.
They hear their mum call them. "Come on, kids, it's time to board our plane. We have to show our tickets and go through the gate."
They follow their mum and dad and get on their plane. They find their seats and buckle their belts. They look out the window and see the ground and the cars and the people. They hear the pilot say something on the speaker.
"Hello, everyone, this is your pilot speaking. Welcome aboard flight 123 to Sunny Beach. We are ready to take off. Please sit back and enjoy the flight."
The plane starts to move and makes a loud noise. Tom and Anna feel the plane go faster and faster. They see the ground get smaller and smaller. They see the clouds get closer and closer. They are flying!
"Wow, Anna, we are flying! We are in the sky!" Tom says.
"I know, Tom, it's amazing! We are so high! Look, there is the sun!" Anna says.
They smile and laugh and clap their hands. They are not sad at all. They are very happy. They are flying to their holiday.
"""


result = await translate(
    source_lang="English",
    target_lang="Chinese",
    source_text=text,
    country="China",
)
print(result)

### Data sampling

First I check how many examples the training set has. Every example ends with `<|endoftext|>`, so counting occurrences of `endoftext` gives the number of examples.

In [None]:
! grep -o "endoftext" ../../Data/TinyStoriesInstruct/TinyStories-Instruct-train.txt  | wc -l 

Close to 2.5 million examples is a bit much (the Microsoft paper ran `pretrain` directly on the entire dataset).
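
For reference, the same count can be taken in Python without loading the whole file into memory. This is a minimal sketch: `count_records` is a hypothetical helper, and the demo runs on a tiny temporary file rather than the real training file.

```python
import os
import tempfile


def count_records(path, marker="<|endoftext|>"):
    """Count marker occurrences line by line, without reading the whole file at once."""
    with open(path, "r", encoding="utf-8") as f:
        return sum(line.count(marker) for line in f)


# Tiny stand-in file; the real input would be the TinyStories-Instruct train file.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.txt")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("Story: a\n<|endoftext|>\nStory: b\n<|endoftext|>\nStory: c\n<|endoftext|>\n")
    demo_count = count_records(demo_path)

print(demo_count)
```

Unlike the `grep` call, this matches the full `<|endoftext|>` marker rather than the bare word `endoftext`.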

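Many studies suggest that for `SFT`, sheer data volume matters less than quality: when the quality is high enough, even a small dataset can train well.

So I plan to randomly sample 11000 examples as a first try.

My plan:
1. Walk the `train` set and keep the combinations of the four instruction types as balanced as possible (this requires first counting how the combinations are distributed).
2. Translate the 11000 sampled examples with the `translation-agent` above.
3. Arrange the translated data into the `json` format of an `SFT` dataset.

Sampling comes first: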

In [None]:
from collections import Counter
import random


def get_fields(block):
    """Return the set of constraint fields present in one raw block."""
    fields = set()
    if "Words:" in block:
        fields.add("Words")
    if "Random sentence:" in block:
        fields.add("Random sentence")
    if "Features:" in block:
        fields.add("Features")
    if "Summary:" in block:
        fields.add("Summary")
    return fields


def count_field_combinations(file_path):
    """Count how often each combination of constraint fields occurs."""
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()

    blocks = content.split("<|endoftext|>")
    combinations = []

    for block in blocks:
        fields = get_fields(block)
        if fields:  # only record blocks that contain at least one field
            combinations.append(frozenset(fields))

    return Counter(combinations)


def sample_data(file_path, total_samples=11000):
    """Sample roughly the same number of blocks from each field combination."""
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()

    blocks = content.split("<|endoftext|>")
    blocks = [block.strip() for block in blocks if block.strip()]  # drop empty blocks

    combinations = count_field_combinations(file_path)
    # The per-combination quota ignores combinations that occur only once.
    combination_more_than_1 = {k: v for k, v in combinations.items() if v > 1}
    samples_per_combination = total_samples // len(combination_more_than_1)

    sampled_data = []
    # Every combination is visited, including one-off ones, so the result
    # can slightly exceed total_samples.
    for combination in combinations:
        matching_blocks = [
            block for block in blocks if set(get_fields(block)) == set(combination)
        ]
        sampled_data.extend(
            random.sample(
                matching_blocks, min(samples_per_combination, len(matching_blocks))
            )
        )

    return sampled_data

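Run it and check the result. To be on the safe side, I sampled 15000 examples rather than the planned 11000. It takes 1-2 minutes; there is clearly room for optimization, but that is acceptable.

The sampled data is also saved as a `pkl` file for later use.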

In [1]:
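import pickle

# First run: sample from the raw file.
# sft_raw = sample_data(
#     "../../Data/TinyStoriesInstruct/TinyStories-Instruct-train.txt", 15000
# )
# Later runs: load the saved sample instead of resampling.
with open("sft_raw.pkl", "rb") as f:
    sft_raw = pickle.load(f)
print(f"Total sampled examples: {len(sft_raw)}")

# pickle.dump(sft_raw, open("sft_raw.pkl", "wb"))

Total sampled examples: 15001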
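
Much of the 1-2 minute runtime likely comes from rescanning every block once per field combination. A minimal sketch of one possible optimization is to group the blocks in a single pass and then sample from each group. `sample_balanced` and `get_field_set` are hypothetical names, and unlike the code above the quota here is split across all combinations, including those that occur only once.

```python
import random
from collections import defaultdict

FIELD_MARKERS = ("Words:", "Random sentence:", "Features:", "Summary:")


def get_field_set(block):
    """Return the constraint fields present in a block as a hashable set."""
    return frozenset(marker.rstrip(":") for marker in FIELD_MARKERS if marker in block)


def sample_balanced(blocks, total_samples, seed=0):
    """Group blocks by field combination in one pass, then sample each group."""
    groups = defaultdict(list)
    for block in blocks:
        fields = get_field_set(block)
        if fields:
            groups[fields].append(block)
    if not groups:
        return []

    per_group = total_samples // len(groups)
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    sampled = []
    for members in groups.values():
        sampled.extend(rng.sample(members, min(per_group, len(members))))
    return sampled


# Toy data: six blocks with only "Words", six with only "Summary".
toy_blocks = [f"Words: w{i}\nStory: s{i}" for i in range(6)]
toy_blocks += [f"Summary: m{i}\nStory: t{i}" for i in range(6)]

picked = sample_balanced(toy_blocks, total_samples=4)
print(len(picked))
```

Each block is inspected once, so the cost grows with the number of blocks rather than with blocks times combinations.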


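### Batch translation

Now the `translation-agent` can be called to translate the sample.

Besides the async speed-up, I also cache results in a `json` file to avoid translating the same text twice (the `gpt-4o-mini` API is not exactly cheap, so every saved call helps).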

In [None]:
import asyncio
import hashlib
import json

import aiofiles

cache_file = "translation_cache.json"


async def translate_and_cache(block, cache, semaphore):
    # Use a content digest as the key: the built-in hash() is salted per
    # process for strings, so its values would not match across sessions.
    cache_key = hashlib.sha256(block.encode("utf-8")).hexdigest()

    if cache_key in cache:
        return cache[cache_key]

    async with semaphore:
        try:
            result = await translate(
                source_lang="English",
                target_lang="Chinese",
                source_text=block,
                country="China",
            )
            cache[cache_key] = result
            return result
        except Exception as e:
            print(f"Translation failed: {e}")
            return None


async def batch_translate(sampled_data, cache_file, max_workers=10):
    translated_data = []

    try:
        async with aiofiles.open(cache_file, "r") as f:
            cache = json.loads(await f.read())
    except (FileNotFoundError, json.JSONDecodeError):
        cache = {}

    semaphore = asyncio.Semaphore(max_workers)
    tasks = [translate_and_cache(block, cache, semaphore) for block in sampled_data]
    results = await asyncio.gather(*tasks)

    translated_data = [result for result in results if result]

    async with aiofiles.open(cache_file, "w") as f:
        await f.write(json.dumps(cache, ensure_ascii=False, indent=2))

    return translated_data


translated_data = await batch_translate(sft_raw, cache_file, max_workers=100)

With a concurrency of 100, translating the 15000 examples took 48 minutes, or roughly 300 examples per minute.
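
The cache only saves API calls across sessions if its keys are stable. Python's built-in `hash()` is salted per interpreter process for strings by default, so a digest of the text itself is the safer key. A minimal sketch, with `cache_key` as a hypothetical helper name:

```python
import hashlib


def cache_key(text):
    """Derive a key from the text itself, so it is identical in every session."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


key_a = cache_key("Features: Dialogue")
key_b = cache_key("Features: Dialogue")
key_c = cache_key("Features: Twist")
print(key_a == key_b, key_a == key_c, len(key_a))
```

The digest is longer than an integer hash, but it survives kernel restarts, which is what a file-based cache needs.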

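### Post-processing

With translation done, the last step is to arrange the data into the `SFT` dataset format.

(One small issue surfaced here: the translations consistently moved the **summary** field to the end, which broke the field order, so that has to be handled first.)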

In [2]:
import itertools
import json
import random
from collections import Counter
from pprint import pprint

instruction_template = "按照下面输入的约束生成故事"


def split_data(data, keys):
    result = []
    current_key = None
    current_content = ""

    for line in data.split("\n"):
        line = line.strip()
        if any(key in line for key in keys):
            if current_key:
                result.append((current_key, current_content.strip()))
            for key in keys:
                if key in line:
                    current_key, current_content = line.split(key, 1)
                    current_key = key.strip()
                    current_content = current_content.strip().lstrip("：").strip()
                    break
        else:
            current_content += " " + line

    if current_key:
        result.append((current_key, current_content.strip()))

    return result


def process_translated_data(input_file, output_file, expand_data=True):
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    processed_data = []
    constraint_keys = Counter()

    for value in data.values():
        if "故事：" not in value:
            continue
        parts = value.split("故事：")

        if len(parts) == 2:
            input_text = parts[0].strip()
            output_text = parts[1].strip()
            # If the summary ended up after the story, move it back into input_text.
            for keyword in ["总结", "摘要"]:
                marker = f"{keyword}："
                if marker in output_text:
                    output_text, summary = output_text.split(marker, 1)
                    input_text += f"\n{marker}{summary.strip()}"
                    output_text = output_text.strip()
                    break

            # Count the field names that appear in the constraint text.
            lines = input_text.split("\n")
            for line in lines:
                if "：" in line:
                    field_name, _ = line.split("：", 1)
                    constraint_keys[field_name.strip()] += 1

            processed_item = {
                "instruction": instruction_template,
                "input": f"{input_text}",
                "output": output_text,
            }

            processed_data.append(processed_item)
    # Keep only the field names that occur more than 10 times.
    constraint_keys = {k: v for k, v in constraint_keys.items() if v > 10}

    # Data augmentation: emit every ordering of the constraint fields.
    if expand_data:
        expanded_data = []
        for item in processed_data:
            input_tuple_list = split_data(item["input"], constraint_keys)
            if not input_tuple_list:
                continue

            for permutation in itertools.permutations(input_tuple_list):
                new_item = item.copy()
                new_item["input"] = "\n".join(
                    [f"{key}：{value}" for key, value in permutation]
                )
                expanded_data.append(new_item)
    else:
        expanded_data = processed_data

    # Shuffle the result.
    random.shuffle(expanded_data)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(expanded_data, f, ensure_ascii=False, indent=2)

    return expanded_data, constraint_keys


expanded_data, constraint_keys = process_translated_data(
    "translation_cache.json",
    "../../Data/TinyStoriesInstruct/sft_data_no_expansion.json",
    expand_data=False,
)

Here is the processed result; it now matches the classic `SFT` data format.

In [14]:
pprint(expanded_data[0])

{'input': '特点：转折\n'
          '摘要：莉莉在她的大红盒子里发现了一只会魔法的青蛙，并希望有一个朋友可以一起玩，青蛙实现了她的愿望，变成了一个小女孩。\n'
          '随机句子：盒子没有弯曲，但感觉不同。\n'
          '词汇：敲打，盒子，灵活',
 'instruction': '按照下面输入的约束生成故事：',
 'output': '从前，有一个小女孩叫莉莉。她房间里有个大红盒子。莉莉喜欢玩这个盒子。一天，她想看看盒子是否灵活，于是她试着去弯它。盒子没有弯曲，但感觉不同。  \n'
           '莉莉决定用手敲打盒子，看看会发生什么。当她这样做时，盒子打开了！盒子里有一只小小的绿色青蛙。青蛙看着莉莉说：“你好！我是一只会魔法的青蛙。”  \n'
           '莉莉非常惊讶，她简直不敢相信自己的眼睛。青蛙告诉莉莉，他可以实现她一个愿望。莉莉想了一会儿，许下了一个希望有朋友一起玩的愿望。魔法青蛙微笑着变成了一个小女孩，就像莉莉一样。她们整天玩得不亦乐乎。'}


In [15]:
len(expanded_data)

71322
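
A count well above the roughly 15000 translated examples is what the `expand_data=True` branch would produce, so this output presumably comes from a run with expansion switched on. The sketch below shows the idea in isolation: a record with k constraint fields becomes k! records that differ only in field order. `expand_by_permutation` is a hypothetical helper that mirrors the loop in `process_translated_data`.

```python
import itertools


def expand_by_permutation(item, fields):
    """Return one copy of item per ordering of its (name, value) constraint fields."""
    expanded = []
    for ordering in itertools.permutations(fields):
        new_item = dict(item)  # shallow copy so the original record is untouched
        new_item["input"] = "\n".join(f"{name}：{value}" for name, value in ordering)
        expanded.append(new_item)
    return expanded


fields = [
    ("词汇", "敲打，盒子，灵活"),
    ("特点", "转折"),
    ("随机句子", "盒子没有弯曲，但感觉不同。"),
]
item = {
    "instruction": "按照下面输入的约束生成故事",
    "input": "",
    "output": "从前，有一个小女孩叫莉莉。",
}

variants = expand_by_permutation(item, fields)
print(len(variants))  # 3 fields give 3! = 6 orderings
```

Because every variant shares the same `output`, the expansion teaches the model that constraint order should not matter, at the cost of many near-duplicate targets.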

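## Summary
1. Sampled the `TinyStories` `Instruct` data so that instruction combinations are balanced, yielding 15000 raw examples.
2. Built a translation function that calls Andrew Ng's `translation-agent` asynchronously to translate the data.
3. Built an `SFT` dataset in the classic format from the translated data.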