In [1]:
import json
import os
import shutil
from glob import glob
from pathlib import Path
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

In [2]:
# Output directory for the tokenized shards (relative to the notebook's working directory).
TARGET_DIR = "/".join(("data", "data4pretrain"))


In [3]:
# Load the InternLM tokenizer from the local "internlm_tokenizer" directory.
# NOTE(review): trust_remote_code=True lets transformers run Python code shipped
# with the tokenizer files -- only use it with a directory you trust.
internlm_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="internlm_tokenizer",
    trust_remote_code=True,
)

In [4]:
# Sanity check: the recorded output starts with id 1, presumably a BOS token the
# tokenizer prepends -- TODO confirm against the tokenizer config.
sample_text = "你是谁啊啊啊啊"
internlm_tokenizer.encode(sample_text)

[1, 69596, 61275, 81533]

In [5]:
# Scratch check of the one-JSON-object-per-line write format.
# The directory may not exist yet on a fresh kernel, so create it before writing.
# BuildPreTrainDataset wipes its target directory, so this file does not survive cell 6.
scratch_path = Path(TARGET_DIR) / "aaa.json"
scratch_path.parent.mkdir(parents=True, exist_ok=True)
with open(scratch_path, mode="w", encoding="utf-8") as fout:
    fout.write(json.dumps({"a": [1, 2, 3, 4]}, ensure_ascii=False) + "\n")

In [6]:
from typing import Any


class BuildPreTrainDataset:
    """Tokenize JSON-lines text records and write them as sharded ``input_ids`` files.

    Every non-blank line of every selected source file must be a JSON object with a
    ``"text"`` field. Each record is encoded with ``tokenizer.encode`` and written as
    ``{"input_ids": [...]}`` on its own line into ``file_<k>.json`` under
    ``target_dir``, with at most ``file_size`` records per shard. A summary is written
    to ``index.json`` in the same directory when a run finishes.
    """

    def __init__(
        self,
        target_dir: str,
        source_dir: str,
        file_size: int = 100000,
        tokenizer: Optional["PreTrainedTokenizer"] = None,
        max_source_files: Optional[int] = 3,
    ) -> None:
        """Prepare an empty output directory and select the source files.

        Args:
            target_dir: Output directory. WARNING: it is deleted and recreated here.
            source_dir: Directory whose non-hidden regular files are read as JSON lines.
            file_size: Maximum number of records per output shard; must be >= 1.
            tokenizer: Object with an ``encode(text) -> list[int]`` method; required
                before the instance is called.
            max_source_files: Use only the first this-many source files in sorted
                order; ``None`` uses all of them. The default of 3 keeps the cap the
                notebook originally hard-coded (it looks like a debugging limit).

        Raises:
            ValueError: If ``file_size`` is smaller than 1.
        """
        if file_size < 1:
            raise ValueError(f"file_size must be >= 1, got {file_size}")
        self.target_dir = Path(target_dir)
        self.source_dir = Path(source_dir)

        # Start from a clean directory so shards from an earlier run cannot mix in.
        # Operate on the directory we were given, not on a notebook global.
        shutil.rmtree(self.target_dir, ignore_errors=True)
        os.makedirs(self.target_dir)

        # glob "*" skips dotfiles. Directories are dropped because they cannot be opened
        # as text. Sorting makes the record order (and which files a cap keeps)
        # reproducible, because glob returns entries in arbitrary order.
        source_files = sorted(
            path for path in glob(str(self.source_dir / "*")) if os.path.isfile(path)
        )
        if max_source_files is not None:
            source_files = source_files[:max_source_files]
        self.source_file_list = source_files

        self.index_file_name = self.target_dir.joinpath("index.json")
        self.file_size = file_size
        self.tokenizer = tokenizer

    def get_item(self):
        """Yield each JSON record from the selected source files, in order.

        Files are streamed line by line instead of being read into memory whole.
        Whitespace-only lines, such as a trailing blank line, are skipped.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        for source_file in self.source_file_list:
            with open(source_file, encoding="utf-8", mode="r") as fin:
                for line in fin:
                    if line.strip():
                        yield json.loads(line)

    def __call__(self) -> None:
        """Tokenize all records, write the shard files, then write ``index.json``.

        Record ``i`` goes to ``file_{i // file_size}.json``. Shards are opened in
        write mode, so calling the instance again rewrites them instead of
        appending duplicates. An empty source set produces only ``index.json``.

        Raises:
            ValueError: If no tokenizer was supplied.
        """
        if self.tokenizer is None:
            raise ValueError("BuildPreTrainDataset needs a tokenizer to encode text")

        total_item = 0
        total_token = 0
        current_shard = None
        fout = None
        try:
            for item_index, item in tqdm(
                enumerate(self.get_item()), desc="gendata...2file"
            ):
                shard_index = item_index // self.file_size
                if shard_index != current_shard:
                    # Keep one handle open per shard instead of reopening the file
                    # for every record.
                    if fout is not None:
                        fout.close()
                    fout = open(
                        self.target_dir.joinpath(f"file_{shard_index}.json"),
                        encoding="utf-8",
                        mode="w",
                    )
                    current_shard = shard_index
                tokenids = self.tokenizer.encode(item["text"])
                total_token += len(tokenids)
                fout.write(
                    json.dumps({"input_ids": tokenids}, ensure_ascii=False) + "\n"
                )
                total_item += 1
        finally:
            if fout is not None:
                fout.close()

        # One token per record is excluded from the reported total, as the notebook
        # always did. This is presumably the BOS id the tokenizer prepends -- TODO
        # confirm. The shard files themselves still contain every id.
        self.save_total_info(total_item, total_token - total_item)

    def save_total_info(self, total_item: int, total_token: int):
        """Write the run summary as a single JSON object to ``index.json``.

        Args:
            total_item: Number of records written across all shards.
            total_token: Token count to report (the caller decides what is counted).
        """
        with open(self.index_file_name, mode="w", encoding="utf-8") as fout:
            fout.write(
                json.dumps(
                    {"total_item": total_item, "total_token": total_token},
                    ensure_ascii=False,
                )
            )


# Tokenize data/pretrained_data into 100k-record shards under TARGET_DIR.
# This is the slow step of the notebook (it clears TARGET_DIR first).
buildpretrained = BuildPreTrainDataset(
    source_dir="data/pretrained_data",
    target_dir=TARGET_DIR,
    tokenizer=internlm_tokenizer,
    file_size=100_000,
)
buildpretrained()

gendata...2file: 600882it [29:02, 344.79it/s]


In [7]:
class MyClass:
    """Toy example of a generator method that flattens a fixed nested list."""

    def __init__(self) -> None:
        """Nothing to initialise; the data is fixed inside get_item."""

    def get_item(self):
        """Yield 1 through 6 by walking each inner list in turn."""
        nested = ([1, 2, 3], [4, 5, 6])
        for inner in nested:
            yield from inner


myclass = MyClass()
for index, value in enumerate(myclass.get_item()):
    print(value)

1
2
3
4
5
6


In [8]:
# Scratch check: the remainder of 100 divided by 10 (expected 0).
divmod(100, 10)[1]

0