In [3]:
import re
from tqdm import tqdm_notebook as tqdm
import codecs
from collections import defaultdict
import operator
import spacy
import _pickle as pickle
import torch
import numpy as np

## 训练集、测试集、验证集目录

In [24]:
# File names of the raw source articles and tagged target summaries, one per
# split, in the order test / train / val.
path_src = [split + ".txt.src" for split in ("test", "train", "val")]
path_tgt = [split + ".txt.tgt.tagged" for split in ("test", "train", "val")]

## 去标点噪音,转小写

In [12]:
def depunc(path):
    """Lower-case a UTF-8 text file and strip punctuation noise from it.

    Reads ``path`` line by line and writes the cleaned text to
    ``path + '.clean'``. Every input line produces exactly one output line,
    so source and target files stay aligned line by line.

    Removed from each line: ``--``, ``-lrb-...-rrb- `` spans, ``'' ``, ``"``,
    `` `` `` (followed by a space), ``:``, and the ``<t>`` / ``</t>`` tags.
    Whitespace around a stand-alone ``.``, ``?`` or ``!`` is normalised to
    single spaces.

    Args:
        path: Path of the text file to clean.
    """
    noise = re.compile(r"--|-lrb-.*?-rrb- |'' |\"|`` |:|<t> |</t> |</t>")
    # Context managers guarantee both handles are closed even if an
    # exception interrupts the loop.
    with codecs.open(path, encoding='utf-8') as fr, \
            codecs.open(path + '.clean', 'w', encoding='utf-8') as fw:
        for line in tqdm(fr):
            line = noise.sub("", line.lower())
            line = re.sub(r"\s\.\s", r" . ", line)
            line = re.sub(r"\s\?\s", r" ? ", line)
            line = re.sub(r"\s\!\s", r" ! ", line)
            fw.write(line)
            # The substitutions above can swallow the trailing newline (and
            # can even leave the line empty), so restore it when missing.
            # endswith() is safe on an empty string, unlike line[-1].
            if not line.endswith("\n"):
                fw.write("\n")

## 获得所有训练语料，用于训练词向量

In [13]:
def bow(path):
    """Append every line of ``path`` to ``./data/corpus_total.txt``.

    The corpus file is opened in append mode, so calling this function
    several times concatenates the inputs. Running it again on the same
    input duplicates that input in the corpus.

    Args:
        path: Path of a UTF-8 text file to add to the combined corpus.
    """
    # Context managers close both files even if reading or writing fails.
    with codecs.open(path, encoding='utf-8') as fr, \
            codecs.open('./data/corpus_total.txt', 'a', encoding='utf-8') as fw:
        for line in tqdm(fr):
            # Lines yielded by file iteration are never empty, so no
            # emptiness check is needed before writing.
            fw.write(line)

## 调用depunc清洗数据

In [14]:
def prepare_for_model():
    """Run depunc on every source file, then on every target file, under ./data/."""
    for file_group in (path_src, path_tgt):
        for file_name in file_group:
            depunc("./data/" + file_name)

In [15]:
# Clean all six split files; each input gets a "<file>.clean" sibling.
prepare_for_model()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




## 为fasttext训练词向量准备语料

In [16]:
def prepare_for_fasttext():
    """Add the cleaned training source and target files to the combined corpus."""
    for middle in ("src", "tgt.tagged"):
        bow("./data/train.txt." + middle + ".clean")

In [17]:
# Append the cleaned training source and target text to ./data/corpus_total.txt.
prepare_for_fasttext()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




## 建立词典

In [1]:
def build_dic():
    """Count how often each whitespace-separated token occurs in the corpus.

    Returns:
        A ``defaultdict(int)`` mapping every token found in
        ``./data/corpus_total.txt`` to its number of occurrences.
    """
    dic = defaultdict(int)
    # The corpus is written as UTF-8, so read it back with the same encoding
    # instead of relying on the platform default.
    with open("./data/corpus_total.txt", "r", encoding="utf-8") as f:
        for line in tqdm(f):
            for word in line.split():
                dic[word] += 1
    return dic

In [4]:
# Token -> frequency table computed over the combined training corpus.
dictionary = build_dic()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [5]:
# Number of distinct tokens found in the corpus.
print(len(dictionary))

718591


In [6]:
# (token, frequency) pairs ordered from the most to the least frequent token.
sorted_dic = sorted(
    dictionary.items(), key=lambda pair: pair[1], reverse=True)

## 建立映射表并保存
- 原文词典大小为40000
- 文摘词典大小为10000
- 0:PAD
- 1:EOS
- 2:UNK
- 3:SOS
- source_input : W W W W UNK W W EOS PAD PAD PAD
- target_input : SOS w w w UNK w EOS PAD

In [17]:
# Total vocabulary sizes, counting the four special tokens (ids 0-3).
dict_size_src = 40000  # source (article) side
dict_size_tgt = 10000  # target (summary) side

In [18]:
# Build the source-side vocabulary: ids 0-3 are the special tokens, and the
# remaining ids are handed out to corpus words in descending frequency order
# until the vocabulary holds dict_size_src entries.
word2id_src = {'PAD': 0, 'EOS': 1, 'UNK': 2, 'SOS': 3}
id2word_src = {0: 'PAD', 1: 'EOS', 2: 'UNK', 3: 'SOS'}
count = len(word2id_src)
for (k, v) in sorted_dic:
    # Check the limit before adding, so a size of 4 or less stops at once
    # instead of never matching an equality test.
    if count >= dict_size_src:
        break
    word2id_src[k] = count
    id2word_src[count] = k
    count += 1
# The third positional argument is the pickle protocol (True == 1).
# Context managers make sure the files are flushed and closed.
with open("./data/word2id_src.dat", "wb") as f:
    pickle.dump(word2id_src, f, True)
with open("./data/id2word_src.dat", "wb") as f:
    pickle.dump(id2word_src, f, True)

In [19]:
# Build the target-side vocabulary: ids 0-3 are the special tokens, and the
# remaining ids are handed out to corpus words in descending frequency order
# until the vocabulary holds dict_size_tgt entries.
word2id_tgt = {'PAD': 0, 'EOS': 1, 'UNK': 2, 'SOS': 3}
id2word_tgt = {0: 'PAD', 1: 'EOS', 2: 'UNK', 3: 'SOS'}
count = len(word2id_tgt)
for (k, v) in sorted_dic:
    # Check the limit before adding, so a size of 4 or less stops at once
    # instead of never matching an equality test.
    if count >= dict_size_tgt:
        break
    word2id_tgt[k] = count
    id2word_tgt[count] = k
    count += 1
# The third positional argument is the pickle protocol (True == 1).
# Context managers make sure the files are flushed and closed.
with open("./data/word2id_tgt.dat", "wb") as f:
    pickle.dump(word2id_tgt, f, True)
with open("./data/id2word_tgt.dat", "wb") as f:
    pickle.dump(id2word_tgt, f, True)

## 处理待训练语料
- 原文限制长度为500以内，文摘限制在70以内（长度包含EOS，decoder输入还包含SOS）
- 替换为one-hot下标（即词典中的id）
- 补上PAD、EOS、UNK、SOS

In [20]:
# Fixed row lengths, in token ids, that every encoded sequence is truncated
# or padded to. The length includes the trailing EOS and, when present, the
# leading SOS.
SRC_LENGTH = 500  # used for paths without "tgt" in them
TGT_LENGTH = 70  # used for paths containing "tgt"

In [21]:
def word_to_one_hot(path, output, word2id):
    """Encode each line of a cleaned text file as a fixed-length list of ids.

    Each line is split on whitespace, the same tokenisation ``build_dic``
    uses to build the vocabulary. Words missing from ``word2id`` become UNK.
    The row is then truncated, terminated with EOS and right-padded with PAD
    to exactly ``TGT_LENGTH`` ids when ``"tgt"`` occurs in ``path`` and
    ``SRC_LENGTH`` ids otherwise.

    Args:
        path: Path of the cleaned text file. The substring ``"tgt"`` marks it
            as a target-side file.
        output: When true and the file is target-side, SOS is prepended
            before truncation (decoder-input form).
        word2id: Mapping from word to integer id.

    Returns:
        A list with one list of ids per input line, each of the fixed length.
    """
    pad_id, eos_id, unk_id, sos_id = 0, 1, 2, 3
    is_target = "tgt" in path
    restrict_len = TGT_LENGTH if is_target else SRC_LENGTH
    one_hot_matrix = []
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            # split() drops the trailing newline and never yields empty
            # tokens. split(' ')[:-1] discarded the last real word whenever
            # a line did not end in " \n" and turned doubled spaces into UNK.
            one_hot_list = [word2id.get(word, unk_id) for word in line.split()]
            if is_target and output:
                one_hot_list.insert(0, sos_id)
            # Leave room for the EOS appended right after truncation.
            one_hot_list = one_hot_list[:restrict_len - 1]
            one_hot_list.append(eos_id)
            one_hot_list.extend([pad_id] * (restrict_len - len(one_hot_list)))
            one_hot_matrix.append(one_hot_list)
    return one_hot_matrix

## tgt处理两次
- .onehot后缀是加了开始符号，整体右移一个单位。作为decoder输入
- .gold后缀是原始语料，作为gold output计算损失

In [22]:
def _dump_one_hot(clean_path, out_path, output, word2id):
    """Encode ``clean_path`` with word_to_one_hot and pickle it as a tensor."""
    matrix = np.asarray(
        word_to_one_hot(clean_path, output=output, word2id=word2id))
    # The context manager makes sure the pickle is flushed and closed.
    with open(out_path, "wb") as f:
        pickle.dump(torch.from_numpy(matrix), f, True)


def one_hot_for_model():
    """Encode every cleaned split file and pickle the resulting id tensors.

    Writes, under ``./data/``:
      * ``<src>.onehot``: source rows encoded with ``word2id_src``;
      * ``<tgt>.gold``: target rows without SOS, encoded with ``word2id_tgt``;
      * ``<tgt>.onehot``: target rows with a leading SOS, encoded with
        ``word2id_tgt``.
    """
    for src in path_src:
        _dump_one_hot("./data/" + src + ".clean", "./data/" + src + ".onehot",
                      output=False, word2id=word2id_src)
    for tgt in path_tgt:
        _dump_one_hot("./data/" + tgt + ".clean", "./data/" + tgt + ".gold",
                      output=False, word2id=word2id_tgt)
    for tgt in path_tgt:
        _dump_one_hot("./data/" + tgt + ".clean", "./data/" + tgt + ".onehot",
                      output=True, word2id=word2id_tgt)

In [25]:
# Encode all splits: three source files, then each target file twice (.gold and .onehot).
one_hot_for_model()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [26]:
# Reload the encoded test targets to sanity-check them. The context manager
# closes the file handle once the pickle has been read.
with open("./data/test.txt.tgt.tagged.onehot", "rb") as f:
    matrix = pickle.load(f)
print(len(matrix))

11490


In [27]:
# Inspect the first encoded row.
print(matrix[0])

tensor([   3,    2, 1781,   99,   72,  364,   84, 1824,   44,  167,   11,    5,
         886,  413,  326,  387,  571,    4, 2273,   25,    2,    9, 1380,  563,
          35,  122, 2240,    5,  153, 4100,   19,  374,    6,   38, 2212,   99,
           4,    2,    2,   40, 3167,   24,    2,  591,  168,   10,   38, 2843,
          10, 1638, 2918,    6, 2379,   99,    4,    1,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0])


In [28]:
# Decode the first target row back to words. The row was encoded with the
# target vocabulary, so look ids up in id2word_tgt, which is already defined
# at this point in the notebook.
s = [id2word_tgt[token_id] for token_id in matrix[0].numpy()]
print(' '.join(s))

SOS UNK prosecutor says so far no videos were used in the crash investigation despite media reports . journalists at UNK and paris match are very confident the video clip is real , an editor says . UNK UNK had informed his UNK training school of an episode of severe depression , airline says . EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD


## 统计文摘的长度作为mask传给模型
- 方便计算maskedNLLLoss
- 是截断后文摘的长度（从.onehot文件统计，包含开头的SOS和结尾的EOS）

In [29]:
def make_len_mask(path):
    """Compute per-row sequence lengths for a pickled id tensor and save them.

    The length of a row is the number of ids up to and including the first
    EOS (id 1). If a row has no EOS, its full length is used. The lengths are
    pickled as a tensor next to the input, with the ``.onehot`` suffix of
    ``path`` replaced by ``.mask``.

    Args:
        path: Path of a pickled tensor of id rows, e.g.
            ``./data/test.txt.tgt.tagged.onehot``.
    """
    suffix = ".onehot"
    # Derive the output name from the argument rather than from a global
    # loop variable that happens to exist in the calling cell.
    base = path[:-len(suffix)] if path.endswith(suffix) else path
    with open(path, "rb") as f:
        onehot = pickle.load(f)
    len_mask = []
    for sentence_onehot in tqdm(onehot):
        count = 0
        # tolist() yields plain ints, which is much cheaper than comparing
        # one tensor scalar at a time.
        for token_id in sentence_onehot.tolist():
            count += 1
            if token_id == 1:
                break
        len_mask.append(count)
    len_mask = torch.from_numpy(np.asarray(len_mask))
    with open(base + ".mask", "wb") as f:
        pickle.dump(len_mask, f, True)

In [30]:
# Build a length mask for the decoder-input (.onehot) file of every split.
for tgt in path_tgt:
    make_len_mask("./data/" + tgt + ".onehot")

HBox(children=(IntProgress(value=0, max=11490), HTML(value='')))




HBox(children=(IntProgress(value=0, max=287227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13368), HTML(value='')))




In [31]:
# Reload the test-split length mask to check it. The context manager closes
# the file handle once the pickle has been read.
with open("./data/test.txt.tgt.tagged.mask", "rb") as f:
    mask_test = pickle.load(f)
print(mask_test)

tensor([56, 39, 63,  ..., 43, 70, 38])


## 测试batch_loader数据

In [32]:
# This notebook never writes "./data/id2word.dat"; it only saves
# id2word_src.dat and id2word_tgt.dat. Both tables are filled from the same
# frequency-sorted list starting at id 4, so the source table contains every
# entry of the target table and can decode either side.
with open("./data/id2word_src.dat", "rb") as f:
    id2word = pickle.load(f)

In [33]:
# Hard-coded sequence of token ids, decoded with id2word in a later cell.
# NOTE(review): presumably one encoder input row copied from a batch loader;
# its origin is not shown in this notebook, so confirm.
sample_encoder_input = [
    5, 105, 15, 154, 463, 3618, 2433, 1873, 813, 87, 411, 2810, 9, 9172, 5383,
    67, 122, 46, 8, 3875, 116, 174, 5, 2, 473, 779, 142, 7, 665, 29, 5, 307,
    14, 205, 774, 4, 2, 2, 6, 5441, 6, 7, 3158, 403, 24, 8, 2955, 97, 38, 267,
    9, 5, 273, 3618, 9329, 15, 742, 1541, 774, 2, 2, 31, 115, 1345, 151, 31,
    311, 897, 31, 30, 347, 8, 5321, 2810, 15, 2932, 4, 46, 33, 347, 5, 3618,
    500, 1168, 6, 5, 627, 762, 8, 3443, 1873, 16, 813, 4928, 30, 1104, 2074, 9,
    1639, 37, 13, 259, 45, 12, 160, 24, 575, 4133, 12, 160, 11, 8, 3028, 4,
    1865, 132, 665, 2, 2, 2, 2, 6, 5441, 6, 7, 5, 67, 813, 12, 8, 1873, 11,
    2287, 3618, 3028, 2, 2, 5, 105, 15, 154, 200, 1873, 9, 702, 4, 33, 347, 8,
    3618, 98, 31, 440, 565, 113, 25, 3618, 9, 8, 575, 12, 160, 4, 2, 259, 2, 2,
    2, 13, 386, 30, 3618, 831, 18, 36, 11, 31, 8685, 928, 9, 565, 4, 33, 347,
    8, 3618, 500, 1168, 887, 9, 25, 9379, 4, 19, 5, 74, 11, 2, 15, 3028, 6,
    559, 2510, 12, 172, 106, 813, 210, 226, 153, 3028, 4, 2, 6, 37, 15, 729,
    3618, 831, 36, 11, 31, 142, 6, 347, 31, 3618, 38, 31, 440, 2, 2, 37, 7, 27,
    8, 403, 18, 31, 2, 2, 113, 25, 3618, 9, 652, 1177, 12, 160, 4, 12, 267, 6,
    2, 347, 8, 2932, 30, 733, 5605, 6, 2901, 6, 37, 132, 413, 11, 2955, 6, 112,
    2214, 265, 4, 5605, 15, 270, 8024, 15, 1132, 299, 3596, 427, 31, 9, 160, 8,
    60, 3618, 6, 48, 33, 347, 887, 9, 31, 440, 15, 2433, 4, 20, 78, 23, 53, 18,
    134, 9, 113, 85, 3618, 9, 2401, 12, 359, 9, 64, 8024, 8, 93, 1358, 3618, 6,
    28, 17, 7, 1141, 9, 39, 23, 6, 17, 5605, 63, 2214, 6, 48, 42, 417, 8, 493,
    54, 5, 307, 14, 284, 1873, 19, 2, 19, 378, 133, 4, 5, 67, 2486, 11, 34,
    2027, 1873, 264, 26, 5, 368, 11, 585, 602, 12, 4778, 46, 2, 2, 2192, 8,
    3618, 19, 2756, 11, 31, 553, 6, 37, 50, 395, 347, 8, 3028, 4, 31, 3618, 7,
    336, 3428, 606, 9, 2, 619, 226, 12, 6206, 6, 1
]

In [34]:
# Hard-coded sequence of token ids, decoded with id2word in a later cell.
# It starts with 3 (SOS) and ends with 1 (EOS) followed by 0 (PAD) ids.
# NOTE(review): presumably a decoder input row from a batch loader; confirm.
sample_decoder_input = [
    3, 5, 142, 14, 665, 2013, 665, 132, 29, 329, 2810, 9, 113, 3618, 9, 4928,
    12, 359, 18, 8, 259, 45, 9, 347, 8, 3618, 4, 4135, 14, 274, 226, 7, 416,
    12, 5, 307, 14, 205, 1873, 75, 5, 706, 11, 111, 122, 4, 3875, 5383, 2, 2,
    6, 5441, 6, 174, 5, 411, 813, 12, 5, 1873, 500, 1168, 4, 2, 1, 0, 0, 0, 0,
    0
]

In [35]:
# Hard-coded sequence of token ids, decoded with id2word in a later cell.
# NOTE(review): presumably a gold target row from a batch loader. The two
# leading 0 (PAD) ids followed by 3 (SOS) look unexpected, so confirm.
sample_decoder_gold = [
    0, 0, 3, 1601, 2, 6, 2945, 6, 13, 7, 7829, 2569, 342, 151, 1496, 4, 89, 13,
    9, 230, 886, 9, 3014, 227, 322, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0
]

In [36]:
# Hard-coded sequence of token ids, decoded with id2word in a later cell.
# NOTE(review): presumably ids predicted by a model; confirm the origin.
sample_decoder_output = [
    0, 5, 4, 3411, 2281, 4, 2, 7, 2, 3621, 10, 2, 3621, 484, 5, 2, 2, 4, 10, 0,
    5, 10, 4, 2, 7, 2, 7, 1259, 2, 7063, 1259, 2, 7, 1259, 12, 4, 2, 4, 7,
    1421, 3621, 3771, 3771, 9, 1259, 1259, 1259, 2, 4, 1285, 2, 10, 7874, 28,
    4, 688, 4, 12, 2, 2, 9089, 9089, 1259, 7063, 688, 2, 9089, 9089, 1259,
    7063, 0, 0, 0, 0
]

In [37]:
# Hard-coded sequence of token ids, decoded with id2word in a later cell.
# NOTE(review): presumably target ids gathered across a batch; confirm.
sample_cur_tgt_batch = [
    5, 7753, 2, 7527, 2369, 70, 2527, 2, 2088, 2, 1848, 2547, 2, 453, 901, 747,
    2, 164, 116, 4502, 2, 5, 7982, 3436, 2, 1925, 521, 2, 5379, 4963, 2545,
    559, 1335, 2, 5, 2, 9139, 5, 2, 1749, 2149, 4888, 6916, 797, 5150, 5587,
    7723, 2, 2334, 1130, 1215, 33, 5, 423, 2, 3403, 3162, 4685, 2689, 1480, 45,
    796, 1334, 717, 2086, 6788, 4369, 2, 2, 2444, 671, 2, 8719, 328, 5093,
    7673, 1556, 2008, 855, 4540, 1490, 2, 6391, 2, 1057, 747, 5, 5619, 620,
    3984, 1268, 655, 176, 661, 2, 2, 527, 7999, 8, 1749
]

In [38]:
# Turn the sample encoder ids back into words and show them as one sentence.
s = [id2word[token_id] for token_id in sample_encoder_input]
print(' '.join(s))

the only that see enough cap ian emotional capital them others grounds and raul aerial can very their a button most under the UNK political similar u.s. to european have the daughter for away station . UNK UNK , peterson , to interesting britain his a moves did an six and the asked cap interrupted that interview midfielder station UNK UNK has then feeling house has again experts has from using a violated grounds that solar . their be using the cap november jackson , the center christmas a cats emotional ' capital radiation from seems 2001 and fast they was money been 's these his line shore 's these in a explains . shopping man european UNK UNK UNK UNK , peterson , to the can capital 's a emotional in weapon cap explains UNK UNK the only that see london emotional and doctors . be using a cap down has 2011 let made at cap and a line 's these . UNK money UNK UNK UNK was arrested from cap injuries ` she in has bonds view and let . be using a cap november jackson longer and at disrupt . is

In [39]:
# Turn the sample decoder-input ids back into words and show them.
s = [id2word[token_id] for token_id in sample_decoder_input]
print(' '.join(s))

SOS the u.s. for european tells european man have military grounds and made cap and radiation 's behind ` a money been and using a cap . decline for face car to together 's the daughter for away emotional time the brought in found very . button aerial UNK UNK , peterson , under the others capital 's the emotional november jackson . UNK EOS PAD PAD PAD PAD PAD


In [40]:
# Turn the sample gold-target ids back into words and show them.
s = [id2word[token_id] for token_id in sample_decoder_gold]
print(' '.join(s))

PAD PAD SOS ceremony UNK , understanding , was to ivf stabbed minister house ways . now was and woman crash and conspiracy young several . UNK EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD


In [41]:
# Turn the sample current-target-batch ids back into words and show them.
s = [id2word[token_id] for token_id in sample_cur_tgt_batch]
print(' '.join(s))

the agony UNK scattered tennis into native UNK prosecution UNK critical auction UNK point operation race UNK country most pledged UNK the genuinely capacity UNK confidence tried UNK creatures violation controversy leader raise UNK the UNK reversed the UNK taylor morgan operated pga board assaulting ownership parenting UNK rich weapons eye be the dead UNK horror maximum brussels rugby contract been gone mind huge birmingham fitting spencer UNK UNK woods private UNK squeeze local producers cyprus commission criticism returned mosque strike UNK 1985 UNK sir race the infamous fact keith higher miles death justice UNK UNK thing encourages a taylor


In [42]:
# Turn the sample decoder-output ids back into words and show them.
s = [id2word[token_id] for token_id in sample_decoder_output]
print(' '.join(s))

PAD the . celtic busy . UNK to UNK debris of UNK debris september the UNK UNK . of PAD the of . UNK to UNK to rise UNK wiltshire rise UNK to rise 's . UNK . to century debris 1970s 1970s and rise rise rise UNK . managed UNK of belongs by . 25 . 's UNK UNK answering answering rise wiltshire 25 UNK answering answering rise wiltshire PAD PAD PAD PAD


# 统计语料

In [43]:
# Load the decoder-input tensors of all three splits to count their rows.
# Context managers close each file handle once its pickle has been read.
with open("./data/train.txt.tgt.tagged.onehot", "rb") as f:
    train_tgt = pickle.load(f)
with open("./data/val.txt.tgt.tagged.onehot", "rb") as f:
    vali_tgt = pickle.load(f)
with open("./data/test.txt.tgt.tagged.onehot", "rb") as f:
    test_tgt = pickle.load(f)

In [44]:
# Number of training examples.
print(len(train_tgt))

287227


In [45]:
# Number of validation examples.
print(len(vali_tgt))

13368


In [46]:
# Number of test examples.
print(len(test_tgt))

11490


# 从onehot中恢复出文本
- 即获得截断长度的文本，另存为一份语料，用于fairseq训练

In [2]:
# Split names used to build the per-split input and output file paths below.
parts = ['train', 'val', 'test']

In [52]:
# Decode every encoded source row back to text, one sentence per line, and
# write it to ./data/fairseq/<split>.src. The decoded text keeps the EOS and
# PAD tokens contained in each row.
for p in parts:
    # Load the input first, so a missing pickle does not leave behind an
    # empty, truncated output file. The handle is closed right after reading.
    with open("./data/" + p + ".txt.src.onehot", "rb") as fin:
        onehot = pickle.load(fin)
    with open("./data/fairseq/" + p + ".src", "w", encoding="utf-8") as f:
        for sentence in tqdm(onehot):
            s = [id2word[token_id] for token_id in sentence.numpy()]
            f.write(' '.join(s) + "\n")

HBox(children=(IntProgress(value=0, max=287227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13368), HTML(value='')))




HBox(children=(IntProgress(value=0, max=11490), HTML(value='')))




In [53]:
# Decode every encoded decoder-input row back to text, one sentence per line,
# and write it to ./data/fairseq/<split>.tgt. The decoded text keeps the SOS,
# EOS and PAD tokens contained in each row.
for p in parts:
    # Load the input first, so a missing pickle does not leave behind an
    # empty, truncated output file. The handle is closed right after reading.
    with open("./data/" + p + ".txt.tgt.tagged.onehot", "rb") as fin:
        onehot = pickle.load(fin)
    with open("./data/fairseq/" + p + ".tgt", "w", encoding="utf-8") as f:
        for sentence in tqdm(onehot):
            s = [id2word[token_id] for token_id in sentence.numpy()]
            f.write(' '.join(s) + "\n")

HBox(children=(IntProgress(value=0, max=287227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=13368), HTML(value='')))




HBox(children=(IntProgress(value=0, max=11490), HTML(value='')))




# 直接将clean语料截断

In [15]:
# Truncate every cleaned source line to its first SRC_LENGTH space-separated
# tokens and write the result to ./data/fairseq/<split>.src. This overwrites
# the files produced by decoding the .onehot tensors.
for p in parts:
    # Both files are managed by the with-statement, so the reader is closed
    # even if writing fails. The cleaned files were written as UTF-8.
    with open("./data/" + p + ".txt.src.clean", "r", encoding="utf-8") as fr, \
            open("./data/fairseq/" + p + ".src", "w", encoding="utf-8") as fw:
        for sentence in tqdm(fr):
            s = ' '.join(sentence.split(' ')[:SRC_LENGTH])
            fw.write(s)
            # Truncation cuts off the trailing newline of long lines.
            if not s.endswith("\n"):
                fw.write("\n")

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [16]:
# Truncate every cleaned target line to its first TGT_LENGTH space-separated
# tokens and write the result to ./data/fairseq/<split>.tgt. This overwrites
# the files produced by decoding the .onehot tensors.
for p in parts:
    # Both files are managed by the with-statement, so the reader is closed
    # even if writing fails. The cleaned files were written as UTF-8.
    with open("./data/" + p + ".txt.tgt.tagged.clean", "r",
              encoding="utf-8") as fr, \
            open("./data/fairseq/" + p + ".tgt", "w", encoding="utf-8") as fw:
        for sentence in tqdm(fr):
            s = ' '.join(sentence.split(' ')[:TGT_LENGTH])
            fw.write(s)
            # Truncation cuts off the trailing newline of long lines.
            if not s.endswith("\n"):
                fw.write("\n")

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


