## 通过五折交叉验证来随机划分训练集和测试集 

In [1]:
import logging
import random
from transformers import BertTokenizer

def fold_crossover(lst, num_iteration):
    """Shuffle ``lst`` in place and split it into exactly ``num_iteration`` folds.

    The folds are contiguous slices of the shuffled list. Their sizes differ
    by at most one element: the first ``len(lst) % num_iteration`` folds get
    one extra item. When ``lst`` has fewer items than folds, the trailing
    folds are empty lists.

    Args:
        lst: List of samples; it is shuffled in place.
        num_iteration: Number of folds to produce; must be positive.

    Returns:
        A list of ``num_iteration`` lists that together contain every
        element of ``lst`` exactly once.

    Raises:
        ValueError: If ``num_iteration`` is not positive.
    """
    if num_iteration <= 0:
        raise ValueError("num_iteration must be a positive integer")

    random.shuffle(lst)  # randomize the order before cutting the folds

    # Spread the remainder over the first folds instead of creating an
    # extra, undersized fold at the end.
    base_size, remainder = divmod(len(lst), num_iteration)
    sublists = []
    start = 0
    for fold_index in range(num_iteration):
        stop = start + base_size + (1 if fold_index < remainder else 0)
        sublists.append(lst[start:stop])
        start = stop
    return sublists

def clean_list(lst, id_list):
    """Prefix each article text with its id header and drop short entries.

    Every text in ``lst`` is turned into ``"##id:<id> O" + text`` using the
    id at the same position in ``id_list``. Entries whose total length
    (header included) is below 50 characters are discarded.

    Args:
        lst: Article texts.
        id_list: Article ids, aligned by position with ``lst``.

    Returns:
        A list of the header-prefixed texts that are at least 50 characters
        long, in their original order.

    Raises:
        IndexError: If ``id_list`` is shorter than ``lst``.
    """
    cleaned_list = []
    for index, text in enumerate(lst):
        # The header always comes first, so an entry can never begin with
        # newlines; leading newlines of the text are kept verbatim after it.
        elem = "##id:" + str(id_list[index]) + " O" + text
        if len(elem) >= 50:
            cleaned_list.append(elem)
    return cleaned_list

def read_examples_from_file(file_path, num_iteration):
    """Read an annotated corpus file and split it into train and dev sets.

    Lines starting with ``"##id:"`` mark the beginning of an article; the id
    is the text between that prefix and the first space. The lines that
    follow are concatenated into the article text. Articles are cleaned with
    ``clean_list``, divided into folds with ``fold_crossover``, and one fold
    chosen at random becomes the dev set.

    Args:
        file_path: Path of the UTF-8 encoded corpus file.
        num_iteration: Number of folds; the dev fold index is drawn from
            ``0 .. num_iteration - 1``.

    Returns:
        A ``(train_content, dev_content)`` tuple of article-string lists.
        ``train_content`` holds every cleaned article that is not equal to
        an article in ``dev_content``.
    """
    logging.info('***** 开始读入样本 *****')
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Marker inserted in front of the first line of each article so the
    # joined text can be split back into one piece per article.
    article_separator = "\n\n\n\n\n"
    parts = []
    id_list = []  # id of every article, in file order
    starts_new_article = False
    for line in lines:
        if line.startswith("##id:"):
            # Remember the article id; the header line itself is not kept.
            id_list.append(line[5:].split(" ")[0])
            starts_new_article = True
        elif starts_new_article:
            parts.append(article_separator + line)
            starts_new_article = False
        else:
            parts.append(line)
    content = "".join(parts)

    # The piece before the first separator precedes any article; skip it.
    content_list = clean_list(content.split(article_separator)[1:], id_list)

    # Randomly divide the corpus into folds.
    subcontent_list = fold_crossover(content_list, num_iteration)

    # Pick one fold as the dev set; the index range follows num_iteration
    # instead of assuming five folds.
    dev_index = random.randint(0, num_iteration - 1)
    dev_content = subcontent_list[dev_index]

    # Everything that is not in the dev set forms the training set.
    held_out = set(dev_content)
    train_content = [elem for elem in content_list if elem not in held_out]
    return train_content, dev_content


def generate_5cv_file(model_path, input_path, num_iteration, train_output_path, eval_output_path):
    """Split the corpus at ``input_path`` and write the train and dev files.

    Args:
        model_path: Not used by this function; kept so existing callers
            that pass it keep working.
        input_path: Path of the annotated corpus file to split.
        num_iteration: Number of folds passed to ``read_examples_from_file``.
        train_output_path: Destination file for the training articles.
        eval_output_path: Destination file for the dev articles.

    Returns:
        Always ``0``.
    """
    train_content, dev_content = read_examples_from_file(input_path, num_iteration)

    outputs = (
        (train_output_path, train_content),
        (eval_output_path, dev_content),
    )
    for output_path, docs in outputs:
        with open(output_path, 'w', encoding='utf-8') as f:
            # One article per write, each followed by a newline.
            f.writelines(doc + '\n' for doc in docs)
    return 0

In [3]:
# Number of folds and the directory of the pretrained model.
num_iteration = 5
model_path = '/ssd01/Codes/PersonalCodes/ZhangXianpeng/graduation_design/mechanical_ner/model_bin'

# Manually annotated sample data, followed by the two files to generate.
input_path = '/ssd01/Codes/PersonalCodes/ZhangXianpeng/graduation_design/datas/fineturn_data/mechanical_data/mechine_zhang.txt'
train_output_path = '/ssd01/Codes/PersonalCodes/ZhangXianpeng/graduation_design/datas/fineturn_data/mechanical_data/mechine_train.txt'
eval_output_path = '/ssd01/Codes/PersonalCodes/ZhangXianpeng/graduation_design/datas/fineturn_data/mechanical_data/mechine_dev.txt'

generate_5cv_file(
    model_path=model_path,
    input_path=input_path,
    num_iteration=num_iteration,
    train_output_path=train_output_path,
    eval_output_path=eval_output_path,
)

0