In [77]:
string = 'Today is a good day for taking a walk'
tokens = string.split()
print(tokens)

['Today', 'is', 'a', 'good', 'day', 'for', 'taking', 'a', 'walk']


In [100]:
token_set = set(tokens) # create all unique tokens
word2id = {word:idx for idx,word in enumerate(token_set)} # give unique identifier to each unique token
id2word = {idx:word for idx,word in enumerate(token_set)}
window = 2        # context window size
embeddings = 10  # number of embeddings to be used for representation
epochs = 100     # number of training iterations
lr = 0.001       # learning rate for CBOW 

print(token_set) # vocabulary
vocab_size = len(token_set)  # size of vocabulary

{'Today', 'is', 'walk', 'taking', 'good', 'a', 'day', 'for'}


In [101]:
import torch
import torch.nn as nn

def context_vector(tokens:list):
    # list of values for each token
    val_context = [word2id[word] for word in tokens] 
    return val_context
    
    
context_pairs = []

# loop through all possible cases 
for i in range(window,len(tokens) - window):
    
    context = []
    
    # words to the left
    for j in range(-window,0):
        context.append(tokens[i+j])
    
    # words to the right
    for j in range(1,window+1):
        context.append(tokens[i+j])
        
    context_pairs.append((context,tokens[i]))
    
# show all context pairs in document
print('context, target pairs\n')
for context in context_pairs:
    print(context)
    
# sample tensor conversion
print('\nfor pytorch; context, target word tensors\n')
for context,target in context_pairs:
    X = torch.tensor(context_vector(context))
    y = torch.tensor(word2id[target])
    print(X,y)

context, target pairs

(['Today', 'is', 'good', 'day'], 'a')
(['is', 'a', 'day', 'for'], 'good')
(['a', 'good', 'for', 'taking'], 'day')
(['good', 'day', 'taking', 'a'], 'for')
(['day', 'for', 'a', 'walk'], 'taking')

for pytorch; context, target word tensors

tensor([0, 1, 4, 6]) tensor(5)
tensor([1, 5, 6, 7]) tensor(4)
tensor([5, 4, 7, 3]) tensor(6)
tensor([4, 6, 3, 5]) tensor(7)
tensor([6, 7, 5, 2]) tensor(3)


In [103]:
from torch.optim import Adam

class CBOW(torch.nn.Module):
    
    def __init__(self,vocab_size,embed_dim):
        super(CBOW,self).__init__()
        
        self.embedding = nn.Embedding(vocab_size,embed_dim)
        self.linear = nn.Linear(embed_dim,vocab_size)
#         self.active = nn.ReLU()
        self.active = nn.LogSoftmax(dim=-1)
        
    def forward(self,x):
        x = sum(self.embedding(x)).view(1,-1)
        x = self.linear(x)
        x = self.active(x)
        return x
    
    def word_embedding(self, x):
        word = torch.tensor([word2id[x]])
        return self.embedding(word).view(1,-1)
    

model = CBOW(vocab_size,embeddings)    
criterion = nn.NLLLoss()
optimiser = Adam(model.parameters(),lr=lr)

In [104]:
model.word_embedding('day')  # randomly initialised weights

tensor([[ 0.4653, -0.5790, -0.0121,  0.7312,  0.5235, -1.0668, -1.2971,  1.1165,
          1.2788, -1.3157]], grad_fn=<ViewBackward0>)

In [105]:
# training loop

lst_loss = []
for epoch in range(epochs):
    
    loss = 0.0
    for context,target in context_pairs:
        
        X = torch.tensor(context_vector(context))
        y = torch.tensor([word2id[target]])        

        y_pred = model(X)
        loss += criterion(y_pred,y)
        
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    lst_loss.append(float(loss.detach().numpy()))
        
print(lst_loss[-1])

6.324942111968994


In [106]:
''' Test out our CBOW model '''

# tokenised text
temp = ['good','day','taking','a']

# convert to tensor
cont_vector = torch.tensor(context_vector(temp)) 

# prediction
y_pred = model(cont_vector)
print(y_pred)
pred_index = torch.argmax(y_pred[0]) # get largest argument
print(f'prediction: ',id2word[pred_index.item()])

tensor([[-3.1887, -2.5548, -2.2363, -2.6371, -2.8696, -2.0258, -1.5787, -1.1783]],
       grad_fn=<LogSoftmaxBackward0>)
prediction:  for


In [107]:
model.word_embedding('day')  # embedding for word day after training

tensor([[ 0.5619, -0.5008, -0.0986,  0.6555,  0.5892, -1.1915, -1.2203,  1.0517,
          1.1961, -1.4373]], grad_fn=<ViewBackward0>)