# 按行采样和划分语料

## Imports

In [1]:
import os
import random
from contextlib import ExitStack 
from glob import glob
from itertools import chain, repeat
from multiprocessing import Pool
from pathlib import Path

import numpy as np
from tqdm.auto import tqdm

## Constants

In [2]:
# Corpus file(s) to sample from; the notebook works on them line by line.
# NOTE(review): presumably one JSON record per line — confirm.
CORPUS_FILES = [
    '../data/wikipedia/wikidump_lines.json',
]

## Functions

In [3]:
def get_lines(path, show_progress=False):
    """Return how many lines the text file at *path* contains.

    The file is opened in text mode with the default encoding and iterated
    line by line, so the count matches what iterating the file yields.

    Args:
        path: Path of the file to count.
        show_progress: When true, display a tqdm progress bar over the lines.

    Returns:
        The number of lines (0 for an empty file).
    """
    count = 0
    with open(path) as handle:
        for _ in tqdm(handle, disable=(not show_progress)):
            count += 1
    return count

def get_total_lines(paths, show_progress=True):
    """Count the lines of all *paths*, counting each file in a worker process.

    Args:
        paths: A single path (``str`` or ``Path``) or any iterable of paths,
            including one-shot iterators/generators.
        show_progress: When true, show tqdm bars for dispatching files to the
            pool ("mapping") and for collecting per-file counts ("reducing").

    Returns:
        The total number of lines across all files; 0 when *paths* is empty.
    """
    if isinstance(paths, (str, Path)):
        paths = [paths]
    # Materialize so generators are accepted and the file count is known for
    # sizing both the pool and the progress bar.
    paths = list(paths)
    if not paths:
        # Pool(0) raises ValueError; there is nothing to count anyway.
        return 0
    total = len(paths)
    # Never start more workers than there are files: each worker counts a
    # whole file. os.cpu_count() may return None, so fall back to 1.
    processes = min(os.cpu_count() or 1, total)
    with Pool(processes) as pool:
        it = pool.imap_unordered(
            get_lines,
            tqdm(paths, 'Get lines - mapping', unit='file', disable=(not show_progress))
        )
        # sum() drains every result before the pool is torn down on exit.
        return sum(
            tqdm(it, 'Get lines - reducing', total=total, unit='file', disable=(not show_progress))
        )
        
def iterate_lines(paths, show_progress=True):
    """Return a lazy iterator over every line of every file in *paths*.

    Files are read in the given order. Each file is opened only when the
    iterator reaches it and is closed deterministically once its lines are
    exhausted (or when the iterator is closed / garbage-collected), instead
    of being left for the garbage collector to close.

    Args:
        paths: A single path (``str`` or ``Path``) or an iterable of paths.
        show_progress: When true, wrap the paths in a per-file tqdm bar.

    Returns:
        An iterator of lines exactly as produced by iterating each text-mode
        file object (trailing newline included, if present).
    """
    if isinstance(paths, (str, Path)):
        paths = [paths]
    # Wrap eagerly (as before) so the progress bar is created at call time.
    if show_progress:
        paths = tqdm(paths, unit='file')
    return _iter_file_lines(paths)


def _iter_file_lines(paths):
    """Yield the lines of each file in *paths*, closing each file after use."""
    for path in paths:
        with open(path) as fp:
            yield from fp



In [4]:
# Count every corpus line up front: the sampling mask built in the next cell
# must hold exactly one entry per line.
total_lines = get_total_lines(CORPUS_FILES)
print(total_lines)

HBox(children=(IntProgress(value=0, description='Get lines - mapping', max=1, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Get lines - reducing', max=1, style=ProgressStyle(description…



977539


## Number of sampled lines per split

In [5]:
# Number of lines to sample into each split; all other lines are skipped.
parts = {
    'dev': 100,
    'val': 10,
}
# Optional RNG seed: set an int for a reproducible split. None keeps the
# previous behaviour (freshly OS-seeded randomness on every run).
SEED = None

# Validate up front: with too many requested samples, repeat(None, negative)
# would silently yield nothing. The mask would then be longer than the corpus,
# and zip() would silently drop samples if the final assert is optimized away.
if any(number < 0 for number in parts.values()):
    raise ValueError(f'Sample counts must be non-negative, got {parts}')
n_sampled = sum(parts.values())
if n_sampled > total_lines:
    raise ValueError(
        f'Requested {n_sampled} sampled lines but the corpus has only {total_lines}'
    )

# One label per corpus line: a split name for sampled lines, None otherwise.
mask = []
for name, number in parts.items():
    mask.extend(repeat(name, number))
mask.extend(repeat(None, total_lines - len(mask)))
random.Random(SEED).shuffle(mask)

assert len(mask) == total_lines

## Write the sampled lines to per-split output files

In [7]:
# (split name, destination path) for every split defined in `parts`.
# NOTE(review): outputs are named "baike.*" although CORPUS_FILES points at a
# Wikipedia dump — confirm this naming is intended.
output_files = (
    (split, '../data/baike.{}.json'.format(split))
    for split in parts
)
with ExitStack() as stack:
    # ExitStack closes every output file on exit, including on errors.
    fp_dict = {
        split: stack.enter_context(open(path, 'w'))
        for split, path in output_files
    }
    tagged_lines = zip(mask, iterate_lines(CORPUS_FILES, False))
    for split, text in tqdm(tagged_lines, total=total_lines, unit='line'):
        target = fp_dict.get(split)
        # Lines whose mask label is None belong to no split and are dropped.
        if target is not None:
            target.write(text)


HBox(children=(IntProgress(value=0, max=977539), HTML(value='')))


