In [119]:
import pandas as pd
import numpy as np
import torch
import time
import random
import os
import jieba
import re

## 加载数据集

In [120]:
train_file = open("./LCSTS_ORIGIN/DATA/PART_I.txt", 'r')
train_file = train_file.read().replace('\n','').replace(' ','')
train_label = re.findall('<summary>(.*?)</summary', train_file)
train_text = re.findall('<short_text>(.*?)</short_text>', train_file)

valid_file = open("./LCSTS_ORIGIN/DATA/PART_II.txt", 'r')
valid_file = valid_file.read().replace('\n','').replace(' ','')
valid_label = re.findall('<summary>(.*?)</summary', valid_file)
valid_text = re.findall('<short_text>(.*?)</short_text>', valid_file)

test_file = open("./LCSTS_ORIGIN/DATA/PART_III.txt", 'r')
test_file = test_file.read().replace('\n','').replace(' ','')
test_label = re.findall('<summary>(.*?)</summary', test_file)
test_text = re.findall('<short_text>(.*?)</short_text>', test_file)

In [121]:
train_label = train_label[0:5000]
train_text = train_text[0:5000]

## 使用torchtext构建数据集

In [122]:
from torchtext import data
from torchtext.vocab import Vectors
from torch.nn import init
from tqdm import tqdm

In [123]:
tokenize = lambda x: jieba.lcut(x)
TEXT = data.Field(sequential=True, tokenize=tokenize, fix_length=100, use_vocab=True)
LABEL = data.Field(sequential=True, tokenize=tokenize, fix_length=20, use_vocab=True)

In [124]:
def get_dataset(txt_data, txt_label, text_field, label_field, test=False):
    fields = [("text", text_field), ("summary", label_field)]
    examples = []
    if test:
        for text in tqdm(txt_data):
            examples.append(data.Example.fromlist([text, None], fields))
    else:
        for text, label in tqdm(zip(txt_data, txt_label)):
            examples.append(data.Example.fromlist([text, label], fields))
    return examples, fields

# 得到构建Dataset所需的examples和fields
train_examples, train_fields = get_dataset(train_text, train_label, TEXT, LABEL)
valid_examples, valid_fields = get_dataset(valid_text, valid_label, TEXT, LABEL)
test_examples, test_fields = get_dataset(test_text, None, TEXT, None, test=True)
# 构建Dataset数据集
train = data.Dataset(train_examples, train_fields)
valid = data.Dataset(valid_examples, valid_fields)
test = data.Dataset(test_examples, test_fields)

5000it [00:03, 1605.34it/s]
10666it [00:07, 1466.32it/s]
100%|██████████| 1106/1106 [00:00<00:00, 1716.54it/s]


In [150]:
print(len(train[0].text))

59


## 构建词表

In [126]:
TEXT.build_vocab(train)
LABEL.build_vocab(train)

In [149]:
TEXT.vocab.freqs.most_common(10)
LABEL.vocab.freqs.most_common(10)

[('：', 887),
 ('“', 752),
 ('”', 743),
 ('的', 724),
 ('海南', 566),
 ('海口', 451),
 ('大', 398),
 ('数据', 397),
 ('！', 369),
 ('？', 351)]

## 构建数据集迭代器

In [148]:
from torchtext.data import Iterator, BucketIterator
# 同时对训练集和验证集进行迭代器的构建
train_iter, val_iter = BucketIterator.splits(
        (train, valid), # 构建数据集所需的数据集
        batch_sizes=(32, 32),
        device=6, # 如果使用gpu，此处将-1更换为GPU的编号
        sort_key=lambda x: len(x.text), # the BucketIterator needs to be told what function it should use to group the data.
        sort_within_batch=False,
        repeat=False # we pass repeat=False because we want to wrap this Iterator layer.
)

test_iter = Iterator(test, batch_size=32, device=6, sort=False, sort_within_batch=False, repeat=False)

The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


In [129]:
for idx, batch in enumerate(train_iter):
    print(batch)
    text, label = batch.text, batch.summary
    print(text.shape, label.shape, torch.t(batch.summary).shape)


[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of siz


[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 100x32]
	[.summary]:[torch.LongTensor of size 20x32]
torch.Size([100, 32]) torch.Size([20, 32]) torch.Size([32, 20])

[torchtext.data.batch.Batch of siz