In [10]:
import numpy as np
import string
from nltk.corpus import stopwords
import nltk
import random

nltk.download('stopwords')

[nltk_data] Downloading package stopwords to C:\Users\Granth
[nltk_data]     Bagadia\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [11]:
class SkipGram(object):
    def __init__(self, N, window_size):
        self.N = N
        self.X_train = []
        self.y_train = []
        self.window_size = window_size
        self.alpha = 0.001
        self.words = []
        self.word_index = {}

    def initialize(self, V, data):
        self.V = V
        self.W1 = np.random.uniform(-0.8, 0.8, (self.V, self.N))
        self.W2 = np.random.uniform(-0.8, 0.8, (self.N, self.V))
        self.words = data
        for i in range(len(data)):
            self.word_index[data[i]] = i

    def preprocessing(self, corpus):
        stop_words = set(stopwords.words('english'))
        training_data = []
        sentences = corpus.split(".")
        for i in range(len(sentences)):
            sentences[i] = sentences[i].strip()
            sentence = sentences[i].split()
            x = [word.strip(string.punctuation) for word in sentence if word not in stop_words]
            x = [word.lower() for word in x]
            training_data.append(x)
        return training_data

    def prepare_data_for_training(self, corpus):
        sentences = self.preprocessing(corpus)
        data = {}
        for sentence in sentences:
            for word in sentence:
                if word not in data:
                    data[word] = 1
                else:
                    data[word] += 1

        V = len(data)
        data = sorted(list(data.keys()))
        vocab = {data[i]: i for i in range(len(data))}

        for sentence in sentences:
            for i in range(len(sentence)):
                center_word = [0 for _ in range(V)]
                center_word[vocab[sentence[i]]] = 1
                context = [0 for _ in range(V)]
                for j in range(i - self.window_size, i + self.window_size + 1):
                    if i != j and j >= 0 and j < len(sentence):
                        context[vocab[sentence[j]]] += 1
                self.X_train.append(center_word)
                self.y_train.append(context)

        self.initialize(V, data)

    def feed_forward(self, X):
        self.h = np.dot(self.W1.T, X).reshape(self.N, 1)
        self.u = np.dot(self.W2.T, self.h)
        self.y = self.softmax(self.u)
        return self.y

    def softmax(self, x):
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum(axis=0)

    def backpropagate(self, x, t):
        e = self.y - np.asarray(t).reshape(self.V, 1)
        dLdW2 = np.dot(self.h, e.T)
        X = np.array(x).reshape(self.V, 1)
        dLdW1 = np.dot(X, np.dot(self.W2, e).T)
        self.W2 -= self.alpha * dLdW2
        self.W1 -= self.alpha * dLdW1

    def get_negative_samples(self, target_word_index, num_samples):
        neg_samples = []
        while len(neg_samples) < num_samples:
            neg_word_index = random.randint(0, self.V - 1)
            if neg_word_index != target_word_index:
                neg_samples.append(neg_word_index)
        return neg_samples

    def train(self, epochs, negative_samples):
        for epoch in range(1, epochs + 1):
            self.loss = 0
            for j in range(len(self.X_train)):
                center_word_vector = self.X_train[j]
                context_vector = self.y_train[j]


                self.feed_forward(center_word_vector)
                self.backpropagate(center_word_vector, context_vector)


                center_word_index = np.argmax(center_word_vector)
                neg_samples = self.get_negative_samples(center_word_index, negative_samples)

                for neg_index in neg_samples:

                    neg_context_vector = [0] * self.V
                    neg_context_vector[neg_index] = 1
                    self.feed_forward(center_word_vector)
                    self.backpropagate(center_word_vector, neg_context_vector)


                C = 0
                for m in range(self.V):
                    if context_vector[m]:
                        self.loss += -1 * self.u[m][0]
                        C += 1
                self.loss += C * np.log(np.sum(np.exp(self.u)))
            if epoch % 1000 == 0:
                print(f"Epoch {epoch}, Loss: {self.loss}")
            self.alpha *= 1 / (1 + self.alpha * epoch)

    def predict(self, word, number_of_predictions):
        if word in self.words:
            index = self.word_index[word]
            X = [0 for _ in range(self.V)]
            X[index] = 1
            prediction = self.feed_forward(X)
            output = {}
            for i in range(self.V):
                output[prediction[i][0]] = i
            top_context_words = []
            for k in sorted(output, reverse=True):
                top_context_words.append(self.words[output[k]])
                if len(top_context_words) >= number_of_predictions:
                    break
            return top_context_words
        else:
            print("Word not found in dictionary")

    def compute_similarity(self, vec1, vec2):
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

    def rank_words(self, target_vector):
        similarities = {}
        for i in range(self.V):
            word_vector = self.W1[i]
            similarity = self.compute_similarity(target_vector, word_vector)
            similarities[i] = similarity
        ranked_words = sorted(similarities, key=similarities.get, reverse=True)
        return ranked_words

    def compute_mrr_for_window(self, target_word, context_words):
        target_index = self.word_index[target_word]
        target_vector = self.W1[target_index]

        ranked_indices = self.rank_words(target_vector)

        mrr = 0
        for context_word in context_words:
            context_index = self.word_index[context_word]
            rank = ranked_indices.index(context_index) + 1
            mrr += 1 / rank
        return mrr / len(context_words)

    def compute_mrr(self, test_data):
        total_mrr = 0
        for window in test_data:
            target_word, context_words = window[0], window[1:]
            total_mrr += self.compute_mrr_for_window(target_word, context_words)
        average_mrr = total_mrr / len(test_data)
        return average_mrr

    def evaluate_mrr(self, test_corpus):
        test_sentences = self.preprocessing(test_corpus)
        test_data = []
        for sentence in test_sentences:
            for i in range(len(sentence)):
                window = sentence[max(0, i - self.window_size):min(len(sentence), i + self.window_size + 1)]
                test_data.append(window)

        mrr = self.compute_mrr(test_data)
        print("MRR:", mrr)

In [12]:
train_corpus = "The earth revolves around the sun. The moon revolves around the earth."
test_corpus = "The sun revolves around the earth. The earth revolves around the moon."

In [13]:
skipGram = SkipGram(N=50, window_size=2)
skipGram.prepare_data_for_training(train_corpus)
skipGram.train(epochs=10000, negative_samples=5)

Epoch 1000, Loss: 40.201049640150515
Epoch 2000, Loss: 40.17765489533684
Epoch 3000, Loss: 40.169494359353294
Epoch 4000, Loss: 40.165549987689054
Epoch 5000, Loss: 40.16328901236755
Epoch 6000, Loss: 40.16172556129341
Epoch 7000, Loss: 40.160648351432584
Epoch 8000, Loss: 40.159855153536824
Epoch 9000, Loss: 40.15924611745983
Epoch 10000, Loss: 40.15875543000474


In [14]:
skipGram.evaluate_mrr(test_corpus)

MRR: 0.3377777777777778


In [15]:
print(skipGram.predict("around", 3))
print(skipGram.W1[skipGram.word_index["around"]])

['earth', 'revolves', 'sun']
[ 0.10300137  0.5440801   0.19187793 -0.18084309  0.56858257 -0.81117868
  0.86946423  0.31258705 -0.46442859  0.42939062 -0.33877956  0.26585697
 -0.77899383  0.84464795  0.55410059 -0.26953526  0.0629062  -0.73567159
 -0.58223677  0.66609512  0.46389084 -0.49299219  0.49264696  0.68486236
 -0.48251843  0.47207132  0.25314679  0.0023692   0.10809423  0.21426151
 -0.162643    0.59679787 -0.70144888  0.71300725 -0.58103881  0.43951779
  0.86931822  0.13632959  0.17345098 -0.09002836 -0.33356811 -0.42807778
  0.28549781 -0.49739613  0.23173142  0.34241522 -0.08334234 -0.32194007
  0.49998108  0.76961949]
