In [1]:
import os
import struct
import collections
from tensorflow.core.example import example_pb2

# Input text files. write_to_bin reads them as 4-line records:
# results, requests, article, abstract.
TRAIN_FILE = "./data/train_art_summ_prep_os.txt"
VAL_FILE = "./data/val_art_summ_prep_os.txt"

# Markers wrapped around every abstract before serialization; they are
# excluded again when counting vocabulary tokens.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

# Maximum number of most-frequent words written to the vocab file.
VOCAB_SIZE = 50000
# Number of serialized examples copied into each chunk file.
CHUNK_SIZE = 1000

# Output directory for train.bin, val.bin and the vocab file.
FINISHED_FILE_DIR = './data/datav4/finished_files'
# Sub-directory that receives the chunked <split>_NNN.bin files.
CHUNKS_DIR = os.path.join(FINISHED_FILE_DIR, 'chunked')


def chunk_file(finished_files_dir, chunks_dir, name, chunk_size):
    """Split ``<finished_files_dir>/<name>.bin`` into files of at most ``chunk_size`` records.

    The input is a sequence of length-prefixed records: an integer packed
    with struct format ``'q'`` followed by that many payload bytes.  Records
    are copied unchanged, in order, into ``<chunks_dir>/<name>_000.bin``,
    ``<name>_001.bin``, ...

    Args:
        finished_files_dir: directory containing ``<name>.bin``.
        chunks_dir: existing directory that receives the chunk files.
        name: split name used for the input and chunk file names.
        chunk_size: maximum number of records per chunk file (must be >= 1).

    Raises:
        ValueError: if ``chunk_size`` < 1 (which would otherwise loop forever
            writing empty chunks) or if the input ends mid-record.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1, got %r" % (chunk_size,))

    in_file = os.path.join(finished_files_dir, '%s.bin' % name)
    print(in_file)
    header_size = struct.calcsize('q')
    chunk = 0
    finished = False
    # `with` guarantees the input handle is closed even if a write fails.
    with open(in_file, 'rb') as reader:
        while not finished:
            chunk_fname = os.path.join(chunks_dir, '%s_%03d.bin' % (name, chunk))
            with open(chunk_fname, 'wb') as writer:
                for _ in range(chunk_size):
                    len_bytes = reader.read(header_size)
                    if not len_bytes:
                        finished = True
                        break
                    if len(len_bytes) != header_size:
                        raise ValueError("%s: truncated length header in record" % in_file)
                    str_len = struct.unpack('q', len_bytes)[0]
                    example_str = reader.read(str_len)
                    if len(example_str) != str_len:
                        raise ValueError("%s: truncated record (expected %d bytes, got %d)"
                                         % (in_file, str_len, len(example_str)))
                    # Header and payload are copied verbatim; re-packing them is a no-op.
                    writer.write(len_bytes)
                    writer.write(example_str)
            chunk += 1


def chunk_all():
    """Chunk the 'train' and 'val' .bin files of FINISHED_FILE_DIR into CHUNKS_DIR.

    Each chunk file holds at most CHUNK_SIZE records (see chunk_file).
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # isdir()+mkdir() and also creates missing parent directories.
    os.makedirs(CHUNKS_DIR, exist_ok=True)
    for name in ['train', 'val']:
        print("Splitting %s data into chunks..." % name)
        chunk_file(FINISHED_FILE_DIR, CHUNKS_DIR, name, CHUNK_SIZE)
    print("Saved chunked data in %s" % CHUNKS_DIR)


def read_text_file(text_file):
    """Return every line of a UTF-8 text file with leading/trailing whitespace stripped.

    Blank lines are kept as empty strings, so line positions are preserved.
    """
    with open(text_file, encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]


def _make_example(results, requests, article, abstract):
    """Build a tf.Example holding the four text fields as UTF-8 bytes features."""
    tf_example = example_pb2.Example()
    features = tf_example.features.feature
    for key, value in (('results', results), ('requests', requests),
                       ('article', article), ('abstract', abstract)):
        features[key].bytes_list.value.extend([value.encode('utf-8')])
    return tf_example


def _record_tokens(results, requests, article, abstract):
    """Return the non-empty space-separated tokens of one record, minus <s>/</s> markers."""
    abs_tokens = [t for t in abstract.split(' ')
                  if t not in [SENTENCE_START, SENTENCE_END]]
    tokens = results.split(' ') + requests.split(' ') + article.split(' ') + abs_tokens
    stripped = (t.strip() for t in tokens)
    return [t for t in stripped if t != ""]


def write_to_bin(input_file, out_file, makevocab=False):
    """Convert a 4-lines-per-record text file into a length-prefixed tf.Example file.

    Each record in ``input_file`` is four consecutive lines: results,
    requests, article and abstract.  The abstract is wrapped in
    SENTENCE_START / SENTENCE_END.  Every record is serialized as a
    tf.Example and appended to ``out_file`` as a struct ``'q'`` length
    followed by the serialized bytes.

    Args:
        input_file: path of the UTF-8 input text file.
        out_file: path of the binary output file (overwritten).
        makevocab: if True, also count the tokens of every record and write
            the VOCAB_SIZE most common ones, one ``word count`` per line, to
            ``FINISHED_FILE_DIR/vocab``.
    """
    # Read the input before opening out_file, so a missing/unreadable input
    # does not truncate an existing output file.
    lines = read_text_file(input_file)
    usable = len(lines) - len(lines) % 4
    if usable != len(lines):
        # Report an incomplete trailing record instead of dropping it silently.
        print("Warning: ignoring %d trailing line(s) of %s that do not form a "
              "complete 4-line record" % (len(lines) - usable, input_file))

    vocab_counter = collections.Counter() if makevocab else None

    with open(out_file, 'wb') as writer:
        for start in range(0, usable, 4):
            results, requests, article, summary = lines[start:start + 4]
            abstract = "%s %s %s" % (SENTENCE_START, summary, SENTENCE_END)

            # Write the tf.Example: length prefix, then the serialized bytes.
            tf_example_str = _make_example(results, requests, article, abstract).SerializeToString()
            writer.write(struct.pack('q', len(tf_example_str)))
            writer.write(tf_example_str)

            if vocab_counter is not None:
                vocab_counter.update(_record_tokens(results, requests, article, abstract))

    print("Finished writing file %s\n" % out_file)

    if vocab_counter is not None:
        print("Writing vocab file...")
        with open(os.path.join(FINISHED_FILE_DIR, "vocab"), 'w', encoding='utf-8') as writer:
            for word, count in vocab_counter.most_common(VOCAB_SIZE):
                writer.write(word + ' ' + str(count) + '\n')
        print("Finished writing vocab file")


if __name__ == '__main__':
    # Create the output directory on first run.
    if not os.path.exists(FINISHED_FILE_DIR):
        os.makedirs(FINISHED_FILE_DIR)

    # (source text file, output .bin name, whether to build the vocab).
    # Only the training split contributes to the vocabulary.
    conversions = (
        (TRAIN_FILE, "train.bin", True),
        (VAL_FILE, "val.bin", False),
    )
    for source_path, bin_name, build_vocab in conversions:
        write_to_bin(source_path, os.path.join(FINISHED_FILE_DIR, bin_name),
                     makevocab=build_vocab)

    chunk_all()


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Finished writing file ./data/datav4/finished_files/train.bin

Writing vocab file...
Finished writing vocab file
Finished writing file ./data/datav4/finished_files/val.bin

Splitting train data into chunks...
./data/datav4/finished_files/train.bin
Splitting val data into chunks...
./data/datav4/finished_files/val.bin
Saved chunked data in ./data/datav4/finished_files/chunked
