In [1]:
# 用来做数据处理和生成训练数据
import codecs
import collections
import tensorflow as tf
import numpy as np

  from ._conv import register_converters as _register_converters


In [2]:
from operator import itemgetter

In [3]:
RAW_DATA = './simple-examples/data/ptb.train.txt' # training-set corpus; scanned to count word frequencies
VOCAB_OUTPUT = 'ptb.vocab' # file the vocabulary is written to
counter = collections.Counter() # word -> number of occurrences

In [4]:
# Tally every whitespace-separated token of the training corpus.
with codecs.open(RAW_DATA, 'r', 'utf-8') as corpus:
    for sentence in corpus:
        counter.update(sentence.strip().split())

# Rank the words from most to least frequent (ties keep the counter's own order).
sorted_word_to_cnt = counter.most_common()
# The sentence-end marker "<eos>" always takes the first vocabulary slot.
sorted_words = ["<eos>"] + [entry[0] for entry in sorted_word_to_cnt]

# Persist the vocabulary, one word per line.
with codecs.open(VOCAB_OUTPUT, 'w', 'utf-8') as vocab_out:
    vocab_out.writelines(token + '\n' for token in sorted_words)

In [5]:
# Paths for converting the raw text corpora into corpora of word ids
TRAIN_DATA = './simple-examples/data/ptb.train.txt' # training set
TEST_DATA = './simple-examples/data/ptb.test.txt' # test set
VALID_DATA = './simple-examples/data/ptb.valid.txt' # validation set
OUTPUT_TRAIN_DATA = 'ptb.train' # id-encoded training set
OUTPUT_TEST_DATA = 'ptb.test' # id-encoded test set
OUTPUT_VALID_DATA = 'ptb.valid' # id-encoded validation set

In [6]:
# Build word2idx: every vocabulary word maps to its zero-based line number
# in the vocabulary file.
# The file is decoded as UTF-8, the same encoding it was written with, and
# enumerate() supplies the ids so no module-level counter named `id` hides
# the builtin id().
with codecs.open(VOCAB_OUTPUT, 'r', 'utf-8') as vocab_file:
    word2idx = {line.strip(): idx for idx, line in enumerate(vocab_file)}

def getid(word):
    '''
    Return the id stored for `word` in word2idx.

    Words that are not in the vocabulary map to the id of '<unk>'.
    '''
    if word in word2idx:
        return word2idx[word]
    return word2idx['<unk>']

In [7]:
def convert(raw_data_path, output_data_path):
    '''
    Convert a raw text corpus into a corpus of word ids.

    Each input line is split on whitespace, "<eos>" is appended as the
    sentence-end marker, every token is replaced by its id (via getid) and
    the ids are written space-separated, one output line per input line.

    Args:
        raw_data_path: path of the UTF-8 text file to read.
        output_data_path: path of the id file to write (overwritten).
    '''
    # The with-statement closes both files even if reading, the id lookup
    # or writing raises part-way through.
    with codecs.open(raw_data_path, 'r', 'utf-8') as fin, \
            codecs.open(output_data_path, 'w', 'utf-8') as fout:
        for line in fin:
            words = line.strip().split() + ["<eos>"]
            fout.write(' '.join(str(getid(word)) for word in words) + '\n')

In [8]:
# Encode the training, test and validation splits, in that order.
for raw_path, encoded_path in (
    (TRAIN_DATA, OUTPUT_TRAIN_DATA),
    (TEST_DATA, OUTPUT_TEST_DATA),
    (VALID_DATA, OUTPUT_VALID_DATA),
):
    convert(raw_path, encoded_path)

In [10]:
TRAIN_BATCH_SIZE = 20 # training batch size
TRAIN_NUM_STEP = 35 # number of consecutive word ids per row of a training batch

def read_data(file_path):
    '''
    Load an id-encoded corpus as one flat list of ints.

    All lines of the file are concatenated in order, so the whole text
    comes back as a single sequence of word ids.
    '''
    id_list = []
    with open(file_path, 'r') as fin:
        for row in fin:
            # split() with no argument ignores leading/trailing whitespace
            # and blank lines contribute nothing.
            id_list.extend(int(token) for token in row.split())
    return id_list

In [34]:
def make_batches(id_list, batch_size, num_step, verbose=True):
    '''
    Cut a flat sequence of word ids into (inputs, labels) training batches.

    The sequence is truncated to a multiple of batch_size * num_step,
    reshaped to [batch_size, num_batches * num_step] and split along the
    second axis into num_batches arrays of shape [batch_size, num_step].
    The labels are built the same way from the sequence shifted by one
    position, so each label is the id that follows the matching input id.

    Args:
        id_list: the whole text as a list of word ids.
        batch_size: number of rows in every batch.
        num_step: number of consecutive ids per row.
        verbose: when True (the default), print the batch count and the
            intermediate array shapes.

    Returns:
        A list of (data, label) tuples of numpy arrays, each of shape
        [batch_size, num_step].

    Raises:
        ValueError: if batch_size or num_step is not positive, or if
            id_list is too short to fill a single batch.
    '''
    if batch_size <= 0 or num_step <= 0:
        raise ValueError(
            'batch_size and num_step must be positive, got %r and %r'
            % (batch_size, num_step))

    ids_per_batch = batch_size * num_step
    # One id is held back because the labels are the inputs shifted by one.
    num_batches = (len(id_list) - 1) // ids_per_batch
    if num_batches <= 0:
        raise ValueError(
            'need at least %d ids to build one batch, got %d'
            % (ids_per_batch + 1, len(id_list)))
    usable = num_batches * ids_per_batch  # drop the tail that does not fill a batch

    # print() with a single argument behaves the same on Python 2 and 3.
    if verbose:
        print(num_batches)
    data = np.array(id_list[:usable])
    if verbose:
        print(data.shape)
    data = np.reshape(data, [batch_size, num_batches * num_step])
    if verbose:
        print(data.shape)
    # Split along the second axis into num_batches equally wide arrays.
    data_batches = np.split(data, num_batches, axis=1)
    if verbose:
        print(len(data_batches[0]))

    # Labels: the same layout, built from the sequence shifted by one id.
    label = np.array(id_list[1:usable + 1])
    label = np.reshape(label, [batch_size, num_batches * num_step])
    label_batches = np.split(label, num_batches, axis=1)

    return list(zip(data_batches, label_batches))

In [41]:
# Build the training batches and inspect the first one.
train_batches = make_batches(read_data(OUTPUT_TRAIN_DATA), TRAIN_BATCH_SIZE, TRAIN_NUM_STEP)
# print() with a single argument gives the same output on Python 2 and Python 3.
print(train_batches[0][0].shape)  # shape of the first input batch
print(len(train_batches[0]))  # each entry is an (inputs, labels) pair
print(train_batches[0][1])  # label ids of the first batch

1327
(928900,)
(20, 46445)
20
(20, 35)
2
[[9999 9985 9976 9989 9973 9983 9975 9991 9970 9971 9987 9995 9977 9981
  9972 9982 9988 9974 9998 9992 9993 9996 9984    0 8998    2    3   72
   393   33 2148    1  146   19    6]
 [1515   18 1453    1  846  234    1 1374    5 1281    7 1638 1082 3841
    17  380 1355    4  207    0    1 2616    4    1  260   13    5  335
     1    2   16  766 1490   10   42]
 [1126  645   46   20    2 1062   82 1091  474    6 1912    7    2    2
     8 7459   80    6    2 2126    7 1932    0 5740   82 9057  558  549
     2   22 8823    8  537    2    0]
 [  14   93   25 1023    5  255  169   10  207    0   54 1441 1252   22
  1659   15    1  469   42   45   55 1869    1   37    9  207    4  511
    12    3   48    0   14   59   79]
 [   3  394   69  123    0  271  112  608    5 3434  206    7 3396    4
    45  310 1658    6 3329  353    0  367    1  332  119  742  174   90
   137 2402    1 1246    7  819  190]
 [  53 1525  159  725   10   23 1389    9  217 12