Adapted from the PyTorch tutorial "NLP From Scratch: Generating Names with a Character-Level RNN": https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

In [1]:
import pickle

# NOTE: pickle.load can execute arbitrary code -- only load names.pkl from a trusted source.
# Records are used below as record[0] = name, record[1] = language label.
with open("./names.pkl", "rb") as names_file:
    names = pickle.load(names_file)

In [2]:
import random

# Shuffle a copy before splitting. If names.pkl is ordered by language (TODO confirm),
# a plain 80/20 slice would put entire languages only in test_set, so they would never
# be trained. A fixed seed keeps the split reproducible across kernel restarts.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8
shuffled_names = list(names)
random.Random(SPLIT_SEED).shuffle(shuffled_names)
n_train = int(len(shuffled_names) * TRAIN_FRACTION)
train_set = shuffled_names[:n_train]
test_set = shuffled_names[n_train:]

In [3]:
# Every distinct language label (second field of each record) seen in either split.
languages_set = {record[1] for split in (train_set, test_set) for record in split}

In [4]:
# Fix the category order. Iterating a set of strings depends on the per-process hash seed,
# so list(set) gives a different language -> index mapping in every new kernel. The index
# is used by categoryTensor, and a saved model would silently get the wrong categories.
languages_set = sorted(languages_set)
n_categories = len(languages_set)

In [5]:
(len(train_set), len(test_set), len(names))  # (train, test, total) sizes

(16059, 4015, 20074)

In [13]:
import string

# Letter vocabulary; the one-hot width has one extra slot at the end for the EOS marker.
all_letters = string.ascii_letters + " .,;'"
n_letters = 1 + len(all_letters)

# One-hot vector for category
def categoryTensor(category):
    """Return a (1, n_categories) one-hot row for ``category``.

    The hot column is the category's position in ``languages_set``. An unknown
    category raises ValueError (from ``list.index``).
    """
    one_hot = torch.zeros(1, n_categories)
    one_hot[0, languages_set.index(category)] = 1
    return one_hot

def _letterIndex(letter):
    """Return the position of ``letter`` in ``all_letters``.

    Raises:
        ValueError: if ``letter`` is not in the vocabulary. ``str.find`` would return -1,
            which as a tensor index silently selects the last slot (the EOS marker) and
            as an NLLLoss target fails with an opaque out-of-bounds error.
    """
    index = all_letters.find(letter)
    if index < 0:
        raise ValueError(f"character {letter!r} is not in all_letters")
    return index


# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    """Encode ``line`` as a (len(line), 1, n_letters) one-hot tensor, one step per character.

    Raises:
        ValueError: if ``line`` contains a character outside ``all_letters``.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for position, letter in enumerate(line):
        tensor[position][0][_letterIndex(letter)] = 1
    return tensor


# ``LongTensor`` of second letter to end (EOS) for target
def targetTensor(line):
    """Return next-letter targets for ``line``: the indices of line[1:] followed by EOS.

    For a non-empty line the result has len(line) entries, one per input step produced
    by ``inputTensor``. EOS is the extra last slot, index ``n_letters - 1``.

    Raises:
        ValueError: if ``line[1:]`` contains a character outside ``all_letters``.
    """
    letter_indexes = [_letterIndex(letter) for letter in line[1:]]
    letter_indexes.append(n_letters - 1)  # EOS
    return torch.LongTensor(letter_indexes)

In [24]:
import torch
import torch.nn as nn

class RNN(nn.Module):
    """Category-conditioned character RNN for name generation.

    Each step concatenates (category one-hot, input vector, previous hidden state) and
    feeds it through ``i2h`` (next hidden state; purely linear, no activation) and
    ``i2o`` (provisional output). ``o2o`` mixes the new hidden state with that provisional
    output, followed by dropout and log-softmax over ``output_size`` classes.

    Args:
        input_size: width of each per-step input vector.
        hidden_size: width of the hidden state.
        output_size: number of output classes.
        num_categories: width of the category one-hot. When omitted, the notebook-level
            ``n_categories`` global is used (the previous behaviour).
        dropout: dropout probability applied before the log-softmax (default 0.1).
            Dropout is active whenever the module is in training mode.
    """

    def __init__(self, input_size, hidden_size, output_size, num_categories=None, dropout=0.1):
        super().__init__()
        if num_categories is None:
            # Backward-compatible fallback to the global defined in the data-prep cell.
            num_categories = n_categories
        self.hidden_size = hidden_size
        self.num_categories = num_categories

        combined_size = num_categories + input_size + hidden_size
        # Layer creation order (i2h, i2o, o2o) is kept so seeded initialisation is unchanged.
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        """Run one time step.

        Args:
            category: (batch, num_categories) tensor; batch is 1 in this notebook.
            input: (batch, input_size) tensor.
            hidden: (batch, hidden_size) tensor, e.g. from ``initHidden``.

        Returns:
            (output, hidden): (batch, output_size) log-probabilities and the new hidden state.
        """
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        """Return an all-zero (1, hidden_size) initial hidden state."""
        return torch.zeros(1, self.hidden_size)

# Seed torch so weight initialisation (and the dropout masks drawn later) are reproducible.
torch.manual_seed(0)

n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_letters)  # inputs and outputs are both letter vectors (+EOS slot)

In [26]:
from torch.optim import AdamW

# Learning rate for AdamW: too high may diverge, too low may barely learn.
lr = 0.0005
optimizer = AdamW(rnn.parameters(), lr=lr)
# NLLLoss pairs with the model's LogSoftmax output.
criterion = nn.NLLLoss()
def train(category_tensor, input_line_tensor, target_line_tensor):
    """Run one optimisation step of ``rnn`` on a single name (teacher forcing).

    Args:
        category_tensor: category one-hot, as built by ``categoryTensor``.
        input_line_tensor: (len, 1, n_letters) one-hot sequence, as built by ``inputTensor``.
        target_line_tensor: (len,) LongTensor of next-letter indices, as built by
            ``targetTensor``. It is not modified.

    Returns:
        (output, loss): log-probabilities for the last step, and the loss summed (not
        averaged) over all characters of the name, as a Python float.

    Raises:
        ValueError: if the input sequence is empty (there would be nothing to backprop).
    """
    seq_len = input_line_tensor.size(0)
    if seq_len == 0:
        raise ValueError("cannot train on an empty name")

    # Out-of-place: an in-place unsqueeze_ would modify the caller's tensor, giving it an
    # extra dimension every time the same target tensor is passed in again.
    targets = target_line_tensor.unsqueeze(-1)  # (len, 1): one class index per step, batch of 1
    hidden = rnn.initHidden()

    optimizer.zero_grad()

    step_losses = []
    for i in range(seq_len):
        output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)
        step_losses.append(criterion(output, targets[i]))
    loss = torch.stack(step_losses).sum()

    loss.backward()
    optimizer.step()

    return output, loss.item()

In [27]:
import time
import math
import random
n_iters = 200_000    # total training steps, one random name per step
print_every = 5_000  # not referenced by the training loop in this notebook
plot_every = 1_000   # average the loss over windows of this many steps
from tqdm import tqdm
from statistics import mean

# Loss bookkeeping for plotting: the per-window averages, and the running sum for the current window.
all_losses = []
current_loss = 0

def randomTrainingExample():
    """Pick a random record from ``train_set`` and encode it for ``train``.

    Returns:
        (category_tensor, input_line_tensor, target_line_tensor), built with
        ``categoryTensor``, ``inputTensor`` and ``targetTensor`` respectively.
    """
    record = random.choice(train_set)
    name, language = record[0], record[1]
    return categoryTensor(language), inputTensor(name), targetTensor(name)

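# NOTE(review): stale draft copied from the name-classification tutorial. `lineToTensor` is
# not defined in this notebook, and `rnn` here takes (category, input, hidden), so this
# cannot be uncommented as-is. A held-out loss for this model would reuse train()'s
# per-character loop under torch.no_grad(), without the optimizer step.
# def test():
#     losses = []
#     with torch.no_grad():
#       for el in test_set:
#          line_tensor, category_tensor = lineToTensor(el[0]), torch.tensor([languages_set.index(el[1])], dtype=torch.long)
#          hidden = rnn.initHidden()
#          for i in range(line_tensor.size()[0]):
#              output, hidden = rnn(line_tensor[i], hidden)
#          loss = criterion(output, category_tensor)
#          losses.append(loss.item())
#     return mean(losses)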


start = time.time()  # wall-clock start; tqdm also reports elapsed time

# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the notebook.
for step in tqdm(range(1, n_iters + 1)):
    category, input_line_tensor, target_line_tensor = randomTrainingExample()
    output, loss = train(category, input_line_tensor, target_line_tensor)
    current_loss += loss

    # Every `plot_every` steps, record the mean of the per-name (summed) losses and reset.
    if step % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"train_loss {all_losses[-1]}")
        # print("test_loss",test())
        current_loss = 0

  1%|          | 1022/200000 [00:05<21:21, 155.27it/s]

train_loss 21.020474509716035


  1%|          | 2037/200000 [00:09<14:19, 230.25it/s]

train_loss 18.12153672504425


  2%|▏         | 3041/200000 [00:14<14:03, 233.63it/s]

train_loss 17.51363557815552


  2%|▏         | 4044/200000 [00:19<14:18, 228.33it/s]

train_loss 17.065049141168593


  3%|▎         | 5028/200000 [00:24<14:13, 228.46it/s]

train_loss 16.89833425283432


  3%|▎         | 6042/200000 [00:28<14:37, 221.03it/s]

train_loss 17.190678381443025


  4%|▎         | 7035/200000 [00:33<14:28, 222.23it/s]

train_loss 16.97126400065422


  4%|▍         | 8025/200000 [00:38<14:16, 224.17it/s]

train_loss 16.514384172081947


  5%|▍         | 9021/200000 [00:43<21:19, 149.28it/s]

train_loss 16.73128547024727


  5%|▌         | 10033/200000 [00:48<14:01, 225.74it/s]

train_loss 16.881865206956864


  6%|▌         | 11038/200000 [00:52<13:45, 229.04it/s]

train_loss 16.444278674602508


  6%|▌         | 12023/200000 [00:57<15:05, 207.59it/s]

train_loss 16.413511183023452


  7%|▋         | 13031/200000 [01:02<13:48, 225.56it/s]

train_loss 16.565109265327454


  7%|▋         | 14023/200000 [01:07<16:50, 184.12it/s]

train_loss 16.658438471078874


  8%|▊         | 15033/200000 [01:12<13:29, 228.49it/s]

train_loss 16.783077263355256


  8%|▊         | 16025/200000 [01:17<13:37, 225.09it/s]

train_loss 16.216862907409666


  9%|▊         | 17028/200000 [01:22<19:47, 154.02it/s]

train_loss 16.40999446606636


  9%|▉         | 18048/200000 [01:27<12:49, 236.36it/s]

train_loss 16.351042330265045


 10%|▉         | 19036/200000 [01:31<13:22, 225.54it/s]

train_loss 16.147071561336517


 10%|█         | 20035/200000 [01:36<13:47, 217.46it/s]

train_loss 16.480733034610747


 11%|█         | 21033/200000 [01:41<12:42, 234.77it/s]

train_loss 16.423290659427643


 11%|█         | 22027/200000 [01:45<12:37, 235.04it/s]

train_loss 16.009440853118896


 12%|█▏        | 23044/200000 [01:50<12:33, 234.99it/s]

train_loss 16.36322875714302


 12%|█▏        | 24032/200000 [01:55<12:42, 230.84it/s]

train_loss 16.421299243211745


 13%|█▎        | 25022/200000 [01:59<18:33, 157.11it/s]

train_loss 16.148341449975966


 13%|█▎        | 26026/200000 [02:04<13:04, 221.74it/s]

train_loss 16.083916101932527


 14%|█▎        | 27031/200000 [02:09<12:20, 233.56it/s]

train_loss 16.18406449460983


 14%|█▍        | 28028/200000 [02:14<12:54, 222.13it/s]

train_loss 16.123760563135146


 15%|█▍        | 29037/200000 [02:18<12:52, 221.29it/s]

train_loss 15.851438874721527


 15%|█▌        | 30024/200000 [02:24<21:18, 132.96it/s]

train_loss 16.31617490053177


 16%|█▌        | 31041/200000 [02:29<12:05, 232.90it/s]

train_loss 15.934873396873474


 16%|█▌        | 32044/200000 [02:33<12:34, 222.47it/s]

train_loss 15.952878562688827


 17%|█▋        | 33021/200000 [02:38<18:22, 151.40it/s]

train_loss 16.377495810031892


 17%|█▋        | 34027/200000 [02:43<11:52, 232.94it/s]

train_loss 16.343064047217368


 18%|█▊        | 35039/200000 [02:47<12:03, 227.87it/s]

train_loss 16.294335250377657


 18%|█▊        | 36044/200000 [02:52<12:08, 224.96it/s]

train_loss 15.850113852977753


 19%|█▊        | 37032/200000 [02:57<12:05, 224.56it/s]

train_loss 15.931258975028992


 19%|█▉        | 38046/200000 [03:01<11:32, 233.84it/s]

train_loss 16.254582916736602


 20%|█▉        | 39037/200000 [03:06<11:44, 228.62it/s]

train_loss 16.02313740873337


 20%|██        | 40034/200000 [03:11<11:37, 229.19it/s]

train_loss 16.034933081150054


 21%|██        | 41027/200000 [03:15<16:26, 161.22it/s]

train_loss 16.143545143842697


 21%|██        | 42044/200000 [03:21<11:55, 220.71it/s]

train_loss 16.283597616434097


 22%|██▏       | 43026/200000 [03:25<10:52, 240.59it/s]

train_loss 16.072546506404876


 22%|██▏       | 44031/200000 [03:30<16:10, 160.63it/s]

train_loss 16.25597638773918


 23%|██▎       | 45046/200000 [03:35<11:10, 231.17it/s]

train_loss 15.993738594293594


 23%|██▎       | 46025/200000 [03:39<11:37, 220.91it/s]

train_loss 15.88994291973114


 24%|██▎       | 47030/200000 [03:44<11:22, 224.15it/s]

train_loss 16.092481255054473


 24%|██▍       | 48039/200000 [03:48<10:48, 234.46it/s]

train_loss 16.146850522518157


 25%|██▍       | 49026/200000 [03:53<10:54, 230.77it/s]

train_loss 16.22644460773468


 25%|██▌       | 50025/200000 [03:58<11:11, 223.48it/s]

train_loss 15.818018553256989


 26%|██▌       | 51028/200000 [04:02<10:56, 226.96it/s]

train_loss 15.92458460855484


 26%|██▌       | 52023/200000 [04:07<15:59, 154.24it/s]

train_loss 16.089554732561112


 27%|██▋       | 53024/200000 [04:12<11:01, 222.34it/s]

train_loss 16.042582318782806


 27%|██▋       | 54035/200000 [04:16<10:30, 231.66it/s]

train_loss 15.877775828123093


 28%|██▊       | 55042/200000 [04:22<10:44, 224.84it/s]

train_loss 15.957701858997345


 28%|██▊       | 56026/200000 [04:26<10:12, 234.91it/s]

train_loss 16.10293527317047


 29%|██▊       | 57030/200000 [04:31<10:36, 224.55it/s]

train_loss 15.991113948822022


 29%|██▉       | 58031/200000 [04:36<10:32, 224.49it/s]

train_loss 16.056275524139405


 30%|██▉       | 59034/200000 [04:40<10:15, 229.12it/s]

train_loss 16.195020493507386


 30%|███       | 60024/200000 [04:45<14:58, 155.78it/s]

train_loss 16.06483197784424


 31%|███       | 61029/200000 [04:50<10:37, 217.98it/s]

train_loss 15.776455240488053


 31%|███       | 62027/200000 [04:54<10:03, 228.66it/s]

train_loss 15.840650798082352


 32%|███▏      | 63037/200000 [05:00<10:52, 209.97it/s]

train_loss 15.925639992952346


 32%|███▏      | 64027/200000 [05:04<09:33, 237.05it/s]

train_loss 15.839386975765228


 33%|███▎      | 65039/200000 [05:09<10:08, 221.70it/s]

train_loss 15.871844863176346


 33%|███▎      | 66039/200000 [05:14<09:55, 224.95it/s]

train_loss 15.794542990446091


 34%|███▎      | 67039/200000 [05:19<10:07, 218.79it/s]

train_loss 16.05996835923195


 34%|███▍      | 68018/200000 [05:23<14:58, 146.94it/s]

train_loss 16.3111146273613


 35%|███▍      | 69038/200000 [05:29<09:54, 220.16it/s]

train_loss 15.976353908538819


 35%|███▌      | 70033/200000 [05:33<09:28, 228.53it/s]

train_loss 16.00902590751648


 36%|███▌      | 71035/200000 [05:38<10:26, 205.89it/s]

train_loss 16.016016684651376


 36%|███▌      | 72026/200000 [05:43<09:24, 226.74it/s]

train_loss 16.1686739256382


 37%|███▋      | 73044/200000 [05:47<09:23, 225.26it/s]

train_loss 15.834947621107101


 37%|███▋      | 74044/200000 [05:53<09:11, 228.31it/s]

train_loss 16.134604957580567


 38%|███▊      | 75042/200000 [05:57<09:01, 230.97it/s]

train_loss 15.813564151048661


 38%|███▊      | 76017/200000 [06:02<14:17, 144.64it/s]

train_loss 16.072323153734207


 39%|███▊      | 77043/200000 [06:07<09:25, 217.49it/s]

train_loss 15.63762382054329


 39%|███▉      | 78033/200000 [06:11<08:54, 228.24it/s]

train_loss 15.938573860168457


 40%|███▉      | 79032/200000 [06:17<09:00, 223.63it/s]

train_loss 15.913919235944748


 40%|████      | 80037/200000 [06:21<08:31, 234.42it/s]

train_loss 15.85523940563202


 41%|████      | 81024/200000 [06:25<08:42, 227.88it/s]

train_loss 15.721876507878303


 41%|████      | 82042/200000 [06:31<08:21, 235.27it/s]

train_loss 15.805894966602326


 42%|████▏     | 83031/200000 [06:35<08:48, 221.29it/s]

train_loss 15.941659156918526


 42%|████▏     | 84016/200000 [06:40<12:14, 158.00it/s]

train_loss 15.937299144625664


 43%|████▎     | 85032/200000 [06:45<08:22, 229.02it/s]

train_loss 16.22859717273712


 43%|████▎     | 86027/200000 [06:49<08:29, 223.61it/s]

train_loss 15.901674610495567


 44%|████▎     | 87039/200000 [06:54<10:00, 188.21it/s]

train_loss 16.10298603963852


 44%|████▍     | 88038/200000 [06:59<08:23, 222.26it/s]

train_loss 16.05400157535076


 45%|████▍     | 89047/200000 [07:03<07:48, 236.99it/s]

train_loss 16.13138848233223


 45%|████▌     | 90043/200000 [07:09<08:01, 228.37it/s]

train_loss 16.16170492219925


 46%|████▌     | 91040/200000 [07:13<08:13, 221.00it/s]

train_loss 15.902270346403123


 46%|████▌     | 92029/200000 [07:18<10:12, 176.25it/s]

train_loss 15.906674769639968


 47%|████▋     | 93028/200000 [07:23<07:52, 226.57it/s]

train_loss 15.766722251892089


 47%|████▋     | 94036/200000 [07:27<07:40, 229.87it/s]

train_loss 15.922863802671433


 48%|████▊     | 95022/200000 [07:32<11:58, 146.07it/s]

train_loss 15.856671763181687


 48%|████▊     | 96029/200000 [07:37<07:33, 229.12it/s]

train_loss 15.992362245082855


 49%|████▊     | 97044/200000 [07:41<07:35, 226.18it/s]

train_loss 15.956912105560303


 49%|████▉     | 98036/200000 [07:47<07:44, 219.47it/s]

train_loss 15.977737224578858


 50%|████▉     | 99049/200000 [07:51<07:01, 239.75it/s]

train_loss 15.865176891803742


 50%|█████     | 100030/200000 [07:55<07:10, 232.10it/s]

train_loss 15.990447215795516


 51%|█████     | 101034/200000 [08:01<07:32, 218.88it/s]

train_loss 15.662890736579895


 51%|█████     | 102046/200000 [08:06<07:04, 230.62it/s]

train_loss 15.61479485654831


 52%|█████▏    | 103047/200000 [08:11<08:06, 199.44it/s]

train_loss 15.999016711950302


 52%|█████▏    | 104028/200000 [08:15<06:56, 230.63it/s]

train_loss 15.661895928025245


 53%|█████▎    | 105032/200000 [08:20<07:09, 221.26it/s]

train_loss 15.807457952976227


 53%|█████▎    | 106030/200000 [08:25<06:43, 232.83it/s]

train_loss 16.02814550638199


 54%|█████▎    | 107032/200000 [08:30<06:39, 232.76it/s]

train_loss 16.038911470890046


 54%|█████▍    | 108013/200000 [08:34<08:28, 181.01it/s]

train_loss 15.860991528749466


 55%|█████▍    | 109032/200000 [08:39<06:46, 223.94it/s]

train_loss 15.874370914697646


 55%|█████▌    | 110033/200000 [08:44<06:33, 228.38it/s]

train_loss 15.851159352064133


 56%|█████▌    | 111020/200000 [08:49<09:23, 157.80it/s]

train_loss 15.853188555240632


 56%|█████▌    | 112039/200000 [08:54<06:48, 215.52it/s]

train_loss 15.549977199673652


 57%|█████▋    | 113040/200000 [08:58<06:18, 229.85it/s]

train_loss 16.1906196846962


 57%|█████▋    | 114026/200000 [09:03<06:15, 229.03it/s]

train_loss 15.762328636407853


 58%|█████▊    | 115023/200000 [09:08<06:26, 219.69it/s]

train_loss 15.904240379571915


 58%|█████▊    | 116021/200000 [09:13<10:13, 136.82it/s]

train_loss 15.92756687426567


 59%|█████▊    | 117030/200000 [09:19<06:18, 219.48it/s]

train_loss 16.591401599168776


 59%|█████▉    | 118026/200000 [09:23<06:13, 219.69it/s]

train_loss 15.72513193655014


 60%|█████▉    | 119022/200000 [09:28<08:54, 151.52it/s]

train_loss 16.02914896631241


 60%|██████    | 120045/200000 [09:33<05:46, 231.04it/s]

train_loss 15.901782204151154


 61%|██████    | 121037/200000 [09:38<05:56, 221.55it/s]

train_loss 15.963225260972976


 61%|██████    | 122034/200000 [09:43<06:15, 207.51it/s]

train_loss 16.072419536590576


 62%|██████▏   | 123022/200000 [09:47<05:42, 224.63it/s]

train_loss 15.890267127990723


 62%|██████▏   | 124026/200000 [09:52<05:38, 224.75it/s]

train_loss 16.04667667412758


 63%|██████▎   | 125037/200000 [09:58<06:11, 201.79it/s]

train_loss 15.723004937410355


 63%|██████▎   | 126042/200000 [10:02<05:26, 226.36it/s]

train_loss 15.78408639240265


 64%|██████▎   | 127019/200000 [10:07<08:01, 151.45it/s]

train_loss 16.04697588777542


 64%|██████▍   | 128031/200000 [10:12<05:40, 211.64it/s]

train_loss 15.761700340509414


 65%|██████▍   | 129036/200000 [10:17<05:08, 229.81it/s]

train_loss 16.09232737660408


 65%|██████▌   | 130037/200000 [10:22<05:29, 212.14it/s]

train_loss 15.989393565654755


 66%|██████▌   | 131025/200000 [10:27<05:24, 212.29it/s]

train_loss 15.900026126623153


 66%|██████▌   | 132015/200000 [10:31<06:58, 162.31it/s]

train_loss 15.719915755271911


 67%|██████▋   | 133027/200000 [10:37<05:17, 211.10it/s]

train_loss 15.907588210582734


 67%|██████▋   | 134037/200000 [10:41<04:51, 226.60it/s]

train_loss 16.15287686038017


 68%|██████▊   | 135040/200000 [10:47<05:50, 185.41it/s]

train_loss 15.822587366342544


 68%|██████▊   | 136025/200000 [10:51<04:55, 216.37it/s]

train_loss 15.974640994787217


 69%|██████▊   | 137037/200000 [10:55<04:38, 225.76it/s]

train_loss 15.942206257581711


 69%|██████▉   | 138035/200000 [11:01<04:34, 225.36it/s]

train_loss 15.82511098909378


 70%|██████▉   | 139032/200000 [11:05<04:31, 224.37it/s]

train_loss 15.550004354953765


 70%|███████   | 140023/200000 [11:10<05:55, 168.69it/s]

train_loss 16.026342866659164


 71%|███████   | 141022/200000 [11:15<04:36, 213.33it/s]

train_loss 15.841058952569961


 71%|███████   | 142039/200000 [11:19<04:09, 232.16it/s]

train_loss 16.012150951981546


 72%|███████▏  | 143031/200000 [11:25<06:23, 148.46it/s]

train_loss 16.069985862493514


 72%|███████▏  | 144047/200000 [11:29<04:04, 228.76it/s]

train_loss 15.985878143072128


 73%|███████▎  | 145039/200000 [11:34<04:03, 225.96it/s]

train_loss 15.916363318443299


 73%|███████▎  | 146038/200000 [11:39<04:02, 222.12it/s]

train_loss 15.826781267404556


 74%|███████▎  | 147027/200000 [11:44<03:57, 222.83it/s]

train_loss 15.967424093008042


 74%|███████▍  | 148026/200000 [11:48<05:11, 166.99it/s]

train_loss 15.943461787462235


 75%|███████▍  | 149030/200000 [11:53<03:31, 241.34it/s]

train_loss 15.980974002599716


 75%|███████▌  | 150041/200000 [11:58<03:40, 226.35it/s]

train_loss 15.661893762946129


 76%|███████▌  | 151017/200000 [12:03<05:22, 152.01it/s]

train_loss 15.720630689382553


 76%|███████▌  | 152043/200000 [12:08<03:35, 222.91it/s]

train_loss 15.912895483970642


 77%|███████▋  | 153029/200000 [12:12<03:31, 222.09it/s]

train_loss 15.836708634614945


 77%|███████▋  | 154028/200000 [12:18<03:24, 224.34it/s]

train_loss 16.11356113123894


 78%|███████▊  | 155025/200000 [12:22<03:20, 224.03it/s]

train_loss 15.767767828702926


 78%|███████▊  | 156030/200000 [12:27<04:17, 170.94it/s]

train_loss 16.036050401210783


 79%|███████▊  | 157044/200000 [12:32<03:02, 234.99it/s]

train_loss 15.767051329612732


 79%|███████▉  | 158033/200000 [12:36<03:16, 213.81it/s]

train_loss 15.857380621433258


 80%|███████▉  | 159018/200000 [12:42<04:34, 149.44it/s]

train_loss 16.102107254266738


 80%|████████  | 160039/200000 [12:46<02:59, 222.42it/s]

train_loss 15.774814568281174


 81%|████████  | 161045/200000 [12:51<02:51, 226.55it/s]

train_loss 15.747527306556702


 81%|████████  | 162045/200000 [12:56<02:53, 219.12it/s]

train_loss 16.231321276903152


 82%|████████▏ | 163044/200000 [13:01<02:39, 231.17it/s]

train_loss 15.936729292154313


 82%|████████▏ | 164020/200000 [13:05<03:15, 183.76it/s]

train_loss 15.737097281694412


 83%|████████▎ | 165034/200000 [13:10<02:37, 221.60it/s]

train_loss 15.886976227045059


 83%|████████▎ | 166039/200000 [13:15<02:34, 219.24it/s]

train_loss 15.492004974365233


 84%|████████▎ | 167024/200000 [13:20<03:40, 149.32it/s]

train_loss 15.990594772338866


 84%|████████▍ | 168045/200000 [13:25<02:17, 232.88it/s]

train_loss 15.636646650791167


 85%|████████▍ | 169042/200000 [13:29<02:19, 221.87it/s]

train_loss 15.820303593158721


 85%|████████▌ | 170028/200000 [13:34<02:14, 223.31it/s]

train_loss 15.734605233669281


 86%|████████▌ | 171033/200000 [13:39<02:14, 215.98it/s]

train_loss 15.90361462199688


 86%|████████▌ | 172026/200000 [13:43<02:35, 179.51it/s]

train_loss 16.005247123718263


 87%|████████▋ | 173044/200000 [13:49<02:04, 217.19it/s]

train_loss 15.992699988603592


 87%|████████▋ | 174029/200000 [13:53<01:53, 229.18it/s]

train_loss 15.85154773426056


 88%|████████▊ | 175015/200000 [13:58<02:44, 151.97it/s]

train_loss 15.79873294711113


 88%|████████▊ | 176030/200000 [14:03<01:48, 220.60it/s]

train_loss 15.45306594133377


 89%|████████▊ | 177038/200000 [14:08<01:42, 223.05it/s]

train_loss 15.869832126140594


 89%|████████▉ | 178026/200000 [14:13<01:34, 232.31it/s]

train_loss 15.959919010162354


 90%|████████▉ | 179040/200000 [14:17<01:37, 214.10it/s]

train_loss 15.959768371343612


 90%|█████████ | 180029/200000 [14:22<01:48, 183.45it/s]

train_loss 15.648740842342377


 91%|█████████ | 181030/200000 [14:27<01:24, 224.56it/s]

train_loss 15.738000533342362


 91%|█████████ | 182043/200000 [14:32<01:20, 221.99it/s]

train_loss 15.91601589679718


 92%|█████████▏| 183025/200000 [14:37<01:49, 155.50it/s]

train_loss 16.14297343325615


 92%|█████████▏| 184031/200000 [14:41<01:09, 229.97it/s]

train_loss 16.00392128062248


 93%|█████████▎| 185041/200000 [14:46<01:07, 222.97it/s]

train_loss 15.647898645758628


 93%|█████████▎| 186038/200000 [14:53<01:01, 225.85it/s]

train_loss 15.93549103116989


 94%|█████████▎| 187038/200000 [14:57<00:58, 222.51it/s]

train_loss 16.047288871526717


 94%|█████████▍| 188025/200000 [15:01<00:51, 231.46it/s]

train_loss 15.528427793979645


 95%|█████████▍| 189024/200000 [15:07<00:49, 220.54it/s]

train_loss 15.904236383914947


 95%|█████████▌| 190026/200000 [15:12<00:45, 218.96it/s]

train_loss 15.941210552930832


 96%|█████████▌| 191016/200000 [15:17<01:01, 146.31it/s]

train_loss 15.445281387090683


 96%|█████████▌| 192029/200000 [15:21<00:34, 229.64it/s]

train_loss 15.928973904132842


 97%|█████████▋| 193044/200000 [15:26<00:30, 226.02it/s]

train_loss 16.095273601055144


 97%|█████████▋| 194035/200000 [15:31<00:26, 222.19it/s]

train_loss 15.711083534061908


 98%|█████████▊| 195027/200000 [15:36<00:21, 230.16it/s]

train_loss 15.86724893951416


 98%|█████████▊| 196039/200000 [15:40<00:17, 221.73it/s]

train_loss 16.131089668512345


 99%|█████████▊| 197039/200000 [15:46<00:13, 218.82it/s]

train_loss 16.297785317897798


 99%|█████████▉| 198029/200000 [15:50<00:08, 220.42it/s]

train_loss 16.104493117570875


100%|█████████▉| 199015/200000 [15:55<00:06, 152.73it/s]

train_loss 15.389915365219116


100%|██████████| 200000/200000 [16:00<00:00, 208.24it/s]

train_loss 15.848558926701546





In [32]:
max_length = 20

# Sample from a category and starting letter
def sample(category, start_letter='A'):
    """Generate a name for ``category`` that starts with ``start_letter``.

    At each step the top-scoring next letter (``topk(1)``) is appended and fed back in.
    Generation stops at EOS or after ``max_length`` generated letters. The model's
    train/eval mode is not changed here, so with ``rnn`` in training mode its dropout
    stays active during sampling.

    Returns:
        The generated name as a string, including ``start_letter``.
    """
    with torch.no_grad():  # sampling needs no autograd history
        category_tensor = categoryTensor(category)
        step_input = inputTensor(start_letter)
        hidden = rnn.initHidden()
        generated = start_letter
        eos_index = n_letters - 1

        for _ in range(max_length):
            output, hidden = rnn(category_tensor, step_input[0], hidden)
            _, top_index = output.topk(1)
            letter_index = top_index[0][0].item()
            if letter_index == eos_index:
                break
            next_letter = all_letters[letter_index]
            generated += next_letter
            step_input = inputTensor(next_letter)

        return generated

# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
    """Print one generated name from ``category`` for each character of ``start_letters``."""
    for letter in start_letters:
        generated = sample(category, letter)
        print(generated)

# One generated name per starting letter for a few languages (each must be in languages_set).
for language, start_letters in (
    ('Russian', 'RUS'),
    ('German', 'GER'),
    ('Spanish', 'SPA'),
    ('Chinese', 'CHI'),
):
    samples(language, start_letters)

Rovich
Ushanovich
Sharoff
Gran
Eler
Rosher
Salein
Para
Abara
Cha
Han
Ina
