# 诗歌生成

# 数据处理

In [20]:
import numpy as np
import tensorflow as tf
import collections
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import layers, optimizers, datasets

start_token = 'bos'
end_token = 'eos'

def process_dataset(fileName, max_len=200):
    """Read a poem file and convert it to integer-indexed token sequences.

    Every line is split on ':'; the first piece is discarded and the
    remaining pieces are concatenated (so any further ':' is dropped).  The
    resulting text is split into single characters and wrapped in
    ``start_token`` / ``end_token``.

    Args:
        fileName: Path of the UTF-8 encoded text file, one poem per line.
        max_len: Poems with more than this many tokens (boundary tokens
            included) are skipped.  Defaults to 200.

    Returns:
        A tuple ``(instances, word2id, id2word)``:
        ``instances`` is a list of ``(token_ids, length)`` pairs,
        ``word2id`` maps token -> id and ``id2word`` maps id -> token.
        Ids 0 and 1 belong to ``'PAD'`` and ``'UNK'``; all other tokens are
        numbered by descending frequency.  When no poem survives, the
        vocabulary holds only those two entries and ``instances`` is empty.
    """
    examples = []
    with open(fileName, 'r', encoding='utf-8') as fd:
        for line in fd:
            pieces = line.strip().split(':')
            content = ''.join(pieces[1:])
            tokens = [start_token] + list(content) + [end_token]
            if len(tokens) > max_len:
                continue
            examples.append(tokens)

    counter = collections.Counter(tok for poem in examples for tok in poem)

    # Stable sort: tokens with equal counts keep their first-seen order.
    ranked = sorted(counter.items(), key=lambda item: -item[1])
    # Building the tuple directly (instead of unpacking zip(*ranked)) keeps
    # this working when there are no tokens at all.
    words = ('PAD', 'UNK') + tuple(word for word, _ in ranked)
    word2id = {word: idx for idx, word in enumerate(words)}
    id2word = {idx: word for word, idx in word2id.items()}

    indexed_examples = [[word2id[w] for w in poem] for poem in examples]
    instances = [(ids, len(ids)) for ids in indexed_examples]

    return instances, word2id, id2word

def poem_dataset():
    """Build the batched training dataset from ``./poems.txt``.

    Returns:
        A tuple ``(ds, word2id, id2word)``.  ``ds`` yields batches of up to
        100 poems as ``(inputs, targets, lengths)``: ``inputs`` is every
        padded sequence without its last position, ``targets`` is the same
        sequence without its first position, and ``lengths`` is the original
        token count minus one.
    """
    instances, word2id, id2word = process_dataset('./poems.txt')

    element_types = (tf.int64, tf.int64)
    element_shapes = (tf.TensorShape([None]), tf.TensorShape([]))

    def _emit():
        # A fresh list each time the dataset is (re)iterated.
        return list(instances)

    def _shift(x, seqlen):
        # Next-token prediction: the target is the input moved one step left.
        return x[:, :-1], x[:, 1:], seqlen - 1

    ds = tf.data.Dataset.from_generator(_emit, element_types, element_shapes)
    ds = ds.shuffle(buffer_size=10240)
    ds = ds.padded_batch(100, padded_shapes=element_shapes)
    ds = ds.map(_shift)
    return ds, word2id, id2word

# 模型代码：完成建模部分

In [21]:
class myRNNModel(keras.Model):
    """Token-level language model: embedding -> SimpleRNN -> dense logits."""

    def __init__(self, w2id):
        """Create the layers; the vocabulary size is ``len(w2id)``."""
        super().__init__()
        self.v_sz = len(w2id)
        self.embed_layer = tf.keras.layers.Embedding(self.v_sz, 64)

        # The cell is stored separately so get_next_token can step it by hand.
        self.rnncell = tf.keras.layers.SimpleRNNCell(128)
        self.rnn_layer = tf.keras.layers.RNN(self.rnncell, return_sequences=True)
        self.dense = tf.keras.layers.Dense(self.v_sz)

    @tf.function
    def call(self, inp_ids):
        """Run the full forward pass over a batch of id sequences.

        Args:
            inp_ids: token ids, shape (batch_size, seq_len).

        Returns:
            Logits of shape (batch_size, seq_len, v_sz).
        """
        embedded = self.embed_layer(inp_ids)   # (batch_size, seq_len, 64)
        hidden_seq = self.rnn_layer(embedded)
        return self.dense(hidden_seq)          # (batch_size, seq_len, v_sz)

    @tf.function
    def get_next_token(self, x, state):
        """Advance the RNN cell by one step and pick the most likely token.

        Args:
            x: current token ids, shape (b_sz,).
            state: state passed straight to ``self.rnncell.call``.

        Returns:
            ``(out, state)``: the argmax token ids and the cell's new state.
        """
        token_emb = self.embed_layer(x)                             # (b_sz, emb_sz)
        hidden, next_state = self.rnncell.call(token_emb, state)    # (b_sz, h_sz)
        scores = self.dense(hidden)                                 # (b_sz, v_sz)
        return tf.argmax(scores, axis=-1), next_state

    def build(self, input_shape):
        """Delegate building to the Keras base class."""
        super().build(input_shape)

## 一个计算sequence loss的辅助函数，只需了解用途。

In [22]:
def mkMask(input_tensor, maxLen):
    """Turn a tensor of lengths into a boolean mask with one extra axis.

    Args:
        input_tensor: tensor of lengths, any shape.
        maxLen: size of the appended trailing axis.

    Returns:
        A tensor shaped ``shape(input_tensor) + [maxLen]`` produced by
        ``tf.sequence_mask`` on the flattened lengths.
    """
    lengths_flat = tf.reshape(input_tensor, shape=(-1,))
    mask_flat = tf.sequence_mask(lengths_flat, maxlen=maxLen)

    # Restore the original leading dimensions and append the mask axis.
    out_shape = tf.concat(axis=0, values=[tf.shape(input_tensor), [maxLen]])
    return tf.reshape(mask_flat, out_shape)


def reduce_avg(reduce_target, lengths, dim):
    """Average ``reduce_target`` along ``dim`` over the first ``lengths`` entries.

    Entries of axis ``dim`` outside the mask built from ``lengths`` are
    zeroed before summing, and the sum is divided by ``lengths``.

    Args:
        reduce_target : shape(d_0, d_1,..,d_dim, .., d_k)
        lengths : shape(d0, .., d_(dim-1))
        dim : which dimension to average, should be a python number

    Returns:
        A tensor with axis ``dim`` removed, in the dtype of ``reduce_target``.

    Raises:
        ValueError: if ``lengths`` is not of rank ``dim`` or ``reduce_target``
            has rank lower than ``dim + 1``.
    """
    shape_of_lengths = lengths.get_shape()
    shape_of_target = reduce_target.get_shape()
    if len(shape_of_lengths) != dim:
        raise ValueError(('Second input tensor should be rank %d, ' +
                         'while it got rank %d') % (dim, len(shape_of_lengths)))
    if len(shape_of_target) < dim+1 :
        raise ValueError(('First input tensor should be at least rank %d, ' +
                         'while it got rank %d') % (dim+1, len(shape_of_target)))

    # Number of axes that follow `dim` in reduce_target.
    rank_diff = len(shape_of_target) - len(shape_of_lengths) - 1
    mxlen = tf.shape(reduce_target)[dim]
    mask = mkMask(lengths, mxlen)
    if rank_diff != 0:
        # Append singleton axes so mask and lengths broadcast over the
        # trailing axes of reduce_target.
        trailing_ones = [1] * rank_diff
        len_shape = tf.concat(axis=0, values=[tf.shape(lengths), trailing_ones])
        mask_shape = tf.concat(axis=0, values=[tf.shape(mask), trailing_ones])
    else:
        len_shape = tf.shape(lengths)
        mask_shape = tf.shape(mask)
    lengths_reshape = tf.reshape(lengths, shape=len_shape)
    mask = tf.reshape(mask, shape=mask_shape)

    mask_target = reduce_target * tf.cast(mask, dtype=reduce_target.dtype)

    red_sum = tf.reduce_sum(mask_target, axis=[dim], keepdims=False)
    # Cast the divisor to the target's dtype (not a fixed float32) so the
    # division also works for non-float32 targets; the tiny epsilon avoids
    # dividing by an exact zero length.
    divisor = tf.cast(lengths_reshape, dtype=reduce_target.dtype) + 1e-30
    red_avg = red_sum / divisor
    return red_avg

# 定义loss函数，定义训练函数

In [23]:
@tf.function
def compute_loss(logits, labels, seqlen):
    """Return the batch mean of the per-sequence masked cross-entropy.

    The token-level sparse softmax cross-entropy is averaged over each
    sequence with ``reduce_avg`` (using ``seqlen``), and those per-sequence
    values are then averaged over the batch.
    """
    token_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    per_sequence = reduce_avg(token_losses, seqlen, dim=1)
    return tf.reduce_mean(per_sequence)

@tf.function(experimental_relax_shapes=True)
def train_one_step(model, optimizer, x, y, seqlen):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: callable producing logits from ``x``.
        optimizer: optimizer used to apply the gradients.
        x: input token ids.
        y: target token ids.
        seqlen: sequence lengths forwarded to ``compute_loss``.
    """
    with tf.GradientTape() as tape:
        # The forward pass and the loss must both be recorded on the tape.
        batch_loss = compute_loss(model(x), y, seqlen)

    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return batch_loss

def train(epoch, model, optimizer, ds, log_every=500):
    """Train ``model`` for one pass over ``ds``.

    Args:
        epoch: epoch number, used only in the progress message.
        model: model handed to ``train_one_step``.
        optimizer: optimizer handed to ``train_one_step``.
        ds: iterable of ``(x, y, seqlen)`` batches.
        log_every: print the loss every this many steps (default 500);
            a falsy value disables printing.

    Returns:
        The loss of the last batch, or ``0.0`` when ``ds`` yields nothing.
    """
    loss = 0.0
    for step, (x, y, seqlen) in enumerate(ds):
        loss = train_one_step(model, optimizer, x, y, seqlen)

        if log_every and step % log_every == 0:
            print('epoch', epoch, ': loss', loss.numpy())

    return loss

# 训练优化过程

In [24]:
# Settings of the training run.
LEARNING_RATE = 0.0005
NUM_EPOCHS = 30

optimizer = optimizers.Adam(LEARNING_RATE)
train_ds, word2id, id2word = poem_dataset()
model = myRNNModel(word2id)
# Build with an unspecified two-dimensional input shape before training.
model.build(input_shape=(None, None))
for epoch in range(NUM_EPOCHS):
    loss = train(epoch, model, optimizer, train_ds)

epoch 0 : loss 8.821012
epoch 1 : loss 6.625804
epoch 2 : loss 6.0384154
epoch 3 : loss 5.841291
epoch 4 : loss 5.659331
epoch 5 : loss 5.584035
epoch 6 : loss 5.3673873
epoch 7 : loss 5.35361
epoch 8 : loss 5.304324
epoch 9 : loss 5.2393622
epoch 10 : loss 5.2105694
epoch 11 : loss 5.28596
epoch 12 : loss 5.1616797
epoch 13 : loss 5.0026193
epoch 14 : loss 4.9741454
epoch 15 : loss 4.9501534
epoch 16 : loss 5.00656
epoch 17 : loss 4.927034
epoch 18 : loss 4.888111
epoch 19 : loss 4.9201436
epoch 20 : loss 4.8554416
epoch 21 : loss 4.875818
epoch 22 : loss 4.830716
epoch 23 : loss 4.808032
epoch 24 : loss 4.725497
epoch 25 : loss 4.732128
epoch 26 : loss 4.7191505
epoch 27 : loss 4.7271347
epoch 28 : loss 4.74349
epoch 29 : loss 4.690366


# 生成过程

In [25]:
def gen_sentence(max_len=50):
    """Greedily generate ``max_len`` tokens with the global ``model``.

    Decoding starts from the start token and a randomly drawn RNN state, so
    repeated calls give different sentences although each step takes the
    argmax.  Decoding does not stop at the end token.

    Args:
        max_len: number of tokens to generate (default 50).

    Returns:
        A list with the generated tokens (strings looked up in ``id2word``).
    """
    # SimpleRNNCell has a single state tensor; passing exactly one keeps the
    # state structure identical before and after the first step.
    state = [tf.random.normal(shape=(1, model.rnncell.units), stddev=0.5)]
    # int64 matches what tf.argmax returns in get_next_token, so the traced
    # function sees the same input dtype on every step.
    cur_token = tf.constant([word2id[start_token]], dtype=tf.int64)
    collect = []
    for _ in range(max_len):
        cur_token, state = model.get_next_token(cur_token, state)
        collect.append(cur_token.numpy()[0])
    return [id2word[t] for t in collect]
for i in range(20):
    print(''.join(gen_sentence()))

eos人不可见，不是一枝枝。eos得无人事，何人不可寻。eos心无所思，何处是何人。eos子无人事，无人不可寻。eos心
一声清，不见君。eos子不知不可知，不知不是人间人。eos来不是无人事，不得无人不得归。eos得不知何处处，不知
eos君不见君王子，不得人间不得人。eos得不知何处处，不知何处是何人。eos来不是人间事，不是无人不得归。eos得
天上春风起，风吹万里秋。eos心无限意，何处是何人。eos去无人事，何人不可寻。eos心无限意，何处是何人。eos子
一夜月明月，山中不可寻。eos来无限意，不是故乡心。eos路无人事，无人不可寻。eos心无限意，何处是何人。eos子
一片云云起，风吹万里秋。不知何处处，不见白云生。eos得无人事，何妨得此心。eos心无所思，何处是何人。eos子
春风吹，一枝红。eos之一，不见人。eos之不，不见此。eos去不知何处。eos在不知何处，一片云中，不见此时。eos心
人间不见，风雨不成。eos人不见，不见此中。eos处不知，不知何处，何人不可。eos之不是，何人不可。eos之不是，
春风吹，风骚旨格》）eos，何处无人。eos心不可，三年不可。eos人不见，不见《风骚旨格》）eos女不知君不知，不
高阁有秋风，不见山中路。eos来不可见，不是心中事。eos心不可见，不是心中人。eos心不可见，不得心中心。eos之
九衢金玉凤，金缕玉金金。eos影随风雨，风吹入翠微。eos心无限意，何处是何人。eos子无人事，无人不可寻。eos心
一片青山一片云，一声不见人间里。eos来不见白云间，一片月中人不知。eos人不见君不得，不得不知何处知。eos得
一片春风起，春风一夜来。eos心无限意，何处是何人。eos去无人事，何人不可寻。eos心无限意，何处是何人。eos子
一片云云动，风吹万里秋。eos心无限意，何处是何人。eos去无人事，何人不可寻。eos心无限意，何处是何人。eos子
一朝天下不知名。eos来不得无人事，不是无人不得归。eos得不知何处处，不知何处是何人。eos来不是人间事，不是
南国不知何处，不知何事无人。eos来不是君王，不知何事不知。eos之不是，一片月明，不知何处，不见此时。eos之
风吹一片云中，__皎然eos兮不见，不见《风骚旨格》）eos女不知君，不如何处。eos心不得，不见此时。eos之不可
eos