# 1. 构建语料库

In [2]:
# Toy corpus: every sentence has the form "<name> is <role>".
sentences = [
    "Kage is Teacher", "Mazong is Boss", "Niuzong is Boss",
    "Xiaobing is Student", "Xiaoxue is Student"
]
# Flatten the corpus into one whitespace-separated token stream.
words = " ".join(sentences).split()
# Sort the unique tokens. Iterating a set of strings follows hash order, which
# can differ from one interpreter run to the next, so an unsorted vocabulary
# would assign different indices to the same word on every run.
word_list = sorted(set(words))
# Two-way lookup tables between a word and its position in the vocabulary.
word_to_idx = {w: i for i, w in enumerate(word_list)}
idx_to_word = dict(enumerate(word_list))
voc_size = len(word_list)
print(f"词汇表: {word_list}")
print(f"词汇到索引: {word_to_idx}")
print(f"索引到词汇: {idx_to_word}")
print(f"词汇表大小: {voc_size}")

# 2. 生成 SkipGram 数据集

In [3]:
def create_skipgram_dataset(sentences, window_size=2):
    """Build (context_word, center_word) pairs for Skip-Gram training.

    Each sentence is split on whitespace. For every word, each other word
    lying within ``window_size`` positions on either side is paired with it.

    Args:
        sentences: Iterable of whitespace-separated sentence strings.
        window_size: Maximum distance, in tokens, between the center word and
            a context word on each side.

    Returns:
        A list of ``(context_word, center_word)`` tuples, in sentence order.
        Neighbors equal to the center word itself are skipped.
    """
    data = []
    for sentence in sentences:
        tokens = sentence.split()
        for idx, word in enumerate(tokens):
            start = max(idx - window_size, 0)
            # Slice ends are exclusive, so the window must stop at
            # idx + window_size + 1 to include the word that is exactly
            # window_size positions to the right. Slicing clamps to the
            # sentence length on its own.
            stop = idx + window_size + 1
            data.extend(
                (neighbor, word)
                for neighbor in tokens[start:stop]
                if neighbor != word
            )
    return data

In [4]:
# Build the Skip-Gram pairs for the whole corpus and preview the first three.
skipgram_data = create_skipgram_dataset(sentences)
skipgram_preview = skipgram_data[:3]
print(f"SkipGram 数据集: {skipgram_preview}")

# 3. 进行 OneHot 编码

In [5]:
import torch

def one_hot_encoding(word, word_to_idx):
    """Return the one-hot vector of ``word`` under the given vocabulary.

    Args:
        word: A key of ``word_to_idx``.
        word_to_idx: Mapping from each vocabulary word to its index.

    Returns:
        A 1-D tensor with ``len(word_to_idx)`` entries, all zero except for
        a one at the word's index.

    Raises:
        KeyError: If ``word`` is not in ``word_to_idx``.
    """
    position = word_to_idx[word]
    encoded = torch.zeros(len(word_to_idx))
    encoded[position] = 1
    return encoded


# Show the one-hot vector of a single vocabulary word.
word_sample = "Teacher"
print(f"编码前单词: {word_sample}")
encoded_sample = one_hot_encoding(word_sample, word_to_idx)
print(f"Onehot 编码: {encoded_sample}")
# Preview the first three pairs, each shown as (one-hot of the pair's first
# word, vocabulary index of its second word).
skipgram_samples = [
    (one_hot_encoding(context, word_to_idx), word_to_idx[target])
    for context, target in skipgram_data[:3]
]
print(f"SkipGram 数据样例: {skipgram_samples}")

# 4. 定义 SkipGram 模型

In [6]:
import torch.nn as nn

class SkipGram(nn.Module):
    """Two-layer linear network without biases, used as a Skip-Gram model.

    The first layer maps a vector of size ``voc_size`` to ``embedding_size``;
    the second maps it back to ``voc_size`` scores (no softmax is applied).
    """

    def __init__(self, voc_size, embedding_size):
        """Create the two bias-free linear layers.

        Args:
            voc_size: Size of the input vectors and of the output scores.
            embedding_size: Size of the hidden representation.
        """
        super().__init__()
        self.input_to_hidden = nn.Linear(
            in_features=voc_size, out_features=embedding_size, bias=False
        )
        self.hidden_to_output = nn.Linear(
            in_features=embedding_size, out_features=voc_size, bias=False
        )

    def forward(self, X):
        """Project ``X`` into the hidden space and return the output scores."""
        return self.hidden_to_output(self.input_to_hidden(X))


In [7]:
# Use a 2-dimensional hidden layer so the learned vectors can be drawn on a
# plane, then build the Skip-Gram model and print its layer summary.
embedding_size = 2
skipgram_model = SkipGram(voc_size=voc_size, embedding_size=embedding_size)
print(skipgram_model)

# 5. 训练 SkipGram 模型

In [8]:
import torch.optim as optim

# Train the Skip-Gram model with plain SGD, one (context, target) pair per
# optimizer step, recording the mean loss of every 100th epoch.
learning_rate = 0.001
epochs = 1000
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(skipgram_model.parameters(), lr=learning_rate)
loss_values = []
for epoch in range(epochs):
    loss_sum = 0
    for context, target in skipgram_data:
        # Input is the one-hot target word as a batch of one; the label is
        # the vocabulary index of the context word.
        X = one_hot_encoding(target, word_to_idx).float().unsqueeze(0)
        y_true = torch.tensor([word_to_idx[context]], dtype=torch.long)
        optimizer.zero_grad()
        y_pred = skipgram_model(X)
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    # Only report and record on every 100th epoch.
    if (epoch + 1) % 100 != 0:
        continue
    mean_loss = loss_sum / len(skipgram_data)
    print(f"Epoch:{epoch + 1}, Loss:{mean_loss}")
    loss_values.append(mean_loss)




In [9]:
import matplotlib.pyplot as plt

# Select a font that can render the Chinese labels, and keep the minus sign
# drawable with that font.
plt.rcParams.update({
    "font.family": ["Arial Unicode MS"],
    "font.sans-serif": ["Arial Unicode MS"],
    "axes.unicode_minus": False,
})

# One point per recorded loss value (the loop records every 100th epoch).
checkpoints = range(1, epochs // 100 + 1)
plt.plot(checkpoints, loss_values)
plt.title("训练损失曲线")
plt.xlabel("轮次")
plt.ylabel("损失")
plt.show()

# 6. 展示词向量

In [10]:
# Print the learned vector of every vocabulary word. nn.Linear stores its
# weight as (out_features, in_features), so column `idx` of
# input_to_hidden.weight is the hidden vector produced by the one-hot input
# for the word with index `idx`.
print("SkipGram Embedding")
for word, idx in word_to_idx.items():
    print(f"{word}: {skipgram_model.input_to_hidden.weight[:, idx].detach().numpy()}")

In [11]:
# Scatter the first two components of each word's Skip-Gram vector and label
# every point with its word.
fig, ax = plt.subplots()
skipgram_weights = skipgram_model.input_to_hidden.weight.detach().numpy()
for word, idx in word_to_idx.items():
    # Column `idx` of the first layer's weight is the vector for this word.
    vec = skipgram_weights[:, idx]
    point = (vec[0], vec[1])
    ax.scatter(point[0], point[1])
    ax.annotate(word, point, fontsize=12)
plt.title("二维词嵌入")
plt.xlabel("向量维度 1")
plt.ylabel("向量维度 2")
plt.show()

# 7. 生成 CBOW 数据集

In [12]:
def create_cbow_dataset(sentences, window_size=2):
    """Build (center_word, context_words) samples for CBOW training.

    Args:
        sentences: Iterable of whitespace-separated sentence strings.
        window_size: Number of words taken on each side of the center word.

    Returns:
        A list with one ``(center_word, context_words)`` tuple per token,
        where ``context_words`` is the list of words to the left followed by
        the words to the right of the center position.
    """
    data = []
    for sentence in sentences:
        tokens = sentence.split()
        for center, word in enumerate(tokens):
            # Clamp the left edge at 0 so it never becomes a negative index.
            left = tokens[max(center - window_size, 0):center]
            # Slicing already stops at the end of the sentence.
            right = tokens[center + 1:center + window_size + 1]
            data.append((word, left + right))
    return data

    

In [13]:
# Build the CBOW samples for the whole corpus and preview the first three.
cbow_data = create_cbow_dataset(sentences)
cbow_preview = cbow_data[:3]
print(f"Cbow DataSet: {cbow_preview}")

# 8. 构建 CBOW 模型

In [14]:
import torch.nn as nn
class CBOW(nn.Module):
    """Two-layer linear network without biases, used as a CBOW model.

    The inputs are projected to ``embedding_size``, averaged along the first
    dimension, and the average is mapped to ``voc_size`` output scores.
    """

    def __init__(self, voc_size, embedding_size):
        """Create the two bias-free linear layers.

        Args:
            voc_size: Size of each input vector and of the output scores.
            embedding_size: Size of the hidden representation.
        """
        super().__init__()
        self.input_to_hidden = nn.Linear(
            in_features=voc_size, out_features=embedding_size, bias=False
        )
        self.hidden_to_output = nn.Linear(
            in_features=embedding_size, out_features=voc_size, bias=False
        )

    def forward(self, X):
        """Average the projected inputs over dim 0 and return output scores.

        The averaged dimension is kept with size 1, so the result has a
        leading dimension of one.
        """
        projected = self.input_to_hidden(X)
        averaged = projected.mean(dim=0, keepdim=True)
        return self.hidden_to_output(averaged)
    
        

In [15]:
# Use a 2-dimensional hidden layer so the learned vectors can be drawn on a
# plane, then build the CBOW model and print its layer summary.
embedding_size = 2
cbow_model = CBOW(voc_size=voc_size, embedding_size=embedding_size)
print(f"CBOW model: {cbow_model}")

# 9. 训练 CBOW 模型

In [16]:
import torch.optim as optim

# Train the CBOW model with plain SGD, one (target, context words) sample per
# optimizer step, recording the mean loss of every 100th epoch.
learning_rate = 0.001
epochs = 1000
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cbow_model.parameters(), lr=learning_rate)
loss_values = []
for epoch in range(epochs):
    loss_sum = 0
    for target, context in cbow_data:
        # Stack the one-hot vectors of all context words into one matrix with
        # a row per context word; the model averages them internally.
        X = torch.stack([one_hot_encoding(w, word_to_idx) for w in context]).float()
        # The label is the vocabulary index of the center word.
        y_true = torch.tensor([word_to_idx[target]], dtype=torch.long)
        y_pred = cbow_model(X)
        loss = criterion(y_pred, y_true)
        loss_sum += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if (epoch + 1) % 100 == 0:
        # Average over the CBOW samples actually iterated in this epoch; the
        # Skip-Gram dataset has a different number of entries.
        mean_loss = loss_sum / len(cbow_data)
        print(f"Epoch:{epoch + 1}, Loss:{mean_loss}")
        loss_values.append(mean_loss)

In [17]:

import matplotlib.pyplot as plt

# Select a font that can render the Chinese labels, and keep the minus sign
# drawable with that font.
plt.rcParams.update({
    "font.family": ["Arial Unicode MS"],
    "font.sans-serif": ["Arial Unicode MS"],
    "axes.unicode_minus": False,
})

# One point per recorded loss value (the loop records every 100th epoch).
checkpoints = range(1, epochs // 100 + 1)
plt.plot(checkpoints, loss_values)
plt.title("训练损失曲线")
plt.xlabel("轮次")
plt.ylabel("损失")
plt.show()

In [18]:
# Scatter the first two components of each word's CBOW vector and label every
# point with its word. This cell follows CBOW training, so it reads the
# weights of cbow_model rather than those of the earlier Skip-Gram model.
fig, ax = plt.subplots()
for word, idx in word_to_idx.items():
    # nn.Linear stores its weight as (out_features, in_features), so column
    # `idx` is the hidden vector for the word with one-hot index `idx`.
    vec = cbow_model.input_to_hidden.weight[:, idx].detach().numpy()
    ax.scatter(vec[0], vec[1])
    ax.annotate(word, (vec[0], vec[1]), fontsize=12)
plt.title("二维词嵌入")
plt.xlabel("向量维度 1")
plt.ylabel("向量维度 2")
plt.show()