In [52]:
"""
数据来源：https://github.com/L1aoXingyu/Char-RNN-Gluon/tree/master/data
"""
import codecs
import collections
import numpy as np
import logging
import time


In [57]:
class TextDataLoader(object):
    """Character-level text loader that produces random one-hot batches.

    All attributes are filled in by :meth:`load_data`:

    * ``vocab``: list of the kept characters, most frequent first.
    * ``vocab_num``: ``len(vocab)``.
    * ``word_to_index``: dict mapping each kept character to its integer id.
    * ``index_to_word``: dict mapping each integer id back to its character.
    * ``data``: numpy array of character ids for the file, in file order,
      with characters outside the vocabulary dropped.
    * ``data_size``: number of entries in ``data``.
    """

    def __init__(self):
        self.word_to_index = None
        self.index_to_word = None
        self.data = None
        self.data_size = 0
        self.vocab = None
        self.vocab_num = 0

    def load_data(self, file_path, max_vocab_num=5000):
        """Read a UTF-8 text file and build the vocabulary and encoded data.

        Every character of the file (line terminators included) is treated as
        one token. Only the ``max_vocab_num`` most frequent characters are
        kept; ties keep their first-seen order.

        Args:
            file_path: path of the UTF-8 encoded text file to read.
            max_vocab_num: maximum number of distinct characters to keep.
        """
        with codecs.open(file_path, mode='r', encoding='utf-8') as f:
            text = f.read()

        # most_common() sorts by count, descending, with a stable sort, which
        # is the same ordering as an explicit sort(key=count, reverse=True).
        ranked = collections.Counter(text).most_common()
        real_vocab_num = len(ranked)

        self.vocab = [ch for ch, _ in ranked[:max_vocab_num]]
        self.word_to_index = {ch: idx for idx, ch in enumerate(self.vocab)}
        self.index_to_word = dict(enumerate(self.vocab))
        self.vocab_num = len(self.vocab)

        # Out-of-vocabulary characters are silently dropped.
        # TODO: map unknown characters to a dedicated token instead.
        encoded = [self.word_to_index[ch] for ch in text if ch in self.word_to_index]
        self.data_size = len(encoded)
        self.data = np.array(encoded)
        logging.info("load data done. real_vocab_num:%d result_vocab_num:%d data_size:%d",
                     real_vocab_num, self.vocab_num, self.data_size)

    def random_batch_onehot(self, seq_len, batch_size):
        """Sample ``batch_size`` random windows of ``seq_len`` characters.

        Args:
            seq_len: number of consecutive characters per sampled window.
            batch_size: number of windows to sample.

        Returns:
            A list of ``seq_len`` float arrays, each of shape
            ``(vocab_num, batch_size)``. Column ``i`` of the array at step
            ``t`` is the one-hot encoding of the ``t``-th character of the
            ``i``-th sampled window.

        Raises:
            ValueError: if no data is loaded or the data is too short for
                the requested ``seq_len``.
        """
        # Start offsets are drawn from [0, data_size - 1 - seq_len).
        # NOTE(review): the extra "- 1" presumably leaves room for a
        # next-character target after each window -- confirm with callers.
        start_bound = self.data_size - 1 - seq_len
        if start_bound <= 0:
            raise ValueError(
                "not enough data for seq_len=%d: data_size=%d "
                "(call load_data first or use a shorter seq_len)"
                % (seq_len, self.data_size))

        starts = np.random.randint(start_bound, size=batch_size)
        columns = np.arange(batch_size)

        batch = []
        for t in range(seq_len):
            step = np.zeros((self.vocab_num, batch_size), dtype=float)
            # Row = character id at offset t of each window, column = window.
            step[self.data[starts + t], columns] = 1
            batch.append(step)
        return batch
            
        
        

In [58]:
if __name__ == "__main__":
    # Demo: keep only the 10 most frequent characters of the poetry corpus,
    # then show both lookup tables and one random one-hot batch.
    loader = TextDataLoader()
    loader.load_data("poetry.txt", max_vocab_num=10)

    print("word_to_index:%s" % (loader.word_to_index,))
    print("index_to_word:%s" % (loader.index_to_word,))

    # Two random windows of three characters each.
    print(loader.random_batch_onehot(seq_len=3, batch_size=2))


word_to_index:{'，': 0, '。': 1, '\n': 2, '不': 3, '人': 4, '山': 5, '风': 6, '日': 7, '云': 8, '无': 9}
index_to_word:{0: '，', 1: '。', 2: '\n', 3: '不', 4: '人', 5: '山', 6: '风', 7: '日', 8: '云', 9: '无'}
[array([[0., 0.],
       [0., 0.],
       [1., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.]]), array([[1., 0.],
       [0., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]]), array([[0., 0.],
       [1., 0.],
       [0., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]])]
