In [60]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from typing import Sequence

# Toy corpus: the space-separated literal below is split on whitespace into
# single-letter tokens, so `raw_text` is a list of 26 one-character strings.
raw_text: Sequence[str] = (
    "a b c d e f g h i j k l m n o p q r s t u v w x y z"
).split()


class TextReader(Dataset):
    """Map-style dataset of CBOW (context indices, target index) pairs.

    The corpus is tokenized, a vocabulary is built from it, and every position
    that has ``window`` tokens on both sides becomes one training pair whose
    context is the ``2 * window`` surrounding tokens.

    Attributes set by ``__init__``:
        raw_text: the corpus exactly as passed in.
        window: number of context tokens taken on each side of the target.
        vocab: set of distinct tokens.
        vocab_size: ``len(vocab)``.
        word2idx / idx2word: token -> index and index -> token mappings.
        raw_data: list of ``(context_indices, target_index)`` pairs.
    """

    def __init__(self, raw_text: "str | Sequence[str]", window: int):
        """Tokenize ``raw_text`` and precompute the training pairs.

        Args:
            raw_text: either a whitespace-separated string or an already
                tokenized sequence of words.
            window: number of context tokens on each side of the target.

        Raises:
            ValueError: if ``window`` is negative.
        """
        self.raw_text = raw_text
        self.window = window

        # A plain string is split on whitespace; any other sequence is taken
        # to be pre-tokenized and is only copied into a list.
        if isinstance(raw_text, str):
            raw_text_list: list[str] = raw_text.split()
        else:
            raw_text_list = list(raw_text)

        self.vocab, self.vocab_size, self.word2idx, self.idx2word = self.stat_raw_text(
            raw_text_list)

        # Must run after word2idx is set: make_train_data reads self.word2idx.
        self.raw_data = self.make_train_data(raw_text_list, window)

    def stat_raw_text(self, raw_text_list: list[str]):
        """Build the vocabulary and both index mappings for a token list.

        Returns:
            ``(vocab, vocab_size, word2idx, idx2word)`` where ``vocab`` is the
            set of distinct tokens and indices are assigned in sorted token
            order.
        """
        vocab = set(raw_text_list)
        vocab_size = len(vocab)
        # Enumerate a sorted copy rather than the set itself so the mapping
        # does not depend on set iteration order.
        ordered_words = sorted(vocab)
        word2idx = {word: idx for idx, word in enumerate(ordered_words)}
        idx2word = {idx: word for idx, word in enumerate(ordered_words)}
        return vocab, vocab_size, word2idx, idx2word

    def make_train_data(self, raw_text_list: list[str], window: int):
        """Create ``(context_indices, target_index)`` pairs from a token list.

        Only positions with ``window`` tokens available on both sides are
        used, so a corpus shorter than ``2 * window + 1`` yields an empty list.

        Raises:
            ValueError: if ``window`` is negative.
            KeyError: if a token is missing from ``self.word2idx``.
        """
        if window < 0:
            raise ValueError(f"window must be non-negative, got {window}")

        data: list[tuple[list[int], int]] = []
        for i in range(window, len(raw_text_list) - window):
            # Tokens before the target followed by tokens after it; the target
            # itself (position i) is excluded.
            context = raw_text_list[i - window:i] + \
                raw_text_list[i + 1:i + window + 1]
            target = raw_text_list[i]
            data.append(
                (self.context_to_vector(context), self.word2idx[target]))
        return data

    def context_to_vector(self, context: list[str],
                          word2idx: "dict[str, int] | None" = None):
        """Translate context words into their vocabulary indices.

        Args:
            context: words to look up.
            word2idx: mapping to use; defaults to ``self.word2idx`` when
                omitted.

        Raises:
            KeyError: if a word is not present in the mapping.
        """
        mapping = self.word2idx if word2idx is None else word2idx
        return [mapping[word] for word in context]

    def __len__(self) -> int:
        """Number of training pairs."""
        return len(self.raw_data)

    def __getitem__(self, index: int):
        """Return pair ``index`` as ``(context, target)`` long tensors."""
        context, target = self.raw_data[index]
        return (torch.tensor(context, dtype=torch.long),
                torch.tensor(target, dtype=torch.long))


class CBow(nn.Module):
    """Continuous bag-of-words model: embed, sum over context, project to vocab.

    Args:
        vocab_size: number of rows in the embedding table and number of
            output scores.
        embedding_dim: width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # Lookup table with weight shape (vocab_size, embedding_dim).
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Maps the summed embedding to one score per vocabulary entry.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Score every vocabulary entry for the context indices in ``x``.

        The embeddings looked up for ``x`` are summed along dim 0 with that
        dimension kept (size 1), then passed through the linear layer.
        """
        summed = self.embeddings(x).sum(dim=0, keepdim=True)
        return self.linear(summed)


# Training configuration. The epoch count has its own name so the loop
# variable `epoch` below does not overwrite it; re-running the loop therefore
# always trains for the same number of epochs.
num_epochs = 100
window = 2
embedding_dim = 5

# Seed torch's RNG before the model is built so the random parameter
# initialisation is repeatable.
torch.manual_seed(0)

# NOTE(review): stat_raw_text, make_train_data and context_to_vector are called
# here as free functions. The visible source only defines methods with these
# names on TextReader, so they are presumably defined in an earlier cell --
# confirm this still works after "Restart Kernel -> Run All".
vocab, vocab_size, word2idx, idx2word = stat_raw_text(raw_text)
data = make_train_data(raw_text, window)  # type: ignore
model = CBow(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    # Sum of the per-sample losses over one pass through `data`.
    total_loss = 0.0
    for context, target in data:
        context_vector = torch.tensor(
            context_to_vector(context, word2idx), dtype=torch.long)
        # Wrapped in a list to give a 1-element tensor of class indices,
        # which is passed to the criterion alongside the model output.
        target_index = torch.tensor([word2idx[target]], dtype=torch.long)

        optimizer.zero_grad()
        output = model(context_vector)
        loss = criterion(output, target_index)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"epoch: {epoch}, loss: {total_loss}")

epoch: 0, loss: 79.0704243183136
epoch: 1, loss: 75.67793798446655
epoch: 2, loss: 72.52678179740906
epoch: 3, loss: 69.59238135814667
epoch: 4, loss: 66.8529144525528
epoch: 5, loss: 64.28892278671265
epoch: 6, loss: 61.88314151763916
epoch: 7, loss: 59.62040662765503
epoch: 8, loss: 57.48754823207855
epoch: 9, loss: 55.47321283817291
epoch: 10, loss: 53.56761360168457
epoch: 11, loss: 51.76227378845215
epoch: 12, loss: 50.0497624874115
epoch: 13, loss: 48.423459112644196
epoch: 14, loss: 46.877384305000305
epoch: 15, loss: 45.406075060367584
epoch: 16, loss: 44.00450521707535
epoch: 17, loss: 42.668042838573456
epoch: 18, loss: 41.39242708683014
epoch: 19, loss: 40.17375087738037
epoch: 20, loss: 39.00845295190811
epoch: 21, loss: 37.893294274806976
epoch: 22, loss: 36.82534384727478
epoch: 23, loss: 35.80194345116615
epoch: 24, loss: 34.82068282365799
epoch: 25, loss: 33.879368513822556
epoch: 26, loss: 32.975992143154144
epoch: 27, loss: 32.10868388414383
epoch: 28, loss: 31.275705

In [61]:
# Evaluate the model on every (context, target) pair in `data` and report the
# fraction predicted correctly. `data` is the same collection the training
# loop iterates over, so this is accuracy on the training pairs.
total = 0
correct = 0
# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    for context, target in data:
        context_vector = torch.tensor(
            context_to_vector(context, word2idx), dtype=torch.long)
        output = model(context_vector)
        # Highest-scoring vocabulary index, mapped back to its word.
        output_word = idx2word[output.argmax().item()]
        correct += int(output_word == target)
        total += 1
        print(
            f"input: {context}, output: {output_word}, target: {target}")

# Report NaN instead of raising ZeroDivisionError when `data` is empty.
accuracy = correct / total if total else float("nan")
print(f"accuracy: {accuracy:.4f}")

input: ['a', 'b', 'd', 'e'], output: c, target: c
input: ['b', 'c', 'e', 'f'], output: d, target: d
input: ['c', 'd', 'f', 'g'], output: e, target: e
input: ['d', 'e', 'g', 'h'], output: f, target: f
input: ['e', 'f', 'h', 'i'], output: g, target: g
input: ['f', 'g', 'i', 'j'], output: h, target: h
input: ['g', 'h', 'j', 'k'], output: i, target: i
input: ['h', 'i', 'k', 'l'], output: j, target: j
input: ['i', 'j', 'l', 'm'], output: k, target: k
input: ['j', 'k', 'm', 'n'], output: l, target: l
input: ['k', 'l', 'n', 'o'], output: m, target: m
input: ['l', 'm', 'o', 'p'], output: n, target: n
input: ['m', 'n', 'p', 'q'], output: o, target: o
input: ['n', 'o', 'q', 'r'], output: p, target: p
input: ['o', 'p', 'r', 's'], output: q, target: q
input: ['p', 'q', 's', 't'], output: r, target: r
input: ['q', 'r', 't', 'u'], output: s, target: s
input: ['r', 's', 'u', 'v'], output: t, target: t
input: ['s', 't', 'v', 'w'], output: u, target: u
input: ['t', 'u', 'w', 'x'], output: v, target: v
