In [2]:
import torch
import torch.nn as nn

from typing import Any


class SkipGram(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.output = nn.Linear(embedding_dim, vocab_size)

    def forward(self, center_words):
        embeds = self.embedding(center_words)
        out = self.output(embeds)
        return out


def stat_raw_text(raw_text: list[str]):
    vocab = set(raw_text)
    vocab_size = len(vocab)
    word2idx = {word: i for i, word in enumerate(vocab)}
    idx2word = {i: word for i, word in enumerate(vocab)}
    return vocab, vocab_size, word2idx, idx2word


def make_train_data(raw_text: list[str], window_size: int):
    data = []
    test_data = []
    for i in range(window_size, len(raw_text) - window_size):
        center_word = raw_text[i]
        indices = list(range(i - window_size, i)) + \
            list(range(i + 1, i + window_size + 1))

        test_data.append((center_word, [raw_text[j] for j in indices]))
        for j in indices:
            data.append((center_word, raw_text[j]))
    return data, test_data


def word_to_idx_tensor(word: str, word2idx: dict[str, int]) -> torch.Tensor:
    return torch.tensor([word2idx[word]], dtype=torch.long)


raw_text: list[str] = """We are about to study the idea of a computational process.
    Computational processes are abstract beings that inhabit computers.
    As they evolve, processes manipulate other abstract things called data.
    The evolution of a process is directed by a pattern of rules
    called a program. People create programs to direct processes. In effect,
    we conjure the spirits of the computer with our spells.""".split()  # type: ignore

vocab, vocab_size, word2idx, idx2word = stat_raw_text(raw_text)
data,  test_data = make_train_data(raw_text, 2)

input_indices = torch.tensor([word2idx[center]
                             for center, _ in data], dtype=torch.long)
output_indices = torch.tensor([word2idx[out]
                              for _, out in data], dtype=torch.long)

epochs = 500
embedding_dim = 300
learning_rate = 0.001
model = SkipGram(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(input_indices)
    loss = criterion(outputs, output_indices)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 4.024089813232422
Epoch: 1, Loss: 3.9155352115631104
Epoch: 2, Loss: 3.809387445449829
Epoch: 3, Loss: 3.705781936645508
Epoch: 4, Loss: 3.604853391647339
Epoch: 5, Loss: 3.5067331790924072
Epoch: 6, Loss: 3.411552906036377
Epoch: 7, Loss: 3.3194375038146973
Epoch: 8, Loss: 3.230499505996704
Epoch: 9, Loss: 3.144841432571411
Epoch: 10, Loss: 3.0625476837158203
Epoch: 11, Loss: 2.983687400817871
Epoch: 12, Loss: 2.908306837081909
Epoch: 13, Loss: 2.8364319801330566
Epoch: 14, Loss: 2.768064498901367
Epoch: 15, Loss: 2.7031824588775635
Epoch: 16, Loss: 2.6417407989501953
Epoch: 17, Loss: 2.5836734771728516
Epoch: 18, Loss: 2.5288944244384766
Epoch: 19, Loss: 2.477301836013794
Epoch: 20, Loss: 2.4287776947021484
Epoch: 21, Loss: 2.383195638656616
Epoch: 22, Loss: 2.340420961380005
Epoch: 23, Loss: 2.300316333770752
Epoch: 24, Loss: 2.262744665145874
Epoch: 25, Loss: 2.2275702953338623
Epoch: 26, Loss: 2.1946616172790527
Epoch: 27, Loss: 2.163891553878784
Epoch: 28, Loss: 2

In [8]:
def same_count(list1, list2):
    count = 0
    for word in list1:
        if word in list2:
            count += 1
    return count


def sorted_list(a):
    return sorted(a, key=lambda x: x[0])


for center_word, contexts in test_data:
    center_word = word_to_idx_tensor(center_word, word2idx)
    output = model(center_word)
    indices = torch.topk(output[0], 4).indices.tolist()
    predicted_words = [idx2word[idx] for idx in indices]
    print(
        f"Context: {sorted(contexts)}, Predicted: {sorted(predicted_words)}, accuracy: {same_count(contexts, predicted_words)}")

Context: ['We', 'are', 'study', 'to'], Predicted: ['We', 'are', 'study', 'to'], accuracy: 4
Context: ['about', 'are', 'study', 'the'], Predicted: ['about', 'direct', 'programs', 'study'], accuracy: 2
Context: ['about', 'idea', 'the', 'to'], Predicted: ['about', 'idea', 'the', 'to'], accuracy: 4
Context: ['idea', 'of', 'study', 'to'], Predicted: ['computer', 'conjure', 'of', 'spirits'], accuracy: 1
Context: ['a', 'of', 'study', 'the'], Predicted: ['a', 'of', 'study', 'the'], accuracy: 4
Context: ['a', 'computational', 'idea', 'the'], Predicted: ['a', 'called', 'computer', 'the'], accuracy: 2
Context: ['computational', 'idea', 'of', 'process.'], Predicted: ['directed', 'is', 'of', 'pattern'], accuracy: 1
Context: ['Computational', 'a', 'of', 'process.'], Predicted: ['Computational', 'a', 'of', 'process.'], accuracy: 4
Context: ['Computational', 'a', 'computational', 'processes'], Predicted: ['Computational', 'a', 'computational', 'processes'], accuracy: 4
Context: ['are', 'computational'