In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from pygments.lexer import words

In [8]:
def make_batch(sentences, word_dict):
    """Convert sentences into (context indices, next-word index) training pairs.

    Each sentence is split on whitespace. Every word except the last is
    looked up in ``word_dict`` to form the input context; the last word's
    index is the prediction target (a causal language-model setup).

    Args:
        sentences: iterable of whitespace-separated strings.
        word_dict: mapping from word to integer index.

    Returns:
        ``(input_batch, target_batch)`` -- a list of context index lists and
        a list of target indices, both in the order of ``sentences``.
    """

    def encode(sentence):
        # Context (words 1..n-1) is resolved before the target (word n).
        tokens = sentence.split()
        context_ids = [word_dict[token] for token in tokens[:-1]]
        target_id = word_dict[tokens[-1]]
        return context_ids, target_id

    pairs = [encode(sentence) for sentence in sentences]
    input_batch = [context_ids for context_ids, _ in pairs]
    target_batch = [target_id for _, target_id in pairs]
    return input_batch, target_batch

In [21]:
class NNLM(nn.Module):
    """Feed-forward neural network language model.

    Computes ``b + W x + U tanh(d + H x)`` where ``x`` is the concatenation
    of the embeddings of the ``n_step`` context words.

    Args:
        voc_size: number of distinct words (embedding rows and output logits).
        m: embedding dimension.
        n_step: number of context words per example.
        n_hidden: width of the tanh hidden layer.
    """

    def __init__(self, voc_size, m, n_step, n_hidden):
        super(NNLM, self).__init__()
        # Keep the shape hyper-parameters on the instance so forward() does
        # not depend on module-level globals of the same name.
        self.n_step = n_step
        self.m = m
        self.C = nn.Embedding(voc_size, m)  # word index -> m-dim embedding
        self.H = nn.Linear(n_step * m, n_hidden, bias=False)  # input -> hidden
        self.d = nn.Parameter(torch.ones(n_hidden))  # hidden bias
        self.U = nn.Linear(n_hidden, voc_size, bias=False)  # hidden -> output
        self.W = nn.Linear(n_step * m, voc_size, bias=False)  # direct input -> output
        self.b = nn.Parameter(torch.ones(voc_size))  # output bias

    def forward(self, X):
        """Return unnormalised next-word scores.

        Args:
            X: integer tensor of word indices, [batch_size, n_step].

        Returns:
            Float tensor of logits, [batch_size, voc_size].
        """
        X = self.C(X)  # X : [batch_size, n_step, m]
        X = X.view(-1, self.n_step * self.m)  # [batch_size, n_step * m]
        hidden = torch.tanh(self.d + self.H(X))  # [batch_size, n_hidden]
        output = self.b + self.W(X) + self.U(hidden)  # [batch_size, voc_size]
        return output

In [22]:
# Toy corpus: each sentence supplies n_step context words plus one target word.
sentences = ["i like dog", "i love coffee", "i hate milk"]

n_step = 2  # number of context words fed to the model
n_hidden = 2  # width of the hidden layer
m = 2  # embedding dimension

# Sort the unique words so the word <-> index mapping is identical on every
# run; iterating a bare set of strings can change order between kernel
# restarts because of hash randomisation.
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}  # word -> index
number_dict = {i: w for i, w in enumerate(word_list)}  # index -> word
voc_size = len(word_dict)

In [23]:
# Seed PyTorch's RNG before building the model so the randomly initialised
# embedding and linear weights (and hence the training curve) are the same
# on every Restart & Run All.
torch.manual_seed(0)

model = NNLM(voc_size, m, n_step, n_hidden)
# CrossEntropyLoss is applied to the model's unnormalised scores and the
# integer target indices.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [26]:
# Encode the corpus once and convert to int64 tensors, the index dtype
# expected by nn.Embedding and nn.CrossEntropyLoss.
input_batch, target_batch = make_batch(sentences, word_dict)
input_batch = torch.tensor(input_batch, dtype=torch.long)
target_batch = torch.tensor(target_batch, dtype=torch.long)

n_epochs = 5000  # full-batch gradient steps
log_every = 100  # print the loss every this many epochs

for epoch in range(n_epochs):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    output = model(input_batch)  # [batch_size, voc_size] scores
    loss = criterion(output, target_batch)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % log_every == 0:
        # .item() extracts the Python float instead of formatting a tensor.
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))

Epoch: 0100 cost = 0.011559
Epoch: 0200 cost = 0.010265
Epoch: 0300 cost = 0.009154
Epoch: 0400 cost = 0.008195
Epoch: 0500 cost = 0.007362
Epoch: 0600 cost = 0.006634
Epoch: 0700 cost = 0.005996
Epoch: 0800 cost = 0.005434
Epoch: 0900 cost = 0.004937
Epoch: 1000 cost = 0.004496
Epoch: 1100 cost = 0.004102
Epoch: 1200 cost = 0.003751
Epoch: 1300 cost = 0.003436
Epoch: 1400 cost = 0.003153
Epoch: 1500 cost = 0.002899
Epoch: 1600 cost = 0.002668
Epoch: 1700 cost = 0.002460
Epoch: 1800 cost = 0.002271
Epoch: 1900 cost = 0.002099
Epoch: 2000 cost = 0.001942
Epoch: 2100 cost = 0.001799
Epoch: 2200 cost = 0.001668
Epoch: 2300 cost = 0.001549
Epoch: 2400 cost = 0.001439
Epoch: 2500 cost = 0.001338
Epoch: 2600 cost = 0.001246
Epoch: 2700 cost = 0.001161
Epoch: 2800 cost = 0.001082
Epoch: 2900 cost = 0.001010
Epoch: 3000 cost = 0.000943
Epoch: 3100 cost = 0.000881
Epoch: 3200 cost = 0.000823
Epoch: 3300 cost = 0.000770
Epoch: 3400 cost = 0.000721
Epoch: 3500 cost = 0.000675
Epoch: 3600 cost = 0