In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [30]:
# set flags / seeds
torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

In [13]:
EPOCHS = 100
test_sentence = """n-gram models are widely used in statistical natural
language processing . In speech recognition , phonemes and sequences of
phonemes are modeled using a n-gram distribution . For parsing , words
are modeled such that each n-gram is composed of n words . For language
identification , sequences of characters / graphemes ( letters of the alphabet
) are modeled for different languages For sequences of characters ,
the 3-grams ( sometimes referred to as " trigrams " ) that can be
generated from " good morning " are " goo " , " ood " , " od " , " dm ",
" mo " , " mor " and so forth , counting the space character as a gram
( sometimes the beginning and end of a text are modeled explicitly , adding
" __g " , " _go " , " ng_ " , and " g__ " ) . For sequences of words ,
the trigrams that can be generated from " the dog smelled like a skunk "
are " # the dog " , " the dog smelled " , " dog smelled like ", " smelled
like a " , " like a skunk " and " a skunk # " .""".split()

In [14]:
trigrams = [([test_sentence[i], test_sentence[i+1]],
            test_sentence[i+2]) for i in range(len(test_sentence) - 2)]

vocab = set(test_sentence)

word2idx = {word: i for i, word in enumerate(vocab)}
idx2word = {i: word for word, i in word2idx.items()}

In [26]:
class NGram(nn.Module):
    def __init__(self, vocab_size, embedding_dim=16, context_size=2):
        super().__init__()

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.l1 = nn.Linear(context_size * embedding_dim, 128)
        self.l2 = nn.Linear(128, vocab_size)
        self._init_weight()

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view(1, -1)
        out = F.relu(self.l1(embeds))
        out = self.l2(out)
        log_probs = F.log_softmax(out, dim=-1)
        return log_probs

    # 初始化参数
    def _init_weight(self, scope=0.1):
        # 这里要使用 data 才能修改值
        self.embeddings.weight.data.uniform_(-scope, scope)
        self.l1.weight.data.uniform_(0, scope)
        self.l1.bias.data.fill_(0)
        self.l2.weight.data.uniform_(0, scope)
        self.l2.bias.data.fill_(0)

In [27]:
criterion = nn.NLLLoss()
model = NGram(len(vocab))
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [28]:
model.train()
for epoch in range(EPOCHS):
    total_loss = torch.Tensor([0])
    for context, target in trigrams:
        context_idxs = list(map(lambda w: word2idx[w], context))
        context_var = torch.LongTensor(context_idxs)

        model.zero_grad()

        log_probs = model(context_var)
        loss = criterion(log_probs, torch.LongTensor([word2idx[target]]))

        loss.backward()
        optimizer.step()

        total_loss += loss.data
    print(total_loss.item())

model.eval()
def predict(context):
    context_idxs = list(map(lambda w: word2idx[w], context))
    context_var = torch.LongTensor(context_idxs)

    with torch.no_grad():
        predict = model(context_var)
        index = (torch.max(predict, 1)[1]).tolist()[0]
        return idx2word[index]


for context in [["widely", "used"], ["and", "so"], ["are", "modeled"]]:
    print("{} + {} = {}".format(context[0], context[1], predict(context)))

908.8380126953125
798.0969848632812
744.8796997070312
707.7742919921875
673.3292846679688
637.6712036132812
599.3861083984375
557.2484130859375
510.922119140625
461.93707275390625
412.8521728515625
367.5153503417969
327.61126708984375
290.9803466796875
258.08612060546875
228.65982055664062
203.10037231445312
180.83921813964844
163.3822784423828
148.52215576171875
137.2045135498047
127.42076110839844
119.11690521240234
112.48844146728516
106.7303237915039
102.3620834350586
98.2813720703125
95.12329864501953
92.1830825805664
90.16474151611328
88.42735290527344
86.56852722167969
84.93019104003906
83.75806427001953
82.68902587890625
81.57109069824219
80.80103302001953
79.99109649658203
79.34346771240234
78.94111633300781
78.17378997802734
77.45096588134766
77.22769927978516
76.73011016845703
76.53297424316406
76.14115905761719
75.9480972290039
75.77069854736328
75.21294403076172
75.12097930908203
74.6357421875
74.49486541748047
74.4771499633789
74.07783508300781
73.84907531738281
73.542411