In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np
import random
import math
import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
import utils

# 读取数据集

In [8]:
with open('text8.train.txt','r') as fin:
    text = fin.read()

In [9]:
def word_tokenize(text):
    return text.split()

words = text.split()

In [10]:
words[:10]

['anarchism',
 'originated',
 'as',
 'a',
 'term',
 'of',
 'abuse',
 'first',
 'used',
 'against']

In [11]:
print("Total words: {}".format(len(words)))
print("Unique words: {}".format(len(set(words))))

Total words: 15304686
Unique words: 240006


In [32]:
vocab = dict(Counter(words))
vocab

{'anarchism': 303,
 'originated': 505,
 'as': 119348,
 'a': 293387,
 'term': 6533,
 'of': 537144,
 'abuse': 531,
 'first': 25760,
 'used': 20455,
 'against': 7640,
 'early': 9129,
 'working': 2054,
 'class': 3111,
 'radicals': 106,
 'including': 8694,
 'the': 959616,
 'diggers': 25,
 'english': 10582,
 'revolution': 1862,
 'and': 376233,
 'sans': 54,
 'culottes': 6,
 'french': 7750,
 'whilst': 442,
 'is': 166132,
 'still': 6664,
 'in': 335477,
 'pejorative': 104,
 'way': 5793,
 'to': 285950,
 'describe': 1203,
 'any': 10796,
 'act': 3216,
 'that': 99419,
 'violent': 581,
 'means': 3822,
 'destroy': 423,
 'organization': 2147,
 'society': 3745,
 'it': 66521,
 'has': 34244,
 'also': 40020,
 'been': 22950,
 'taken': 2795,
 'up': 11199,
 'positive': 1157,
 'label': 549,
 'by': 100994,
 'self': 2614,
 'defined': 2192,
 'anarchists': 200,
 'word': 5066,
 'derived': 1533,
 'from': 65835,
 'greek': 4197,
 'without': 5152,
 'archons': 10,
 'ruler': 542,
 'chief': 1933,
 'king': 6641,
 'politica

# 二次采样

In [43]:
np.random.seed(123)
def sub_sampling(vocab, threshold):
    word_counter = Counter(vocab)
    total_count = len(vocab)
    freq= {word: count/total_count for word, count in vocab.items()}
    p_drop ={word:(1-np.sqrt(threshold/ freq[word])) for word in word_counter}
    trained_words = {word:vocab[word] for word in vocab if p_drop[word]<np.random.random()}
    return trained_words

train_words = sub_sampling(vocab, threshold= 1e-4)

In [58]:
train_words['<unk>'] = np.sum(list(vocab.values())) - np.sum(list(train_words.values()))
train_words

{'abuse': 531,
 'diggers': 25,
 'sans': 54,
 'culottes': 6,
 'society': 3745,
 'by': 100994,
 'archons': 10,
 'differing': 206,
 'authoritarian': 168,
 'anarchy': 95,
 'imply': 231,
 'chaos': 309,
 'nihilism': 33,
 'anomie': 6,
 'harmonious': 26,
 'governance': 160,
 'believe': 2186,
 'truly': 366,
 'how': 4266,
 'respect': 902,
 'argue': 969,
 'anthropologists': 150,
 'gatherer': 32,
 'bands': 686,
 'egalitarian': 55,
 'decreed': 64,
 'godwin': 74,
 'organisation': 507,
 'rothbard': 62,
 'zeno': 41,
 'citium': 7,
 'omnipotence': 20,
 'regimentation': 7,
 'anabaptists': 74,
 'religious': 3256,
 'forerunners': 23,
 'writes': 352,
 'holy': 1889,
 'premise': 140,
 'arrive': 207,
 'levellers': 11,
 'communistic': 7,
 'something': 1568,
 'armand': 63,
 'baron': 328,
 'lahontan': 2,
 'nouveaux': 12,
 'voyages': 97,
 'dans': 32,
 'rique': 5,
 'septentrionale': 2,
 'where': 11121,
 'laws': 2258,
 'enquiry': 54,
 'justice': 1238,
 'major': 6758,
 'philosophical': 879,
 'anarchiste': 2,
 'mainly

In [69]:
print("oringal words: {}".format(len(vocab)))
print("update words: {}".format(len(train_words)))

oringal words: 240006
update words: 228213


In [59]:
idx_to_word = [word for word in train_words.keys()]
word_to_index = {word:i for i, word in enumerate(idx_to_word)}
word_to_index['<unk>']

228212

In [60]:
word_counts = np.array([count for count in train_words.values()], 
                        dtype = np.float32)

word_freqs = word_counts/ np.sum(word_counts)
word_freqs = word_freqs ** (3./4.)
word_freqs = word_counts/ np.sum(word_counts)

In [63]:
word_counts

array([5.310000e+02, 2.500000e+01, 5.400000e+01, ..., 1.000000e+00,
       1.000000e+00, 1.506468e+07], dtype=float32)

In [61]:
word_freqs

array([3.0335827e-05, 1.4282404e-06, 3.0849992e-06, ..., 5.7129615e-08,
       5.7129615e-08, 8.6063939e-01], dtype=float32)

# 构造batch

In [62]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, 
                  word_freqs, word_counts):
        super(WordEmbeddingDataset,self).__init__()
        self.text_encoded = [word_to_idx.get(word, word_to_idx['<unk>']) 
                             for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
    
    def __len__(self):#这个数据集一共有多少个item
        return len(self.text_encoded)
    
    def __getitem__(self, idx):#给定一个index返回一个item
        #给定一个idx返回对应中心词
        center_word = self.text_encoded[idx]
        #window内单词的index
        pos_indices = list(range(idx-c, idx)) + list(range(idx+1, idx+c+1))
        #取余数防止超出text长度
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        neg_words = torch.multinomial(self.word_freqs, 
                                      K*pos_words.shape[0],
                                      True)#采样方法——伯努利,依据单词的频率采样
        return center_word, pos_words, neg_words 

In [66]:
#周围的含义
c = 3 #context window
K = 100 #number of negative samples
num_epochs = 10
batch_size = 32
learning_rate = 0.2
embedding_size = 100  #词向量维度

dataset = WordEmbeddingDataset(words, word_to_index, idx_to_word, 
                               word_freqs, word_counts)
#shuffle() 方法将序列的所有元素随机排序
dataloader = tud.DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = 0)

In [67]:
for i, (center_word, pos_words, neg_words) in enumerate(dataloader):
    print(center_word, pos_words, neg_words)
    if i>5:
        break

tensor([ 13126, 228212, 228212, 228212,  45359, 228212, 228212, 228212, 228212,
           878, 228212, 228212, 228212, 228212, 228212, 228212, 228212, 228212,
         49609, 228212, 228212,      5, 213452, 228212, 228212, 228212, 228212,
        228212, 228212,    151, 228212, 228212]) tensor([[ 13125, 228212, 228212,  13127,  11712,  13128],
        [228212, 228212,  38971, 228212, 228212, 228212],
        [228212, 228212, 228212, 228212,     52, 228212],
        [228212, 228212, 228212, 228212, 228212, 228212],
        [228212, 228212,   2970, 228212, 228212, 228212],
        [228212, 228212, 228212,   2070, 228212, 156849],
        [228212, 228212, 228212, 228212, 228212, 228212],
        [228212,   7883,   7851, 228212, 228212, 228212],
        [  2042, 228212, 228212, 228212, 228212,  14538],
        [228212,      5, 228212, 228212, 228212, 228212],
        [228212, 228212, 228212, 228212, 228212, 228212],
        [228212, 228212, 228212, 228212, 228212, 228212],
        [228212

# 构建模型

In [70]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        ''' 初始化输出和输出embedding
        '''
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        
        initrange = 0.5 / self.embed_size
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)
        
        
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        
        
    def forward(self, input_labels, pos_labels, neg_labels):
        '''
        input_labels: 中心词, [batch_size]
        pos_labels: 中心词周围 context window 出现过的单词 [batch_size * (window_size * 2)]
        neg_labelss: 中心词周围没有出现过的单词，从 negative sampling 得到 [batch_size, (window_size * 2 * K)]
        
        return: loss, [batch_size]
        '''
        
        batch_size = input_labels.size(0)
        
        input_embedding = self.in_embed(input_labels) # B * embed_size
        pos_embedding = self.out_embed(pos_labels) # B * (2*C) * embed_size
        neg_embedding = self.out_embed(neg_labels) # B * (2*C * K) * embed_size
      
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze() # B * (2*C)
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze() # B * (2*C*K)

        log_pos = F.logsigmoid(log_pos).sum(1)
        log_neg = F.logsigmoid(log_neg).sum(1) # batch_size
       
        loss = log_pos + log_neg
        
        return -loss
    
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()

model.input_embeddings() 是我们最终需要的embedding！

# 模型训练

In [71]:
vocab_size = len(idx_to_word)

In [78]:
SEED = 1234

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = EmbeddingModel(vocab_size, embedding_size)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [79]:
model

EmbeddingModel(
  (out_embed): Embedding(228213, 100)
  (in_embed): Embedding(228213, 100)
)

In [80]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [81]:
count_parameters(model) 

45642600

评估两个单词的相似性

In [82]:
def evaluate(filename, embedding_weights): 
    if filename.endswith(".csv"):
        data = pd.read_csv(filename, sep=",")
    else:
        data = pd.read_csv(filename, sep="\t")
    human_similarity = []
    model_similarity = []
    for i in data.iloc[:, 0:2].index:
        word1, word2 = data.iloc[i, 0], data.iloc[i, 1]
        if word1 not in word_to_idx or word2 not in word_to_idx:
            continue
        else:
            word1_idx, word2_idx = word_to_idx[word1], word_to_idx[word2]
            word1_embed, word2_embed = embedding_weights[[word1_idx]], embedding_weights[[word2_idx]]
            model_similarity.append(float(sklearn.metrics.pairwise.cosine_similarity(word1_embed, word2_embed)))
            human_similarity.append(float(data.iloc[i, 2]))

    return scipy.stats.spearmanr(human_similarity, model_similarity)# , model_similarity

寻找到与word最接近的10个单词

In [83]:
def find_nearest(word):
    index = word_to_idx[word]
    embedding = embedding_weights[index]
    cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
    return [idx_to_word[i] for i in cos_dis.argsort()[:10]]

In [None]:
for e in range(num_epochs):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        
        input_labels = input_labels.long().to(device)
        pos_labels = pos_labels.long().to(device)
        neg_labels = neg_labels.long().to(device)
            
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            with open(LOG_FILE, "a") as fout:
                fout.write("epoch: {}, iter: {}, loss: {}\n".format(e, i, loss.item()))
                print("epoch: {}, iter: {}, loss: {}".format(e, i, loss.item()))
            
        
        if i % 2000 == 0:
            embedding_weights = model.input_embeddings()
            sim_simlex = evaluate("simlex-999.txt", embedding_weights)
            sim_men = evaluate("men.txt", embedding_weights)
            sim_353 = evaluate("wordsim353.csv", embedding_weights)
            with open(LOG_FILE, "a") as fout:
                print("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                    e, i, sim_simlex, sim_men, sim_353, find_nearest("monster")))
                fout.write("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                    e, i, sim_simlex, sim_men, sim_353, find_nearest("monster")))
                
    embedding_weights = model.input_embeddings()
    np.save("embedding-{}".format(embedding_size), embedding_weights)
    torch.save(model.state_dict(), "embedding-{}.th".format(embedding_size))