https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

In [110]:
import pickle
with open("./names.pkl","rb") as file:
  names = pickle.load(file)

In [111]:
train_set = names[:int(len(names)*0.8)]
test_set = names[int(len(names)*0.8):]

In [112]:
languages_set = set()
for arr in [train_set,test_set]:
  for el in arr:
    languages_set.add(el[1])

In [113]:
languages_set=list(languages_set)

In [114]:
len(train_set),len(test_set), len(names)

(16059, 4015, 20074)

In [115]:
import string
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

def letterToIndex(letter):
    return all_letters.find(letter)

def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor

In [116]:
import torch.nn as nn
import torch
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

n_hidden = 256
rnn = RNN(n_letters, n_hidden, len(languages_set))

In [117]:
lr = 0.001 # If you set this too high, it might explode. If too low, it might not learn
from torch.optim import AdamW
optimizer = AdamW(rnn.parameters(), lr=lr)
criterion = nn.NLLLoss()
def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    optimizer.zero_grad()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()
    optimizer.step()

    return output, loss.item()

In [118]:
import time
import math
import random
n_iters = 200000
print_every = 5000
plot_every = 1000
from tqdm import tqdm
from statistics import mean

# Keep track of losses for plotting
current_loss = 0
all_losses = []

def randomTrainingExample():
    el = random.choice(train_set)
    return lineToTensor(el[0]), torch.tensor([languages_set.index(el[1])], dtype=torch.long)

def test():
    losses = []
    with torch.no_grad():
      for el in test_set:
         line_tensor, category_tensor = lineToTensor(el[0]), torch.tensor([languages_set.index(el[1])], dtype=torch.long)
         hidden = rnn.initHidden()
         for i in range(line_tensor.size()[0]):
             output, hidden = rnn(line_tensor[i], hidden)
         loss = criterion(output, category_tensor)
         losses.append(loss.item())
    return mean(losses)


start = time.time()

for iter in tqdm(range(1, n_iters + 1)):
    line_tensor, category_tensor  = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    # Add current loss avg to list of losses
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        print("train_loss", all_losses[-1])
        print("test_loss",test())
        current_loss = 0

  0%|          | 993/200000 [00:03<11:25, 290.29it/s]

train_loss 1.831503066888501


  1%|          | 1053/200000 [00:06<1:05:51, 50.35it/s]

test_loss 1.6419061571904565


  1%|          | 1981/200000 [00:09<11:31, 286.49it/s]

train_loss 1.5851343831165576


  1%|          | 2036/200000 [00:12<1:09:21, 47.57it/s]

test_loss 1.4427444379183476


  1%|▏         | 2982/200000 [00:15<10:36, 309.49it/s]

train_loss 1.4012951051891


  2%|▏         | 3044/200000 [00:18<1:02:02, 52.92it/s]

test_loss 1.4213282134520084


  2%|▏         | 3997/200000 [00:22<14:52, 219.70it/s]

train_loss 1.3529136951655436


  2%|▏         | 4054/200000 [00:24<1:12:50, 44.83it/s]

test_loss 1.2328863437192261


  2%|▏         | 4982/200000 [00:27<10:11, 318.82it/s]

train_loss 1.3784811224467772


  3%|▎         | 5047/200000 [00:30<58:22, 55.67it/s]  

test_loss 1.2147311265999916


  3%|▎         | 5985/200000 [00:33<13:54, 232.58it/s]

train_loss 1.302954516189271


  3%|▎         | 6039/200000 [00:36<1:23:30, 38.71it/s]

test_loss 1.257550836467169


  3%|▎         | 6982/200000 [00:39<10:39, 301.68it/s]

train_loss 1.2460625995220154


  4%|▎         | 7044/200000 [00:42<1:01:48, 52.03it/s]

test_loss 1.23583720028498


  4%|▍         | 7982/200000 [00:45<14:28, 221.16it/s]

train_loss 1.2702650765521721


  4%|▍         | 8037/200000 [00:48<1:26:32, 36.97it/s]

test_loss 1.1828472879520706


  4%|▍         | 8989/200000 [00:51<10:05, 315.30it/s]

train_loss 1.256687814490173


  5%|▍         | 9058/200000 [00:54<56:03, 56.77it/s]  

test_loss 1.183033177926313


  5%|▍         | 9995/200000 [00:57<09:11, 344.68it/s]

train_loss 1.2282110551416416


  5%|▌         | 10030/200000 [01:00<1:40:37, 31.47it/s]

test_loss 1.1087060217044646


  5%|▌         | 10994/200000 [01:03<10:05, 311.91it/s]

train_loss 1.1812298812229682


  6%|▌         | 11056/200000 [01:06<57:58, 54.32it/s]  

test_loss 1.164736817552231


  6%|▌         | 11977/200000 [01:09<10:00, 313.19it/s]

train_loss 1.159434256276681


  6%|▌         | 12042/200000 [01:12<1:16:54, 40.74it/s]

test_loss 1.1615243096228554


  6%|▋         | 12983/200000 [01:15<09:46, 319.08it/s]

train_loss 1.2310353723006828


  7%|▋         | 13049/200000 [01:18<56:02, 55.59it/s]  

test_loss 1.2370941878467219


  7%|▋         | 13995/200000 [01:21<09:10, 337.65it/s]

train_loss 1.1957114651885559


  7%|▋         | 14029/200000 [01:24<1:32:07, 33.65it/s]

test_loss 1.05687203135781


  7%|▋         | 14993/200000 [01:27<09:28, 325.20it/s]

train_loss 1.0886052595532205


  8%|▊         | 15059/200000 [01:30<54:27, 56.61it/s]  

test_loss 1.105855080296232


  8%|▊         | 15974/200000 [01:33<10:23, 295.37it/s]

train_loss 1.1389289070349655


  8%|▊         | 16026/200000 [01:36<1:19:50, 38.40it/s]

test_loss 1.0704663097644675


  8%|▊         | 16975/200000 [01:40<10:23, 293.41it/s]

train_loss 1.1198783698050574


  9%|▊         | 17036/200000 [01:42<59:19, 51.40it/s]  

test_loss 1.061426205021981


  9%|▉         | 17996/200000 [01:45<09:58, 303.88it/s]

train_loss 1.0800953908157171


  9%|▉         | 18027/200000 [01:48<1:28:39, 34.21it/s]

test_loss 1.09087380739752


  9%|▉         | 18973/200000 [01:52<09:32, 316.26it/s]

train_loss 1.139960703359815


 10%|▉         | 19041/200000 [01:54<53:35, 56.27it/s]  

test_loss 1.024723769862538


 10%|▉         | 19984/200000 [01:57<10:16, 292.07it/s]

train_loss 31.88862112392143


 10%|█         | 20035/200000 [02:00<1:09:12, 43.34it/s]

test_loss 3.5778266105830383


 10%|█         | 20974/200000 [02:04<09:39, 309.10it/s]

train_loss 1.732123164212615


 11%|█         | 21037/200000 [02:06<56:03, 53.21it/s]  

test_loss 1.2405014011313953


 11%|█         | 21987/200000 [02:09<09:38, 307.52it/s]

train_loss 1.1766471387453012


 11%|█         | 22050/200000 [02:12<55:40, 53.27it/s]  

test_loss 1.1155477392285131


 11%|█▏        | 22988/200000 [02:16<10:24, 283.22it/s]

train_loss 1.1010578594460676


 12%|█▏        | 23051/200000 [02:18<56:02, 52.63it/s]  

test_loss 1.0434004467803533


 12%|█▏        | 23998/200000 [02:22<12:03, 243.34it/s]

train_loss 1.0068511002470586


 12%|█▏        | 24023/200000 [02:25<1:44:46, 27.99it/s]

test_loss 1.0178922357406843


 12%|█▏        | 24974/200000 [02:29<09:06, 320.39it/s]

train_loss 1.0505018650496787


 13%|█▎        | 25040/200000 [02:31<51:50, 56.25it/s]  

test_loss 0.992389449129762


 13%|█▎        | 25997/200000 [02:34<08:55, 324.83it/s]

train_loss 1.0469741362996428


 13%|█▎        | 26030/200000 [02:37<1:11:07, 40.77it/s]

test_loss 0.9633000706495773


 13%|█▎        | 26996/200000 [02:41<09:26, 305.44it/s]

train_loss 0.9730235835932254


 14%|█▎        | 27057/200000 [02:43<55:45, 51.69it/s]  

test_loss 1.024546036443414


 14%|█▍        | 27984/200000 [02:46<09:16, 308.86it/s]

train_loss 1.002725900143192


 14%|█▍        | 28044/200000 [02:49<54:42, 52.39it/s]  

test_loss 0.9670868625528659


 14%|█▍        | 28972/200000 [02:53<09:36, 296.77it/s]

train_loss 0.9741857495817892


 15%|█▍        | 29034/200000 [02:55<55:19, 51.50it/s]  

test_loss 1.0191480492155067


 15%|█▍        | 29996/200000 [02:59<09:11, 308.52it/s]

train_loss 0.9480374346727554


 15%|█▌        | 30027/200000 [03:01<1:11:39, 39.53it/s]

test_loss 0.9913026757114998


 15%|█▌        | 30990/200000 [03:05<12:39, 222.67it/s]

train_loss 1.0829435743036446


 16%|█▌        | 31043/200000 [03:08<1:14:00, 38.05it/s]

test_loss 1.0247416629026205


 16%|█▌        | 31986/200000 [03:11<08:54, 314.47it/s]

train_loss 1.1241614720943562


 16%|█▌        | 32042/200000 [03:13<56:31, 49.53it/s]  

test_loss 0.9997091713154617


 16%|█▋        | 32996/200000 [03:17<12:42, 218.96it/s]

train_loss 1.046105967368692


 17%|█▋        | 33051/200000 [03:20<1:03:46, 43.64it/s]

test_loss 1.3347679064806595


 17%|█▋        | 33987/200000 [03:23<08:47, 314.90it/s]

train_loss 1.1403873429527784


 17%|█▋        | 34047/200000 [03:25<53:13, 51.97it/s]  

test_loss 1.0811079616633081


 17%|█▋        | 34987/200000 [03:29<12:24, 221.64it/s]

train_loss 1.081541109143434


 18%|█▊        | 35044/200000 [03:32<1:08:20, 40.22it/s]

test_loss 1.2860518051072656


 18%|█▊        | 35970/200000 [03:35<08:26, 324.06it/s]

train_loss 1.1478974648653106


 18%|█▊        | 36033/200000 [03:37<52:34, 51.98it/s]  

test_loss 1.018298476736884


 18%|█▊        | 36999/200000 [03:41<13:22, 203.15it/s]

train_loss 1.0373121987790828


 19%|█▊        | 37048/200000 [03:44<1:19:41, 34.08it/s]

test_loss 1.0075415481302648


 19%|█▉        | 37986/200000 [03:47<10:56, 246.87it/s]

train_loss 0.9845312907680268


 19%|█▉        | 38040/200000 [03:50<1:03:31, 42.49it/s]

test_loss 1.0303741022590167


 19%|█▉        | 38987/200000 [03:53<11:03, 242.73it/s]

train_loss 1.0280053216353362


 20%|█▉        | 39030/200000 [03:57<1:45:27, 25.44it/s]

test_loss 0.9935961758653362


 20%|█▉        | 39968/200000 [04:01<08:23, 318.03it/s]

train_loss 35.98884987572131


 20%|██        | 40033/200000 [04:03<50:59, 52.29it/s]  

test_loss 3.429279754467489


 20%|██        | 40985/200000 [04:07<12:20, 214.76it/s]

train_loss 1.8842319929335252


 21%|██        | 41038/200000 [04:10<1:09:36, 38.06it/s]

test_loss 1.2039730026346875


 21%|██        | 41978/200000 [04:13<08:08, 323.73it/s]

train_loss 1.096604389428452


 21%|██        | 42048/200000 [04:15<44:38, 58.97it/s]  

test_loss 1.0840186588358325


 21%|██▏       | 42978/200000 [04:18<12:00, 217.91it/s]

train_loss 1.089364511200576


 22%|██▏       | 43030/200000 [04:22<1:13:26, 35.62it/s]

test_loss 1.0173235303171306


 22%|██▏       | 43984/200000 [04:25<08:40, 299.69it/s]

train_loss 1.025707513116482


 22%|██▏       | 44049/200000 [04:27<48:45, 53.31it/s]  

test_loss 1.0308182776708048


 22%|██▏       | 44976/200000 [04:30<07:41, 335.96it/s]

train_loss 1.0453815568248628


 23%|██▎       | 45044/200000 [04:34<1:00:07, 42.96it/s]

test_loss 0.9813561943942896


 23%|██▎       | 45996/200000 [04:37<09:17, 276.44it/s]

train_loss 0.9675900096130572


 23%|██▎       | 46060/200000 [04:40<52:10, 49.18it/s]  

test_loss 0.9997683828365049


 23%|██▎       | 46998/200000 [04:43<07:44, 329.12it/s]

train_loss 1.018370141497966


 24%|██▎       | 47064/200000 [04:46<1:00:24, 42.19it/s]

test_loss 0.9873061088082052


 24%|██▍       | 47969/200000 [04:49<08:17, 305.46it/s]

train_loss 1.038603015036006


 24%|██▍       | 48034/200000 [04:52<46:43, 54.20it/s]  

test_loss 1.0106545695853177


 24%|██▍       | 48980/200000 [04:55<08:45, 287.63it/s]

train_loss 0.9974744438012754


 25%|██▍       | 49044/200000 [04:59<1:04:08, 39.22it/s]

test_loss 1.013662109247866


 25%|██▍       | 49991/200000 [05:02<07:53, 316.77it/s]

train_loss 1.0599300984442235


 25%|██▌       | 50055/200000 [05:04<47:10, 52.97it/s]  

test_loss 1.0016514333116788


 25%|██▌       | 50972/200000 [05:07<07:38, 324.79it/s]

train_loss 0.9381973134853306


 26%|██▌       | 51029/200000 [05:10<59:34, 41.67it/s]  

test_loss 1.0455196384423886


 26%|██▌       | 51997/200000 [05:14<08:04, 305.17it/s]

train_loss 1.1014356777513077


 26%|██▌       | 52060/200000 [05:16<45:47, 53.85it/s]  

test_loss 1.0022099338777426


 26%|██▋       | 52995/200000 [05:19<07:28, 327.62it/s]

train_loss 1.0634990954949262


 27%|██▋       | 53028/200000 [05:22<1:06:58, 36.58it/s]

test_loss 1.0418602685809042


 27%|██▋       | 53984/200000 [05:25<07:24, 328.79it/s]

train_loss 1.0139491652684463


 27%|██▋       | 54045/200000 [05:28<46:51, 51.91it/s]  

test_loss 1.0314790085789523


 27%|██▋       | 54990/200000 [05:31<07:35, 318.34it/s]

train_loss 1.0997486060303274


 28%|██▊       | 55022/200000 [05:34<59:00, 40.94it/s] 

test_loss 1.0356030960235665


 28%|██▊       | 55975/200000 [05:37<08:16, 289.96it/s]

train_loss 1.0909005223974573


 28%|██▊       | 56041/200000 [05:40<44:18, 54.15it/s]  

test_loss 1.0595705869308687


 28%|██▊       | 56996/200000 [05:43<07:46, 306.52it/s]

train_loss 1.020277218057287


 29%|██▊       | 57063/200000 [05:45<42:21, 56.24it/s]

test_loss 0.9937222275813602


 29%|██▉       | 57979/200000 [05:49<07:22, 320.66it/s]

train_loss 1.0741393755120516


 29%|██▉       | 58046/200000 [05:52<42:10, 56.11it/s]

test_loss 0.9483842059617086


 29%|██▉       | 58972/200000 [05:55<07:48, 301.05it/s]

train_loss 1.0123209846790686


 30%|██▉       | 59036/200000 [05:57<43:27, 54.05it/s]

test_loss 0.9619468792239587


 30%|██▉       | 59976/200000 [06:01<12:13, 191.00it/s]

train_loss 1.0257498976976895


 30%|███       | 60026/200000 [06:04<59:48, 39.01it/s]  

test_loss 0.978221828729418


 30%|███       | 60998/200000 [06:07<07:46, 298.06it/s]

train_loss 0.9633251860577621


 31%|███       | 61061/200000 [06:09<43:33, 53.16it/s]

test_loss 0.987519931342081


 31%|███       | 61976/200000 [06:13<09:39, 238.01it/s]

train_loss 1.0211859386189017


 31%|███       | 62027/200000 [06:16<58:40, 39.20it/s]  

test_loss 1.0025807074777904


 31%|███▏      | 62994/200000 [06:19<07:32, 302.86it/s]

train_loss 0.9899389542978315


 32%|███▏      | 63059/200000 [06:21<43:17, 52.72it/s]  

test_loss 1.0247251172869258


 32%|███▏      | 63986/200000 [06:25<11:08, 203.41it/s]

train_loss 1.132104002528416


 32%|███▏      | 64038/200000 [06:28<1:03:44, 35.55it/s]

test_loss 0.9813815760049613


 32%|███▏      | 64995/200000 [06:31<07:10, 313.41it/s]

train_loss 1.1237767911583811


 33%|███▎      | 65054/200000 [06:34<43:11, 52.06it/s]

test_loss 1.004542503193071


 33%|███▎      | 65975/200000 [06:37<09:11, 243.19it/s]

train_loss 0.978063489241993


 33%|███▎      | 66021/200000 [06:40<1:11:27, 31.25it/s]

test_loss 1.0154899322354742


 33%|███▎      | 66988/200000 [06:44<07:14, 305.96it/s]

train_loss 0.9586676447608327


 34%|███▎      | 67050/200000 [06:47<41:55, 52.85it/s]

test_loss 0.9897978510217893


 34%|███▍      | 67995/200000 [06:50<09:21, 235.09it/s]

train_loss 1.0102639200031003


 34%|███▍      | 68050/200000 [06:53<56:11, 39.13it/s]  

test_loss 0.9849609940556291


 34%|███▍      | 68976/200000 [06:57<07:13, 302.28it/s]

train_loss 0.9499821385007308


 35%|███▍      | 69039/200000 [06:59<41:08, 53.04it/s]

test_loss 0.965314652441275


 35%|███▍      | 69983/200000 [07:03<09:25, 229.75it/s]

train_loss 1.0845412261797756


 35%|███▌      | 70038/200000 [07:06<58:48, 36.84it/s]  

test_loss 0.9658673732797827


 35%|███▌      | 70974/200000 [07:09<06:45, 318.22it/s]

train_loss 1.357359541124411


 36%|███▌      | 71040/200000 [07:11<39:28, 54.44it/s]

test_loss 1.033096066262447


 36%|███▌      | 71978/200000 [07:15<09:51, 216.40it/s]

train_loss 0.9695740641574143


 36%|███▌      | 72031/200000 [07:18<1:00:16, 35.38it/s]

test_loss 0.9724235341664802


 36%|███▋      | 72966/200000 [07:21<06:53, 307.02it/s]

train_loss 0.924766936543756


 37%|███▋      | 73033/200000 [07:24<38:21, 55.17it/s]

test_loss 0.9413797379372199


 37%|███▋      | 73998/200000 [07:27<08:19, 252.20it/s]

train_loss 0.8949971885726405


 37%|███▋      | 74055/200000 [07:30<57:23, 36.58it/s]  

test_loss 0.9812775022477819


 37%|███▋      | 74971/200000 [07:33<06:41, 311.39it/s]

train_loss 1.0323908731670364


 38%|███▊      | 75031/200000 [07:36<39:30, 52.72it/s]

test_loss 0.928825896230553


 38%|███▊      | 75994/200000 [07:39<06:50, 301.84it/s]

train_loss 1.018569310459727


 38%|███▊      | 76057/200000 [07:43<51:09, 40.37it/s]  

test_loss 3.9805868774895043


 38%|███▊      | 76983/200000 [07:46<06:25, 319.07it/s]

train_loss 20.630130338574922


 39%|███▊      | 77045/200000 [07:48<39:46, 51.53it/s]

test_loss 1.701756740742734


 39%|███▉      | 77981/200000 [07:51<06:23, 318.12it/s]

train_loss 1.4192332694422578


 39%|███▉      | 78043/200000 [07:55<51:41, 39.32it/s]  

test_loss 1.1184459357421572


 39%|███▉      | 78996/200000 [07:58<06:39, 302.87it/s]

train_loss 0.9877587318030449


 40%|███▉      | 79057/200000 [08:01<37:51, 53.25it/s]

test_loss 1.0118551309112076


 40%|███▉      | 79988/200000 [08:04<06:38, 301.44it/s]

train_loss 0.999557882900012


 40%|████      | 80050/200000 [08:07<50:36, 39.51it/s]  

test_loss 0.9746170550301357


 40%|████      | 80996/200000 [08:10<06:33, 302.80it/s]

train_loss 0.9394726584140698


 41%|████      | 81061/200000 [08:13<36:30, 54.29it/s]

test_loss 0.9482919085182134


 41%|████      | 81974/200000 [08:16<06:22, 308.51it/s]

train_loss 0.9094360530906777


 41%|████      | 82027/200000 [08:19<52:42, 37.30it/s]  

test_loss 0.9675988786380378


 41%|████▏     | 82993/200000 [08:23<06:44, 289.59it/s]

train_loss 0.9942486337547066


 42%|████▏     | 83056/200000 [08:26<39:49, 48.94it/s]

test_loss 0.9874826449804033


 42%|████▏     | 83996/200000 [08:29<06:09, 314.13it/s]

train_loss 0.8991582388963181


 42%|████▏     | 84028/200000 [08:32<1:04:07, 30.14it/s]

test_loss 1.1400296244721304


 42%|████▏     | 84966/200000 [08:35<05:53, 325.25it/s]

train_loss 1.0645088268537022


 43%|████▎     | 85033/200000 [08:38<33:56, 56.46it/s]

test_loss 0.9740454730417171


 43%|████▎     | 85988/200000 [08:41<06:44, 282.10it/s]

train_loss 0.9794398151116784


 43%|████▎     | 86038/200000 [08:45<55:25, 34.27it/s]  

test_loss 0.982428062164349


 43%|████▎     | 86976/200000 [08:48<06:44, 279.46it/s]

train_loss 0.9746404295573484


 44%|████▎     | 87039/200000 [08:50<36:16, 51.89it/s]

test_loss 0.9554693872119893


 44%|████▍     | 87986/200000 [08:53<05:58, 312.35it/s]

train_loss 1.0385859414949923


 44%|████▍     | 88041/200000 [08:57<45:49, 40.72it/s]

test_loss 0.9737862626967041


 44%|████▍     | 88976/200000 [09:00<05:59, 309.07it/s]

train_loss 0.9733197707701821


 45%|████▍     | 89039/200000 [09:02<33:57, 54.46it/s]

test_loss 0.9641060294569757


 45%|████▍     | 89969/200000 [09:05<06:17, 291.79it/s]

train_loss 1.0406999103659265


 45%|████▌     | 90022/200000 [09:08<43:59, 41.67it/s]

test_loss 1.0084633731614228


 45%|████▌     | 90976/200000 [09:12<06:03, 299.99it/s]

train_loss 1.0613790140076211


 46%|████▌     | 91037/200000 [09:14<36:00, 50.44it/s]

test_loss 3.7534170318346867


 46%|████▌     | 91973/200000 [09:18<05:52, 306.75it/s]

train_loss 835.3203594633287


 46%|████▌     | 92026/200000 [09:20<39:19, 45.76it/s]

test_loss 2.6074450938154494


 46%|████▋     | 92990/200000 [09:24<05:36, 317.82it/s]

train_loss 1.8235690323865912


 47%|████▋     | 93049/200000 [09:27<34:22, 51.87it/s]

test_loss 1.6507908779824485


 47%|████▋     | 93996/200000 [09:30<05:49, 303.32it/s]

train_loss 1.5276463484774423


 47%|████▋     | 94060/200000 [09:32<32:15, 54.74it/s]

test_loss 1.275702935119569


 47%|████▋     | 94987/200000 [09:36<05:57, 293.64it/s]

train_loss 1.2245246967986614


 48%|████▊     | 95046/200000 [09:39<36:15, 48.25it/s]

test_loss 1.1860961381004145


 48%|████▊     | 95994/200000 [09:42<05:46, 300.01it/s]

train_loss 1.017921372836798


 48%|████▊     | 96058/200000 [09:45<32:25, 53.42it/s]

test_loss 1.0885650567115432


 48%|████▊     | 96985/200000 [09:49<06:32, 262.70it/s]

train_loss 2.4798445580140456


 49%|████▊     | 97045/200000 [09:51<35:45, 47.99it/s]

test_loss 1.273384579464602


 49%|████▉     | 97987/200000 [09:54<06:02, 281.37it/s]

train_loss 1.1824519755211622


 49%|████▉     | 98047/200000 [09:57<34:57, 48.62it/s]

test_loss 1.0552310706032388


 49%|████▉     | 98991/200000 [10:01<06:33, 256.72it/s]

train_loss 1.0523810882601323


 50%|████▉     | 99050/200000 [10:04<36:22, 46.25it/s]

test_loss 0.9713159122976315


 50%|████▉     | 99982/200000 [10:07<05:37, 296.39it/s]

train_loss 0.9497970740419233


 50%|█████     | 100041/200000 [10:10<34:12, 48.71it/s]

test_loss 0.9597180488234875


 50%|█████     | 100982/200000 [10:14<05:48, 284.43it/s]

train_loss 0.9720891640842415


 51%|█████     | 101042/200000 [10:16<33:58, 48.54it/s]

test_loss 0.9710278411805678


 51%|█████     | 101988/200000 [10:20<05:21, 305.31it/s]

train_loss 0.9506685793462438


 51%|█████     | 102049/200000 [10:22<32:00, 51.01it/s]

test_loss 0.9586489747006925


 51%|█████▏    | 102999/200000 [10:26<05:19, 303.85it/s]

train_loss 0.9504490213750869


 52%|█████▏    | 103060/200000 [10:29<32:12, 50.16it/s]

test_loss 0.9257644926489272


 52%|█████▏    | 103993/200000 [10:32<05:09, 310.11it/s]

train_loss 0.91673437751145


 52%|█████▏    | 104052/200000 [10:35<32:39, 48.97it/s]

test_loss 0.9561000787898758


 52%|█████▏    | 104987/200000 [10:39<05:47, 273.55it/s]

train_loss 17.80231755738708


 53%|█████▎    | 105046/200000 [10:41<33:59, 46.57it/s]

test_loss 22.69224760641061


 53%|█████▎    | 105990/200000 [10:44<05:06, 306.75it/s]

train_loss 2.3572640180115063


 53%|█████▎    | 106051/200000 [10:47<30:58, 50.54it/s]

test_loss 1.0958152698251051


 53%|█████▎    | 106982/200000 [10:51<08:14, 188.21it/s]

train_loss 1.097875098917978


 54%|█████▎    | 107032/200000 [10:54<39:15, 39.47it/s]  

test_loss 0.9675717922824825


 54%|█████▍    | 107976/200000 [10:57<05:19, 288.14it/s]

train_loss 0.9301961646287237


 54%|█████▍    | 108038/200000 [10:59<30:01, 51.06it/s]

test_loss 0.9242739543378823


 54%|█████▍    | 108992/200000 [11:04<09:15, 163.69it/s]

train_loss 0.8633049323985033


 55%|█████▍    | 109039/200000 [11:07<48:29, 31.26it/s]  

test_loss 1.721139185460232


 55%|█████▍    | 109992/200000 [11:10<05:06, 294.01it/s]

train_loss 0.9625645369200813


 55%|█████▌    | 110056/200000 [11:13<29:43, 50.42it/s]

test_loss 0.910291841620043


 55%|█████▌    | 110982/200000 [11:17<06:36, 224.45it/s]

train_loss 0.8753913201238752


 56%|█████▌    | 111038/200000 [11:20<39:02, 37.98it/s]

test_loss 0.9127076045056792


 56%|█████▌    | 111995/200000 [11:23<04:37, 317.53it/s]

train_loss 1.059355580774506


 56%|█████▌    | 112059/200000 [11:25<27:07, 54.04it/s]

test_loss 0.9715422684799389


 56%|█████▋    | 112982/200000 [11:29<07:17, 198.73it/s]

train_loss 0.9650631003233389


 57%|█████▋    | 113032/200000 [11:32<42:01, 34.49it/s]  

test_loss 0.9099373984184089


 57%|█████▋    | 113983/200000 [11:35<04:41, 305.36it/s]

train_loss 0.9405501095155596


 57%|█████▋    | 114042/200000 [11:38<28:41, 49.92it/s]

test_loss 0.9455279619670384


 57%|█████▋    | 114981/200000 [11:41<06:49, 207.76it/s]

train_loss 7.767810471705144


 58%|█████▊    | 115027/200000 [11:45<45:32, 31.09it/s]  

test_loss 1.102313927960513


 58%|█████▊    | 115978/200000 [11:48<04:56, 283.12it/s]

train_loss 1.0219139692188148


 58%|█████▊    | 116033/200000 [11:51<29:54, 46.79it/s]

test_loss 0.9669766743623076


 58%|█████▊    | 116986/200000 [11:55<06:31, 211.84it/s]

train_loss 0.9009602254665023


 59%|█████▊    | 117035/200000 [11:58<39:38, 34.89it/s]

test_loss 0.9510034004061136


 59%|█████▉    | 117996/200000 [12:01<04:54, 278.39it/s]

train_loss 0.8713563293677089


 59%|█████▉    | 118052/200000 [12:04<29:04, 46.97it/s]

test_loss 0.9544862726520239


 59%|█████▉    | 118985/200000 [12:08<08:01, 168.13it/s]

train_loss 0.979692918206049


 60%|█████▉    | 119029/200000 [12:11<40:15, 33.52it/s]  

test_loss 0.939957727623277


 60%|█████▉    | 119991/200000 [12:14<04:53, 272.33it/s]

train_loss 0.89080099259087


 60%|██████    | 120047/200000 [12:17<28:31, 46.72it/s]

test_loss 1.005950896437952


 60%|██████    | 120989/200000 [12:21<06:36, 199.21it/s]

train_loss 1.0125072544694718


 61%|██████    | 121033/200000 [12:23<36:07, 36.43it/s]

test_loss 0.9950026039898472


 61%|██████    | 121979/200000 [12:27<04:19, 300.79it/s]

train_loss 0.9948192668263509


 61%|██████    | 122042/200000 [12:29<25:28, 51.02it/s]

test_loss 0.9286062230984942


 61%|██████▏   | 122980/200000 [12:33<06:24, 200.51it/s]

train_loss 27.82073194682951


 62%|██████▏   | 123016/200000 [12:37<53:58, 23.77it/s]  

test_loss 3.221083685941143


 62%|██████▏   | 123982/200000 [12:40<04:07, 306.90it/s]

train_loss 1.5298529286676945


 62%|██████▏   | 124035/200000 [12:43<26:21, 48.04it/s]

test_loss 1.24186483415261


 62%|██████▏   | 124988/200000 [12:47<04:48, 260.22it/s]

train_loss 1.1138928733022297


 63%|██████▎   | 125044/200000 [12:50<27:36, 45.24it/s]

test_loss 1.0159771061231833


 63%|██████▎   | 125970/200000 [12:53<04:32, 271.48it/s]

train_loss 1.0042235126182932


 63%|██████▎   | 126032/200000 [12:56<24:28, 50.38it/s]

test_loss 0.9818714550877505


 63%|██████▎   | 126994/200000 [13:00<03:58, 306.25it/s]

train_loss 0.9805152025123021


 64%|██████▎   | 127056/200000 [13:02<23:15, 52.26it/s]

test_loss 0.9584238785059783


 64%|██████▍   | 127980/200000 [13:05<03:58, 302.56it/s]

train_loss 0.9853333661206782


 64%|██████▍   | 128042/200000 [13:08<23:16, 51.53it/s]

test_loss 0.9348611984068307


 64%|██████▍   | 128978/200000 [13:12<04:20, 272.15it/s]

train_loss 1.0095341858066849


 65%|██████▍   | 129037/200000 [13:15<25:26, 46.48it/s]

test_loss 0.9754039450008737


 65%|██████▍   | 129969/200000 [13:18<03:45, 310.20it/s]

train_loss 0.9484874128041042


 65%|██████▌   | 130033/200000 [13:21<21:42, 53.70it/s]

test_loss 0.9402090442776744


 65%|██████▌   | 130967/200000 [13:25<03:58, 289.22it/s]

train_loss 1.0305218452841454


 66%|██████▌   | 131029/200000 [13:27<22:47, 50.45it/s]

test_loss 0.9947785929881545


 66%|██████▌   | 131981/200000 [13:31<03:45, 301.16it/s]

train_loss 0.9565052632265029


 66%|██████▌   | 132042/200000 [13:33<21:58, 51.55it/s]

test_loss 1.0382115626670119


 66%|██████▋   | 132995/200000 [13:37<04:03, 275.52it/s]

train_loss 1.042467156374245


 67%|██████▋   | 133053/200000 [13:40<23:06, 48.27it/s]

test_loss 0.9738164480561579


 67%|██████▋   | 133981/200000 [13:43<03:44, 293.71it/s]

train_loss 1.0231562756268604


 67%|██████▋   | 134044/200000 [13:46<21:13, 51.79it/s]

test_loss 0.9872966474449917


 67%|██████▋   | 134996/200000 [13:50<04:16, 252.97it/s]

train_loss 1.0176701570606004


 68%|██████▊   | 135053/200000 [13:52<22:57, 47.14it/s]

test_loss 1.0680651231391407


 68%|██████▊   | 135069/200000 [13:52<06:40, 162.22it/s]


KeyboardInterrupt: ignored

In [121]:
import numpy as np
languages_set2=np.array(languages_set)
with torch.no_grad():
    line_tensor = lineToTensor("Ivanov")
    hidden = rnn.initHidden()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    print(output)
    print(languages_set2[(-output).argsort()[0][:3]])

tensor([[-6.2351e-04, -1.1507e+01, -8.8284e+00, -1.1983e+01, -1.0631e+01,
         -1.6028e+01, -9.8823e+00, -1.1080e+01, -1.2941e+01, -1.2600e+01,
         -1.6891e+01, -1.4690e+01, -7.9562e+00, -1.2311e+01, -1.8311e+01,
         -1.2021e+01, -1.4164e+01, -1.3268e+01]])
['Russian' 'Arabic' 'Czech']
