In [None]:
import string
import unicodedata
import torch
import torch.nn.functional as F

# Vocabulary: ASCII letters plus the space, comma, and apostrophe that appear
# in some names.
characters = set(string.ascii_letters + " ,'")

# Assign indices in sorted order. Iterating the set directly would give an
# order that can differ between interpreter runs (string hashes are randomized
# per process), so the same character could map to a different index after a
# kernel restart and any saved tensors or weights would no longer line up.
token_to_index: dict[str, int] = {
    token: index for index, token in enumerate(sorted(characters))
}

# Inverse lookup: index -> character.
index_to_token: dict[int, str] = {i: c for c, i in token_to_index.items()}

def t2i(token: str) -> int:
    """Look up the vocabulary index of ``token``.

    Raises:
        KeyError: if ``token`` is not a key of ``token_to_index``.
    """
    index = token_to_index[token]
    return index

def i2t(index: int) -> str:
    """Look up the token stored at vocabulary position ``index``.

    Raises:
        KeyError: if ``index`` is not a key of ``index_to_token``.
    """
    token = index_to_token[index]
    return token

def unicode_to_ascii(s) -> str:
    """Reduce ``s`` to the characters present in the ``characters`` vocabulary.

    The string is NFD-normalized first, which splits accented letters into a
    base letter followed by combining marks. Combining marks (Unicode category
    "Mn") and every character outside ``characters`` are then dropped.

    Returns:
        The filtered string; it may be empty if nothing survives the filter.
    """
    kept: list[str] = []
    for ch in unicodedata.normalize("NFD", s):
        # Skip nonspacing combining marks left over from decomposition.
        if unicodedata.category(ch) == "Mn":
            continue
        # Skip anything the vocabulary cannot represent.
        if ch not in characters:
            continue
        kept.append(ch)
    return "".join(kept)

def str_to_one_hot(name: str) -> torch.Tensor:
    """Return a one-hot encoded tensor for a name.

    Each character of ``name`` is mapped to its vocabulary index with ``t2i``
    and one-hot encoded over ``len(characters)`` classes.

    Args:
        name: the string to encode; every character must be known to ``t2i``.

    Returns:
        An integer tensor of shape ``(1, len(name), len(characters))``. An
        empty ``name`` yields shape ``(1, 0, len(characters))``.

    Raises:
        KeyError: propagated from ``t2i`` for a character outside the vocabulary.
    """
    # Force an integer dtype: torch.tensor([]) defaults to float32, and
    # F.one_hot only accepts index (int64) tensors, so an empty name would
    # otherwise raise instead of producing an empty encoding.
    indices = torch.tensor([t2i(c) for c in name], dtype=torch.long)
    encoded = F.one_hot(indices, num_classes=len(characters))
    # Add a leading dimension of size 1 in front of the sequence dimension.
    return encoded.unsqueeze(0)


# Quick sanity check on a name that contains an apostrophe.
tensor = str_to_one_hot("O'Connor")
# One value per line: the tensor itself, then its size via both accessors
# (size() and .shape report the same torch.Size).
print(tensor, tensor.size(), tensor.shape, sep="\n")

In [None]:
from data.names_dataset import NamesDataset

# Location of the raw name files, relative to this notebook.
NAMES_DATA_FOLDER = "../datasets/names"

# Build the dataset with the helpers defined above. How NamesDataset applies
# each transform is defined in data.names_dataset (not visible here).
names_dataset = NamesDataset(
    data_folder=NAMES_DATA_FOLDER,
    transform_input=unicode_to_ascii,
    transform_output=str_to_one_hot,
    transform_label=str_to_one_hot,
)

In [None]:
def collate_fn(batch) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Split a batch of ``(input, label)`` samples into two parallel lists.

    The samples are kept as Python lists instead of being stacked into one
    tensor (presumably because names differ in length, so the per-sample
    tensors cannot be stacked -- confirm against NamesDataset).

    Args:
        batch: iterable of two-element samples ``(input, label)``.

    Returns:
        ``(inputs, labels)`` where ``inputs[i]`` and ``labels[i]`` come from
        the i-th sample. An empty batch returns ``([], [])``.

    Raises:
        ValueError: if a sample does not unpack into exactly two elements.
    """
    inputs: list[torch.Tensor] = []
    labels: list[torch.Tensor] = []
    # Unpack per sample rather than via zip(*batch): zip(*[]) yields nothing,
    # so two-target unpacking of its result fails on an empty batch.
    for sample_input, sample_label in batch:
        inputs.append(sample_input)
        labels.append(sample_label)
    return inputs, labels

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64
# Fixed seed so the train/test membership and the shuffle order are the same
# after every kernel restart; without it both draw from the global RNG and
# evaluation results are not comparable between runs.
SPLIT_SEED = 42

# 85% / 15% train/test split.
train_dataset, test_dataset = torch.utils.data.random_split(
    names_dataset,
    [0.85, 0.15],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are reshuffled every epoch, from a seeded generator.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Evaluation keeps a stable order, so no shuffling and no generator needed.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn
)

# Pull one batch to sanity-check what collate_fn hands to the training loop.
train_features, train_labels = next(iter(train_dataloader))

# collate_fn returns two parallel lists, one entry per sample.
assert len(train_features) == len(train_labels)

# Show a single representative sample instead of the whole batch: printing up
# to BATCH_SIZE full one-hot tensors floods the cell output.
print(train_features[0])
print(train_labels[0])
print(len(train_features))
print(len(train_labels))
print(train_features[0].size())
print(train_labels[0].size())
