In [None]:
import tensorflow as tf
import numpy as np
import random

In [None]:
# Load the corpus, one name per line. The context manager closes the file
# handle as soon as the read finishes instead of leaving it open until
# garbage collection.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

# Vocabulary: every distinct character appearing in the names, plus '.',
# which the later cells use as the start-padding and end-of-name token.
chars = sorted(set(''.join(words)) | {'.'})

# Two-way lookup between characters and their integer ids.
stoi = {s: i for i, s in enumerate(chars)}
itos = {i: s for s, i in stoi.items()}
stoi

In [None]:
# Build (context, next-character) example pairs from the shuffled names.
block_size = 7  # how many preceding tokens form one context
xs, ys = [], []

random.shuffle(words)
pad = stoi['.']
for name in words:
    # Left-pad with `block_size` '.' tokens and terminate with one '.', then
    # slide a window over the sequence: each window is an input and the token
    # immediately after it is the target.
    encoded = [pad] * block_size + [stoi[c] for c in name] + [pad]
    for start in range(len(encoded) - block_size):
        xs.append(encoded[start:start + block_size])
        ys.append(encoded[start + block_size])

# The first 90% of the examples are used for training, the rest for testing.
n = int(0.9 * len(xs))

xtrs = np.array(xs[:n])
ytrs = np.array(ys[:n])
xtes = np.array(xs[n:])
ytes = np.array(ys[n:])

In [None]:
# Character-level MLP assembled one layer at a time: embed every context
# token, flatten the embeddings into a single vector, apply one tanh hidden
# layer, and finish with a softmax over the vocabulary.
vocab_size = len(chars)

model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(block_size,)))
model.add(tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=16))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='tanh'))
model.add(tf.keras.layers.Dense(vocab_size, activation='softmax'))

In [None]:
# Categorical cross-entropy matches the one-hot targets that the training and
# evaluation cells feed to the model; accuracy is tracked alongside the loss.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Print Keras' summary of the model so the architecture can be checked before training.
model.summary()

In [None]:
# 100 training rounds. Each round draws 512 random row indices (via
# np.random.randint, so an index may repeat) and fits the model on that
# sample for 10 epochs.
for i in range(100):
    ix = np.random.randint(0, xtrs.shape[0], (512,))
    bx, by = xtrs[ix], tf.one_hot(ytrs[ix], len(chars))
    model.fit(bx, by, epochs=10, verbose=0)
    # Report progress after every tenth round.
    if i % 10 == 9:
        print(f'#{int(i / 10) + 1} done')

In [None]:
# Report loss and accuracy on both splits. Targets are one-hot encoded to
# match the categorical cross-entropy loss the model was compiled with.
n_classes = len(chars)

# Training split.
ytrse = tf.one_hot(ytrs, depth=n_classes)
training_loss, training_accuracy = model.evaluate(x=xtrs, y=ytrse, verbose=0)
print(f'{training_loss=:.04f} {training_accuracy=:.04f}')

# Held-out test split.
ytese = tf.one_hot(ytes, depth=n_classes)
test_loss, test_accuracy = model.evaluate(x=xtes, y=ytese, verbose=0)
print(f'{test_loss=:.04f} {test_accuracy=:.04f}')

In [None]:
# Sample 25 names from the model, one character at a time.
for _ in range(25):
    # Start from an all-'.' context. It is kept as token ids (as in the
    # dataset-building cell) so it can be fed to the model without being
    # re-encoded through `stoi` on every step.
    context = [stoi['.']] * block_size
    while True:
        # Predicted distribution over the next character for this one context.
        p = model.predict(np.array([context]), verbose=0)[0]
        # Keep the sampled id in its own name so the loop counter is never
        # overwritten inside the loop body.
        next_ix = np.random.choice(len(chars), p=p)
        ch = itos[next_ix]
        # Slide the window: drop the oldest token, append the new one.
        context = context[1:] + [next_ix]
        print(ch, end='')
        # '.' marks the end of a name.
        if ch == '.':
            break
    print()