In [None]:
import re
from tqdm import tqdm_notebook as tqdm
import codecs
from collections import defaultdict
import operator
import spacy
import _pickle as pickle
import torch
import numpy as np

## 训练集、测试集、验证集目录

In [None]:
# File names (relative to ./data/) of the source-side text for each split.
path_src = ["test.txt.src", "train.txt.src", "val.txt.src"]
# File names (relative to ./data/) of the tagged target-side text, in the same
# split order as path_src.
path_tgt = [
    "test.txt.tgt.tagged", "train.txt.tgt.tagged", "val.txt.tgt.tagged"
]

## 去标点噪音,转小写

In [None]:
def depunc(path):
    """Write a lower-cased, noise-stripped copy of ``path`` to ``path + '.clean'``.

    For every input line the text is lower-cased; ``--``, ``-lrb- ... -rrb- ``
    spans, quote tokens, colons and ``<t>`` / ``</t>`` tags are removed; and
    the whitespace around a standalone ``.``, ``?`` or ``!`` is normalised to
    single spaces. Each line written ends with a newline, including lines that
    end up empty after cleaning.

    Args:
        path: Path of a UTF-8 encoded text file to clean.
    """
    noise = re.compile(r"--|-lrb-.*?-rrb- |'' |\"|`` |:|<t> |</t> |</t>")
    # Context managers close both files even if cleaning fails part-way.
    with codecs.open(path, encoding='utf-8') as fr, \
            codecs.open(path + '.clean', 'w', encoding='utf-8') as fw:
        for line in tqdm(fr):
            line = noise.sub("", line.lower())
            line = re.sub(r"\s\.\s", r" . ", line)
            line = re.sub(r"\s\?\s", r" ? ", line)
            line = re.sub(r"\s\!\s", r" ! ", line)
            # The punctuation rules can swallow the trailing newline, and a
            # final unterminated line can be cleaned down to "". endswith()
            # handles both cases without indexing into an empty string.
            if not line.endswith("\n"):
                line += "\n"
            fw.write(line)

## 获得所有训练语料，用于训练词向量

In [None]:
def bow(path, corpus_path='./data/corpus_total.txt'):
    """Append every line of ``path`` to the shared corpus file.

    The corpus is opened in append mode, so repeated calls accumulate text.

    Args:
        path: UTF-8 text file whose lines are appended.
        corpus_path: Destination corpus file; defaults to the location used
            by the rest of this notebook.
    """
    # Context managers close both files even if a read or write fails.
    with codecs.open(path, encoding='utf-8') as fr, \
            codecs.open(corpus_path, 'a', encoding='utf-8') as fw:
        for line in tqdm(fr):
            fw.write(line)

## 调用depunc清洗数据

In [None]:
def prepare_for_model():
    """Run ``depunc`` on every source file, then every target file, under ``./data/``."""
    for file_name in [*path_src, *path_tgt]:
        depunc("./data/" + file_name)

In [None]:
# Produce a ".clean" copy of every source and target file under ./data/.
prepare_for_model()

## 为fasttext训练词向量准备语料

In [None]:
def prepare_for_fasttext():
    """Rebuild ``./data/corpus_total.txt`` from the cleaned training files.

    ``bow`` opens the corpus in append mode, so the file is emptied first.
    Without this, re-running the cell would duplicate every line.
    """
    # Truncate (or create) the corpus so the result does not depend on how
    # many times this cell has been executed.
    with codecs.open('./data/corpus_total.txt', 'w', encoding='utf-8'):
        pass
    bow("./data/train.txt.src.clean")
    bow("./data/train.txt.tgt.tagged.clean")

In [None]:
# Collect the cleaned training source and target text into ./data/corpus_total.txt.
prepare_for_fasttext()

## 建立词典

In [None]:
def build_dic():
    """Count whitespace-separated tokens in ``./data/corpus_total.txt``.

    Returns:
        A ``defaultdict(int)`` mapping each token to its number of occurrences.
    """
    dic = defaultdict(int)
    # The corpus is written as UTF-8 elsewhere in this notebook, so read it
    # with the same encoding instead of the platform default.
    with open("./data/corpus_total.txt", "r", encoding="utf-8") as f:
        for line in tqdm(f):
            for word in line.split():
                dic[word] += 1
    return dic

In [None]:
# Token -> occurrence count, built from ./data/corpus_total.txt.
dictionary = build_dic()

In [None]:
# Number of distinct tokens counted in the corpus.
print(len(dictionary))

In [None]:
# (token, count) pairs ordered from the most to the least frequent token.
sorted_dic = sorted(dictionary.items(), key=lambda pair: pair[1], reverse=True)

## 建立映射表并保存
- 原文词典大小为40000
- 文摘词典大小为10000
- 0:PAD
- 1:EOS
- 2:UNK
- 3:SOS
- source_input : W W W W UNK W W EOS PAD PAD PAD
- target_input : SOS w w w UNK w EOS PAD

In [None]:
# Total vocabulary sizes; the four reserved ids (0-3) count toward these limits.
dict_size_src = 40000
dict_size_tgt = 10000

In [None]:
# Source vocabulary. Ids 0-3 are reserved for the special tokens; the most
# frequent corpus tokens then receive consecutive ids until the vocabulary
# holds dict_size_src entries.
word2id_src = {'PAD': 0, 'EOS': 1, 'UNK': 2, 'SOS': 3}
id2word_src = {0: 'PAD', 1: 'EOS', 2: 'UNK', 3: 'SOS'}
count = 4  # next free id after the reserved tokens
for (k, v) in sorted_dic:
    word2id_src[k] = count
    id2word_src[count] = k
    count += 1
    if count == dict_size_src:
        break
# Context managers flush and close the pickle files deterministically instead
# of leaving the handles to the garbage collector.
with open("./data/word2id_src.dat", "wb") as f:
    pickle.dump(word2id_src, f, True)
with open("./data/id2word_src.dat", "wb") as f:
    pickle.dump(id2word_src, f, True)

In [None]:
# Target vocabulary. Ids 0-3 are reserved for the special tokens; the most
# frequent corpus tokens then receive consecutive ids until the vocabulary
# holds dict_size_tgt entries.
word2id_tgt = {'PAD': 0, 'EOS': 1, 'UNK': 2, 'SOS': 3}
id2word_tgt = {0: 'PAD', 1: 'EOS', 2: 'UNK', 3: 'SOS'}
count = 4  # next free id after the reserved tokens
for (k, v) in sorted_dic:
    word2id_tgt[k] = count
    id2word_tgt[count] = k
    count += 1
    if count == dict_size_tgt:
        break
# Context managers flush and close the pickle files deterministically instead
# of leaving the handles to the garbage collector.
with open("./data/word2id_tgt.dat", "wb") as f:
    pickle.dump(word2id_tgt, f, True)
with open("./data/id2word_tgt.dat", "wb") as f:
    pickle.dump(id2word_tgt, f, True)

## 处理待训练语料
- 原文限制长度为500以内（SRC_LENGTH），标题限制在70以内（TGT_LENGTH）
- 替换为one-hot下标
- 补上PAD、EOS、UNK、SOS

In [None]:
# Fixed sequence lengths (in ids, EOS included) that word_to_one_hot truncates
# or pads to: SRC_LENGTH for source files, TGT_LENGTH for target files.
SRC_LENGTH = 500
TGT_LENGTH = 70

In [None]:
def word_to_one_hot(path, output, word2id):
    """Encode each line of a cleaned text file as a fixed-length list of ids.

    Every line becomes ``[SOS?] tokens... EOS PAD...``. The token ids are
    truncated so that the row, EOS included, is ``restrict_len`` long, and
    shorter rows are right-padded with PAD.

    Args:
        path: Cleaned UTF-8 text file. A path containing ``"tgt"`` is treated
            as a target file and limited to ``TGT_LENGTH``; any other path is
            limited to ``SRC_LENGTH``.
        output: When true and ``path`` is a target file, id 3 (SOS) is
            prepended before truncation.
        word2id: Mapping from token to id. Tokens missing from it map to id 2
            (UNK).

    Returns:
        A list with one list of ints per input line.
    """
    pad_id, eos_id, unk_id, sos_id = 0, 1, 2, 3
    is_target = "tgt" in path
    restrict_len = TGT_LENGTH if is_target else SRC_LENGTH
    one_hot_matrix = []
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            # split() with no argument tokenises the same way as build_dic.
            # It keeps the last word of a line that does not end in a space
            # and never yields empty tokens (which would be encoded as UNK)
            # for runs of spaces.
            ids = [word2id.get(word, unk_id) for word in line.split()]
            if is_target and output:
                ids.insert(0, sos_id)
            # Leave room for the EOS marker, then pad up to the fixed length.
            ids = ids[:restrict_len - 1]
            ids.append(eos_id)
            ids.extend([pad_id] * (restrict_len - len(ids)))
            one_hot_matrix.append(ids)
    return one_hot_matrix

## tgt处理两次，
- .onehot后缀是加了开始符号，整体右移一个单位。作为decoder输入
- .gold后缀是原始语料，作为gold output计算损失

In [None]:
def one_hot_for_model():
    """Encode every cleaned split and pickle the resulting id tensors.

    Three passes run in order, each reading ``./data/<name>.clean``:

    * source files  -> ``./data/<name>.onehot`` (source vocabulary)
    * target files  -> ``./data/<name>.gold``   (target vocabulary, no SOS)
    * target files  -> ``./data/<name>.onehot`` (target vocabulary, SOS prepended)
    """
    jobs = (
        (path_src, ".onehot", False, word2id_src),
        (path_tgt, ".gold", False, word2id_tgt),
        (path_tgt, ".onehot", True, word2id_tgt),
    )
    for names, suffix, output, word2id in jobs:
        for name in names:
            matrix = np.asarray(
                word_to_one_hot(
                    "./data/" + name + ".clean", output=output,
                    word2id=word2id))
            # Close each pickle file as soon as it is written rather than
            # leaking the handle.
            with open("./data/" + name + suffix, "wb") as f:
                pickle.dump(torch.from_numpy(matrix), f, True)

In [None]:
# Encode every cleaned split and write the .onehot / .gold tensors under ./data/.
one_hot_for_model()

In [None]:
# Sanity check: reload the test-split target tensor and show its row count.
# The context manager closes the file once the tensor has been read.
with open("./data/test.txt.tgt.tagged.onehot", "rb") as f:
    matrix = pickle.load(f)
print(len(matrix))

In [None]:
# Show the first row of the loaded id matrix.
print(matrix[0])

In [None]:
# Decode the first row back into tokens. `matrix` was loaded from a target
# ".onehot" file, so the target id -> word table is the matching one. The
# previous name `id2word` is never defined in this notebook.
s = [id2word_tgt[idx] for idx in matrix[0].tolist()]
print(' '.join(s))

## 统计文摘的长度作为mask传给模型
- 方便计算maskedNLLLoss
- 是截断后文摘的长度

In [None]:
def make_len_mask(path):
    """Compute per-row sequence lengths for a pickled id matrix and save them.

    The length of a row is the number of ids up to and including the first
    EOS (id 1), or the full row length when no EOS is present. The lengths
    are pickled as a tensor next to the input: a trailing ``.onehot`` in
    ``path`` is replaced by ``.mask`` (otherwise ``.mask`` is appended).

    Args:
        path: Pickle file holding an iterable of id rows (a 2-D tensor here).
    """
    eos_id = 1
    with open(path, "rb") as f:
        onehot = pickle.load(f)
    len_mask = []
    for sentence_onehot in tqdm(onehot):
        count = 0
        for token_id in sentence_onehot:
            count += 1
            if token_id == eos_id:
                break
        len_mask.append(count)
    len_mask = torch.from_numpy(np.asarray(len_mask))
    # Derive the output name from the argument, not from a global loop
    # variable, so the function works whatever it is called with.
    suffix = ".onehot"
    stem = path[:-len(suffix)] if path.endswith(suffix) else path
    with open(stem + ".mask", "wb") as f:
        pickle.dump(len_mask, f, True)

In [None]:
# Write a length ".mask" file for each target split from its ".onehot" tensor.
# NOTE(review): keep the loop variable named `tgt` unless make_len_mask is
# confirmed not to read a module-level name `tgt` when building its output path.
for tgt in path_tgt:
    make_len_mask("./data/" + tgt + ".onehot")

In [None]:
# Sanity check: reload and display the length mask of the test split. The
# context manager closes the file once the tensor has been read.
with open("./data/test.txt.tgt.tagged.mask", "rb") as f:
    mask_test = pickle.load(f)
print(mask_test)

# 统计语料

In [None]:
# Reload the target ".onehot" tensors of all three splits. Each file is
# closed as soon as its tensor has been read.
with open("./data/train.txt.tgt.tagged.onehot", "rb") as f:
    train_tgt = pickle.load(f)
with open("./data/val.txt.tgt.tagged.onehot", "rb") as f:
    vali_tgt = pickle.load(f)
with open("./data/test.txt.tgt.tagged.onehot", "rb") as f:
    test_tgt = pickle.load(f)

In [None]:
# Number of rows in the training-split target tensor.
print(len(train_tgt))

In [None]:
# Number of rows in the validation-split target tensor.
print(len(vali_tgt))

In [None]:
# Number of rows in the test-split target tensor.
print(len(test_tgt))

# 从onehot中恢复出文本
- 即获得截断长度的文本，另存为一份语料，用于fairseq训练

In [None]:
# Split names used to build the input and ./data/fairseq/ output file names below.
parts = ['train', 'val', 'test']

In [None]:
# Turn each split's source id tensor back into text, one line per row, under
# ./data/fairseq/. `id2word_src` is used because the ids come from the source
# vocabulary; the previous name `id2word` is never defined in this notebook.
for p in parts:
    # Load first, so a missing input does not leave a truncated output file.
    with open("./data/" + p + ".txt.src.onehot", "rb") as fr:
        onehot = pickle.load(fr)
    with open("./data/fairseq/" + p + ".src", "w", encoding="utf-8") as f:
        for sentence in tqdm(onehot):
            s = [id2word_src[idx] for idx in sentence.tolist()]
            f.write(' '.join(s) + "\n")

In [None]:
# Turn each split's target id tensor back into text, one line per row, under
# ./data/fairseq/. `id2word_tgt` is used because the ids come from the target
# vocabulary; the previous name `id2word` is never defined in this notebook.
for p in parts:
    # Load first, so a missing input does not leave a truncated output file.
    with open("./data/" + p + ".txt.tgt.tagged.onehot", "rb") as fr:
        onehot = pickle.load(fr)
    with open("./data/fairseq/" + p + ".tgt", "w", encoding="utf-8") as f:
        for sentence in tqdm(onehot):
            s = [id2word_tgt[idx] for idx in sentence.tolist()]
            f.write(' '.join(s) + "\n")

# 直接将clean语料截断

In [None]:
# Write a copy of each cleaned source file, truncated to at most SRC_LENGTH
# tokens per line, under ./data/fairseq/.
for p in parts:
    # Both files are managed by the `with` block, so the reader is closed even
    # on error. UTF-8 matches the encoding the .clean files were written with.
    with open("./data/" + p + ".txt.src.clean", "r", encoding="utf-8") as fr, \
            open("./data/fairseq/" + p + ".src", "w", encoding="utf-8") as fw:
        for sentence in tqdm(fr):
            # split() ignores runs of spaces, so only real tokens count toward
            # the limit; every output line ends with exactly one newline.
            fw.write(' '.join(sentence.split()[:SRC_LENGTH]) + "\n")

In [None]:
# Write a copy of each cleaned target file, truncated to at most TGT_LENGTH
# tokens per line, under ./data/fairseq/.
for p in parts:
    # Both files are managed by the `with` block, so the reader is closed even
    # on error. UTF-8 matches the encoding the .clean files were written with.
    with open("./data/" + p + ".txt.tgt.tagged.clean", "r",
              encoding="utf-8") as fr, \
            open("./data/fairseq/" + p + ".tgt", "w", encoding="utf-8") as fw:
        for sentence in tqdm(fr):
            # split() ignores runs of spaces, so only real tokens count toward
            # the limit; every output line ends with exactly one newline.
            fw.write(' '.join(sentence.split()[:TGT_LENGTH]) + "\n")
        fr.close()