In [None]:
import torch
import torch.nn as nn
import random
import matplotlib.pyplot as plt
import sys
sys.path.append("../common/")
import name2countries_helper
import data_preprocessing

In [None]:
def randomSample(X_int_encoding, Y_int_encoding):
    """Pick one position uniformly at random and return it with its pair.

    The position is drawn over the length of ``X_int_encoding``; the same
    position is then used to index ``Y_int_encoding``.

    Returns:
        A tuple ``(index, X_int_encoding[index], Y_int_encoding[index])``.
    """
    position = random.randrange(len(X_int_encoding))
    sample_x = X_int_encoding[position]
    sample_y = Y_int_encoding[position]
    return position, sample_x, sample_y

def randomSampleTransform(X, Y, X_int_encoding, Y_int_encoding, n_letters, category_num):
    """Draw one random training example and convert it to tensors.

    Args:
        X: Raw inputs; indexed with the sampled position.
        Y: Raw labels; indexed with the sampled position.
        X_int_encoding: Integer-encoded inputs, parallel to ``X``.
        Y_int_encoding: Integer-encoded labels, parallel to ``Y``.
        n_letters: Size passed to ``one_hot_encoding`` for the input.
        category_num: Size passed to ``one_hot_encoding`` for the label.

    Returns:
        A tuple ``(X[index], Y[index], x_tensor, y_tensor, y_int)`` where the
        tensors wrap the arrays produced by
        ``data_preprocessing.one_hot_encoding`` and ``y_int`` is the sampled
        entry of ``Y_int_encoding``.

    Raises:
        IndexError: If one-hot encoding the sampled entry fails with an
            ``IndexError``; the message carries the sampled index.
    """
    # Sample outside the try block so that `index` is always bound when the
    # error path needs to report it.
    index, cur_x, cur_y = randomSample(X_int_encoding, Y_int_encoding)
    try:
        cur_x_np = data_preprocessing.one_hot_encoding(cur_x, n_letters)
        cur_y_np = data_preprocessing.one_hot_encoding(cur_y, category_num)
    except IndexError as err:
        # Previously the index was printed and execution fell through to the
        # return statement, which then failed on unbound locals and hid the
        # real cause. Surface the original error with its context instead.
        raise IndexError(
            "one-hot encoding failed for sample index {}".format(index)
        ) from err
    cur_x_tensor = torch.from_numpy(cur_x_np)
    cur_y_tensor = torch.from_numpy(cur_y_np)
    return X[index], Y[index], cur_x_tensor, cur_y_tensor, cur_y


class RNN(nn.Module):
    """Single-step recurrent cell built from two linear layers.

    Each call takes the one-hot representation of the current character and
    the previous hidden state (all zeros for the first character). It returns
    the output scores for every category together with the hidden state to
    feed into the next step.

    No softmax or other non-linearity is applied here: ``output`` holds raw,
    unnormalized scores and ``hidden`` is a plain linear projection.

    Args:
        input_size: Length of one input vector.
        hidden_size: Length of the hidden state vector.
        output_size: Number of output scores (categories).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers read the concatenation [input, hidden].
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: Tensor whose last dimension has length ``input_size``.
            hidden: Tensor whose last dimension has length ``hidden_size``;
                leading dimensions must match those of ``input``.

        Returns:
            A tuple ``(output, hidden)`` with last dimensions of length
            ``output_size`` and ``hidden_size`` respectively.
        """
        # Concatenate along the feature (last) axis. For 1-D tensors this is
        # the same as dim 0, and it also makes batched (B, n) inputs work.
        combined = torch.cat((input, hidden), dim=-1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        return output, hidden

def get_category_from_output(output):
    """Return the single highest score in ``output`` and where it occurs.

    Returns:
        A tuple ``(value, index)`` of plain Python numbers taken from the
        first entry of ``output.topk(1)``.
    """
    best = output.topk(1)
    best_value = best.values[0].item()
    best_index = best.indices[0].item()
    return best_value, best_index

def train(x_tensor, y_tensor, rnn, y_int, n_hidden, criterion=None, lr=None):
    """Run one manual gradient-descent step on a single sequence.

    The sequence is fed to ``rnn`` one slice at a time along dimension 0,
    starting from an all-zero hidden state. The loss is computed on the
    output of the last step only.

    Args:
        x_tensor: Input sequence; ``x_tensor[i]`` is passed to ``rnn`` at
            step ``i``.
        y_tensor: Unused; kept so existing call sites keep working.
        rnn: Module called as ``rnn(step_input, hidden)`` and returning
            ``(output, hidden)``.
        y_int: Integer index of the target class.
        n_hidden: Length of the initial zero hidden state.
        criterion: Loss callable taking ``(output, target)``. Defaults to the
            module-level ``loss_func``.
        lr: Step size for the parameter update. Defaults to the module-level
            ``learning_rate``.

    Returns:
        A tuple ``(output, loss_value)``: the last-step output (computed
        before the parameter update) and the loss as a Python float.

    Raises:
        ValueError: If ``x_tensor`` has no steps along dimension 0.
    """
    if criterion is None:
        criterion = loss_func
    if lr is None:
        lr = learning_rate

    n_steps = x_tensor.size(0)
    if n_steps == 0:
        # Without at least one step there is no output to score.
        raise ValueError("x_tensor must contain at least one time step")

    # The hidden state for the first character is all zeros.
    hidden = torch.zeros(n_hidden)
    rnn.zero_grad()
    for i in range(n_steps):
        output, hidden = rnn(x_tensor[i], hidden)

    # Add a leading axis of size 1 so the loss sees a batch of one sample.
    output = output.unsqueeze(dim=0)
    target = torch.tensor([y_int], dtype=torch.long)
    loss = criterion(output, target)
    loss.backward()

    # No optimizer is used: update each parameter by hand with
    # p <- p - lr * grad. The keyword form of add_ replaces the deprecated
    # add_(alpha, tensor) positional overload.
    with torch.no_grad():
        for p in rnn.parameters():
            if p.grad is None:
                # Parameter did not take part in this forward pass.
                continue
            p.add_(p.grad, alpha=-lr)

    output = output.squeeze(dim=0)
    return output, loss.item()


def evaluate(rnn, n_hidden, char_tokenizer, n_letters, x_string):
    """Feed ``x_string`` through ``rnn`` step by step and return the last output.

    The string is encoded with ``char_tokenizer.texts_to_sequences``, one-hot
    encoded with ``data_preprocessing.one_hot_encoding`` and then passed to
    ``rnn`` one slice at a time, starting from an all-zero hidden state of
    length ``n_hidden``.

    Returns:
        The output produced by ``rnn`` at the final step.
    """
    encoded = char_tokenizer.texts_to_sequences([x_string])
    one_hot = data_preprocessing.one_hot_encoding(encoded, n_letters)
    # squeeze(dim=0) drops the leading axis when it has size 1 - presumably
    # the batch-of-one axis created by wrapping x_string in a list; verify
    # against one_hot_encoding.
    steps = torch.from_numpy(one_hot).squeeze(dim=0)

    state = torch.zeros(n_hidden)
    for position in range(steps.size(0)):
        output, state = rnn(steps[position], state)
    return output


def predict_name(rnn, n_hidden, x_string, char_tokenizer, n_letters, int2country, n_predictions=3):
    """Print the ``n_predictions`` highest-scoring countries for ``x_string``.

    Scores come from ``evaluate`` and are computed with gradient tracking
    disabled. One line is printed per prediction, best first, with the
    country looked up in ``int2country`` by category index.
    """
    with torch.no_grad():
        scores = evaluate(rnn, n_hidden, char_tokenizer, n_letters, x_string)
        # Keep only the top N categories.
        best_scores, best_positions = scores.topk(n_predictions)
        for score_t, position_t in zip(best_scores, best_positions):
            score = score_t.item()
            country = int2country[position_t.item()]
            print("name - {}, score - {}, country - {}".format(x_string, score, country))


# Step 1: load the tokenizer, vocabulary/category sizes and the datasets.
print("#1 load data")
(
    char_tokenizer,
    n_letters,
    country_num,
    X,
    Y,
    X_int_encoding,
    Y_int_encoding,
    int2country,
) = name2countries_helper.prepare_name2countries_data_for_task()
print(n_letters, country_num, sep="\n")

# Step 2: build the model; one output score per country.
print("#2 init RNN")
n_hidden = 128
n_categories = country_num
rnn = RNN(input_size=n_letters, hidden_size=n_hidden, output_size=n_categories)

# Step 3: loss, step size and reporting cadence.
print("#3 define parameters")
loss_func = nn.CrossEntropyLoss()
learning_rate = 5e-4
n_iters = 100000
print_every = plot_every = 5000
# Loss accumulated since the last plot point, and the history of averages.
current_loss = 0
all_losses = []


print("#4 start training")
# `iteration` rather than `iter`, so the builtin iter() is not shadowed for
# the rest of the session.
for iteration in range(1, n_iters + 1):
    x_ori, y_ori, x_tensor, y_tensor, y_int = randomSampleTransform(
        X, Y, X_int_encoding, Y_int_encoding, n_letters, n_categories
    )
    output, loss = train(x_tensor, y_tensor, rnn, y_int, n_hidden)
    current_loss += loss
    if iteration % print_every == 0:
        # The top-1 prediction is only needed for this progress line, so it
        # is computed here instead of on every iteration.
        predict, predict_index = get_category_from_output(output)
        print(iteration)
        print("name - {}, country - {}, actual - {}, {}, predict - {}, {}".format(x_ori, y_ori, y_int, int2country[y_int], predict_index, int2country[predict_index]))
    # Record the average loss over the last `plot_every` iterations.
    if iteration % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0
print("#5 print loss")
for avg_loss in all_losses:
    print(avg_loss)

# Plot the averaged training loss recorded during training.
plt.figure()
plt.plot(all_losses)

# Spot-check the trained model on a few hand-picked names.
print("# 5 test single")
for sample_name in ('Dovesky', 'Jackson', 'Satoshi'):
    predict_name(rnn, n_hidden, sample_name, char_tokenizer, n_letters, int2country)