In [84]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from typing import Sequence

raw_text: str = """a b c d e f g h i j k l m n o p q r s t u v w x y z"""


class TextReader(Dataset):
    """Dataset of (context, target) pairs for CBOW training.

    The raw text is split on whitespace. Every token that has ``window``
    neighbours on both sides becomes a target, and those ``2 * window``
    neighbours (left ones first, then right ones) form its context.
    """

    def __init__(self, raw_text: str, window: int):
        """Tokenize ``raw_text``, build the vocabulary and the training pairs.

        Args:
            raw_text: Whitespace-separated tokens.
            window: Number of context tokens taken on each side of a target.

        Raises:
            ValueError: If ``window`` is negative.
        """
        self.raw_text = raw_text
        self.window = window

        raw_text_list: list[str] = raw_text.split()
        self.vocab, self.vocab_size, self.word2idx, self.idx2word = self.stat_raw_text(
            raw_text_list)

        self.raw_data = self.make_train_data(raw_text_list, window)

    def context_to_vector(self, context: list[str], word2idx: dict[str, int]):
        """Map each word of ``context`` to its index in ``word2idx``.

        Raises:
            KeyError: If a word is missing from ``word2idx``.
        """
        return [word2idx[word] for word in context]

    def stat_raw_text(self, raw_text_list: list[str]):
        """Compute vocabulary statistics for a token list.

        Returns:
            A tuple ``(vocab, vocab_size, word2idx, idx2word)`` where ``vocab``
            is the set of distinct tokens and the two dicts are inverse
            mappings between tokens and integer indices.
        """
        vocab = set(raw_text_list)
        # Set iteration order for strings changes between interpreter runs
        # (hash randomization). Indexing the sorted vocabulary keeps token ids
        # stable, so trained weights remain meaningful across runs.
        ordered_vocab = sorted(vocab)
        word2idx = {word: idx for idx, word in enumerate(ordered_vocab)}
        idx2word = dict(enumerate(ordered_vocab))
        return vocab, len(vocab), word2idx, idx2word

    def make_train_data(self, raw_text_list: list[str], window: int):
        """Build the list of ``(context_words, target_word)`` training pairs.

        Tokens closer than ``window`` positions to either end of the text are
        skipped as targets because they lack a full context. A text shorter
        than ``2 * window + 1`` tokens therefore yields an empty list.

        Raises:
            ValueError: If ``window`` is negative.
        """
        if window < 0:
            raise ValueError(f"window must be non-negative, got {window}")

        data: list[tuple[list[str], str]] = []
        for i in range(window, len(raw_text_list) - window):
            left = raw_text_list[i - window:i]
            right = raw_text_list[i + 1:i + window + 1]
            data.append((left + right, raw_text_list[i]))
        return data

    def __len__(self):
        """Return the number of (context, target) pairs."""
        return len(self.raw_data)

    def __getitem__(self, idx):
        """Return pair ``idx`` as ``(context_tensor, target_tensor)``.

        Both tensors have dtype ``torch.long``; the context tensor is 1-D with
        one index per context word and the target tensor is a scalar.
        """
        context, target = self.raw_data[idx]
        context_tensor = torch.tensor(
            self.context_to_vector(context, self.word2idx), dtype=torch.long)
        target_tensor = torch.tensor(self.word2idx[target], dtype=torch.long)
        return context_tensor, target_tensor


class CBow(nn.Module):
    """Continuous bag-of-words model.

    Context word indices are embedded, the embeddings are summed over the
    context dimension, and a linear layer turns the summed vector into one
    score per vocabulary entry.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of distinct tokens (rows of the embedding
                table and width of the output).
            embedding_dim: Size of each token embedding.
        """
        super().__init__()
        # One embedding_dim-sized row per vocabulary index.
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Projects the pooled context vector to vocabulary-sized scores.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Score every vocabulary entry for each context in ``x``.

        Args:
            x: Tensor of token indices whose last dimension is the context
                length; all leading dimensions are folded into the batch.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` holding unnormalized
            scores.
        """
        context_len = x.size(-1)
        # Collapse any leading dimensions so a single 1-D context becomes a
        # batch of one.
        indices = x.view(-1, context_len)
        embedded = self.embeddings(indices)
        # Sum over the context positions, keeping a length-1 axis in place.
        pooled = embedded.sum(dim=1, keepdim=True)
        scores = self.linear(pooled)
        batch_size = scores.size(0)
        # Drop the length-1 axis left over from the pooling step.
        return scores.view(batch_size, -1)


# Training configuration.
# The epoch count lives in its own name: reusing `epoch` for both the count
# and the loop variable would leave it at 99 after the loop, so re-running
# the loop in the notebook would silently train for fewer epochs each time.
num_epochs = 100
window = 2
embedding_dim = 10
dataset = TextReader(raw_text, window)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

model = CBow(dataset.vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    # Sum of the per-batch losses returned by the criterion for this epoch.
    total_loss = 0
    model.train()
    for context, target in dataloader:
        optimizer.zero_grad()
        output = model(context)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"epoch: {epoch}, loss: {total_loss}")

epoch: 0, loss: 40.399088859558105
epoch: 1, loss: 38.147178649902344
epoch: 2, loss: 36.04372477531433
epoch: 3, loss: 34.05264854431152
epoch: 4, loss: 32.22078466415405
epoch: 5, loss: 30.469457864761353
epoch: 6, loss: 28.885462284088135
epoch: 7, loss: 27.360865354537964
epoch: 8, loss: 25.96306025981903
epoch: 9, loss: 24.63059628009796
epoch: 10, loss: 23.39414346218109
epoch: 11, loss: 22.224555134773254
epoch: 12, loss: 21.15007770061493
epoch: 13, loss: 20.11043119430542
epoch: 14, loss: 19.170201241970062
epoch: 15, loss: 18.280733227729797
epoch: 16, loss: 17.432404279708862
epoch: 17, loss: 16.662469148635864
epoch: 18, loss: 15.93205451965332
epoch: 19, loss: 15.248654246330261
epoch: 20, loss: 14.57697343826294
epoch: 21, loss: 13.987346827983856
epoch: 22, loss: 13.418839275836945
epoch: 23, loss: 12.856502532958984
epoch: 24, loss: 12.363566935062408
epoch: 25, loss: 11.872780203819275
epoch: 26, loss: 11.40958958864212
epoch: 27, loss: 10.992053866386414
epoch: 28, lo

In [85]:
# test
# Run the trained model over every (context, target) pair in the dataset,
# print each prediction next to its target, and report overall accuracy.
total = 0
correct = 0
model.eval()
# Gradients are not needed for evaluation; disabling them avoids building
# an autograd graph for every forward pass.
with torch.no_grad():
    for context_tensor, target_tensor in dataset:
        output = model(context_tensor)
        # The model returns a batch of one row for a single 1-D context.
        output_index = output[0].argmax().item()
        target_index = target_tensor.item()
        correct += int(output_index == target_index)
        total += 1
        print(
            f"input: {[dataset.idx2word[idx] for idx in context_tensor.tolist()]}, output: {dataset.idx2word[output_index]}, target: {dataset.idx2word[target_index]}")

# Guard against an empty dataset (text shorter than 2 * window + 1 tokens),
# which would otherwise raise ZeroDivisionError.
accuracy = correct / total if total else 0.0
print(f"accuracy: {accuracy:.4f}")

input: ['a', 'b', 'd', 'e'], output: c, target: c
input: ['b', 'c', 'e', 'f'], output: d, target: d
input: ['c', 'd', 'f', 'g'], output: e, target: e
input: ['d', 'e', 'g', 'h'], output: f, target: f
input: ['e', 'f', 'h', 'i'], output: g, target: g
input: ['f', 'g', 'i', 'j'], output: h, target: h
input: ['g', 'h', 'j', 'k'], output: i, target: i
input: ['h', 'i', 'k', 'l'], output: j, target: j
input: ['i', 'j', 'l', 'm'], output: k, target: k
input: ['j', 'k', 'm', 'n'], output: l, target: l
input: ['k', 'l', 'n', 'o'], output: m, target: m
input: ['l', 'm', 'o', 'p'], output: n, target: n
input: ['m', 'n', 'p', 'q'], output: o, target: o
input: ['n', 'o', 'q', 'r'], output: p, target: p
input: ['o', 'p', 'r', 's'], output: q, target: q
input: ['p', 'q', 's', 't'], output: r, target: r
input: ['q', 'r', 't', 'u'], output: s, target: s
input: ['r', 's', 'u', 'v'], output: t, target: t
input: ['s', 't', 'v', 'w'], output: u, target: u
input: ['t', 'u', 'w', 'x'], output: v, target: v
