In [1]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

<torch._C.Generator at 0x2baf01fc210>

In [2]:
CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# By deriving a set from `raw_text`, we deduplicate the array
vocab = set(raw_text)
vocab_size = len(vocab)

word_to_ix = {word: i for i, word in enumerate(vocab)}
data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))
print(data[:5])


[(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study'), (['to', 'study', 'idea', 'of'], 'the'), (['study', 'the', 'of', 'a'], 'idea')]


In [3]:
vocab_size

49

In [116]:
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(2*context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)
    def test_embedding(self,inputs):
        if inputs.data[0] == 36:
            print(self.embeddings(inputs[0]))
            print(self.embeddings.weight)
        else:
            pass
    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        self.test_embedding(inputs)
        
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out)
        return log_probs
    

# create your model and train.  here are some functions to help you make
# the data ready for use by your module

In [117]:
losses = []
loss_function = nn.NLLLoss()
model = CBOW(vocab_size, 10, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [118]:
def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    tensor = torch.LongTensor(idxs)
    #print(tensor)
    return autograd.Variable(tensor)


In [119]:

for epoch in range(2):
    total_loss = torch.Tensor([0])
    print(str(epoch)+'*******')
    for context, target in data:
        
        # Step 1. Prepare the inputs to be passed to the model (i.e, turn the words
        # into integer indices and wrap them in variables)
        context_var = make_context_vector(context, word_to_ix)
        if context_var.data[0] == 36:
            print(context)
        # Step 2. Recall that torch *accumulates* gradients. Before passing in a
        # new instance, you need to zero out the gradients from the old
        # instance
        model.zero_grad()

        # Step 3. Run the forward pass, getting log probabilities over next
        # words
        log_probs = model(context_var)

        # Step 4. Compute your loss function. (Again, Torch wants the target
        # word wrapped in a variable)
        loss = loss_function(log_probs, autograd.Variable(
            torch.LongTensor([word_to_ix[target]])))

        # Step 5. Do the backward pass and update the gradient
        loss.backward()
        optimizer.step()

        total_loss += loss.data
    losses.append(total_loss)
print(losses, len(losses))  

0*******
['are', 'about', 'study', 'the']
Variable containing:
-0.2805  0.8897  1.1810 -0.0653  0.7099  1.4811  1.4208  1.2979 -0.0694 -0.3817
[torch.FloatTensor of size 1x10]

Parameter containing:
 0.7747 -0.7813 -0.8984  1.2846 -0.7147  2.1725 -0.9450 -0.2979 -0.0682  0.0382
-0.4303 -2.7833  2.4552  0.0798  0.2870 -0.5868 -1.2120 -0.1455  1.1880 -2.0991
 0.0692 -1.0046 -1.4106 -0.0915 -1.2086 -0.7630 -0.0951 -0.4809  0.6724 -0.9945
-1.2721 -0.8173 -0.1642 -0.4784 -0.0912 -2.4624  0.9267  0.9246  0.0243 -0.3899
-0.2931 -0.2631  1.9666  0.0820 -0.7569  0.4529  0.3982  0.8903  0.7367  0.7852
-2.0423  0.6389 -0.5267 -1.1582 -0.1310 -0.9850  0.5880 -0.7275 -1.1458  1.2855
 0.1970 -0.2333 -2.5346 -1.2132  0.2244 -1.5374  1.9983 -0.0735  0.2364 -0.8533
 0.1642 -0.2454 -1.3531 -1.7252 -1.1503 -2.7380  0.7589 -0.0798  0.3698  0.0653
 0.8205  1.4976 -1.5908  0.3253 -0.2204 -0.6457  1.4039 -0.4400 -0.7118 -0.4682
-0.8177 -0.7243  1.1279  0.7613  0.5012  1.1098  0.1932 -2.0337 -0.7202  0.0267
 