# 循环神经网络

In [3]:
import tensorflow as tf
import numpy as np

class DataLoader:
    """Character-level view of the Nietzsche text with random batch sampling.

    Attributes set by ``__init__``:
        raw_text: the lower-cased contents of the downloaded file.
        chars: sorted list of the distinct characters in ``raw_text``.
        char_indices: maps each character to its position in ``chars``.
        indices_char: maps each position in ``chars`` back to its character.
        text: ``raw_text`` encoded as a list of character indices.
    """

    def __init__(self):
        """Fetch the text file through Keras and build the vocabulary."""
        corpus_path = tf.keras.utils.get_file(
            'nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(corpus_path, encoding='utf-8') as handle:
            self.raw_text = handle.read().lower()
        self.chars = sorted(set(self.raw_text))
        self.char_indices = {ch: pos for pos, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))
        self.text = [self.char_indices[ch] for ch in self.raw_text]

    def get_batch(self, sequence_length, batch_size):
        """Sample ``batch_size`` random windows and the index following each.

        Args:
            sequence_length: number of consecutive indices in each window.
            batch_size: number of windows to draw.

        Returns:
            A pair of NumPy arrays: the windows, shaped
            [batch_size, sequence_length], and the index that immediately
            follows each window, shaped [batch_size].
        """
        # One randint call per sample; the exclusive upper bound guarantees
        # that the element right after every window exists.
        starts = [
            np.random.randint(0, len(self.text) - sequence_length)
            for _ in range(batch_size)
        ]
        windows = [self.text[s: s + sequence_length] for s in starts]
        targets = [self.text[s + sequence_length] for s in starts]
        return np.array(windows), np.array(targets)
            

In [6]:
class RNN(tf.keras.Model):
    """Character-level model: an LSTM cell unrolled over time plus a dense layer."""

    def __init__(self, num_chars, batch_size, sequence_length):
        """Create the layers and store the configuration.

        Args:
            num_chars: one-hot depth of the inputs and number of output units.
            batch_size: stored on the instance; ``call`` sizes the recurrent
                state from the batch it actually receives instead.
            sequence_length: number of time steps ``call`` unrolls.
        """
        super().__init__()
        self.num_chars = num_chars
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        """Run the LSTM over the sequence and score the next character.

        Args:
            inputs: integer character indices, indexed as [batch, time]; the
                first ``sequence_length`` time steps are consumed.
            from_logits: when true, return the dense layer's raw output
                instead of applying softmax.

        Returns:
            The dense layer's output for the last time step, passed through
            softmax unless ``from_logits`` is true.
        """
        # Convert the input indices to one-hot vectors for the LSTM cell.
        inputs = tf.one_hot(inputs, depth=self.num_chars)

        # Size the initial state from the batch actually passed in rather
        # than the constructor's batch_size, so the state always matches the
        # input (e.g. when generating from a single sequence).
        state = self.cell.get_initial_state(
            batch_size=tf.shape(inputs)[0], dtype=tf.float32)
        for t in range(self.sequence_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        return tf.nn.softmax(logits)
