In [1]:
import torch
import pickle

In [4]:
class Linear():
    """Affine layer computing ``X @ weight`` with an optional additive bias.

    ``weight`` has shape ``(fan_in, fan_out)`` and is drawn from a standard
    normal scaled by ``1 / sqrt(fan_in)``; ``bias`` starts at zero and is
    ``None`` when ``bias=False``.
    """
    def __init__(self, fan_in, fan_out, bias=True, device="cpu"):
        init_scale = fan_in**0.5
        self.weight = torch.randn((fan_in, fan_out), device=device) / init_scale
        if bias:
            self.bias = torch.zeros(fan_out, device=device)
        else:
            self.bias = None
    def __call__(self, X):
        """Apply the layer to ``X``; the result is also cached on ``self.out``."""
        projected = X @ self.weight
        if self.bias is not None:
            # In-place add: the matmul result is a fresh tensor owned by this call.
            projected += self.bias
        self.out = projected
        return projected
    def parameters(self):
        """Return the tensors of this layer: weight, then bias if present."""
        params = [self.weight]
        if self.bias is not None:
            params.append(self.bias)
        return params

class BatchNorm1d():
    """Batch normalisation over dimension 0 with a learnable scale and shift.

    In training mode (``self.training`` is True) the input is normalised with
    the statistics of the current batch, and exponential moving averages of
    those statistics are stored in ``running_mean`` / ``running_var``.
    Otherwise the stored running statistics are used.

    Args:
        dim: number of features; all internal tensors have shape ``(1, dim)``.
        eps: constant added to the standard deviation to avoid division by zero.
        momentum: weight given to the newest batch in the running averages.
        device: device on which the internal tensors are created.

    Notes:
        * ``eps`` is added to ``sqrt(var)``, not to ``var`` inside the square
          root as ``torch.nn.BatchNorm1d`` does. This is left unchanged so the
          forward values of this layer stay exactly as before.
        * ``X.var`` is called with its defaults, so a training-mode batch with
          a single row yields NaN; switch ``training`` off for such inputs.
    """
    def __init__(self, dim, eps=1e-5, momentum=0.001, device="cpu"):
        self.gamma = torch.ones((1,dim), device=device)
        self.beta = torch.zeros((1,dim), device=device)
        self.eps = eps
        self.momentum = momentum
        self.running_mean = torch.zeros((1,dim), device=device)
        self.running_var = torch.ones((1,dim), device=device)
        self.training = True
    def __call__(self, X):
        """Normalise ``X`` along dim 0 and apply ``gamma`` / ``beta``."""
        if self.training:
            # The batch statistics must stay part of the autograd graph so the
            # backward pass accounts for their dependence on X.
            mean = X.mean(0, keepdim=True)
            var = X.var(0, keepdim=True)
            # Only the running-average bookkeeping is excluded from autograd.
            with torch.no_grad():
                self.running_mean = (1-self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1-self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        return self.gamma * (X - mean) / (torch.sqrt(var)+self.eps) + self.beta
    def parameters(self):
        """Return the learnable tensors: ``[gamma, beta]``."""
        return [self.gamma, self.beta]

class Embedding():
    """Lookup table mapping indices to rows of a ``(num_embeddings, dim)`` matrix.

    The table is initialised from a standard normal distribution.
    """
    def __init__(self, num_embeddings, dim, device="cpu"):
        table_shape = (num_embeddings, dim)
        self.weight = torch.randn(table_shape, device=device)
    def __call__(self, ix):
        """Index the table with ``ix``; the result is also cached on ``self.out``."""
        looked_up = self.weight[ix]
        self.out = looked_up
        return looked_up
    def parameters(self):
        """Return the single tensor of this layer, the embedding table."""
        return [self.weight]

class Flatten():
    """Collapse every dimension after the first into a single one."""
    def __call__(self, X):
        """Reshape ``X`` to ``(X.shape[0], -1)``; the result is cached on ``self.out``."""
        leading = X.shape[0]
        flattened = X.reshape(leading, -1)
        self.out = flattened
        return flattened
    def parameters(self):
        """This layer has no tensors of its own."""
        return list()

class Tanh():
    """Element-wise hyperbolic tangent activation."""
    def __call__(self, X):
        """Return ``tanh(X)``; the result is also cached on ``self.out``."""
        activated = torch.tanh(X)
        self.out = activated
        return activated
    def parameters(self):
        """This layer has no tensors of its own."""
        return list()

class Sequential():
    """Container that applies a list of layers one after another.

    Args:
        layers: list of callables that also expose ``parameters()``. The list
            object itself is stored (not copied). When omitted, each instance
            gets its own new empty list.
    """
    def __init__(self, layers=None):
        # A literal ``[]`` default would be one list shared by every instance
        # created without arguments, so create a fresh list per instance.
        self.layers = [] if layers is None else layers
    def __call__(self, X):
        """Feed ``X`` through every layer in order and return the final output."""
        for layer in self.layers:
            X = layer(X)
        return X
    def parameters(self):
        """Return the parameters of all layers as one flat list, in layer order."""
        return [p for layer in self.layers for p in layer.parameters()]

In [5]:
# Restore the trained network and its vocabulary tables from disk.
# SECURITY: torch.load(weights_only=False) and pickle.load both unpickle
# arbitrary objects and can execute code while doing so; only use them on
# files you created or otherwise trust.
# weights_only=False is required because the file holds a whole pickled object
# built from the classes defined above, not just a state dict (recent PyTorch
# releases default to weights_only=True and would refuse to load it).
model = torch.load('params/model.pt', weights_only=False)
with open('params/chars.pkl', 'rb') as chars_file:
    chars = pickle.load(chars_file)
with open('params/char_index.pkl', 'rb') as char_index_file:
    char_index = pickle.load(char_index_file)

# The cells below feed the model one context at a time, so layers that switch
# on a `training` flag (BatchNorm1d) must use their running statistics.
for layer in getattr(model, 'layers', []):
    if hasattr(layer, 'training'):
        layer.training = False

In [16]:
# Device that input tensors are moved to before they are fed to the model.
# NOTE(review): presumably this has to match the device the loaded model's
# tensors live on — confirm against how params/model.pt was saved.
device = "cuda"

In [17]:
def next_char(prev_8):
    """Sample the character that follows a context of character indices.

    Args:
        prev_8: sequence of integer character indices forming the context
            (the callers in this notebook pass the previous 8 indices).

    Returns:
        The entry of ``chars`` at the sampled index.
    """
    # Build a batch with a single row; -1 keeps this independent of the
    # context length instead of hard-coding 8.
    context = torch.tensor(prev_8, device=device).reshape(1, -1)
    # Sampling never back-propagates, so skip building an autograd graph.
    with torch.no_grad():
        # The model output is treated as unnormalised scores over characters.
        probs = torch.softmax(model(context), dim=1)
    sampled_ix = torch.multinomial(probs, 1).item()
    return chars[sampled_ix]
def get_name():
    """Generate one name by sampling characters until '!' is drawn.

    The context starts as eight zeros (presumably the padding/start index —
    confirm against ``char_index``) and slides by one position after every
    sampled character. The terminating '!' is not part of the returned string.
    """
    context = [0] * 8
    name = ''
    while True:
        ch = next_char(context)
        if ch == '!':
            return name
        name += ch
        # Drop the oldest index and append the one just sampled.
        context = context[1:] + [char_index[ch]]

In [18]:
# Draw 50 names from the loaded model and show one per line.
for i in range(50):
    generated = get_name()
    print(generated)

magirathi
indravasaha
dharshini
amitiika
anruthan
banaka
anavinana
eyaankar
kishwata
nirjety
varishan
prasinipan
kalavan
sreethavendux
yukisney
bhashvini
jalais
anil
yesuthayan
nirmipa
ashuka
irangaj
mahasidha
avallika
bhasukhan
niresh
roshii
nalini
vanesha
thuvena
aamil
mayamesh
krishna
manavayuha
hrishyan
kavinan
binani
jeyasha
yaveenan
chayuth
lovan
jeeyani
legasha
kaleshwen
kalishree
eyallikan
anuriva
pamanya
ibhkuyan
nerudheep


In [27]:
def get_normalised_negative_log_likelihood(word, context_len=8):
    """Average negative log-likelihood per character of ``word`` under ``model``.

    Starting from a context of ``context_len`` zeros, each character is scored
    given the characters before it, and the context slides forward by one.
    Only the characters of ``word`` are scored; no end-of-name symbol is added.

    Args:
        word: string whose characters are all keys of ``char_index``.
        context_len: number of previous indices fed to the model (default 8,
            the value used elsewhere in this notebook).

    Returns:
        Tensor of shape ``(1,)`` holding the summed NLL divided by ``len(word)``.

    Raises:
        KeyError: if a character is missing from ``char_index``.
        ZeroDivisionError: if ``word`` is empty.
    """
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    nll = 0
    for c in word:
        target_ix = char_index[c]
        # log_softmax is evaluated in a numerically stable way, whereas
        # log(softmax(x)) turns into -inf once a probability underflows to 0.
        log_probs = torch.log_softmax(model(context), dim=1)
        nll = nll - log_probs[:, target_ix]
        # Slide the window: drop the oldest index, append the current one.
        appended = torch.tensor([[target_ix]], dtype=torch.long, device=device)
        context = torch.cat((context[:, 1:], appended), dim=1)
    return nll / len(word)

In [29]:
# Per-character NLL of two example words; a lower value means the model finds the word more likely.
tuple(get_normalised_negative_log_likelihood(w).item() for w in ('arjun', 'steve'))

(2.286522626876831, 3.1762070655822754)