In [55]:
import torch
import torch.nn.functional as F

In [56]:
words = open('names.txt', 'r').read().split()

In [63]:
import itertools

chars = sorted(list(set(''.join(words))))
chars.append('.')

stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

# Generate all possible pairs of characters
all_possible_pairs = list(itertools.product(chars, repeat=2))

stoi_pairs = {pair: i for i, pair in enumerate(all_possible_pairs)}

itos_pairs = {i: pair for i, pair in enumerate(all_possible_pairs)}

In [58]:
xs, ys = [], []

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi_pairs[(ch1, ch2)]
        ix2 = stoi[ch3]
        # print(ch1, ch2, ch3)
        xs.append(ix1)
        ys.append(ix2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

In [88]:
# gradient descent
for k in range(100):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=729).float() # input to the network: one-hot encoding
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
  # print(loss.item())
  
  # backward pass
  W.grad = None # set to zero the gradient
  loss.backward()
  
  # update
  W.data += -50 * W.grad
print(loss.item())

2.1618266105651855


In [94]:
generator = torch.Generator().manual_seed(9)
generated_names = []
for i in range(10):
    out = []
    context = ['.', '.']  # Starting with a bigram representing the start token ('.')
    
    while True:
        # Map the bigram (context) to a unique index
        context_index = stoi_pairs[tuple(context)]  # Bigram index from the pair
        
        # One-hot encode the context index (bigram)
        x_encoded = F.one_hot(torch.tensor([context_index]), num_classes=729).float()

        # Forward pass through the model
        logits = x_encoded @ W # Matrix multiplication
        counts = logits.exp()  # Convert logits to counts
        probas = counts / counts.sum(1, keepdims=True)  # Convert to probabilities
        
        # Sample the next character index from the probabilities
        ix = torch.multinomial(probas, num_samples=1, replacement=True, generator=generator).item()
        
        # Break if end token is predicted
        if ix == 0:
            break
            
        # Append the predicted character and shift context
        out.append(itos[ix])
        context = [context[1], itos[ix]]  # Shift the bigram for next prediction
    
    generated_names.append(''.join(out))
generated_names

['ola',
 'ilfvjfjqcwhibmhyklynn',
 'taidfvfinea',
 'mia',
 'znh',
 'royah',
 'zggh',
 'gicyber',
 'iraeqljion',
 'dulinickeiji']