In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchtext
from torchtext.vocab import vocab, Vocab
from torchtext.data import get_tokenizer

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import random
from collections import Counter, OrderedDict

from tqdm import trange, tqdm

In [4]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [5]:
C = 3
K = 15
MAX_VOCAB_SIZE = 10000
EMBEDDING_SIZE = 100
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


In [6]:
with open('text8/text8.train.txt') as f:
    text = f.read()

In [7]:
type(text)

str

In [8]:
tokenizer = get_tokenizer("basic_english")
tokens:list[str] = tokenizer(text)

counter = Counter(tokens).most_common(MAX_VOCAB_SIZE - 1)
counter_sorted_by_freq = sorted(counter, key=lambda x: x[1], reverse=True)
counter_dict = OrderedDict(counter_sorted_by_freq)

vocabulary = vocab(counter_dict, specials=["<unk>"],special_first = False)
vocabulary.set_default_index(vocabulary["<unk>"])

word_counts = torch.tensor([count for count in counter_dict.values()], dtype=torch.float32)
word_freqs = word_counts / torch.sum(word_counts)
word_freqs = word_freqs ** (3./4.)

In [9]:
res = vocabulary(["apple", "hello"])

In [10]:
class WordEmbeddingDataset(data.Dataset):
    def __init__(self, text: list[str], vocab: Vocab, word_freqs: torch.Tensor):
        super(WordEmbeddingDataset, self).__init__()

        self.text_encoded = torch.tensor(vocab(text), dtype=torch.long)
        self.vocab = vocab
        self.word_freqs = word_freqs

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx: int):
        center_words = self.text_encoded[idx]

        pos_indices = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]

        pos_words_num = len(pos_indices)

        pos_words = self.text_encoded[pos_indices]

        neg_words = torch.multinomial(
            self.word_freqs, K * pos_words_num, replacement=True
        )
        
        while len(set(pos_indices) & set(neg_words.tolist())) > 0:
            neg_words = torch.multinomial(
                self.word_freqs, K * pos_words_num, replacement=True
            )

        return center_words, pos_words, neg_words


In [11]:
dataset = WordEmbeddingDataset(tokens, vocabulary, word_freqs)
dataloader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [12]:
dataset[0]

(tensor(4997),
 tensor([  18,   10, 9999, 3098,   11,    5]),
 tensor([   9, 7386,  177,  390, 9202, 6696, 3054,    8,    0, 1552,   34, 5140,
          281,  263, 3342,   50, 9694, 7388,  456,  350, 5468,    4, 6672, 3118,
           24, 2078,   16, 3906,   29, 2005,   71,  395,  224, 1062,   87, 6866,
         2570, 4883,  114, 3268, 3807,    0, 1033,  276, 1919,  532,  444, 1385,
         2344, 3691, 4442, 1154,  197,    5, 4330,  130, 7619, 1889, 2640, 9707,
         4630,  188,   29,  262, 3285,  137, 9952,  270, 2821,  479, 2818, 8008,
          917, 2605, 1747, 6108,  475, 9238, 4524, 9572, 6651,   14,   51, 7351,
         1489,   11,  812, 3627,   30, 6306]))

In [13]:
class Word2VectorModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int):
        super(Word2VectorModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        self.in_embedding = nn.Embedding(vocab_size, embedding_size)
        self.out_embedding = nn.Embedding(vocab_size, embedding_size)

    def forward(self, input_labels: torch.Tensor, pos_labels: torch.Tensor, neg_labels: torch.Tensor):
        # [batch_size, embed_size, 1] 中心词只有一个
        in_emb: torch.Tensor = self.in_embedding(input_labels).unsqueeze(2)
        
        # [batch_size, (C * 2), embedding_size] 相邻窗口词有C * 2个
        pos_emb = self.out_embedding(pos_labels)
        
        # [batch_size, (C * 2 * K), embedding_size] 负采样词有C * 2 * K个
        neg_emb = self.out_embedding(neg_labels)

        pos_dot = torch.bmm(pos_emb, in_emb).squeeze(2)

        neg_dot = torch.bmm(neg_emb, in_emb).squeeze(2)

        log_pos = F.logsigmoid(pos_dot).sum(1)

        log_neg = F.logsigmoid(neg_dot).sum(1)

        loss = log_pos + log_neg
        
        return -loss
    
    def input_embedding(self):
        return self.in_embedding.weight.detach()

model = Word2VectorModel(MAX_VOCAB_SIZE, EMBEDDING_SIZE).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
epoch_bar = trange(EPOCHS)
for e in epoch_bar:
    for i, (input_labels, pos_labels, neg_labels) in enumerate(tqdm(dataloader)):
        input_labels = input_labels.to(DEVICE)
        pos_labels = pos_labels.to(DEVICE)
        neg_labels = neg_labels.to(DEVICE)

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()

        optimizer.step()

        if i % 100 == 0:
            epoch_bar.set_postfix(epoch=e + 1, iteration=i + 1, loss=loss.item())


In [None]:
embedding_weights = model.input_embedding()
torch.save(model.state_dict(), f"embedding-{EMBEDDING_SIZE}.th")

In [34]:
vocabulary(["hello"])

[6524]

In [None]:
# def find_nearest(word):
#     index = vocabulary([word])
#     embedding = embedding_weights[index]
#     cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
#     return [idx2word[i] for i in cos_dis.argsort()[:10]]