In [139]:
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [140]:
with open('./data/names.txt', "r") as f:
    names = [each.lower().strip() for each in f.readlines()]

all_chars="abcdefghijklmnopqrstuvwxyz*$"
n_chars = len(all_chars)
data = [[all_chars.index(char) for char in each] for each in names]
total_lens = [len(each) for each in data]

In [141]:
print(all_chars.index("*"))
print(all_chars.index("$"))

data_tensors = torch.nn.utils.rnn.pad_sequence([torch.Tensor(each).int() for each in data], batch_first=True, padding_value=26)


26
27


In [142]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, src, lens):
        super(CustomDataset, self).__init__()

        self.src = src
        self.tgt = torch.zeros_like(src)
        self.lens = lens

        self.tgt[:,:-1] = self.src[:,1:]
        self.tgt[:,-1] = 26

        for ix, each in enumerate(self.lens):
            self.tgt[ix, each-1] = 27


    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.lens[idx], self.tgt[idx]


In [195]:
class Gen(nn.Module):
    def __init__(self, in_size, hidden_size, out_size, lens, n_layers,dropout, device):
        super(Gen, self).__init__()

        self.num_layers = n_layers
        self.hidden_size = hidden_size
        self.device = device

        self.lens = lens
        self.embed = nn.Embedding(in_size, hidden_size, padding_idx=26)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=n_layers, batch_first=True, dropout=dropout)
        self.out = nn.Linear(hidden_size, out_size)

    def init_hidden(self, bs):
        h = torch.zeros(self.num_layers, bs, self.hidden_size).to(self.device)
        c = torch.zeros(self.num_layers, bs, self.hidden_size).to(self.device)
        return h,c
    def forward (self, X,lens, h, c):
        X = self.embed(X)
        # print("IN", X.shape)
        if lens is not None:
            X = torch.nn.utils.rnn.pack_padded_sequence(X, lens, batch_first=True, enforce_sorted=False)
            X, (h, c )= self.lstm(X,( h, c))
            X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True,total_length=15)
        else:
            X, (h,c) = self.lstm(X, (h,c))
        # X = F.relu(X)

        X = self.out(X)

        return X,h,c
    
    def generate(self,start="a", gen_len=7,temperature=0.85):

        hidden, cell = self.init_hidden(1)
        initial_input = torch.Tensor([all_chars.index(each) for each in start]).int().unsqueeze(0)
        # print(initial_input.shape)
        predicted = start

        for p in range(len(start) - 1):
            _, hidden, cell = self.forward(
                initial_input[p].view(1).to(self.device),None, hidden, cell
            )

        last_char = initial_input[-1].unsqueeze(0)
        # print("INITIAL LAST CHAR: ", last_char.shape)
        # print("SHAPED", last_char.view(1).shape)

        for p in range(gen_len):
            output, hidden, cell = self.forward(
                last_char.to(self.device),None, hidden, cell
            )
            # print("OUTPUTSHAPE",output.shape)
            output_dist = F.softmax(output, dim=2).squeeze(1)
            # print("OUDIST", output_dist.shape)
            top_char = torch.multinomial(output_dist, 1)[0]
            # print("TOP",top_char)
            predicted_char = all_chars[top_char]
            # print(predicted_char)
            predicted += predicted_char
            last_char = torch.Tensor([all_chars.index(predicted_char)]).int().unsqueeze(0)
            # print("LAST",last_char.shape, last_char)

        return predicted
     





In [196]:
bs = 128
dataset = CustomDataset(data_tensors, total_lens)
loader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)


In [247]:
torch.autograd.set_detect_anomaly(True)
torch.manual_seed(33)
num_epochs = 20
device = torch.device("mps")
# loss_fn = nn.CrossEntropyLoss()
model = Gen(n_chars, 256, n_chars, total_lens, 2, 0.5, device)
model.load_state_dict(torch.load("./weights/gen_weights_2_29.pt"))
model = model.to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.0003)
for ep in (pbar:=tqdm(range(num_epochs))):
    for idx, (src, lens, tgt) in enumerate(loader):
        # print(src[:3])
        # print(tgt[:3])
        # print(lens)
        optim.zero_grad()
        src = src.to(device)
        tgt = tgt.to(device)
        # print("SRC", src.shape)
        h, c = model.init_hidden(bs)
        preds, h, c = model(src, lens, h, c)
        # print("PREDS",preds.shape)
        loss = F.cross_entropy(preds.permute(0,2,1), tgt.long(), ignore_index=26)
        loss.backward()
        for name, param in model.named_parameters():
            param.grad.clamp(-1,1)
        optim.step()
        pbar.set_description(f"{ep+20}:{idx}/{len(loader)} Loss:{loss.item()}")
        if (ep+1) %5 == 0:
            torch.save(model.state_dict(), f"./weights/gen_weights_3_{ep+20}.pt")
        

        

  File "/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/doriclink/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/Users/doriclink/Library/Python/3.9/lib/python/site-packages/traitlets/config/application.py", line 978, in launch_instance
    app.start()
  File "/Users/doriclink/Library/Python/3.9/lib/python/site-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/Users/doriclink/Library/Python/3.9/lib/python/site-packages/tornado/platform/asyncio.py", line 215, in start
    self.asyncio_loop.run_forever()
  File "/Applications/Xcode.app/Co

RuntimeError: Function 'MpsLinearBackward0' returned nan values in its 0th output.

In [256]:
model.load_state_dict(torch.load("./weights/gen_weights_2_29.pt"))
model = model.to(device)
model.eval()
out = model.generate("e", 2)

print(out)

e*a
