In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from pygments.lexer import words

In [2]:
def make_batch(sentences, word_dict):
    """Turn each sentence into one (context, next-word) training pair.

    Each sentence is split on whitespace. All words except the last become the
    input context and the last word becomes the prediction target, i.e. the
    "causal language model" setup: predict word n from words 1..n-1.

    Args:
        sentences: iterable of whitespace-separated sentences.
        word_dict: mapping from word to integer index. Every word that appears
            in ``sentences`` must be a key (a missing word raises ``KeyError``).

    Returns:
        ``(input_batch, target_batch)`` where ``input_batch`` is a list with one
        list of context word indices per sentence, and ``target_batch`` is a
        list with the index of each sentence's last word.
    """
    input_batch = []
    target_batch = []
    for sen in sentences:
        tokens = sen.split()
        context = [word_dict[w] for w in tokens[:-1]]  # words 1..n-1 -> input
        target = word_dict[tokens[-1]]  # word n -> target ("causal" LM objective)
        input_batch.append(context)
        target_batch.append(target)
    return input_batch, target_batch

In [3]:
class NNLM(nn.Module):
    """Feed-forward neural network language model (Bengio et al., 2003).

    Given the indices of the previous ``n_step`` words, produces unnormalized
    next-word scores (logits) over the vocabulary:

        y = b + W x + U tanh(d + H x)

    where x is the concatenation of the ``n_step`` word embeddings.

    Args:
        voc_size: vocabulary size (number of output scores per example).
        m: embedding dimension per word.
        n_step: number of context words per example.
        n_hidden: width of the tanh hidden layer.
    """

    def __init__(self, voc_size, m, n_step, n_hidden):
        super().__init__()
        # Keep the context geometry on the instance so forward() does not
        # silently depend on notebook-level globals named ``n_step`` / ``m``.
        self.n_step = n_step
        self.m = m
        self.C = nn.Embedding(voc_size, m)  # word index -> m-dim embedding
        self.H = nn.Linear(n_step * m, n_hidden, bias=False)  # input -> hidden
        self.d = nn.Parameter(torch.ones(n_hidden))  # hidden-layer bias
        self.U = nn.Linear(n_hidden, voc_size, bias=False)  # hidden -> output
        self.W = nn.Linear(n_step * m, voc_size, bias=False)  # direct input -> output
        self.b = nn.Parameter(torch.ones(voc_size))  # output bias

    def forward(self, X):
        """Return next-word logits.

        Args:
            X: LongTensor of word indices; each row holds ``n_step`` indices
                (e.g. shape [batch_size, n_step]).

        Returns:
            Tensor of shape [batch_size, voc_size] with unnormalized scores.
        """
        X = self.C(X)  # [batch_size, n_step, m]
        X = X.view(-1, self.n_step * self.m)  # [batch_size, n_step * m]
        tanh = torch.tanh(self.d + self.H(X))  # [batch_size, n_hidden]
        output = self.b + self.W(X) + self.U(tanh)  # [batch_size, voc_size]
        return output

In [4]:
# Toy corpus: every sentence must have exactly n_step context words + 1 target.
sentences = ["i like dog", "i love coffee", "i hate milk"]

n_step = 2  # context words per example (words per sentence - 1)
n_hidden = 2  # hidden-layer width
m = 2  # embedding dimension per word
SEED = 0  # fixes weight init so Restart & Run All reproduces the training log
torch.manual_seed(SEED)

# sorted() gives a stable word -> index mapping; bare set() iteration order for
# strings depends on hash randomization and changes between kernel sessions.
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
voc_size = len(word_dict)

In [5]:
# Loss first: constructing it touches no parameters or RNG state.
criterion = nn.CrossEntropyLoss()  # consumes raw logits + integer class targets
# Build the language model, then hand its parameters to Adam.
model = NNLM(voc_size, m, n_step, n_hidden)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
# Full-batch training on the toy corpus.
N_EPOCHS = 5000  # number of optimizer steps
LOG_EVERY = 100  # print the loss every this many epochs

input_batch, target_batch = make_batch(sentences, word_dict)
input_batch = torch.tensor(input_batch, dtype=torch.long)  # [batch_size, n_step]
target_batch = torch.tensor(target_batch, dtype=torch.long)  # [batch_size]

for epoch in range(1, N_EPOCHS + 1):
    optimizer.zero_grad()
    output = model(input_batch)  # [batch_size, voc_size] logits
    loss = criterion(output, target_batch)
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0:
        # .item() extracts a plain Python float from the 0-d loss tensor.
        print(f"Epoch: {epoch:04d} loss = {loss.item():.6f}")

Epoch: 0100 loss = 1.626768
Epoch: 0200 loss = 1.144668
Epoch: 0300 loss = 0.868681
Epoch: 0400 loss = 0.714122
Epoch: 0500 loss = 0.623982
Epoch: 0600 loss = 0.565981
Epoch: 0700 loss = 0.521377
Epoch: 0800 loss = 0.475226
Epoch: 0900 loss = 0.402908
Epoch: 1000 loss = 0.263795
Epoch: 1100 loss = 0.151724
Epoch: 1200 loss = 0.097270
Epoch: 1300 loss = 0.068839
Epoch: 1400 loss = 0.051929
Epoch: 1500 loss = 0.040870
Epoch: 1600 loss = 0.033143
Epoch: 1700 loss = 0.027481
Epoch: 1800 loss = 0.023180
Epoch: 1900 loss = 0.019821
Epoch: 2000 loss = 0.017137
Epoch: 2100 loss = 0.014954
Epoch: 2200 loss = 0.013151
Epoch: 2300 loss = 0.011642
Epoch: 2400 loss = 0.010366
Epoch: 2500 loss = 0.009277
Epoch: 2600 loss = 0.008338
Epoch: 2700 loss = 0.007524
Epoch: 2800 loss = 0.006813
Epoch: 2900 loss = 0.006189
Epoch: 3000 loss = 0.005638
Epoch: 3100 loss = 0.005149
Epoch: 3200 loss = 0.004713
Epoch: 3300 loss = 0.004323
Epoch: 3400 loss = 0.003974
Epoch: 3500 loss = 0.003659
Epoch: 3600 loss = 0