In [None]:
import numpy as np
from collections import Counter
from sklearn.preprocessing import OneHotEncoder


In [None]:
# Toy corpus: six short sentences.
sentences = [
    "he is a king",
    "she is a queen",
    "he is a man",
    "she is a woman",
    "wars make a king",
    "peace makes a queen",
]
# Flatten the corpus into a single lowercase token stream. Sentence boundaries
# are dropped for simplicity, so context windows built later can span the end
# of one sentence and the start of the next.
words = []
for sentence in sentences:
    words.extend(token.lower() for token in sentence.split())


In [None]:
# Vocabulary and word frequency.
# Counter keeps first-occurrence order, so word ids follow corpus order.
word_counts = Counter(words)
vocab = [*word_counts]
vocab_size = len(vocab)
word_to_index = {token: idx for idx, token in enumerate(vocab)}
index_to_word = dict(enumerate(vocab))


In [None]:
# Window size for context
window_size = 3
def generate_training_data(words, word_to_index, window_size):
    """Build (target, context) index pairs for skip-gram training.

    For each position i in `words`, every token up to `window_size` positions
    before or after i (excluding i itself) is paired with words[i] as a context
    word. `words` is treated as one continuous stream, so windows may span
    sentence boundaries.

    Args:
        words: list of tokens.
        word_to_index: mapping token -> integer id; every token must be a key.
        window_size: number of neighbouring tokens taken on each side.

    Returns:
        Integer array of shape (num_pairs, 2) whose columns are
        (target_idx, context_idx). It is (0, 2) when no pairs are produced
        (e.g. empty `words`).

    Raises:
        KeyError: if a token that takes part in a pair is missing from
        `word_to_index`.
    """
    training_data = []
    for i, target in enumerate(words):
        context = words[max(0, i - window_size): i] + words[i + 1: i + 1 + window_size]
        training_data.extend(
            (word_to_index[target], word_to_index[context_word]) for context_word in context
        )
    # Pin dtype and 2-D shape: np.array([]) alone would be a float array of
    # shape (0,), which breaks column indexing / pair unpacking downstream.
    return np.array(training_data, dtype=int).reshape(-1, 2)

# Every (target, context) index pair the model is trained on.
training_data = generate_training_data(
    words=words,
    word_to_index=word_to_index,
    window_size=window_size,
)


In [None]:
# One-hot encoding
def to_one_hot(word_index, vocab_size):
    """Return a float vector of length `vocab_size` with a 1.0 at `word_index`.

    Indexing follows NumPy rules, so an out-of-range index raises IndexError.
    """
    encoded = np.zeros(vocab_size)
    encoded[word_index] = 1
    return encoded


In [None]:
# Initialize weights
# A seeded generator makes Restart & Run All reproduce the same weights (and
# therefore the same learned embeddings). Change SEED to try another init.
SEED = 42
rng = np.random.default_rng(SEED)
embedding_dim = 5
# W1: (vocab_size, embedding_dim); row i is later read out as word i's embedding.
W1 = rng.uniform(-1, 1, (vocab_size, embedding_dim))
# W2: (embedding_dim, vocab_size); maps the hidden vector to a score per vocab word.
W2 = rng.uniform(-1, 1, (embedding_dim, vocab_size))


In [None]:
# Forward pass
def forward_pass(one_hot_input):
    """Score every vocabulary word as a context candidate for one input word.

    Reads the module-level weights W1 and W2 at call time.

    Args:
        one_hot_input: one-hot vector of the target word (length vocab_size).

    Returns:
        (hidden, y_pred): `hidden` is one_hot_input . W1 (for a one-hot input,
        the target word's row of W1). `y_pred` is the softmax of hidden . W2,
        i.e. a probability per vocabulary word.
    """
    projection = np.dot(one_hot_input, W1)
    probabilities = softmax(np.dot(projection, W2))
    return projection, probabilities

# Softmax function
def softmax(x, axis=0):
    """Numerically stable softmax of `x` along `axis`.

    The maximum is subtracted per slice along `axis` before exponentiating.
    Subtracting a single global maximum would underflow every entry of a slice
    lying far below it, yielding 0/0 = NaN. For 1-D input the result is
    identical to the plain formulation.

    Args:
        x: array-like of scores.
        axis: axis to normalize over (default 0, the original behavior).

    Returns:
        Array with the same shape as `x` whose slices along `axis` sum to 1.
    """
    x = np.asarray(x)
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / exp_x.sum(axis=axis, keepdims=True)


In [None]:
# Backpropagation
def backward_pass(error, hidden, one_hot_input, weights_out=None):
    """Gradients of the loss w.r.t. W1 and W2 for one (target, context) pair.

    Args:
        error: y_pred - one_hot_context, i.e. the gradient of the softmax
            cross-entropy loss w.r.t. the output scores.
        hidden: hidden vector returned by forward_pass for this input.
        one_hot_input: one-hot vector of the target word.
        weights_out: hidden->output matrix (W2) used to backpropagate `error`.
            Defaults to the module-level W2, read at call time, so existing
            callers are unchanged. Pass it explicitly to avoid relying on
            global state.

    Returns:
        (dW1, dW2): gradients shaped like W1 (vocab_size, embedding_dim) and
        W2 (embedding_dim, vocab_size).
    """
    if weights_out is None:
        weights_out = W2
    dW2 = np.outer(hidden, error)
    # Push the output error back through W2 to the hidden layer. Only the row
    # of the active (one-hot) input word receives a nonzero gradient.
    dW1 = np.outer(one_hot_input, np.dot(weights_out, error))
    return dW1, dW2


In [None]:
# Training: per-pair gradient descent over every (target, context) pair.
# Note: W1/W2 are updated in place, so re-running this cell continues training
# from the current weights instead of starting over.
learning_rate = 0.01
epochs = 10

for epoch in range(epochs):
    loss = 0
    for target_idx, context_idx in training_data:
        x_target = to_one_hot(target_idx, vocab_size)
        y_context = to_one_hot(context_idx, vocab_size)

        hidden, y_pred = forward_pass(x_target)
        # Cross-entropy of the true context word under the pre-update prediction.
        loss += -np.log(y_pred[context_idx])

        error = y_pred - y_context
        dW1, dW2 = backward_pass(error, hidden, x_target)
        W1 -= learning_rate * dW1
        W2 -= learning_rate * dW2
    print(f"Epoch {epoch}, Loss: {loss}")


Epoch 0, Loss: 2.5766744200105025
Epoch 0, Loss: 5.445291424757586
Epoch 0, Loss: 6.923903565570465
Epoch 0, Loss: 10.06076810973243
Epoch 0, Loss: 12.184831824097026
Epoch 0, Loss: 14.350115561726177
Epoch 0, Loss: 17.251145154532153
Epoch 0, Loss: 20.053906511578294
Epoch 0, Loss: 22.714764189785996
Epoch 0, Loss: 26.010823262261376
Epoch 0, Loss: 28.82133611144585
Epoch 0, Loss: 32.26730589789078
Epoch 0, Loss: 35.50887668368157
Epoch 0, Loss: 38.307530953468515
Epoch 0, Loss: 41.368183562542725
Epoch 0, Loss: 44.112938509576175
Epoch 0, Loss: 46.884551305504175
Epoch 0, Loss: 48.981087783818154
Epoch 0, Loss: 51.24835352455501
Epoch 0, Loss: 54.36485898194253
Epoch 0, Loss: 57.54482183722108
Epoch 0, Loss: 60.15768947051011
Epoch 0, Loss: 63.00982474258522
Epoch 0, Loss: 65.13293526229033
Epoch 0, Loss: 67.30401523070304
Epoch 0, Loss: 70.62410921021021
Epoch 0, Loss: 74.13883044872361
Epoch 0, Loss: 75.47113663307236
Epoch 0, Loss: 78.75936700942293
Epoch 0, Loss: 81.6139596714147

In [None]:
# Each word's learned embedding is its row of W1. `.copy()` snapshots the rows:
# plain `W1[i]` would be a view, so re-running the training cell (which updates
# W1 in place) would silently change the vectors stored here.
word_embeddings = {word: W1[i].copy() for word, i in word_to_index.items()}
for word, embedding in word_embeddings.items():
    print(f"{word}: {embedding}")


he: [ 0.00246618  0.41254924  0.23325808 -0.29285939  0.0961159 ]
is: [ 0.70426141  0.25911051 -0.64530519 -0.17949431  0.52033904]
a: [-0.19143089  0.19952307  0.4657596  -0.30542171  0.72586937]
king: [-0.40444153 -0.23770762  0.02700394 -0.11833331  0.45339811]
she: [-0.70217659  0.21126998 -0.46749718  0.36691592 -0.68722176]
queen: [ 0.50227483  0.40602354 -0.54670739  0.28562272  0.37581573]
man: [ 0.32334353 -0.67383419 -0.19703408 -1.0008143  -0.85198242]
woman: [-0.23143204  0.0157053   0.46537266 -0.66080491 -0.89779761]
wars: [ 0.12672642  0.59596468  0.41216848 -0.39669606  0.09385499]
make: [ 0.25211405 -0.60399437  0.0062198   0.52831562 -0.89871701]
peace: [-0.74110534  0.66862619  0.04351192 -0.39275598  0.1577187 ]
makes: [0.52278634 0.65463242 0.38044735 0.38311549 0.34581924]
