In [1]:
import tensorflow as tf
import numpy as np
import glob
import json
from collections import Counter
from tqdm import tqdm
from snownlp import SnowNLP

a = tf.test.is_built_with_cuda()
b = tf.test.is_gpu_available(
    cuda_only=False,
    min_cuda_compute_capability=None
)     
print(a)
print(b)

True
True


In [2]:
paths = glob.glob('./dataset/json/poet.*.json')
poets = []
#print(paths)

In [3]:
for path in paths:
    data = open(path, 'r', encoding='utf-8').read()
    data = json.loads(data)
    #print(data)
    for item in data:
        content = ''.join(item['paragraphs'])
        if len(content)>= 24 and len(content) <= 32:
            content = SnowNLP(content)
            content=content.han
            if len(content)%4==0:
                poets.append('[' + content + ']')


In [4]:
print('We have %d Chinese ancient poets' % len(poets), poets[0], poets[-1])

We have 104147 Chinese ancient poets [欲出未出光辣达，千山万山如火发。须臾走向天上来，逐却残星赶却月。] [书劒催人不暂闲，洛阳羁旅复秦关。容颜岁岁愁边改，乡国时时梦里还。]


In [5]:
np.random.shuffle(poets)
num_validation_samples = 8000
validation_data = poets[:num_validation_samples]
training_data=poets[num_validation_samples:]
training_data.sort(key=lambda x: len(x))
validation_data.sort(key=lambda x: len(x))

In [6]:
print(len(training_data))

96147


In [7]:
print(validation_data[0])

[命啸无人啸，含娇何处娇？徘徊花上月，空度可怜宵。]


In [8]:
training_chars = []
for item in training_data:
    training_chars += [c for c in item]
print('We have %d words' % len(training_chars))

training_chars = sorted(Counter(training_chars).items(), key=lambda x:x[1], reverse=True)
print('We have %d different words in training set' % len(training_chars))
print(training_chars[:30])

training_chars = [c[0] for c in training_chars]
training_char2id = {c: i + 1 for i, c in enumerate(training_chars)}
training_id2char = {i + 1: c for i, c in enumerate(training_chars)}
print(len(training_char2id))


We have 3131174 words
We have 7927 different words in training set
[('，', 193150), ('。', 192667), ('[', 96150), (']', 96150), ('不', 33677), ('人', 26960), ('一', 26076), ('风', 20763), ('无', 20361), ('山', 19765), ('来', 18069), ('花', 16535), ('有', 14894), ('春', 14411), ('日', 13537), ('天', 12645), ('中', 12414), ('何', 12012), ('时', 11631), ('云', 11563), ('是', 11191), ('年', 11012), ('知', 10967), ('水', 10841), ('自', 10792), ('得', 10730), ('上', 10641), ('月', 10454), ('如', 9884), ('生', 9397)]
7927


In [9]:
batch_size = 256
X_training_data = []
Y_training_data = []

for b in range(len(training_data) // batch_size):
    start = b * batch_size
    end = b * batch_size + batch_size
    batch = [[training_char2id[c] for c in training_data[i]] for i in range(start, end)]
    #if count==1:
        #print(len(batch[0]))
    maxlen = max(map(len, batch))
    X_training_batch = np.full((batch_size, maxlen - 1), 0, np.int32)
    Y_training_batch = np.full((batch_size, maxlen - 1), 0, np.int32)

    for i in range(batch_size):
        X_training_batch[i, :len(batch[i]) - 1] = batch[i][:-1]
        Y_training_batch[i, :len(batch[i]) - 1] = batch[i][1:]
    #if maxlen!=34:
        #print(maxlen)
    X_training_data.append(X_training_batch)
    Y_training_data.append(Y_training_batch)
    
print(len(X_training_data), len(Y_training_data))

375 375


In [10]:
validation_chars = []
for item in validation_data:
    validation_chars += [c for c in item]
print('We have %d words in validation set' % len(validation_chars))

validation_chars = sorted(Counter(validation_chars).items(), key=lambda x:x[1], reverse=True)
print('We have %d different words in validation set' % len(validation_chars))
print(validation_chars[:30])

validation_chars = [c[0] for c in validation_chars]
validation_char2id = {c: i + 1 for i, c in enumerate(validation_chars)}
validation_id2char = {i + 1: c for i, c in enumerate(validation_chars)}
print(len(validation_char2id))

We have 260536 words in validation set
We have 5285 different words in validation set
[('，', 16059), ('。', 16042), ('[', 8000), (']', 8000), ('不', 2780), ('人', 2207), ('一', 2109), ('风', 1753), ('无', 1665), ('山', 1654), ('来', 1422), ('花', 1401), ('有', 1259), ('春', 1246), ('日', 1098), ('何', 1049), ('天', 999), ('中', 997), ('云', 991), ('时', 960), ('年', 960), ('知', 942), ('自', 933), ('得', 919), ('是', 888), ('如', 862), ('上', 849), ('水', 846), ('月', 816), ('生', 762)]
5285


In [11]:
batch_size = 256
X_validation_data = []
Y_validation_data = []

for b in range(len(validation_data) // batch_size):
    start = b * batch_size
    end = b * batch_size + batch_size
    batch = [[validation_char2id[c] for c in validation_data[i]] for i in range(start, end)]
    #if count==1:
        #print(len(batch[0]))
    maxlen = max(map(len, batch))
    X_validation_batch = np.full((batch_size, maxlen - 1), 0, np.int32)
    Y_validation_batch = np.full((batch_size, maxlen - 1), 0, np.int32)

    for i in range(batch_size):
        X_validation_batch[i, :len(batch[i]) - 1] = batch[i][:-1]
        Y_validation_batch[i, :len(batch[i]) - 1] = batch[i][1:]
    #if maxlen!=34:
        #print(maxlen)
    X_validation_data.append(X_validation_batch)
    Y_validation_data.append(Y_validation_batch)
    
print(len(X_validation_data), len(Y_validation_data))
print(b)
print(type(X_validation_data[0]))

31 31
30
<class 'numpy.ndarray'>


In [12]:
import pickle
with open('dictionary.pkl', 'wb') as fw:
    pickle.dump([training_char2id, training_id2char], fw)

In [12]:
tf.reset_default_graph()
hidden_size = 256
num_layer = 2
embedding_size = 256

X_training= tf.placeholder(tf.int32, [batch_size, None])
Y_training= tf.placeholder(tf.int32, [batch_size, None])
'''
X_validation= tf.placeholder(tf.int32, [batch_size, None])
Y_validation= tf.placeholder(tf.int32, [batch_size, None])
'''
#print(X)
learning_rate = tf.Variable(0.0, trainable=False)
#print(learning_rate)
ikeep_prob = tf.placeholder(tf.float32, name='ikeep_prob')
okeep_prob = tf.placeholder(tf.float32, name='okeep_prob')
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(hidden_size), input_keep_prob=ikeep_prob, output_keep_prob=ikeep_prob )for i in range(num_layer)], 
    state_is_tuple=True)

initial_state = cell.zero_state(batch_size, tf.float32)

tr_embeddings = tf.Variable(tf.random_uniform([len(training_char2id) + 1, embedding_size], -1.0, 1.0))
tr_embedded = tf.nn.embedding_lookup(tr_embeddings, X_training)
#print(tr_embeddings)
# outputs: batch_size, max_time, hidden_size
# last_states: 2 tuple(two LSTM), 2 tuple(c and h)
#              batch_size, hidden_size
tr_outputs, tr_last_states = tf.nn.dynamic_rnn(cell, tr_embedded, initial_state=initial_state)
tr_outputs = tf.reshape(tr_outputs, [-1, hidden_size])                # batch_size * max_time, hidden_size

training_logits = tf.layers.dense(tr_outputs, units=len(training_char2id) + 1)       # batch_size * max_time, len(char2id) + 1
training_logits = tf.reshape(training_logits, [batch_size, -1, len(training_char2id) + 1]) # batch_size, max_time, len(char2id) + 1
probs = tf.nn.softmax(training_logits)                                   # batch_size, max_time, len(char2id) + 1
loss = tf.reduce_mean(tf.contrib.seq2seq.sequence_loss(training_logits, Y_training, tf.ones_like(Y_training, dtype=tf.float32)))



'''
val_embeddings = tf.Variable(tf.random_uniform([len(training_char2id) + 1, embedding_size], -1.0, 1.0))
val_embedded = tf.nn.embedding_lookup(val_embeddings, X_validation)
print(val_embeddings)
# outputs: batch_size, max_time, hidden_size
# last_states: 2 tuple(two LSTM), 2 tuple(c and h)
#              batch_size, hidden_size
val_outputs, val_last_states = tf.nn.dynamic_rnn(cell, val_embedded, initial_state=initial_state)
val_outputs = tf.reshape(val_outputs, [-1, hidden_size])                # batch_size * max_time, hidden_size
validation_logits = tf.layers.dense(val_outputs, units=len(training_char2id) + 1)       # batch_size * max_time, len(char2id) + 1
validation_logits = tf.reshape(validation_logits, [batch_size, -1, len(training_char2id) + 1]) # batch_size, max_time, len(char2id) + 1
'''

#validation_accuracy=tf.reduce_mean(tf.contrib.seq2seq.sequence_loss(training_logits, Y_training, tf.ones_like(Y_training, dtype=tf.float32)))



params = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(loss, params), 5)
optimizer = tf.train.AdamOptimizer(learning_rate).apply_gradients(zip(grads, params))


In [13]:
print(cell)

<tensorflow.python.ops.rnn_cell_impl.MultiRNNCell object at 0x00000259685E0BA8>


In [14]:
print(X_training)

print(tr_embeddings)
print(tr_embedded)
print(tr_outputs)
print(training_logits)
print(tr_last_states)
print(params)

Tensor("Placeholder:0", shape=(256, ?), dtype=int32)
<tf.Variable 'Variable_1:0' shape=(7928, 256) dtype=float32_ref>
Tensor("embedding_lookup/Identity:0", shape=(256, ?, 256), dtype=float32)
Tensor("Reshape:0", shape=(?, 256), dtype=float32)
Tensor("Reshape_1:0", shape=(256, ?, 7928), dtype=float32)
(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(256, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(256, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(256, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(256, 256) dtype=float32>))
[<tf.Variable 'Variable_1:0' shape=(7928, 256) dtype=float32_ref>, <tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(512, 1024) dtype=float32_ref>, <tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(1024,) dtype=float32_ref>, <tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(512, 1024) dtype=float32_ref>, <tf.Variable 'rnn/multi_rnn_cell/cell_1/ls

In [15]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
Loss=[]
Validation=[]
for epoch in range(100):
    sess.run(tf.assign(learning_rate, 0.0004 * (0.97 ** epoch)))
    
    data_index = np.arange(len(X_training_data))
    np.random.shuffle(data_index)
    X_training_data = [X_training_data[i] for i in data_index]
    Y_training_data = [Y_training_data[i] for i in data_index]
    
    losses = []
    validates=[]

    for i in tqdm(range(len(X_training_data))):
        ls_,  _ = sess.run([loss, optimizer], feed_dict={X_training: X_training_data[i], Y_training: Y_training_data[i], ikeep_prob: 0.8, okeep_prob: 0.8})
        losses.append(ls_)
    #
    
    for i in tqdm(range(len(X_validation_data))):
        validate_acc = sess.run([loss], feed_dict={X_training: X_validation_data[i], Y_training: Y_validation_data[i], ikeep_prob: 1, okeep_prob: 1})
        validates.append(validate_acc)
    '''
    #
    
    if epoch+1==10:
        saver = tf.train.Saver()
        saver.save(sess, './10/10_epoch')
    
    if epoch+1==20:
        saver = tf.train.Saver()
        saver.save(sess, './20/20_epoch')
    if epoch+1==30:
        saver = tf.train.Saver()
        saver.save(sess, './30/30_epoch')
    if epoch+1==40:
        saver = tf.train.Saver()
        saver.save(sess, './40/40_epoch')
    if epoch+1==50:
        saver = tf.train.Saver()
        saver.save(sess, './50/50_epoch')
        
    if epoch+1==60:
        saver = tf.train.Saver()
        saver.save(sess, './60/60_epoch')
    
    if epoch+1==70:
        saver = tf.train.Saver()
        saver.save(sess, './70/70_epoch')
    if epoch+1==80:
        saver = tf.train.Saver()
        saver.save(sess, './80/80_epoch')
    if epoch+1==90:
        saver = tf.train.Saver()
        saver.save(sess, './90/90_epoch')
    if epoch+1==100:
        saver = tf.train.Saver()
        saver.save(sess, './100/100_epoch')
    '''
    
    print('Epoch %d Loss %.5f' % (epoch, np.mean(losses)))
    Loss.append(np.mean(losses))
    print('Epoch %d Validates %.5f' % (epoch, np.mean(validates)))
    Validation.append(np.mean(validates))

100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:07<00:00,  5.52it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.32it/s]


Epoch 0 Loss 6.69854
Epoch 0 Validates 6.52082


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.74it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.83it/s]


Epoch 1 Loss 6.32235
Epoch 1 Validates 6.13153


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.84it/s]


Epoch 2 Loss 6.01712
Epoch 2 Validates 6.06950


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.36it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.86it/s]


Epoch 3 Loss 5.89121
Epoch 3 Validates 6.10205


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.67it/s]


Epoch 4 Loss 5.78708
Epoch 4 Validates 6.14971


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.68it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.73it/s]


Epoch 5 Loss 5.65762
Epoch 5 Validates 6.23994


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.48it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.93it/s]


Epoch 6 Loss 5.50609
Epoch 6 Validates 6.35740


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.37it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.94it/s]


Epoch 7 Loss 5.41112
Epoch 7 Validates 6.43625


100%|████████████████████████████████████████████████████████████████████████████████| 375/375 [01:06<00:00,  5.84it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 11.83it/s]


Epoch 8 Loss 5.34969
Epoch 8 Validates 6.48057


 10%|███████▉                                                                         | 37/375 [00:06<00:53,  6.33it/s]


KeyboardInterrupt: 

In [17]:
print(Loss)
print(Validation)

[6.708487, 6.542999, 6.47859, 6.1634636, 5.9582043, 5.8722515, 5.8110027, 5.701758, 5.6210666, 5.580378, 5.5502963, 5.5204234, 5.4796486, 5.4413443, 5.4097214, 5.3864746, 5.366487, 5.3481503, 5.327319, 5.3054676, 5.281816, 5.2583537, 5.236086, 5.2168775, 5.198056, 5.1808333, 5.1650047, 5.1502104, 5.135682, 5.12194, 5.109127, 5.096859, 5.0848994, 5.0738664, 5.0636315, 5.0533843, 5.0444274, 5.0353065, 5.0265217, 5.0185204, 5.011227, 5.0037847, 4.9969974, 4.990392, 4.983714, 4.977443, 4.972288, 4.9653583, 4.960594, 4.9556413]
[6.5855975, 6.530785, 6.3774776, 6.056278, 6.0460277, 6.0778613, 6.1174946, 6.2397423, 6.2987475, 6.3240294, 6.313289, 6.3476143, 6.3538804, 6.377427, 6.4043336, 6.426238, 6.4298086, 6.44656, 6.44417, 6.459502, 6.461572, 6.480503, 6.502726, 6.5101876, 6.520776, 6.5377436, 6.546568, 6.556171, 6.5809355, 6.5737624, 6.5882535, 6.5878105, 6.6167054, 6.6074634, 6.621916, 6.6341963, 6.625648, 6.630437, 6.6416764, 6.65174, 6.653723, 6.6700177, 6.656472, 6.668778, 6.6718335,