In [1]:
import tensorflow as tf
from tensorflow import keras
tf.random.set_seed(777)

In [2]:
num_classes = 5
input_dim = 5  # one-hot size
hidden_size = 5  # output from the LSTM
batch_size = 1   # one sentence
sequence_length = 6  # ihello
learning_rate = 0.1
epochs = 50

In [3]:
idx2char = ['h', 'i', 'e', 'l', 'o']

# teach hello: hihell -> ihello
x_data = [[0, 1, 0, 2, 3, 3]]   # hihell
x_one_hot = tf.one_hot(x_data, num_classes)
y_data = [[1, 0, 2, 3, 3, 4]]    # ihello

In [4]:
# create dataset
dataset = tf.data.Dataset.from_tensor_slices((x_one_hot, y_data)).batch(batch_size)

In [5]:
dataset

<_BatchDataset element_spec=(TensorSpec(shape=(None, 6, 5), dtype=tf.float32, name=None), TensorSpec(shape=(None, 6), dtype=tf.int32, name=None))>

In [6]:
def rnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(hidden_size, return_sequences=True, stateful=True),
        tf.keras.layers.Dense(num_classes)
    ])
    return model

In [7]:
# cost function
def loss_function(model, x, y):
    logits = model(x)
    return tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))

# optimize
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [8]:
model = rnn_model()
# training
for epoch in range(epochs):
    total_loss = 0.0
    for x, y in dataset:
        with tf.GradientTape() as tape:
            loss = loss_function(model, x, y)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        total_loss += loss
    print(f"Epoch {epoch + 1}, Loss: {total_loss}")

    # testing
    results = model(x_one_hot)
    predicted_indices = tf.argmax(results)
    predicted_chars = [idx2char[idx] for idx in predicted_indices[0].numpy()]
    print(f"Prediction str: {''.join(predicted_chars)}")

Epoch 1, Loss: 1.6062116622924805
Prediction str: hhhhh
Epoch 2, Loss: 1.5579748153686523
Prediction str: hhhhh
Epoch 3, Loss: 1.4895553588867188
Prediction str: hhhhh
Epoch 4, Loss: 1.4215855598449707
Prediction str: hhhhh
Epoch 5, Loss: 1.3260153532028198
Prediction str: hhhhh
Epoch 6, Loss: 1.210224986076355
Prediction str: hhhhh
Epoch 7, Loss: 1.0659083127975464
Prediction str: hhhhh
Epoch 8, Loss: 0.9099225997924805
Prediction str: hhhhh
Epoch 9, Loss: 0.7700164318084717
Prediction str: hhhhh
Epoch 10, Loss: 0.6491644978523254
Prediction str: hhhhh
Epoch 11, Loss: 0.5405043959617615
Prediction str: hhhhh
Epoch 12, Loss: 0.4360748827457428
Prediction str: hhhhh
Epoch 13, Loss: 0.34126606583595276
Prediction str: hhhhh
Epoch 14, Loss: 0.2604968845844269
Prediction str: hhhhh
Epoch 15, Loss: 0.1956607848405838
Prediction str: hhhhh
Epoch 16, Loss: 0.14693088829517365
Prediction str: hhhhh
Epoch 17, Loss: 0.11136766523122787
Prediction str: hhhhh
Epoch 18, Loss: 0.0854882225394249
Pre