In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

In [None]:
# Download the name corpus into the working directory as 'female.txt' (read by the cells below).
!wget http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/female.txt

--2022-11-24 07:42:29--  http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/female.txt
Resolving www.cs.cmu.edu (www.cs.cmu.edu)... 128.2.42.95
Connecting to www.cs.cmu.edu (www.cs.cmu.edu)|128.2.42.95|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35751 (35K) [text/plain]
Saving to: ‘female.txt’


2022-11-24 07:42:29 (445 KB/s) - ‘female.txt’ saved [35751/35751]



In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Peek at the raw file: one list entry per line, trailing newline characters included.
with open('female.txt', 'r') as f:
    lines = list(f)


['# List of common female names.\n', '# Copyright (c) January 1991 by Mark Kantrowitz.\n', '# 4987 names\n', '# Thanks to Bill.Ross for about 1000 additional names.\n', '# Version 1.3 (29-MAR-94)\n', '\n', 'Abagael\n', 'Abagail\n', 'Abbe\n', 'Abbey\n', 'Abbi\n', 'Abbie\n', 'Abby\n', 'Abigael\n', 'Abigail\n', 'Abigale\n', 'Abra\n', 'Acacia\n', 'Ada\n', 'Adah\n', 'Adaline\n', 'Adara\n', 'Addie\n', 'Addis\n', 'Adel\n', 'Adela\n', 'Adelaide\n', 'Adele\n', 'Adelice\n', 'Adelina\n', 'Adelind\n', 'Adeline\n', 'Adella\n', 'Adelle\n', 'Adena\n', 'Adey\n', 'Adi\n', 'Adiana\n', 'Adina\n', 'Adora\n', 'Adore\n', 'Adoree\n', 'Adorne\n', 'Adrea\n', 'Adria\n', 'Adriaens\n', 'Adrian\n', 'Adriana\n', 'Adriane\n', 'Adrianna\n', 'Adrianne\n', 'Adrien\n', 'Adriena\n', 'Adrienne\n', 'Aeriel\n', 'Aeriela\n', 'Aeriell\n', 'Ag\n', 'Agace\n', 'Agata\n', 'Agatha\n', 'Agathe\n', 'Aggi\n', 'Aggie\n', 'Aggy\n', 'Agna\n', 'Agnella\n', 'Agnes\n', 'Agnese\n', 'Agnesse\n', 'Agneta\n', 'Agnola\n', 'Agretha\n', 'Aida\n',

In [None]:
# Read the raw corpus; every entry still carries its trailing newline.
with open('female.txt', 'r') as f:
    lines = f.readlines()

names = []
max_len = 0
# Skip the first six lines (the '#' header and the blank separator line).
for line in lines[6:]:
    # Remove only the line terminator, then lower-case. Slicing off the last
    # character unconditionally would truncate a final line that has no newline.
    curr_name = line.rstrip('\n').lower()
    # Keep purely alphabetic entries; anything with spaces, hyphens, etc. is dropped.
    if curr_name.isalpha():
        names.append(curr_name)
        max_len = max(len(curr_name), max_len)
max_len += 1  # reserve one extra position for the end-of-sequence (EOS) token
print('Maximum Length : ' + str(max_len))

Maximum Length : 14


In [None]:
class NameDataset(Dataset):
    """Character-level dataset of names for next-character prediction.

    Each name is encoded as indices 0-25 for 'a'-'z' and right-padded to
    ``max_len`` with the end-of-sequence (EOS) index ``num_classes - 1`` (26).

    Args:
        names: sequence of lower-case alphabetic strings.
        max_len: padded length of every encoded name, including the EOS slot.
    """

    def __init__(self, names, max_len):
        self.names = names
        self.max_len = max_len
        self.a_order = ord('a')
        self.z_order = ord('z')
        self.num_classes = 26 + 1  # a-z plus the end-of-sequence token

    def __len__(self):
        """Return the number of names in the dataset."""
        return len(self.names)

    def __getitem__(self, idx):
        """Return the encoded sample for the name at position ``idx``.

        Returns:
            dict with keys
            'input'    -- LongTensor of length ``max_len - 1`` (padded name without its last slot),
            'output'   -- LongTensor of length ``max_len - 1`` ('input' shifted left by one step),
            'length'   -- number of characters in the name,
            'original' -- the name string itself.
        """
        # Read from the instance, not from a notebook-level global, so the
        # dataset always serves the names it was constructed with.
        name = self.names[idx]

        eos_index = self.num_classes - 1
        padding_name = [eos_index] * self.max_len
        curr_name = [ord(ch) - self.a_order for ch in name]
        padding_name[:len(curr_name)] = curr_name

        # Shift by one position so the target at step t is the character at t + 1.
        sample = dict()
        sample['input'] = torch.LongTensor(padding_name[:-1])   # e.g. e m m a 26 26 ...
        sample['output'] = torch.LongTensor(padding_name[1:])   # e.g. m m a 26 26 26 ...
        sample['length'] = len(name)
        sample['original'] = name

        return sample

In [None]:
# Mini-batch size; passed to the DataLoader below so changing it here actually takes effect.
batch_size = 64
dataset = NameDataset(names, max_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Draw one shuffled mini-batch and show its first element, field by field,
# followed by the batched tensor shapes.
sample = next(iter(dataloader))
print(*(sample[field][0] for field in ('input', 'output', 'length', 'original')), sep='\n')
print(sample['input'].shape, sample['output'].shape)

tensor([ 4, 12, 12,  0, 26, 26, 26, 26, 26, 26, 26, 26, 26])
tensor([12, 12,  0, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26])
tensor(4)
emma
torch.Size([64, 13]) torch.Size([64, 13])


In [None]:
# Reorder the batch from the longest name to the shortest; a handy
# preparation step when sequences of different lengths are handled explicitly.
total_lengths = sample['length']
sort_length, sort_idx = torch.sort(input=total_lengths, dim=-1, descending=True)
sort_input, sort_output = sample['input'][sort_idx], sample['output'][sort_idx]
print(sort_length)
print(sort_input.shape)

tensor([11, 10,  9,  9,  9,  9,  9,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,
         4,  4,  4,  4,  4,  4,  4,  4,  4,  3])
torch.Size([64, 13])


In [None]:
class RNNmodel(nn.Module):
    """Character-level LSTM language model: embedding -> one-layer LSTM -> linear.

    Args:
        lstm_dim: size of both the character embedding and the LSTM hidden state.
        num_classes: vocabulary size; the last index (``num_classes - 1``) is
            treated as the end-of-sequence (EOS) token when sampling.
        max_len: stored on the instance and used to bound the sampling loop in ``test``.
    """

    def __init__(self, lstm_dim=256, num_classes=dataset.num_classes, max_len=max_len):
        super().__init__()
        self.lstm_dim = lstm_dim
        self.num_classes = num_classes
        self.max_len = max_len
        self.char_embedding = nn.Embedding(num_embeddings=num_classes,
                                           embedding_dim=lstm_dim)
        self.lstm = nn.LSTM(input_size=lstm_dim,
                            hidden_size=lstm_dim,
                            num_layers=1,
                            batch_first=True,
                            )

        self.out_linear = nn.Linear(lstm_dim, num_classes)

    def forward(self, sort_input, sort_output, sort_length):
        """Return per-step logits for a batch of index sequences.

        Args:
            sort_input: LongTensor of character indices, batch first.
            sort_output: unused; kept so existing callers keep working.
            sort_length: unused; kept so existing callers keep working.

        Returns:
            Tensor of logits with a trailing dimension of size ``num_classes``.
        """
        ## originally, recommended to use torch.nn.utils.rnn.pack_padded_sequence, when we have variable lengths
        ## but in this case, it is left out because beginners can be more confused with this
        lstm_input = self.char_embedding(sort_input)
        lstm_out, (h, c) = self.lstm(lstm_input)
        out = self.out_linear(lstm_out)

        return out

    def test(self, start_char):
        """Sample a name one character at a time, starting from ``start_char``.

        Args:
            start_char: a single lower-case letter used as the first character.

        Returns:
            The generated string, beginning with ``start_char``. Sampling stops
            when the EOS index is drawn or after ``max_len + 1`` sampled steps.
        """
        generated_name = [start_char]

        # Use the device the weights live on rather than a notebook-level global,
        # so sampling works wherever the model has been moved.
        model_device = self.char_embedding.weight.device
        eos_index = self.num_classes - 1

        start_order = torch.tensor([[ord(start_char) - ord('a')]],
                                   dtype=torch.long, device=model_device)
        state = None  # None makes the LSTM start from its default (zero) state

        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            # Same bound as ``cnt <= self.max_len`` with cnt counting sampled characters.
            for _ in range(self.max_len + 1):
                curr_embed = self.char_embedding(start_order)
                lstm_out, state = self.lstm(curr_embed, state)
                out = self.out_linear(lstm_out)

                sample_next = torch.distributions.Categorical(logits=out[0, 0, :]).sample().item()
                if sample_next == eos_index:
                    break

                generated_name.append(chr(ord('a') + sample_next))
                # Feed the sampled character back in as the next input.
                start_order = torch.tensor([[sample_next]],
                                           dtype=torch.long, device=model_device)

        return ''.join(generated_name)


In [None]:
# Build the model with its default hyper-parameters.
model = RNNmodel()
# Smoke-test one forward pass on the batch drawn earlier. This runs before the
# model is moved to `device`, while the model and the batch are on the same device.
model(sample['input'], sample['output'], sample['length'])
# Move the parameters to the target device, then create the optimizer over them.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def train(model, optimizer, sample):
    """Run one optimisation step on a single mini-batch.

    Args:
        model: module called as ``model(inputs, targets, lengths)`` that returns
            logits shaped (B, T, C).
        optimizer: optimizer over ``model``'s parameters.
        sample: dict with tensor entries 'input', 'output' and 'length'.

    Returns:
        The cross-entropy loss of this batch as a Python float.
    """
    optimizer.zero_grad()
    criteria = nn.CrossEntropyLoss()

    # Follow the model's own device instead of a notebook-level global, so the
    # batch always lands where the parameters are.
    model_device = next(model.parameters()).device

    total_lengths = sample['length']
    sort_length, sort_idx = torch.sort(total_lengths, descending=True)

    sort_input = sample['input'][sort_idx].to(model_device)
    sort_output = sample['output'][sort_idx].to(model_device)
    sort_length = sort_length.to(model_device)

    pred = model(sort_input, sort_output, sort_length)  # B T C
    B, T, C = pred.shape

    # Flatten batch and time so every position is scored as one classification.
    curr_loss = criteria(pred.reshape(B*T, C), sort_output.reshape(B*T))

    curr_loss.backward()
    optimizer.step()

    return curr_loss.item()

In [None]:
max_epoch = 200
for epoch in tqdm(range(max_epoch)):
    # Running mean of the per-batch losses over this epoch.
    total_loss = 0.0
    for sample in dataloader:
        curr_loss = train(model, optimizer, sample)
        total_loss += curr_loss / len(dataloader)

    # np.random.randint excludes its upper bound, so add 1 to make 'z' a
    # possible start character as well.
    start_char = chr(np.random.randint(ord('a'), ord('z') + 1))
    print('[EPOCH {}] TRAIN LOSS: {}, SAMPLED NAME: {}'.format(epoch,
                                                               total_loss, 
                                                               model.test(start_char)))



  0%|          | 0/200 [00:00<?, ?it/s]

[EPOCH 0] TRAIN LOSS: 1.9628541209758863, SAMPLED NAME: gkkurulmytplzi
[EPOCH 1] TRAIN LOSS: 1.2664598715610995, SAMPLED NAME: vzjbuqzgweelfs
[EPOCH 2] TRAIN LOSS: 1.1441614887653249, SAMPLED NAME: clrakn
[EPOCH 3] TRAIN LOSS: 1.0811922068779278, SAMPLED NAME: quwpbmeru
[EPOCH 4] TRAIN LOSS: 1.0372098676669295, SAMPLED NAME: ltghita
[EPOCH 5] TRAIN LOSS: 1.003789742023517, SAMPLED NAME: folorie
[EPOCH 6] TRAIN LOSS: 0.9764890861816892, SAMPLED NAME: f
[EPOCH 7] TRAIN LOSS: 0.9535979674412656, SAMPLED NAME: celle
[EPOCH 8] TRAIN LOSS: 0.9346155401987908, SAMPLED NAME: omili
[EPOCH 9] TRAIN LOSS: 0.918441164187896, SAMPLED NAME: uuty
[EPOCH 10] TRAIN LOSS: 0.9040813468969786, SAMPLED NAME: carline
[EPOCH 11] TRAIN LOSS: 0.8914285661318365, SAMPLED NAME: wv
[EPOCH 12] TRAIN LOSS: 0.8802611086613092, SAMPLED NAME: prellanine
[EPOCH 13] TRAIN LOSS: 0.8704198018098487, SAMPLED NAME: carmya
[EPOCH 14] TRAIN LOSS: 0.8605944965130242, SAMPLED NAME: eclane
[EPOCH 15] TRAIN LOSS: 0.85194596877464