In [None]:
import string
import unicodedata
import torch

# Vocabulary: ASCII letters plus the space, comma, and apostrophe that appear
# in some names.
characters = set(string.ascii_letters + " ,'")

# Assign indices in sorted character order. Iterating the set directly follows
# string-hash order, which can differ between interpreter runs, so the one-hot
# positions would not be reproducible across kernel restarts.
token_to_index: dict[str, int] = {
    c: i for i, c in enumerate(sorted(characters))
}

# Inverse lookup: index -> character.
index_to_token: dict[int, str] = {i: c for c, i in token_to_index.items()}

def t2i(token: str) -> int:
    """Look up the vocabulary index of ``token``.

    Raises KeyError if ``token`` is not a key of ``token_to_index``.
    """
    index = token_to_index[token]
    return index

def i2t(index: int) -> str:
    """Look up the character stored at vocabulary position ``index``.

    Raises KeyError if ``index`` is not a key of ``index_to_token``.
    """
    token = index_to_token[index]
    return token

def unicode_to_ascii(s) -> str:
    """Reduce ``s`` to the characters in the module-level ``characters`` set.

    The string is NFD-normalized so accented letters split into a base letter
    plus combining marks; combining marks (category "Mn") are dropped, and so
    is any remaining character that is not in ``characters``.
    """
    kept = []
    for ch in unicodedata.normalize("NFD", s):
        if unicodedata.category(ch) == "Mn":
            # Combining mark left over from decomposition.
            continue
        if ch not in characters:
            continue
        kept.append(ch)
    return "".join(kept)

def name_to_tensor(name: str) -> torch.Tensor:
    """One-hot encode ``name``, one character per row.

    Returns a tensor of shape (len(name), 1, len(characters)) that is zero
    everywhere except for a 1 at index ``t2i(c)`` in the row of each
    character ``c``.
    """
    one_hot = torch.zeros(len(name), 1, len(characters))
    for position, char in enumerate(name):
        one_hot[position, 0, t2i(char)] = 1
    return one_hot

def tensor_to_name(tensor: torch.Tensor) -> str:
    """Decode a tensor back into a string.

    For every index ``row`` along dimension 0, takes the argmax of
    ``tensor[row][0]`` and maps it to a character with ``i2t``; the
    characters are concatenated in row order.
    """
    return "".join(
        i2t(int(tensor[row][0].argmax().item()))
        for row in range(tensor.size(0))
    )

In [None]:
from data.names_dataset import NamesDataset


# Build the dataset, passing unicode_to_ascii as its input transform.
names_dataset = NamesDataset("../datasets/names", transform_input=unicode_to_ascii)

# Gather every distinct character seen across the names the dataset yields.
all_characters = set()
for name, country_index in names_dataset:
    all_characters |= set(name)

print(len(all_characters))
print(sorted(all_characters))

# Round-trip two sample names through the one-hot encoding and back.
for sample_name in ("O'Connor", "Smith, John"):
    print(tensor_to_name(name_to_tensor(sample_name)))



55
[' ', "'", ',', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
