In [1]:
import torch, torch.nn as nn


In [2]:
def make_context_vector(ctx, word2ix):
    """Convert a sequence of context words into a 1-D tensor of word indices.

    Args:
        ctx: iterable of words; the output keeps their order.
        word2ix: mapping from word to integer index.

    Returns:
        A ``torch.long`` tensor containing ``word2ix[w]`` for every ``w`` in ``ctx``.

    Raises:
        KeyError: if a word is not in ``word2ix`` (for a plain ``dict``).
    """
    indices = []
    for word in ctx:
        indices.append(word2ix[word])
    return torch.tensor(indices, dtype=torch.long)


In [3]:
# Toy training corpus. It is tokenized by splitting on whitespace only, so
# punctuation stays attached to the word it follows (e.g. "process.").
corpus = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells."""
text = corpus.split()


In [4]:
# Build the vocabulary and the word <-> index lookup tables.
# The unique words are sorted so that index assignment is deterministic:
# the iteration order of a set of strings can differ between interpreter
# runs (string hash randomization), which would give each kernel restart a
# different word -> index mapping.
vocab = sorted(set(text))
v2i = {w: i for i, w in enumerate(vocab)}  # word  -> index
i2v = dict(enumerate(vocab))               # index -> word
print("vocab_size =", len(vocab))


vocab_size = 49


In [5]:
# Training pairs for CBOW: (context, target), where the context is the two
# words before and the two words after the target. The first and last two
# positions of the text are skipped because they lack a full window.
data = [
    (text[i - 2:i] + text[i + 1:i + 3], text[i])
    for i in range(2, len(text) - 2)
]
print("examples:", data[:3])


examples: [(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study')]


In [6]:
class CBOW(nn.Module):
    """Continuous bag-of-words model.

    The embeddings of the context words are summed, passed through one
    hidden ReLU layer of width 128, and projected to log-probabilities over
    the vocabulary.

    Args:
        v: vocabulary size (number of embedding rows and of output classes).
        e: embedding dimension.
    """

    def __init__(self, v, e):
        super().__init__()
        self.emb = nn.Embedding(v, e)
        self.l1 = nn.Linear(e, 128)
        self.l2 = nn.Linear(128, v)
        # Built once here instead of on every forward call; it has no parameters.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary.

        Args:
            x: integer tensor of word indices, either ``(context,)`` for a
               single example or ``(batch, context)`` for several.

        Returns:
            Tensor of shape ``(1, v)`` for a 1-D input, or ``(batch, v)`` for
            a 2-D input.
        """
        # Sum over the context axis, which is second-to-last once embedded.
        summed = self.emb(x).sum(dim=-2)
        # A single example yields a bare (e,) vector; add the batch axis so
        # the output is (1, v), as the loss expects.
        if summed.dim() == 1:
            summed = summed.unsqueeze(0)
        hidden = torch.relu(self.l1(summed))
        return self.log_softmax(self.l2(hidden))


In [7]:
# Seed PyTorch's RNG before the model is built so the randomly initialized
# weights are the same on every run.
torch.manual_seed(0)

EMBEDDING_DIM = 100    # size of each word embedding
LEARNING_RATE = 0.001  # SGD step size

m = CBOW(len(vocab), EMBEDDING_DIM)
lossf = nn.NLLLoss()  # takes log-probabilities, which CBOW.forward returns
opt = torch.optim.SGD(m.parameters(), lr=LEARNING_RATE)


In [8]:
# Train for 50 passes over the data, taking one SGD step per
# (context, target) pair.
for epoch in range(50):
    epoch_loss = 0.0
    for context, target in data:
        context_ids = make_context_vector(context, v2i)
        target_id = torch.tensor([v2i[target]])

        opt.zero_grad()
        log_probs = m(context_ids)
        loss = lossf(log_probs, target_id)
        loss.backward()
        opt.step()

        epoch_loss += loss.item()
    # optional: print loss for monitoring
    # print("epoch loss:", epoch_loss)
print("Training done!")


Training done!


In [9]:
# Ask the trained model for the word missing from "People create ___ to direct".
ctx = ['People', 'create', 'to', 'direct']
# Inference only: no_grad avoids recording an autograd graph that would
# never be used for a backward pass.
with torch.no_grad():
    out = m(make_context_vector(ctx, v2i))
# out has one row of log-probabilities; pick the most likely word.
pred = i2v[torch.argmax(out[0]).item()]
print("Context:", ctx)
print("Prediction:", pred)


Context: ['People', 'create', 'to', 'direct']
Prediction: programs
