## CBOW - Word2Vec Implementation Pytorch

In [1]:
import re
import nltk
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from nltk.corpus import webtext
from keras.preprocessing import sequence
from keras.utils import np_utils
from keras.preprocessing import text
from keras.utils import np_utils
from keras.preprocessing import sequence
import matplotlib.pyplot as plt

pd.options.display.max_colwidth = 200
%matplotlib inline

In [2]:
import nltk
nltk.download('stopwords')
nltk.download('webtext')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package webtext to /root/nltk_data...
[nltk_data]   Unzipping corpora/webtext.zip.


True

## Pre-Processing text Code

In [3]:
wordpt = nltk.WordPunctTokenizer()
stop_words = nltk.corpus.stopwords.words('english')

def normalize_document(doc):
    # lower case and remove special characters\whitespaces
    doc = re.sub(r'[^a-zA-Z\s]', '', doc, re.I|re.A)
    doc = doc.lower()
    doc = doc.strip()
    # tokenize document
    tokens = wordpt.tokenize(doc)
    # filter stopwords out of document
    filtered_tokens = [token for token in tokens if token not in stop_words]
    # re-create document from filtered tokens
    doc = ' '.join(filtered_tokens)
    return doc

normalize_corpus = np.vectorize(normalize_document)

In [4]:
corpus = ['The sky is blue and beautiful.',
          'Love this blue and beautiful sky!',
          'The quick brown fox jumps over the lazy dog.',
          "A king's breakfast has sausages, ham, bacon, eggs, toast and beans",
          'I love green eggs, ham, sausages and bacon!',
          'The brown fox is quick and the blue dog is lazy!',
          'The sky is very blue and the sky is very beautiful today',
          'The dog is lazy but the brown fox is quick!'    
]
labels = ['weather', 'weather', 'animals', 'food', 'food', 'animals', 'weather', 'animals']

corpus = np.array(corpus)
corpus_df = pd.DataFrame({'Document': corpus, 
                          'Category': labels})
corpus_df = corpus_df[['Document', 'Category']]
corpus_df

Unnamed: 0,Document,Category
0,The sky is blue and beautiful.,weather
1,Love this blue and beautiful sky!,weather
2,The quick brown fox jumps over the lazy dog.,animals
3,"A king's breakfast has sausages, ham, bacon, eggs, toast and beans",food
4,"I love green eggs, ham, sausages and bacon!",food
5,The brown fox is quick and the blue dog is lazy!,animals
6,The sky is very blue and the sky is very beautiful today,weather
7,The dog is lazy but the brown fox is quick!,animals


In [8]:
# build a sample vocab
vocab = []

for fileid in webtext.fileids():
    vocab.append(webtext.raw(fileid))

### text preprocessing (Remove tags e.g HTML,Remove special characters, Remove stopwords) === Clean data

In [9]:
tokenizer = text.Tokenizer()
tokenizer.fit_on_texts(corpus)
word2id = tokenizer.word_index

word2id['PAD'] = 0
id2word = {v:k for k, v in word2id.items()}
wids = [[word2id[w] for w in text.text_to_word_sequence(doc)] for doc in corpus]

vocab_size = len(word2id)
embed_size = 100
window_size = 2

print('Vocabulary Size:', vocab_size)
print('Vocabulary Sample:', list(word2id.items())[:10])

Vocabulary Size: 31
Vocabulary Sample: [('the', 1), ('is', 2), ('and', 3), ('sky', 4), ('blue', 5), ('beautiful', 6), ('quick', 7), ('brown', 8), ('fox', 9), ('lazy', 10)]


### [context_words, target_word] pairs

In [10]:
def generate_context_word_pairs(corpus, window_size, vocab_size):
    X = []
    Y = []
    context_length = window_size*2
    for words in wids:
        sentence_length = len(words)
        for index, word in enumerate(words):           
            start = index - window_size
            end = index + window_size + 1
            context = [words[i] for i in range(start, end)if 0 <= i < sentence_length and i != index]
            x = sequence.pad_sequences([context], maxlen=context_length)
            X.append(x)
            Y.append(word)
    return X,Y

## CBOW (Contineous bag of Words Model architecture)

In [11]:
import torch
import torch.nn as nn
import numpy as np

class CBOW(torch.nn.Module):

    def __init__(self, inp_size , vocab_size, embedding_dim=100):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 100)
        self.activation_function1 = nn.ReLU()        
        self.linear2 = nn.Linear(100, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim = -1)
        
    def forward(self, inputs):
        embeds = sum(self.embeddings(torch.from_numpy(inputs).long())).view(1,-1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out
    
model = CBOW(window_size*2,vocab_size)
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [13]:
for epoch in range(1, 1000):
    loss = 0
    i = 0
    X,Y = generate_context_word_pairs(corpus=wids, window_size=window_size, vocab_size=10000)
    for x, y in zip(X,Y):
        i += 1
        optimizer.zero_grad()
        log_probs = model(x[0])
        loss = loss_function(log_probs,torch.Tensor([y]).long())
        loss.backward()
        optimizer.step()
        loss += loss.data
    print('Epoch:', epoch, '\tLoss:', loss)

Epoch: 1 	Loss: tensor(5.1038, grad_fn=<AddBackward0>)
Epoch: 2 	Loss: tensor(4.9339, grad_fn=<AddBackward0>)
Epoch: 3 	Loss: tensor(4.7704, grad_fn=<AddBackward0>)
Epoch: 4 	Loss: tensor(4.6138, grad_fn=<AddBackward0>)
Epoch: 5 	Loss: tensor(4.4486, grad_fn=<AddBackward0>)
Epoch: 6 	Loss: tensor(4.3001, grad_fn=<AddBackward0>)
Epoch: 7 	Loss: tensor(4.1443, grad_fn=<AddBackward0>)
Epoch: 8 	Loss: tensor(3.9982, grad_fn=<AddBackward0>)
Epoch: 9 	Loss: tensor(3.8371, grad_fn=<AddBackward0>)
Epoch: 10 	Loss: tensor(3.6970, grad_fn=<AddBackward0>)
Epoch: 11 	Loss: tensor(3.5633, grad_fn=<AddBackward0>)
Epoch: 12 	Loss: tensor(3.4263, grad_fn=<AddBackward0>)
Epoch: 13 	Loss: tensor(3.2965, grad_fn=<AddBackward0>)
Epoch: 14 	Loss: tensor(3.1613, grad_fn=<AddBackward0>)
Epoch: 15 	Loss: tensor(3.0402, grad_fn=<AddBackward0>)
Epoch: 16 	Loss: tensor(2.9310, grad_fn=<AddBackward0>)
Epoch: 17 	Loss: tensor(2.8105, grad_fn=<AddBackward0>)
Epoch: 18 	Loss: tensor(2.7157, grad_fn=<AddBackward0>)
E

In [25]:
weights = model.embeddings(torch.Tensor([list(range(0,vocab_size))]).long())

pd.DataFrame(weights.view(-1,100).tolist(), index=list(id2word.values())[0:]).head(10)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
the,-1.020247,0.858016,-0.972954,1.621598,0.015694,0.071058,-0.546601,0.627173,0.0701,0.540836,-1.658709,-0.710162,0.553695,0.681404,0.93994,1.242287,0.095183,-1.436292,0.910662,0.286473,0.412264,0.236389,0.63772,-1.741466,-1.664177,0.194743,0.228008,-0.64035,0.526864,0.28124,2.378866,1.220054,-2.521966,0.165114,-0.357386,0.639706,0.945562,-2.888935,1.225117,0.340856,...,1.06993,0.774617,-0.063797,0.463545,1.626877,1.04387,-0.687107,-1.10419,1.49926,0.428856,-0.612172,-0.791734,1.190712,0.0771,0.949545,-0.92926,0.752076,-0.10831,0.605397,0.636176,-0.630588,-0.763144,-0.626585,1.092433,-0.112572,-0.522684,0.018455,0.125532,-0.902343,1.652762,1.157533,1.45648,1.167684,0.679525,0.390032,0.883022,1.013172,0.127352,-0.370458,0.407416
is,-0.469278,0.731419,-2.046814,1.10763,-0.602588,0.407658,-0.544868,-1.117358,-2.473301,0.07916,0.954797,-1.084227,-0.515376,-0.23578,-0.376261,-0.193844,-0.043524,0.727588,0.505118,1.006242,-0.607311,-0.436902,1.811525,-0.897329,-0.348212,0.015865,1.939071,0.395484,1.446937,-0.788801,0.87563,0.702702,-1.203162,0.620883,-2.84522,2.736629,-0.391041,-0.1365,0.472429,-1.558961,...,2.048724,-0.249179,-1.945006,0.344417,0.122567,0.743959,-0.53915,0.288773,0.19854,-0.402431,0.996831,-0.342212,1.165547,-0.513374,-0.188044,-0.739049,1.485368,-0.371456,0.615853,0.180818,-1.017328,-0.179646,-0.637363,0.252615,-0.06857,-0.491677,0.869759,0.058729,0.296049,0.85282,-0.52971,-0.124254,0.493743,0.997346,0.847354,1.380898,1.105859,0.46841,-1.300014,0.211676
and,1.303654,-0.400301,0.474038,1.324667,-2.26677,0.128912,-0.284075,-0.957212,-1.241336,0.697034,0.906708,-0.607447,-0.22234,-2.749931,0.071322,-0.807236,0.392371,-0.240701,0.277111,-0.4167,-0.802208,1.83897,0.603281,-0.184513,0.061163,1.31152,1.531312,-1.966784,0.910535,0.025413,-2.657902,-0.771684,0.841913,1.475178,0.801869,-0.333531,1.314127,0.111316,-0.349497,-0.301652,...,-0.805469,0.663779,0.062341,-1.293513,0.419038,1.36008,-0.428137,-1.014816,2.154513,0.037632,-1.194724,-0.565663,-0.230392,1.56047,0.745605,-1.303745,-0.616904,0.891984,-0.829454,0.00904,-0.167278,0.10052,0.025883,-0.090795,-0.806667,-0.212954,-0.57166,1.301951,-1.995423,-0.757284,-1.709719,-0.327805,-0.244655,-0.214154,0.025581,-0.82059,0.131171,0.425759,-0.504117,1.773358
sky,1.449165,0.153278,-1.509703,-0.340397,-0.808891,0.449833,-0.356767,-1.193008,0.558038,0.764705,0.095494,0.418742,0.50398,0.34879,-1.417122,-0.39001,-1.054957,0.292385,-2.709125,0.646774,-1.06224,-1.409375,-0.158662,0.855864,0.027267,-0.104874,0.945288,1.170664,1.482344,-0.349548,-0.754053,-2.300823,-1.179725,2.56183,-0.39251,2.203889,-1.327304,-0.063887,-0.411009,-1.667183,...,-0.587274,-1.464557,1.315149,-0.181798,-0.215293,-0.064679,-0.125366,1.888723,1.516137,-0.700118,-0.243861,1.494783,-0.556467,0.085412,-2.062423,-1.162731,-0.515584,-1.128738,0.951601,-0.017662,-2.747708,-0.893692,-0.329091,-0.308133,0.417533,-0.089249,-0.588559,0.202914,-0.494305,-0.19185,0.045936,2.844258,0.67656,0.412295,-0.157112,-0.010236,-0.340997,0.800338,0.298349,0.183261
blue,-0.469048,-0.312949,0.123316,1.329445,-0.551616,-0.133613,-0.911876,0.007528,2.264735,-0.722735,0.055664,-1.083289,0.455743,-0.041167,-0.014597,-0.915513,0.417387,-0.742937,-0.534138,0.256865,0.496202,0.672305,-0.081591,1.141091,0.126867,-0.233428,-0.047002,0.639444,1.844794,0.935714,2.214909,1.384562,-0.715399,1.816184,0.282828,0.029626,-1.683331,0.306919,0.286198,-0.184079,...,-1.06138,0.460302,-0.739375,1.300935,0.556676,0.095558,-0.119986,-0.871369,-0.106704,-0.253442,-0.041808,0.202927,-0.038642,0.794085,-0.859474,0.157171,1.115904,1.366879,1.360007,-0.997755,0.543524,0.206829,0.667261,-0.277421,-1.065924,1.259926,1.12996,0.334031,-1.456439,-1.293781,-0.815232,-2.548271,-0.507428,0.511041,0.378241,1.546835,-0.228272,0.910393,-2.542138,2.244305
beautiful,-0.533289,-1.827823,-1.032799,-0.451846,0.894649,0.748004,1.423167,0.299123,1.135338,0.562234,0.655088,-0.11121,-1.503113,0.532644,-0.92909,0.205524,0.453472,2.669638,1.494946,1.303193,-0.559608,-1.185347,-1.100271,1.257161,-2.833918,1.233855,-0.0356,0.023203,0.544362,-0.163978,0.631034,2.037613,1.275391,1.610185,-1.418666,-2.337459,-0.100569,-0.214049,0.246707,0.769618,...,-0.866297,0.129173,-0.164398,0.627094,-0.244747,-0.768476,0.259487,0.354147,0.3337,-0.829295,0.270319,0.299325,-1.462435,0.618261,0.635193,0.839166,-0.812424,-0.886244,-1.163957,-0.170511,-0.735343,-3.068204,-1.123936,0.032735,0.532889,0.084369,-1.344577,-0.91287,-1.17175,-0.942274,-0.466412,2.093053,1.168456,1.49962,-0.231552,0.603686,0.1414,-0.061871,-0.691325,-0.207963
quick,-0.558469,-0.937921,-0.6952,2.036233,1.005487,0.400959,0.102566,-1.534967,0.229444,0.2138,-0.71618,-0.071263,0.846729,-0.002873,0.712897,-0.503977,0.68497,1.153641,-0.035872,0.718323,-0.149394,-1.899651,0.425477,-1.223452,-0.281957,-0.086701,-1.098551,-1.214337,0.875268,-0.198583,-0.60376,1.951853,0.452532,-0.478115,0.786527,-0.292773,-1.606256,0.355996,1.234043,0.538837,...,0.407044,-0.616079,-1.552686,0.16148,1.508547,1.535354,-1.188404,0.022315,1.206893,-0.230642,0.525955,-0.411932,1.052186,-1.426858,0.772702,0.410571,2.02269,0.427967,-0.236282,-2.390484,-0.810348,1.703603,1.794683,-0.031043,1.53866,-0.721608,-0.382144,0.785618,-1.141353,0.622182,-0.171468,0.061658,0.587657,-0.942987,-1.038087,-0.553705,-1.465742,0.107394,0.593551,0.867413
brown,0.393551,0.450134,-0.079574,0.268751,1.278429,0.613324,-0.850955,-1.418713,0.596749,-0.345802,0.950896,-0.025625,-1.168137,0.793924,-2.977073,-0.668861,1.038769,-0.083873,0.679444,-0.738849,-0.430129,-0.074172,0.663841,0.483713,1.198106,1.100216,-2.123839,0.062811,-0.477568,-0.211494,-0.315232,-0.184991,1.263915,1.327454,-0.317354,0.562355,0.205006,-0.359538,-1.399411,0.60399,...,0.245905,0.249995,0.207692,-1.040047,-0.546615,-1.59664,0.017206,-0.460313,-0.473303,-0.445533,1.025541,-0.217998,-0.569339,-1.148984,0.874883,2.073158,0.776481,-0.50493,-0.226141,2.423683,-2.107358,-0.331935,1.325698,-0.067668,0.190376,0.62428,-0.458483,0.667423,1.764585,0.414479,0.61054,1.147712,-0.101477,-0.694855,0.206836,0.363352,-0.757771,1.55818,-0.250275,1.686646
fox,0.797055,1.411196,-0.031755,-0.41571,0.437677,-0.657659,0.354126,0.090269,0.325975,-2.354204,0.107897,-0.545043,-1.713373,-0.675881,1.067379,0.405168,0.817522,-0.046598,-0.891815,0.11198,-0.225963,0.112378,-0.113158,0.944294,-1.158892,1.507765,-1.871655,-0.772381,-3.0802,-1.35665,-1.736288,-1.821512,0.791796,-0.07305,0.336441,0.421689,-1.321763,1.158198,-2.189407,-1.608815,...,-0.141466,-1.450243,-0.314869,1.57435,0.951972,1.280705,0.31433,-1.201877,0.33451,0.089217,0.224379,0.550554,0.538171,0.594383,0.131937,-0.734224,-1.917705,0.283525,-0.938562,-1.972943,0.597334,-1.106533,1.231255,-0.223283,1.583508,0.981517,-0.033549,-2.119381,2.264282,1.134446,-0.785851,1.335509,-1.517941,-1.34934,0.753482,0.640396,-0.159388,1.045616,0.765108,-1.618113
lazy,-1.776232,0.995156,1.236979,1.855693,0.06658,0.742035,0.651343,-2.000114,-0.701359,-1.533831,-1.086544,0.519863,0.184006,-0.27405,2.911425,0.115098,-1.112191,-0.479669,1.260256,-0.132396,-1.566988,-0.787054,0.502047,0.440928,-0.449495,-2.725736,0.063942,-1.871578,-0.684811,-0.386248,-0.146623,-0.403312,0.504547,-2.330001,0.666139,-0.979507,1.063577,1.380405,0.362965,2.558188,...,0.077024,0.204456,0.052451,1.531995,1.10002,-0.782305,0.171052,0.772296,0.057735,1.645802,-0.042587,0.384749,-0.337151,-0.776732,0.059966,-0.806825,0.517419,-0.666449,-1.164875,-0.713597,-0.276371,0.01294,-0.516041,1.522119,0.7486,-1.616673,-0.584538,-0.564948,0.231805,0.781042,-0.329508,0.476513,0.212365,-1.254748,-0.556678,-1.258437,-0.174385,-0.158279,0.970042,-0.846919


In [26]:
from sklearn.metrics.pairwise import euclidean_distances

weights = weights.view(-1,100)
distance_matrix = euclidean_distances(weights.detach().numpy())
similar_words = {search_term: [id2word[idx] for idx in distance_matrix[word2id[search_term]-1].argsort()[1:4]+1] 
                 for search_term in ['the', 'fox', 'beautiful','brown','lazy']}

similar_words

{'beautiful': ['over', 'but', 'brown'],
 'brown': ['has', 'jumps', "king's"],
 'fox': ['a', 'toast', 'brown'],
 'lazy': ['quick', 'a', 'this'],
 'the': ['is', 'over', 'quick']}