In [1]:
import os 
from tqdm import tqdm, trange
import torch 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [2]:
data_path = "./data/names.txt"

with open(data_path, "r") as f:
    raw_names = f.readlines()

names_processed  = [name.lower().strip() for name in raw_names]

In [3]:
def get_vocab(list_of_names):
    vocab = {}
    idx = 0
    for each in list_of_names:
        for char in each:
            if char not in vocab:
                vocab[char] = idx
                idx+=1
    vocab["*"] = idx
    idx+=1
    vocab["<EOF>"] = idx
    return vocab

In [4]:
def pad(list_of_names):
    max_len = max([len(name) for name in list_of_names])
    out = []

    for name in list_of_names:
        req = max_len - len(name)
        chars = list(name)
        chars += ["<EOF>"] + ["*"] * req
        out.append(chars)
    return out

In [5]:
def list_to_torch(list_of_names, vocab):
    out = []
    for name in list_of_names:
        chars = [vocab[char] for char in name]
        out.append(chars)
    
    return torch.Tensor((out))
    


In [20]:
vocab = get_vocab(names_processed)
# print(vocab)
padded_names = pad(names_processed)
padded_tgt = []
for name in padded_names:
    padded_tgt.append(name[1:] + ["*"])
input_data = list_to_torch(padded_names, vocab)
input_targets = list_to_torch(padded_tgt, vocab)


class CustomSet(torch.utils.data.Dataset):
    def __init__(self, tgt, src):
        self.tgt = tgt
        self.src = src
        self.transform = None
        self.target_transform = None
    def __len__(self):
        return len(self.src)
    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]

dset = CustomSet(input_data, input_targets)

In [21]:
class Gen(nn.Module):
    def __init__(self, num_layers, vocab, emb_size, hidden_size):
        super(Gen, self).__init__()
        self.num_layers = num_layers
        self.vocab = vocab
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        self.emb = nn.Embedding(len(self.vocab), self.emb_size, padding_idx=self.vocab["*"])
        self.gru = nn.GRU(self.emb_size, self.hidden_size, num_layers=self.num_layers, batch_first=True, dropout=0.3)
        self.fc = nn.Linear(self.hidden_size, len(self.vocab))
        self.relu = nn.ReLU()

    def forward(self, X, h):
        X = self.emb(X)
        X, h = self.gru(X, h)
        X = self.relu(X)
        X = self.fc(X)
        

        return X, h
    def init_hidden(self,b):
        return torch.zeros(self.num_layers,b, self.hidden_size)


    

In [22]:
num_epochs = 10
model = Gen(2, vocab, 200, 300)
optim = torch.optim.Adam(model.parameters(), lr = 0.003)
loss_fn = nn.CrossEntropyLoss()
batch_size = 128
losses = []

loader = torch.utils.data.DataLoader(dset, batch_size=128, shuffle=True)
for ep in trange(num_epochs):
    h  = model.init_hidden(128)
    for ix, (src, tgt) in tqdm(enumerate(loader)):
        optim.zero_grad()
        if src.shape[0] != batch_size:
            continue
        # print(src.shape, tgt.shape)
        preds, h = model(src.long(), h)
        h = h.detach()
        # print(preds.shape)
        preds=preds.permute(0,2,1)
        loss = loss_fn(preds,tgt.long())
        loss.backward()
        losses.append(loss.item())
        optim.step()


143it [00:40,  3.57it/s]0:00<?, ?it/s]
143it [00:49,  2.88it/s]0:40<06:00, 40.01s/it]
143it [00:58,  2.47it/s]1:29<06:05, 45.66s/it]
143it [00:55,  2.58it/s]2:27<05:59, 51.31s/it]
143it [01:09,  2.05it/s]3:23<05:17, 52.97s/it]
143it [01:04,  2.21it/s]4:32<04:55, 59.04s/it]
143it [00:59,  2.42it/s]5:37<04:03, 60.98s/it]
143it [01:10,  2.01it/s]6:36<03:01, 60.38s/it]
143it [01:00,  2.37it/s]7:47<02:07, 63.76s/it]
143it [01:02,  2.28it/s]8:48<01:02, 62.72s/it]
100%|██████████| 10/10 [09:50<00:00, 59.10s/it]


In [23]:
torch.save(model.state_dict(), "weights.pth")

with torch.no_grad():
    model.eval()


In [27]:
starting = "a"
length = 10
model.load_state_dict(torch.load("weights.pth"))
x = torch.tensor(vocab[starting])
x = x.unsqueeze(0).unsqueeze(0)
print(x.shape)
word = []
with torch.no_grad():
    model.eval()
    h = model.init_hidden(1)
    print(h.shape)
    for each in range(length):
        pred, h = model(x, h)
    
        # print(pred.shape)
        pred = pred.softmax(dim=2).squeeze(1)
        # print(pred.shape)
        pred = torch.argmax(pred, dim=1)
        print(pred.shape)
        x = pred.unsqueeze(1)
        word.append(pred.item())
for each in word:
    print(list(vocab.keys())[list(vocab.values()).index(each)])  # Prints george




torch.Size([1, 1])
torch.Size([2, 1, 300])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])
m
a
m
a
m
a
m
a
m
a
