In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Load the names
names = [name.strip() for name in open("names.txt")]

In [3]:
# Process characters
chars = ["."] + sorted(list(set("".join(names))))
stoi = {chars[i]: i for i in range(len(chars))}
itos = {v: k for k, v in stoi.items()}

In [4]:
# Create a dataset for trigram model
xs = []
ys = []
for name in names:
    chars = ["."] + list(name) + ["."]
    for ch1, ch2, ch3 in zip(chars, chars[1:], chars[2:]):
        xs.append([stoi[ch1], stoi[ch2]])
        ys.append(stoi[ch3])

xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [5]:
input_length = len(stoi)

In [6]:
# One-hot encode the inputs
x1 = F.one_hot(xs[:, 0], num_classes=input_length).float()
x2 = F.one_hot(xs[:, 1], num_classes=input_length).float()
xn = torch.cat([x1, x2], dim=1)

In [7]:
# Initialize weights
g = torch.Generator().manual_seed(2147483647)
w = torch.randn((2 * input_length, input_length), generator=g, requires_grad=True)

# Training parameters
num_samples = len(xs)
lr = 50

In [9]:
# Training loop
for i in range(200):
    # Forward pass
    logits = xn @ w
    exp = torch.exp(logits)
    softmax = exp / exp.sum(dim=1, keepdim=True)

    # Loss calculation
    loss = -torch.log(softmax[range(num_samples), ys]).mean()

    # Backward pass
    w.grad = None
    loss.backward()

    # Gradient descent update
    with torch.no_grad():
        w -= lr * w.grad
print(loss.item())

2.2465996742248535


In [10]:
# Generate new names
g = torch.Generator().manual_seed(2147483647)
for _ in range(10):
    ix1, ix2 = 0, 0
    output = []
    while True:
        xenc = torch.cat([F.one_hot(torch.tensor(ix1), num_classes=input_length).float(),
                          F.one_hot(torch.tensor(ix2), num_classes=input_length).float()])
        logits = xenc @ w
        exp = torch.exp(logits)
        softmax = exp / exp.sum(dim=0, keepdim=True)
        ix = torch.multinomial(softmax, num_samples=1, replacement=True, generator=g).item()
        if ix == 0:
            break
        output.append(itos[ix])
        ix1, ix2 = ix2, ix
    print("".join(output))

aunide
aliasad
ushfay
ainn
aui
ritoleras
get
adannaa
zabileniassibdainrwi
ol
