# 文段句格式语料转换为 JSON Lines 格式

原始语料的格式是 JSON Lines，每一行是一篇文章，其内容是二维数组，第一个维度是段落，第二个维度是句子

## 代码准备

### Imports

In [20]:
import json
import math
import os
import sys
import random
from contextlib import closing, ExitStack
from datetime import timedelta
from fileinput import FileInput
from functools import partial
from glob import glob, iglob
from itertools import chain, cycle, islice, count, repeat
from multiprocessing import Pool
from time import time
from contextlib import ExitStack 
from glob import glob
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import sentencepiece as spm
from tqdm.auto import tqdm


### Constants

In [2]:
# Token budget for one output sample: proc_line starts a new sample once the
# running token count would reach SEQ_LENGTH + MIN_CTX_LEN // 2.
SEQ_LENGTH = 1024
# Samples with fewer tokens than this are discarded when the intermediate
# files are written.
MIN_CTX_LEN = 32

# Names of the dataset splits processed throughout this notebook.
DATASET_NAMES = ['train', 'valid', 'test']

## SentencePiece

In [3]:
# Path of the SentencePiece model file used for counting tokens.
SPM_MODEL = '../data/spm/gpt2_huamei_corpus_bpe_32k_v2.model'


# Module-level processor; proc_line reads this global to encode sentences.
SP = spm.SentencePieceProcessor()
SP.load(SPM_MODEL)

True

### Functions

In [4]:
def text_files_line_iterator(paths):
    """Return a lazy iterator over every line of every file in *paths*.

    Files are read one after another in the order given, wrapped in a tqdm
    progress bar that advances once per file.  Each file is opened only when
    iteration reaches it and is closed as soon as its last line has been
    yielded, rather than leaving open handles behind.

    Args:
        paths: Iterable of text file paths.

    Returns:
        An iterator yielding the lines of all files, newline included.
    """
    def _lines_of(path):
        # The context manager releases the handle when the file is exhausted
        # or when this generator is closed.
        with open(path) as fd:
            yield from fd

    return chain.from_iterable(
        _lines_of(path)
        for path in tqdm(paths, '[iter files]', unit='file')
    )


def single_text_file_line_count(path, show_progress_bar=False):
    """Count the lines of one text file.

    Args:
        path: Path of the file to read.
        show_progress_bar: When true, wrap the line iteration in tqdm.

    Returns:
        The number of lines in the file.
    """
    n_lines = 0
    with open(path) as stream:
        lines = stream
        if show_progress_bar:
            lines = tqdm(stream)
        for _ in lines:
            n_lines += 1
    return n_lines
        

def text_files_line_count(paths):
    """Count the lines of all files in *paths* using a process pool.

    Each file is counted by single_text_file_line_count in a worker process
    and the per-file counts are summed as they arrive.

    Args:
        paths: Iterable of text file paths.

    Returns:
        The total number of lines over all files.
    """
    # Unsized iterables (e.g. generators) have no len(); the summing progress
    # bar then runs without a known total.
    total = None
    try:
        total = len(paths)
    except (AttributeError, TypeError):
        pass
    with Pool() as pool:
        counts = pool.imap_unordered(
            single_text_file_line_count,
            tqdm(paths, '[map files ]', unit='file'),
        )
        progress = tqdm(counts, '[sum files ]', unit='file', total=total)
        return sum(progress)


def proc_line(line):
    """Split one article (a JSON line) into token-bounded text samples.

    The line is parsed as a two-level list (paragraphs containing sentences).
    Sentences are stripped, blank ones skipped, and the rest concatenated
    into a sample until adding the next sentence would bring the running
    token count to SEQ_LENGTH + MIN_CTX_LEN // 2 or more; the current sample
    is then emitted and the sentence starts a new one.  A single sentence is
    never split, so a sentence that alone reaches the budget becomes a sample
    longer than the budget.

    Args:
        line: One line of the source corpus.

    Returns:
        A list of ``{'text': str, 'length': int}`` dicts, where ``length`` is
        the sum of the token counts of the individually encoded sentences.
        The list is empty for a blank line.
    """
    result = []
    line = line.strip()
    if not line:
        return result
    paragraphs = json.loads(line)
    max_tokens = SEQ_LENGTH + MIN_CTX_LEN // 2
    parts = []  # sentences of the sample being built; joined once on emit
    n_text = 0
    for sentence in chain.from_iterable(paragraphs):
        sentence = sentence.strip()
        if not sentence:
            continue
        n_sentence = len(SP.encode_as_ids(sentence))
        if n_text + n_sentence >= max_tokens:
            # Emit only a non-empty sample: when the very first sentence
            # already reaches the budget there is nothing accumulated yet,
            # and an empty {'text': '', 'length': 0} record must not be
            # produced.
            if n_text:
                result.append({'text': ''.join(parts), 'length': n_text})
            parts = []
            n_text = 0
        parts.append(sentence)
        n_text += n_sentence
    if n_text:
        result.append({'text': ''.join(parts), 'length': n_text})
    return result
    

## 语料文件

### 输入文件

### 列出输入文件

In [9]:
%%time

BASE_SRC_DIR = '/nfs/server01_public/data/gpt2/output/gpt2_huamei_corpus.json.train'


def list_data_files(ds):
    """List the source files of one dataset split.

    Searches BASE_SRC_DIR recursively for entries named ``<ds>_*`` and keeps
    the regular files whose extension starts with ``.json`` (case-insensitive).

    Args:
        ds: Split name; must be one of DATASET_NAMES.

    Returns:
        A list of matching file paths.
    """
    assert ds in DATASET_NAMES
    pattern = os.path.join(BASE_SRC_DIR, '**', f'{ds}_*')
    candidates = tqdm(
        iglob(pattern, recursive=True),
        f'search for {ds} data files',
        unit='file',
    )
    found = []
    for path in candidates:
        if not os.path.isfile(path):
            continue
        extension = os.path.splitext(path)[1].lower()
        if extension.startswith('.json'):
            found.append(path)
    return found


# dataset[<split>]['source'] holds one {'path', 'lines'} record per input
# file of that split.
dataset = {}

for ds in DATASET_NAMES:
    dds = dataset[ds] = {}
    source = dds['source'] = []
    paths = list_data_files(ds)
    for path in tqdm(paths, 'count lines'):
        source.append({
            'path': path,
            # Use the existing helper so each file handle is closed as soon
            # as its lines have been counted.
            'lines': single_text_file_line_count(path),
        })

# Per-split subtotal of input lines.
print('输入语料理行数小计:')
for ds, dds in dataset.items():
    source_list = dds['source']
    lines = sum(dsrc['lines'] for dsrc in source_list)
    print(f'{ds}: {lines:,d}')

HBox(children=(IntProgress(value=1, bar_style='info', description='search for train data files', max=1, style=…




HBox(children=(IntProgress(value=0, description='count lines', max=19, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=1, bar_style='info', description='search for valid data files', max=1, style=…




HBox(children=(IntProgress(value=0, description='count lines', max=19, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=1, bar_style='info', description='search for test data files', max=1, style=P…




HBox(children=(IntProgress(value=0, description='count lines', max=19, style=ProgressStyle(description_width='…


输入语料理行数小计:
train: 10,839,332
valid: 52,908
test: 13,222
CPU times: user 52 s, sys: 5.92 s, total: 58 s
Wall time: 57.7 s


### 中间输出文件

In [12]:
BASE_MID_DIR = '../tmp'

# Make sure the scratch directory exists before any split is written to it.
os.makedirs(BASE_MID_DIR, exist_ok=True)

# Record, per split, where its intermediate sample file will be written.
for ds, dds in dataset.items():
    mid_path = os.path.join(BASE_MID_DIR, f'{ds}.mid.txt')
    dds['mid_file'] = mid_path
    print(f'{ds}: {mid_path}')


train: ../tmp/train.mid.txt
valid: ../tmp/valid.mid.txt
test: ../tmp/test.mid.txt


## 执行

In [17]:
# For every split: stream all source lines through proc_line in a process
# pool and write the resulting samples, one JSON object per line, to the
# split's intermediate file.
for ds, dds in dataset.items():
    print(f'处理 {ds} 数据 ...')
    total = sum(d['lines'] for d in dds['source'])
    n_samples = n_discard = 0
    with Pool() as pool, \
         FileInput(d['path'] for d in dds['source']) as iterable, \
         open(dds['mid_file'], 'w') as fp:
        it = pool.imap_unordered(
            proc_line,
            tqdm(iterable, 'map lines', total=total),
            chunksize=128
        )
        for result in tqdm(it, 'reduce all', total=total):
            for d in result:
                if d['length'] >= MIN_CTX_LEN:
                    s = json.dumps(d, ensure_ascii=False)
                    n_samples += 1
                    fp.write(s + '\n')
                else:
                    # Too short to be useful as a sample.
                    n_discard += 1
    print(f'{ds} 得到语料样本数：{n_samples:,d}')
    print(f'{ds} 抛弃语料样本数：{n_discard:,d}')
    print()


处理 train 数据 ...


HBox(children=(IntProgress(value=0, description='map lines', max=10839332, style=ProgressStyle(description_wid…

HBox(children=(IntProgress(value=0, description='reduce all', max=10839332, style=ProgressStyle(description_wi…



train 得到语料样本数：13,569,050
train 抛弃语料样本数：445,104

处理 valid 数据 ...


HBox(children=(IntProgress(value=0, description='map lines', max=52908, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='reduce all', max=52908, style=ProgressStyle(description_width…



valid 得到语料样本数：69,394
valid 抛弃语料样本数：1,659

处理 test 数据 ...


HBox(children=(IntProgress(value=0, description='map lines', max=13222, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='reduce all', max=13222, style=ProgressStyle(description_width…



test 得到语料样本数：16,804
test 抛弃语料样本数：471



### 记录中间文件行数

In [42]:
# Count the samples actually written to each intermediate file and store the
# count under 'mid_lines' for the later cells.
for ds, dds in tqdm(dataset.items()):
    # The intermediate path is stored under 'mid_file' (there is no 'mid'
    # key); the helper also closes the file once it has been counted.
    lines = single_text_file_line_count(dds['mid_file'])
    print(f'{ds} 中间样本数: {lines:,d}')
    dds['mid_lines'] = lines

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

train 中间样本数: 13,569,050
valid 中间样本数: 69,394
test 中间样本数: 16,804



### 查看输出文件

In [29]:
# Space-separated list of the intermediate files, interpolated into the
# shell commands below.
midfiles_string = ' '.join(d['mid_file'] for d in dataset.values())

# Disk usage, then line counts, of the intermediate files.
!du -hc {midfiles_string}
!echo ""
!wc -l {midfiles_string}

26G	../tmp/train.mid.txt
143M	../tmp/valid.mid.txt
33M	../tmp/test.mid.txt
26G	总用量

   13569050 ../tmp/train.mid.txt
      69394 ../tmp/valid.mid.txt
      16804 ../tmp/test.mid.txt
   13655248 总用量


## 长度统计

## 统计绘图软件包

采用 holoviz

见：  <http://holoviews.org/user_guide/Large_Data.html>

**如果没有安装，运行**：

In [None]:
%conda install -y -c defaults -c conda-forge -c pyviz holoviz

In [None]:
import json

import pandas as pd
import holoviews as hv
import hvplot.pandas  # noqa

import datashader as ds
import datashader.transfer_functions as tf

from holoviews.operation.datashader import datashade, shade, spread, dynspread, rasterize, spread
from tqdm.auto import tqdm

hv.extension('bokeh')

In [48]:
%%time

def extract_length_from_line(line):
    """Return the ``length`` field of one JSON-encoded sample line.

    Args:
        line: A line holding a JSON object, or a blank line.

    Returns:
        The value stored under ``'length'``; 0 when the line is blank or the
        object has no such key.
    """
    stripped = line.strip()
    if stripped:
        return json.loads(stripped).get('length', 0)
    return 0


def iget_corpus_length_dict(paths):
    """Yield ``{'length': n}`` for every line of the given sample files.

    Lines are read through FileInput and parsed by extract_length_from_line
    in a process pool; results are yielded in completion order, not file
    order.

    Args:
        paths: Iterable of sample file paths.

    Yields:
        One single-key dict per input line.
    """
    with Pool() as workers, FileInput(paths) as lines:
        lengths = workers.imap_unordered(
            extract_length_from_line, lines, chunksize=128
        )
        yield from ({'length': n} for n in lengths)

# Collect the length of every intermediate sample (all splits together) into
# a single-column DataFrame.
mid_paths = (d['mid_file'] for d in dataset.values())
total_mid_lines = sum(d['mid_lines'] for d in dataset.values())

# closing() finalizes the generator, and with it the pool and input files it
# holds, even if DataFrame construction stops early.
with closing(iget_corpus_length_dict(mid_paths)) as iterable:
    df = pd.DataFrame(tqdm(iterable, total=total_mid_lines))

HBox(children=(IntProgress(value=0, max=13655248), HTML(value='')))


CPU times: user 2min 29s, sys: 36.7 s, total: 3min 5s
Wall time: 2min 54s


In [49]:
df.describe()

Unnamed: 0,length
count,13655250.0
mean,477.7334
std,380.4918
min,32.0
25%,135.0
50%,327.0
75%,952.0
max,27480.0


In [50]:
# Plot sample length against row index; the points are rendered through
# datashader rather than as individual glyphs.
points = hv.Points(df.reset_index(), ['index', 'length'])
spread(datashade(points).opts(height=800, width=800))

## 拆分数据集

根据实际得到的样本数量，以及上面的分析，我们决定这一次数据集的数量:

In [60]:
# Small-scale demo of the selection-mask technique used in the next cell:
# 8 flags, the first 5 set to 1, then shuffled in place.
x = np.zeros(8, dtype=np.int8)
x[:5] = 1
x
np.random.shuffle(x)
x

array([1, 1, 1, 0, 1, 0, 1, 0], dtype=int8)

In [61]:
%%time

# Number of samples to keep for each split.
parts = dict(
    train={'n_sample': 10_000_000},
    valid={'n_sample': 50_000},
    test={'n_sample': 10_000},
)


# Build, per split, a shuffled 0/1 mask with exactly n_sample ones; line i of
# the intermediate file is kept when mask[i] is 1.
for ds, dds in dataset.items():
    n_sample = parts[ds]['n_sample']
    n_total = dds['mid_lines']
    assert n_sample <= n_total

    mask = np.zeros(n_total, dtype=np.int8)
    parts[ds]['mask'] = mask
    mask[:n_sample] = 1
    np.random.shuffle(mask)
    assert len(mask) == n_total


CPU times: user 398 ms, sys: 27.3 ms, total: 426 ms
Wall time: 423 ms


## 输出文件

In [55]:
BASE_DEST_DIR = '../data/hmwebmix'

# Make sure the output directory exists before any split is written to it.
os.makedirs(BASE_DEST_DIR, exist_ok=True)

# Record, per split, the path of its final output file.
for ds, dds in dataset.items():
    dest_path = os.path.join(BASE_DEST_DIR, f'hmwebmix.{ds}.json')
    dds['dest_file'] = dest_path
    print(f'{ds}: {dest_path}')


train: ../data/hmwebmix/hmwebmix.train.json
valid: ../data/hmwebmix/hmwebmix.valid.json
test: ../data/hmwebmix/hmwebmix.test.json


In [63]:
# Copy the masked-in lines of each intermediate file to its output file.
for ds, dds in dataset.items():
    mask = parts[ds]['mask']
    n_total = dds['mid_lines']
    with open(dds['mid_file']) as fp_src, open(dds['dest_file'], 'w') as fp_dst:
        pairs = tqdm(zip(fp_src, mask), f'sampling for {ds}', total=n_total)
        for s, b in pairs:
            if not b:
                continue
            fp_dst.write(s.strip() + '\n')

HBox(children=(IntProgress(value=0, description='sampling for train', max=13569050, style=ProgressStyle(descri…




HBox(children=(IntProgress(value=0, description='sampling for valid', max=69394, style=ProgressStyle(descripti…




HBox(children=(IntProgress(value=0, description='sampling for test', max=16804, style=ProgressStyle(descriptio…




查看：

In [66]:
# Space-separated list of the output files, interpolated into the shell
# commands below.
files = ' '.join(v['dest_file'] for v in dataset.values())

# Line counts, then disk usage, of the output files.
!wc -l {files}
!echo "\n"
!du -hc {files}

   10000000 ../data/hmwebmix/hmwebmix.train.json
      50000 ../data/hmwebmix/hmwebmix.valid.json
      10000 ../data/hmwebmix/hmwebmix.test.json
   10060000 总用量


19G	../data/hmwebmix/hmwebmix.train.json
103M	../data/hmwebmix/hmwebmix.valid.json
20M	../data/hmwebmix/hmwebmix.test.json
19G	总用量
