In [9]:
import math
import random


text = "book"
chars = list(set(text)) 
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for ch, i in char_to_idx.items()}

input_size = len(chars)
hidden_size = 8         
output_size = input_size
learning_rate = 0.1


def random_matrix(rows, cols):
    return [[random.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


Wxh = random_matrix(hidden_size, input_size) 
Whh = random_matrix(hidden_size, hidden_size)
Why = random_matrix(output_size, hidden_size) 
bh = [0.0 for _ in range(hidden_size)]        
by = [0.0 for _ in range(output_size)]       


In [10]:

def tanh(x):
    return math.tanh(x)

def dtanh(x):
    return 1.0 - math.tanh(x)**2

def softmax(x):
    max_x = max(x)
    exp_x = [math.exp(i - max_x) for i in x]
    sum_exp = sum(exp_x)
    return [i / sum_exp for i in exp_x]

def matmul(mat, vec):
    return [sum(m*v for m,v in zip(row, vec)) for row in mat]

def add(vec1, vec2):
    return [a + b for a, b in zip(vec1, vec2)]

def one_hot(index, size):
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


In [11]:
def forward(inputs, target, h_prev):
    xs, hs, ys, ps = {}, {}, {}, {}
    hs[-1] = h_prev[:] 
    loss = 0.0

    for t in range(len(inputs)):
        xs[t] = one_hot(inputs[t], input_size)

        h_input = add(matmul(Wxh, xs[t]), matmul(Whh, hs[t-1]))
        h_input = add(h_input, bh)

        hs[t] = [tanh(x) for x in h_input]

        y = add(matmul(Why, hs[t]), by)
        ps[t] = softmax(y)

        if t == len(inputs) - 1:
            loss += -math.log(ps[t][target])

    cache = (xs, hs, ps)
    return loss, cache, hs[len(inputs) - 1]


In [12]:
def backward(cache, inputs, target):
    global Wxh, Whh, Why, bh, by

    xs, hs, ps = cache

    dWxh = [[0.0 for _ in range(input_size)] for _ in range(hidden_size)]
    dWhh = [[0.0 for _ in range(hidden_size)] for _ in range(hidden_size)]
    dWhy = [[0.0 for _ in range(hidden_size)] for _ in range(output_size)]
    dbh = [0.0 for _ in range(hidden_size)]
    dby = [0.0 for _ in range(output_size)]
    dh_next = [0.0 for _ in range(hidden_size)]

    for t in reversed(range(len(inputs))):
        dy = ps[t][:]
        if t == len(inputs) - 1:
            dy[target] -= 1 
        
        for i in range(output_size):
            for j in range(hidden_size):
                dWhy[i][j] += dy[i] * hs[t][j]
        
        for i in range(output_size):
            dby[i] += dy[i]

        dh = [sum(Why[k][i] * dy[k] for k in range(output_size)) + dh_next[i] for i in range(hidden_size)]
        dh_raw = [dtanh(hs[t][i]) * dh[i] for i in range(hidden_size)]

        for i in range(hidden_size):
            for j in range(input_size):
                dWxh[i][j] += dh_raw[i] * xs[t][j]

        for i in range(hidden_size):
            for j in range(hidden_size):
                dWhh[i][j] += dh_raw[i] * hs[t-1][j]

        for i in range(hidden_size):
            dbh[i] += dh_raw[i]

        dh_next = [sum(Whh[k][i] * dh_raw[k] for k in range(hidden_size)) for i in range(hidden_size)]

  
    for i in range(hidden_size):
        for j in range(input_size):
            Wxh[i][j] -= learning_rate * dWxh[i][j]
    for i in range(hidden_size):
        for j in range(hidden_size):
            Whh[i][j] -= learning_rate * dWhh[i][j]
    for i in range(output_size):
        for j in range(hidden_size):
            Why[i][j] -= learning_rate * dWhy[i][j]
    for i in range(hidden_size):
        bh[i] -= learning_rate * dbh[i]
    for i in range(output_size):
        by[i] -= learning_rate * dby[i]


In [13]:
inputs = [char_to_idx[ch] for ch in text[:-1]] 
target = char_to_idx[text[-1]]                 

h_prev = [0.0 for _ in range(hidden_size)]


n_epochs = 1000
for epoch in range(n_epochs):
    loss, cache, h_prev = forward(inputs, target, h_prev)
    backward(cache, inputs, target)

    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss:.4f}')


Epoch 0, Loss: 1.1075
Epoch 100, Loss: nan
Epoch 200, Loss: nan
Epoch 300, Loss: nan
Epoch 400, Loss: nan
Epoch 500, Loss: nan
Epoch 600, Loss: nan
Epoch 700, Loss: nan
Epoch 800, Loss: nan
Epoch 900, Loss: nan
