In [136]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [137]:
animals = open('animals.txt', 'r').read().splitlines()
animals = [a.lower() for a in animals]

In [138]:
animals[:10]

['canidae',
 'felidae',
 'cat',
 'cattle',
 'dog',
 'donkey',
 'goat',
 'guinea pig',
 'horse',
 'pig']

In [139]:
print('Number of animal names:', len(animals))
print('Min animal name size:', min(len(a) for a in animals))
print('Max animal name size:', max(len(a) for a in animals))

Number of animal names: 520
Min animal name size: 2
Max animal name size: 33


In [140]:
# Creating the dictionary of trigrams from animal names
b = {}
for a in animals:
    chs = list(a) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        trigram = (ch1, ch2, ch3)
        b[trigram] = b.get(trigram, 0) + 1

In [141]:
sorted(b.items(), key = lambda kv: -kv[1])

[(('e', 'r', '.'), 28),
 (('l', 'e', '.'), 26),
 (('l', 'i', 's'), 26),
 (('s', 't', '.'), 26),
 (('o', 'n', '.'), 26),
 (('i', 's', 't'), 24),
 (('i', 's', 'h'), 24),
 (('f', 'i', 's'), 23),
 (('s', 'h', '.'), 23),
 (('a', 't', '.'), 20),
 (('t', 'i', 'c'), 17),
 (('m', 'e', 's'), 17),
 (('s', 'e', '.'), 16),
 (('a', 'n', 't'), 16),
 (('i', 'c', ' '), 16),
 (('e', 's', 't'), 16),
 (('t', 'e', 'r'), 15),
 (('s', 't', 'i'), 15),
 (('o', 'm', 'e'), 15),
 (('i', 'n', 'g'), 14),
 (('d', 'o', 'm'), 14),
 (('c', 'a', 't'), 13),
 (('r', 'd', '.'), 12),
 (('c', 'a', 'n'), 11),
 (('k', 'e', 'y'), 10),
 (('a', 'r', 'd'), 10),
 (('r', 'k', '.'), 10),
 (('h', 'a', 'l'), 10),
 (('o', 'w', 'l'), 10),
 (('e', 'y', '.'), 9),
 (('i', 'n', 'e'), 9),
 (('p', 'i', 'g'), 9),
 (('f', 'a', 'l'), 9),
 (('c', 'h', 'i'), 9),
 (('a', 'r', 'k'), 9),
 (('r', 'i', 'c'), 9),
 (('p', 'a', 'r'), 9),
 (('a', 'm', 'e'), 9),
 (('a', 'n', '.'), 9),
 (('a', 'n', 'd'), 9),
 (('a', 'l', 'e'), 9),
 (('h', 'e', 'r'), 9),
 (('a

In [142]:
chars = sorted(list(set(''.join(animals)))) # list of characters used
stoi = {s: i+1 for i,s in enumerate(chars)} # mapping string to index
stoi['.'] = 0 # end character
itos = {i: s for s,i in stoi.items()} # mapping index to string

In [143]:
# Creating the traning set of all the trigrams
xs, ys = [], []
for a in animals:
    chs = ['.'] + list(a) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num_examples = xs.nelement() // 2
num_classes = len(stoi)

In [144]:
print(f'Number of traning examples: {num_examples}')
print(f'Size of traning data: {xs.shape}')
print(f'Size of target data: {ys.shape}')

Number of traning examples: 4056
Size of traning data: torch.Size([4056, 2])
Size of target data: torch.Size([4056])


In [145]:
# Combine the one=hot encodings into a single vector
x_one_hot = F.one_hot(xs, num_classes=num_classes).float()
x_one_hot = x_one_hot.view(num_examples, -1)  # shape: [num_examples, 2*num_classes]

In [146]:
# Each row of the matrix represents a bigram encoded
# Each digit of the biagram are 28 bits or num_classes bits
print(f'Shape of input data: {x_one_hot.shape}')

Shape of input data: torch.Size([4056, 56])


In [114]:
# Initialize a single weight matrix
g = torch.Generator().manual_seed(2147483647)
W =  torch.randn((2*num_classes, num_classes), generator=g, requires_grad=True)
print(f'Shape of matrix of weights: {W.shape}')

Shape of matrix of weights: torch.Size([56, 28])


In [147]:
# Gradient descent
for k in range(1):
    # Forward pass
    logits = x_one_hot @ W  # shape: [num_examples, num_classes]
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True) # normalize each row
    loss = -probs[torch.arange(num_examples), ys].log().mean()
    print(loss.item())
    
    # Backward pass
    W.grad = None
    loss.backward()
    
    # Update weights
    W.data += -50 * W.grad

torch.Size([4056, 28])
2.152465581893921


In [135]:
# Finally, we sample from the neural net model
g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    out = ['.']
    ix = 0
    while True:
        # Prepare the input from the last two characters
        if len(out) > 1:
            ix1, ix2 = stoi[out[-2]], stoi[out[-1]]  # Last two characters
        else:
            ix1, ix2 = stoi['.'], stoi[out[-1]]  # Use '.' if only one character has been generated
        
        # Create one-hot encoding for the concatenated last two characters
        xenc1 = F.one_hot(torch.tensor([ix1]), num_classes=num_classes).float()
        xenc2 = F.one_hot(torch.tensor([ix2]), num_classes=num_classes).float()
        xenc = torch.cat((xenc1, xenc2), dim=-1)  # Concatenate the encodings
        
        # Generate logits for the next character
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        
        # Sample the next character
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        
        # Break if the end token is generated or length exceeds a limit (e.g., to avoid infinite loops)
        if ix == 0 or len(out) > 20:  # Adjust the length limit as needed
            break
    
    # Print the generated word, excluding the start/end token
    print(''.join(out[1:-1]))

ctoick
hor bug
orkey falamperh
musnorot
estal
chees bueank
ug briac haly la
crine
mat
ansaneerttlighendst
