In [2]:
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import jieba

# Toy corpus: the 26 lowercase ASCII letters, each separated by one space.
raw_text: str = "a b c d e f g h i j k l m n o p q r s t u v w x y z"


def read_dir(directory: str, encoding: str = 'utf-8') -> list[str]:
    """Tokenize every regular file in ``directory`` with jieba into one flat list.

    Files are visited in sorted filename order so the resulting token sequence
    is the same on every run (``os.listdir`` makes no ordering promise).
    Entries that are not regular files (for example a checkpoint folder) are
    skipped instead of crashing ``open``.

    Args:
        directory: Path of the folder holding the text files.
        encoding: Text encoding used to decode each file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        The tokens of all files, in reading order, with surrounding whitespace
        stripped and with stop words and empty tokens removed.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be opened.
        UnicodeDecodeError: If a file is not valid text in ``encoding``.
    """
    stop_words = {'、', '：', '。', '，', '的', '等', '一', '二',
                  '三', '（', '）', '《', '》', '“', '”', ' ', '\n'}
    tokens: list[str] = []

    for filename in sorted(os.listdir(directory)):
        filepath = os.path.join(directory, filename)
        if not os.path.isfile(filepath):
            continue
        with open(filepath, 'r', encoding=encoding) as f:
            text = f.read()
        for raw_word in jieba.cut(text.strip()):
            # Strip first so the stop-word test sees the same string that is
            # stored; whitespace-only tokens collapse to '' and are dropped.
            word = raw_word.strip()
            if word and word not in stop_words:
                tokens.append(word)
    return tokens


class TextReader(Dataset):
    """Dataset of (context indices, target index) pairs built from a token list.

    For every position that has ``window`` tokens on both sides, the item is
    the ``2 * window`` surrounding tokens (as a ``torch.long`` index tensor)
    and the token at that position (as a scalar ``torch.long`` tensor).
    """

    def context_to_vector(self, context: list[str], word2idx: dict[str, int]) -> list[int]:
        """Translate each word of ``context`` to its index in ``word2idx``.

        Raises:
            KeyError: If a word is missing from ``word2idx``.
        """
        return [word2idx[word] for word in context]

    def stat_raw_text(
        self, raw_text_list: list[str]
    ) -> tuple[set[str], int, dict[str, int], dict[int, str]]:
        """Build the vocabulary and both index mappings for a token list.

        Returns:
            ``(vocab, vocab_size, word2idx, idx2word)`` where ``vocab`` is the
            set of distinct tokens and the two dicts are inverse mappings.
        """
        vocab = set(raw_text_list)
        vocab_size = len(vocab)
        # Iteration order of a set of strings changes between interpreter runs
        # (string hashing is randomised), so indices are assigned from the
        # sorted vocabulary to keep them stable across kernel restarts.
        ordered_vocab = sorted(vocab)
        word2idx = {word: idx for idx, word in enumerate(ordered_vocab)}
        idx2word = dict(enumerate(ordered_vocab))
        return vocab, vocab_size, word2idx, idx2word

    def make_train_data(
        self, raw_text_list: list[str], window: int
    ) -> list[tuple[list[str], str]]:
        """Slide over the tokens and collect ``(context, target)`` word pairs.

        Positions closer than ``window`` to either end are skipped, so a list
        with fewer than ``2 * window + 1`` tokens yields no samples.
        """
        data: list[tuple[list[str], str]] = []
        for i in range(window, len(raw_text_list) - window):
            # ``window`` tokens before position i, then ``window`` tokens after.
            context = raw_text_list[i - window:i] + \
                raw_text_list[i + 1:i + window + 1]
            target = raw_text_list[i]
            data.append((context, target))
        return data

    def __init__(self, raw_text_list: list[str], window: int):
        """Store the tokens, build the vocabulary and pre-compute all samples."""
        self.raw_text_list = raw_text_list
        self.window = window
        self.vocab, self.vocab_size, self.word2idx, self.idx2word = self.stat_raw_text(
            raw_text_list)

        self.raw_data = self.make_train_data(raw_text_list, window)

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.raw_data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(context_tensor, target_tensor)``."""
        context, target_word = self.raw_data[idx]
        context_tensor = torch.tensor(self.context_to_vector(
            context, self.word2idx), dtype=torch.long)
        target_tensor = torch.tensor(
            self.word2idx[target_word], dtype=torch.long)
        return context_tensor, target_tensor


class CBow(nn.Module):
    """Continuous bag-of-words network: mean of context embeddings -> word scores."""

    def __init__(self, vocab_size, embedding_dim):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of rows in the embedding table and number of
                output scores per sample.
            embedding_dim: Width of each embedding vector.
        """
        super().__init__()
        # Weight matrix of shape (vocab_size, embedding_dim).
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Maps an embedding_dim-wide vector to vocab_size scores.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Score every vocabulary word for each row of context indices.

        ``x`` is flattened to two dimensions keeping its last axis, so a 1-D
        tensor of indices is treated as a single sample. The result has one
        row per sample.
        """
        context_len = x.size(-1)
        index_rows = x.view(-1, context_len)
        embedded = self.embeddings(index_rows)
        # Average over the context axis, keeping it as a length-1 dimension.
        pooled = torch.mean(embedded, dim=1, keepdim=True)
        scores = self.linear(pooled)
        n_rows = scores.size(0)
        return scores.view(n_rows, -1)


# Hyper-parameters for the training run.
batch_size = 128
# Held in its own name: the loop below binds `epoch` to the current index, and
# reusing one name for both made re-running this cell train for 99 epochs.
num_epochs = 100
window = 4
embedding_dim = 100

# Seed the global RNG so weight initialisation and batch shuffling repeat
# from run to run.
torch.manual_seed(0)

tokens = read_dir('./data')
dataset = TextReader(tokens, window)
if len(dataset) == 0:
    # A shuffling DataLoader cannot sample from an empty dataset; fail with a
    # message that points at the actual cause.
    raise ValueError(
        f"no training samples: corpus has {len(tokens)} tokens, "
        f"need at least {2 * window + 1} for window={window}")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = CBow(dataset.vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Nothing inside the loop leaves training mode, so it is set once up front.
model.train()
for epoch in range(num_epochs):
    # Sum of the per-batch mean losses for this epoch.
    total_loss = 0.0
    for context, target in dataloader:
        optimizer.zero_grad()
        output = model(context)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"epoch: {epoch}, loss: {total_loss}")

ValueError: too many values to unpack (expected 2)

In [2]:
# Evaluate the model on the same (context, target) windows it was trained on.
# This measures how well the training data was fit, not generalisation.
max_shown = 20  # print at most this many samples to keep the cell output short
total = 0
correct = 0
model.eval()
# Inference only: disable autograd so no graph is built per sample.
with torch.no_grad():
    for context_tensor, target_tensor in dataset:
        output = model(context_tensor)
        # The model returns one row per sample; a single context gives row 0.
        output_index = output[0].argmax().item()
        target_index = target_tensor.item()
        correct += 1 if output_index == target_index else 0
        total += 1
        if total <= max_shown:
            context_words = [dataset.idx2word[idx]
                             for idx in context_tensor.tolist()]
            print(
                f"input: {context_words}, output: {dataset.idx2word[output_index]}, target: {dataset.idx2word[target_index]}")

if total:
    print(f"accuracy: {correct / total:.4f}")
else:
    # Avoid ZeroDivisionError when the dataset produced no samples.
    print("accuracy: n/a (dataset is empty)")

input: ['应用领域', '医疗保健', '辅助', '医生', '疾病诊断', '医学影像', '分析', '药物'], output: 进行, target: 进行
input: ['医疗保健', '辅助', '医生', '进行', '医学影像', '分析', '药物', '研发'], output: 疾病诊断, target: 疾病诊断
input: ['辅助', '医生', '进行', '疾病诊断', '分析', '药物', '研发', '例如'], output: 医学影像, target: 医学影像
input: ['医生', '进行', '疾病诊断', '医学影像', '药物', '研发', '例如', '人工智能'], output: 分析, target: 分析
input: ['进行', '疾病诊断', '医学影像', '分析', '研发', '例如', '人工智能', '可以'], output: 药物, target: 药物
input: ['疾病诊断', '医学影像', '分析', '药物', '例如', '人工智能', '可以', '通过'], output: 研发, target: 研发
input: ['医学影像', '分析', '药物', '研发', '人工智能', '可以', '通过', '分析'], output: 例如, target: 例如
input: ['分析', '药物', '研发', '例如', '可以', '通过', '分析', '大量'], output: 人工智能, target: 人工智能
input: ['药物', '研发', '例如', '人工智能', '通过', '分析', '大量', '医学'], output: 可以, target: 可以
input: ['研发', '例如', '人工智能', '可以', '分析', '大量', '医学', '图像'], output: 通过, target: 通过
input: ['例如', '人工智能', '可以', '通过', '大量', '医学', '图像', '帮助'], output: 分析, target: 分析
input: ['人工智能', '可以', '通过', '分析', '医学', '图像', '帮助', '医生'], output: