In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

In [None]:
import string
import unidecode
# Vocabulary: the characters of string.printable; a character's id is its position here.
all_chars = string.printable
n_chars = len(all_chars)

# Read the training names, one per line, lower-cased and stripped of surrounding whitespace.
with open("names.txt", "r") as f:
    names = [raw_line.lower().strip() for raw_line in f]

In [None]:
# Show the vocabulary, then the ids of the two reserved symbols.
print(all_chars)
for reserved_symbol in ("*", "$"):  # "*" is used as the pad character, "$" as the EOF character
    print(all_chars.index(reserved_symbol))

In [None]:
# Encode every name as a LongTensor of character ids terminated by the EOF marker,
# then right-pad all sequences to a common length so they stack into one 2-D tensor.
eof_idx = all_chars.index("$")  # 65 when all_chars is string.printable
pad_idx = all_chars.index("*")  # 71 when all_chars is string.printable (the model's padding_idx)
encoded = [
    # Build integer tensors directly instead of creating floats and casting afterwards.
    torch.tensor([all_chars.index(char) for char in name] + [eof_idx], dtype=torch.long)
    for name in names
]
out = torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True, padding_value=pad_idx)


In [None]:
class Bob(nn.Module):
    """Character-level language model: embedding -> stacked GRU -> ReLU -> linear.

    Args:
        in_size: Number of distinct input character ids (rows of the embedding).
            Must be greater than 71 because id 71 is reserved as the padding index.
        hidden_size: Width of the embedding vectors and of the GRU hidden state.
        out_size: Number of logits produced per time step.
        n_layers: Number of stacked GRU layers.
        device: Device on which ``generate_sample`` allocates its seed tensor.
    """

    def __init__(self, in_size, hidden_size, out_size, n_layers, device):
        super(Bob, self).__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.n_layers = n_layers
        self.device = device

        self.embed = nn.Embedding(in_size, hidden_size, padding_idx=71)
        torch.nn.init.xavier_uniform_(self.embed.weight)
        # xavier_uniform_ overwrites the whole matrix, including the row that
        # nn.Embedding zeroed for padding_idx. That row receives no gradient, so it
        # would stay random forever; restore the all-zero padding vector.
        with torch.no_grad():
            self.embed.weight[self.embed.padding_idx].zero_()
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, out_size)

    def generate_sample(self, bs, len_gen, start="a"):
        """Greedily generate ``bs`` strings, each seeded with ``start``.

        At every step the most likely next character (argmax of the logits) is
        appended and fed back in together with the GRU hidden state. Relies on the
        module-level ``all_chars`` string to map between characters and ids.

        Args:
            bs: Number of strings to generate.
            len_gen: Number of characters to generate after the seed. Values <= 0
                return the seeds unchanged.
            start: Seed character; must occur in ``all_chars``.

        Returns:
            List of ``bs`` strings, each ``start`` followed by ``len_gen`` characters.
            Decoding is deterministic, so all strings are identical for a given model.
        """
        preds = torch.full(
            (bs, 1), all_chars.index(start), dtype=torch.long, device=self.device
        )
        out = [start] * bs
        hidden = None
        # Sampling needs no gradients; no_grad avoids building an autograd graph.
        with torch.no_grad():
            for _step in range(len_gen):
                logits, hidden, _nan_stage = self(preds, hidden)
                # argmax over logits equals argmax over softmax(logits), so the
                # softmax is skipped.
                preds = torch.argmax(logits, dim=2)
                for ix, char_id in enumerate(preds[:, 0].tolist()):
                    out[ix] += all_chars[char_id]

        return out

    @staticmethod
    def _first_nan_stage(nan_stage, stage, tensor):
        """Return ``stage`` if nothing was flagged yet and ``tensor`` holds a NaN,
        otherwise return ``nan_stage`` unchanged."""
        if nan_stage is None and tensor.isnan().any():
            return stage
        return nan_stage

    def forward(self, x, hidden=None):
        """Run the network on a batch of character-id sequences.

        Args:
            x: Integer tensor of character ids; the GRU is batch-first, so the
                layout is (batch, seq).
            hidden: Optional GRU hidden state from a previous call; ``None`` lets
                the GRU start from its default initial state.

        Returns:
            Tuple ``(logits, hidden, nan_stage)``. ``logits`` has one vector of
            ``out_size`` scores per input position, ``hidden`` is the final GRU
            state, and ``nan_stage`` is ``None`` when no NaN was seen, or the index
            of the first stage whose output contained one
            (0 input, 1 embedding, 2 GRU, 3 ReLU, 4 linear).
        """
        nan_stage = self._first_nan_stage(None, 0, x)

        x = self.embed(x)
        nan_stage = self._first_nan_stage(nan_stage, 1, x)

        # Passing hidden=None is equivalent to calling the GRU without a state.
        x, hidden = self.gru(x, hidden)
        nan_stage = self._first_nan_stage(nan_stage, 2, x)

        x = self.relu(x)
        nan_stage = self._first_nan_stage(nan_stage, 3, x)

        x = self.out(x)
        nan_stage = self._first_nan_stage(nan_stage, 4, x)

        return x, hidden, nan_stage
        


In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Next-character prediction pairs built from a padded batch of sequences.

    The target for each row is the source shifted one step to the left, with the
    freed last position filled with the padding id.

    Args:
        src: 2-D integer tensor of character ids, one padded sequence per row.
        pad_idx: Id written into the last target column. Defaults to 71, the
            padding id used elsewhere in this notebook.
    """

    def __init__(self, src, pad_idx=71):
        self.src = src
        # Shift left by one and pad the tail; full_like keeps src's dtype and device,
        # and concatenation also works when src has zero columns.
        tail = torch.full_like(src[:, :1], pad_idx)
        self.tgt = torch.cat([src[:, 1:], tail], dim=1)

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.src)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` pair for row ``idx``."""
        return self.src[idx], self.tgt[idx]

In [None]:
# Anomaly detection makes the backward pass raise when it produces NaN values.
# It slows training, so disable it once the run is stable.
torch.autograd.set_detect_anomaly(True)
customset = CustomDataset(out)
loader = torch.utils.data.DataLoader(customset, batch_size=128, shuffle=True)
# Prefer Apple's MPS backend, but fall back so the notebook also runs on CUDA or
# CPU-only machines instead of failing at device creation.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
num_epochs = 50
# NOTE(review): padded target positions (id 71) still contribute to this loss.
# Consider ignore_index=71 if the model should not be trained to emit padding.
loss_fn = nn.CrossEntropyLoss()
model = Bob(n_chars, 256, n_chars, 2, device).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.0003)
first_nan = None  # "epoch,batch" of the first NaN seen in the predictions
nan_loss = None   # "epoch,batch" of the first NaN loss
for ep in (pbar := tqdm(range(num_epochs))):
    for idx, (src, tgt) in enumerate(loader):
        # The exception was previously constructed but never raised.
        if src.isnan().any():
            raise ValueError("WTF BRO")

        where = f"{ep},{idx}"
        optim.zero_grad()
        src = src.to(device)
        tgt = tgt.to(device).long()
        preds, _, nanthrown = model(src)

        has_nan = preds.isnan().any().item()
        if has_nan:
            # nanthrown is the first forward stage that produced a NaN.
            print(f"FOUND NAN IN PREDS AT: {where} (first NaN stage: {nanthrown})")
            if first_nan is None:
                first_nan = where

        # CrossEntropyLoss expects the class dimension second: (batch, classes, seq).
        loss = loss_fn(preds.permute(0, 2, 1), tgt)

        # Report only when THIS batch's loss is NaN, not on every batch after the first.
        if loss.isnan().item():
            print(f"FOUND NAN IN lOSS AT: {where}")
            if nan_loss is None:
                nan_loss = where

        loss.backward()
        # The previous torch.clamp(param, -1, 1) discarded its result and did nothing.
        # Clamp the gradients element-wise to [-1, 1] before the optimizer step.
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
        optim.step()

        # Summarise parameter health once per batch: (any NaN, any inf, any -inf).
        params = list(model.parameters())
        param_flags = (
            any(p.isnan().any().item() for p in params),
            any(p.isinf().any().item() for p in params),
            any(p.isneginf().any().item() for p in params),
        )
        pbar.set_description(f"Last Loss: {loss.item()} {where} {param_flags}")
torch.save(model.state_dict(), "weights.pth")





In [None]:
# To sample from previously saved weights instead of the in-memory model, first run:
# model.load_state_dict(torch.load("weights.pth"))
# eval() returns the module itself, so sampling can be chained onto it.
model.eval().generate_sample(bs=4, len_gen=5)