In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import random

In [2]:
def make_training_samples(names, block_size, stoi):

    xs = []
    ys = []

    for name in names:
        name = f"{'.' * block_size}{name}"

        for i in range(len(name) - block_size):
            x = [stoi[char] for char in name[i:i + block_size]]
            y = stoi[name[i + block_size]]

            xs.append(x)
            ys.append(y)

    return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.long)

In [3]:
names = open('names.txt').read().split('\n')

vocab = sorted(list(set(list('.'.join(names)))))
vocab_size = len(vocab)
stoi = {ix:char for char, ix in enumerate(vocab)}
itos = {char:ix for char, ix in enumerate(vocab)}


random.seed(42)
random.shuffle(names)
n1 = int(0.8*len(names))
n2 = int(0.9*len(names))

block_size = 8
batch_size = 256
d_model = 24
n_hidden = 128
n_flat_steps = 2

Xtr,  Ytr  = make_training_samples(names[:n1], block_size, stoi)     # 80%
Xdev, Ydev = make_training_samples(names[n1:n2], block_size, stoi)   # 10%
Xte,  Yte  = make_training_samples(names[n2:], block_size, stoi)     # 10%

In [4]:
def encode(data, is_label = None):

    x = [stoi[char] for name in data for char in name]
    
    return x

In [5]:
class FlattenConsecutive:
  
  def __init__(self, n):
    self.n = n
    
  def __call__(self, x):
    B, T, C = x.shape
    x = x.reshape(B, T//self.n, C*self.n)
    if x.shape[1] == 1:
      x = x.squeeze(1)
    self.out = x
    return self.out
  
  def parameters(self):
    return []

In [6]:
class MlpBatchNorm(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, output_size):
        super(MlpBatchNorm, self).__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.linear = nn.Linear(d_model * 2, n_hidden, bias=False)
        self.linear2 = nn.Linear(n_hidden * 2, n_hidden, bias=False)
        self.linear3 = nn.Linear(n_hidden * 2, n_hidden, bias=False)

        self.bnorm = nn.BatchNorm1d(n_hidden)
        self.bnorm2 = nn.BatchNorm1d(n_hidden)
        self.bnorm3 = nn.BatchNorm1d(n_hidden)

        self.flat = FlattenConsecutive(2)
        self.out_layer = nn.Linear(n_hidden, vocab_size)

    def forward(self, x):

        x = self.embedding(x)
        
        x = self.flat(x)
        x = self.linear(x)
        x = x.permute(0, 2, 1)
        x = self.bnorm(x)
        x = x.permute(0, 2, 1)
        x = F.tanh(x)
        
        x = self.flat(x)
        x = self.linear2(x)
        x = x.permute(0, 2, 1)
        x = self.bnorm2(x)
        x = x.permute(0, 2, 1)
        x = F.tanh(x)

        x = self.flat(x)
        x = self.linear3(x)
        x = self.bnorm3(x)
        x = F.tanh(x)
        
        return self.out_layer(x)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MlpBatchNorm(vocab_size, d_model, n_hidden, vocab_size).to(device)
print(net)

# Enable cuDNN backend for performance optimization
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True

optimizer = optim.SGD(net.parameters(), lr=0.1)

Xval = Xdev.to(device)
Yval = Ydev.to(device)
print(sum(p.nelement() for p in net.parameters()))

MlpBatchNorm(
  (embedding): Embedding(27, 24)
  (linear): Linear(in_features=48, out_features=128, bias=False)
  (linear2): Linear(in_features=256, out_features=128, bias=False)
  (linear3): Linear(in_features=256, out_features=128, bias=False)
  (bnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (out_layer): Linear(in_features=128, out_features=27, bias=True)
)
76579


In [9]:
optimizer = optim.SGD(net.parameters(), lr=0.01)
batch_size = 64

for i in range(100000):

    lr = [0.01 if i >= 20000 else 0.1]
    optimizer = optim.SGD(net.parameters(), lr=lr[0])

    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb, Yb = Xtr[ix].to(device), Ytr[ix].to(device) 

    optimizer.zero_grad()   
    logits = net(Xb)
    loss = F.cross_entropy(logits, Yb)
    loss.backward()
    optimizer.step()    

    val_logits = net(Xval)
    val_loss = F.cross_entropy(val_logits, Yval)
    
    if i % 1000 == 0:
        print(f'Loss: {loss:.4f} --- Val Loss: {val_loss:.4f}')

Loss: 3.3304 --- Val Loss: 3.2947
Loss: 2.3732 --- Val Loss: 2.3175
Loss: 2.2516 --- Val Loss: 2.2792
Loss: 2.2597 --- Val Loss: 2.2488
Loss: 2.3274 --- Val Loss: 2.2393
Loss: 1.9547 --- Val Loss: 2.2126
Loss: 2.1109 --- Val Loss: 2.2051
Loss: 2.1979 --- Val Loss: 2.1938
Loss: 2.0756 --- Val Loss: 2.1836
Loss: 2.1412 --- Val Loss: 2.1884
Loss: 2.1585 --- Val Loss: 2.1707
Loss: 2.0407 --- Val Loss: 2.1766
Loss: 2.2933 --- Val Loss: 2.1684
Loss: 2.0360 --- Val Loss: 2.1642
Loss: 2.0160 --- Val Loss: 2.1569
Loss: 1.9014 --- Val Loss: 2.1547
Loss: 1.9958 --- Val Loss: 2.1489
Loss: 2.3931 --- Val Loss: 2.1592
Loss: 2.1906 --- Val Loss: 2.1517
Loss: 2.0623 --- Val Loss: 2.1473
Loss: 2.0454 --- Val Loss: 2.1547
Loss: 1.8986 --- Val Loss: 2.1013
Loss: 1.8259 --- Val Loss: 2.0975
Loss: 1.9925 --- Val Loss: 2.0970
Loss: 2.0301 --- Val Loss: 2.0922
Loss: 2.1724 --- Val Loss: 2.0907
Loss: 2.0855 --- Val Loss: 2.0918
Loss: 1.8842 --- Val Loss: 2.0917
Loss: 1.7608 --- Val Loss: 2.0927
Loss: 2.0884 -