In [123]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import pandas as pd
import random
import math
import re
import codecs

import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

In [124]:
# Seed every random number generator used in this notebook from one value
# so that sampling and weight initialisation are repeatable.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1ba5c8c5670>

In [125]:
# Hyperparameters for the skip-gram / negative-sampling run.

# Corpus and vocabulary
MAX_VOCAB_SIZE = 30000  # vocabulary size, including the '<unk>' bucket
EMBEDDING_SIZE = 100    # dimensionality of each word vector

# Sampling
C = 3    # number of context words taken on each side of the centre word
K = 100  # number of negative words drawn per positive context word

# Optimisation
NUM_EPOCH = 2
BATCH_SIZE = 128
LEARNING_RATE = 0.2

In [166]:
# Read the whole UTF-8 training corpus and break it into whitespace-separated tokens.
with codecs.open("./data/text8.train.txt", mode="r", encoding="utf-8") as fin:
    text = fin.read().split()

In [168]:
# Keep the (MAX_VOCAB_SIZE - 1) most frequent words; every remaining token
# is counted under a single '<unk>' entry appended at the end.
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
vocab['<unk>'] = len(text) - np.sum(list(vocab.values()))

# Index <-> word lookups follow the dict's insertion order (most frequent first).
idx_to_word = list(vocab)
word_to_index = {token: position for position, token in enumerate(idx_to_word)}

In [172]:
# Show the last ten (word, index) pairs of the vocabulary mapping.
list(word_to_index.items())[-10:]

[('arrowhead', 29990),
 ('detainee', 29991),
 ('gabbro', 29992),
 ('hyperinsulinism', 29993),
 ('lantau', 29994),
 ('landsmannschaft', 29995),
 ('symmes', 29996),
 ('taif', 29997),
 ('meine', 29998),
 ('<unk>', 29999)]

In [224]:
# Raw occurrence counts, in vocabulary-index order.
word_counts = np.array(list(vocab.values()))

# Noise distribution for negative sampling: relative frequency raised to
# the 3/4 power, then renormalised so the probabilities sum to 1.
word_freqs = (word_counts / word_counts.sum()) ** 0.75
word_freqs = word_freqs / word_freqs.sum()

VOCAB_SIZE = len(idx_to_word)

array([1.62470093e-02, 1.05140095e-02, 8.04994631e-03, ...,
       5.01155864e-06, 5.01155864e-06, 1.16691942e-02])

In [225]:
class WordEmbeddingDataset(tud.Dataset):
    """Skip-gram samples with negative sampling over an index-encoded corpus.

    Item ``idx`` is a ``(center_word, pos_words, neg_words)`` triple:

    * ``center_word`` -- the encoded token at position ``idx``;
    * ``pos_words``   -- the ``window_size`` tokens on each side of it
      (positions wrap around the ends of the corpus);
    * ``neg_words``   -- ``num_negatives`` indices per positive word, drawn
      with replacement from ``word_freqs``.

    Args:
        text: iterable of word strings (the tokenised corpus).
        word_to_idx: mapping from word to integer index.
        idx_to_word: sequence mapping index back to word.
        word_freqs: sampling weights over the vocabulary (array-like or tensor).
        word_counts: raw word counts; stored but not used by this class.
        window_size: context words per side; defaults to the module-level ``C``.
        num_negatives: negatives per positive word; defaults to the
            module-level ``K``.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts,
                 window_size=None, num_negatives=None):
        super(WordEmbeddingDataset, self).__init__()
        # Use the mapping that was passed in (not a notebook global). Unknown
        # words go to '<unk>' when present, otherwise to the last index.
        unk_idx = word_to_idx.get('<unk>', len(idx_to_word) - 1)
        self.text_encoded = torch.tensor(
            [word_to_idx.get(word, unk_idx) for word in text], dtype=torch.long
        )
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        # Convert once here; __getitem__ must not rebuild or mutate this.
        self.word_freqs = torch.as_tensor(word_freqs)
        self.word_counts = word_counts
        # Fall back to the notebook-level hyperparameters when not given.
        self.window_size = C if window_size is None else window_size
        self.num_negatives = K if num_negatives is None else num_negatives

    def __len__(self):
        """Return the number of tokens in the encoded corpus."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for position ``idx``."""
        center_word = self.text_encoded[idx]
        window = self.window_size
        pos_indices = list(range(idx - window, idx)) + list(range(idx + 1, idx + 1 + window))
        # Wrap positions that fall off either end of the corpus.
        corpus_len = len(self.text_encoded)
        pos_indices = [i % corpus_len for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # Negatives are sampled independently of the positives, so a context
        # word can occasionally also appear among the negatives.
        neg_words = torch.multinomial(
            self.word_freqs, self.num_negatives * pos_words.shape[0], replacement=True
        )

        return center_word, pos_words, neg_words

In [226]:
# Wrap the tokenised corpus and vocabulary tables in the training dataset.
dataset = WordEmbeddingDataset(
    text=text,
    word_to_idx=word_to_index,
    idx_to_word=idx_to_word,
    word_freqs=word_freqs,
    word_counts=word_counts,
)

In [230]:
# Shuffled mini-batches, loaded in the main process (no worker subprocesses).
dataloader = tud.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

In [229]:
# Smoke-test the loader: print the first seven batches, then stop.
for i, batch in enumerate(dataloader):
    center_word, pos_words, neg_words = batch
    print(center_word, pos_words, neg_words)
    if i >= 6:
        break



tensor([    9,    27,  8960,    10,     3,  4902,     8,    10,    57, 29999,
          594,  1745,   193,   116,  2163,    16,     1,  5471,    27,    18,
            1,  1977,    25,   279,     0,    62,     0,     9,  2497,     0,
            1,  1736,     0,     7,     6,  2248,     0,   654,   979,    23,
            1,  3225,    75,  9417,  4258,   255,     3,     9,   451,  2235,
          214,   110,   691,  1566,    11,   363,    11,  1548,   222,   676,
           10,     5,     7, 29999,     1,   387,   899,   945,     1,   320,
            7,    16, 17904,     1,     5,   249,    16,    36,  1563,     0,
          360,    28,     1,   144,  8707,     9,   589,    10,     0,     0,
          105,   384,  1490,  3738,     1,    34,     2,     4,   416,    78,
         4909,  3983,    23,   736,   301,  1498, 19716,    19,   479,     0,
           26,  7040, 26732,  6937,     1, 29999, 12634,  1608,    16,  2650,
          880,  5569,  2174, 29999, 11607,   333,    44,  6527])

tensor([  132,    31,  6762,   205,   215,   243,   240,    63,   245,     8,
           17,    10,    17,     0,   751,   503,   998,   531,    77,     6,
           66,   813,     9,   242,  3672,    15,  5439, 29999,     9,  3467,
         6986,  3123,     7,   479,  2053,    58,    22,    25, 29999,   734,
           71,     0,     4,    57,   641,    26, 29999, 19197,    68,     8,
          690,     0,    25, 12637,  7497,  5873,     9,     8,  6860,  6348,
            0,     6,   276, 29999,    27,    21,   137,  1487,   793,   452,
           11,    22,   379,     1,  4642,    15,  6683,   108,  3520,     3,
          362,    21,     5,    45,   389,  1386,     0,    16,  1519,    32,
          348,   113,    43,    32,    67,  2427, 29228,  9940,  1235,   160,
          397,  8861, 12776,    60, 20918,     6,    16,   142,    11,  4121,
           26, 24608,    13,     6, 29999,     3,   713,     6,   632,  8021,
         2785,   104,  3085,    18,  2070,  3795,     4,     9])

           14,   953, 14602,     0,  3085,    91,     2,     7]) tensor([[16437, 29999,     3,     9,     9,     3],
        [  258,     1,    16,   383, 29999,    24],
        [  430,    11,   373,     0,   222,  3737],
        [ 2715,     1,  8052,  6243,    33, 10511],
        [ 6148,   173,   244,   560,  1254,     1],
        [  151,   766,     4,    67,    37,  1812],
        [   69,   273,    31,  4789,   162,    24],
        [ 5591, 19034,   330,  5591,  6733,    40],
        [    0,  6652,   442,  7448,   442,     1],
        [  713,    11,    45,     2,   150,    45],
        [    8,     7,     8,  3763,     2,   244],
        [   18,     0, 15685,     0,  8729,     1],
        [  472,     0,   466,   113,    10,  1124],
        [15266,     1,     0,    51,  3257,    65],
        [   99,  1984,  6681,    50,  4383,   991],
        [ 7925,  2653,  1564,   815,   390,  2881],
        [ 6021,   325,     4,     9,     7,     7],
        [   18,  6676,   111,    51,  1670, 18960],

tensor([  857,     8,   352,     8,  3050,   112,     2,  2481,   359,    30,
           24,   435,  3404,     1, 29999,  5899,   214,    11,    49,  4670,
            0,  1037,  5790, 23125,     8,   582,     0,  5718,   857,     0,
           97,  5308,    27,  1436,     7,  5287,   119,  3123,     1, 17252,
            9, 27628,   169, 24884,   167,    17, 29999,     0,  2270,    13,
           35,  5276,   851,    32,   508,   179,   403,   125,  1090,    34,
            0,     3,  4359,  2608,     6,  7982,   755,   556,  5041,     2,
         1919,     1,   265,    18,    31,  9564,    54,    40,     1,    12,
          353,   196,  1069,     5,    53,  4670,  3104,  1268, 25328,     7,
           70, 29999,  2569,  3594,  8178,     1,   719,     4,     3,     2,
         7029, 11457,     0,     5,     2,   694,     1,  9239,  2760,   217,
         1468,     6,  2530,  2463,     2,  1232, 29999,   773,    46,     5,
        14526,     0,   461,  1274,    53,  1440,   424,   121])

In [231]:
class EmbeddingModel(nn.Module):
    """Skip-gram model trained with the negative-sampling objective.

    Two embedding tables are kept: ``in_embed`` for centre words and
    ``out_embed`` for context and negative words.

    Args:
        vocab_size: number of rows in each embedding table.
        embed_size: dimensionality of each word vector.
    """

    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size

        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Return the per-example negative-sampling loss.

        Args:
            input_labels: centre-word indices, shape ``(batch,)``.
            pos_labels: context-word indices, shape ``(batch, n_pos)``.
            neg_labels: negative-word indices, shape ``(batch, n_neg)``.

        Returns:
            Tensor of shape ``(batch,)`` holding
            ``-(sum log sigmoid(u_pos . v) + sum log sigmoid(-u_neg . v))``.
        """
        input_embedding = self.in_embed(input_labels)   # (batch, embed)
        # Context and noise words use the output table, not the input table.
        pos_embedding = self.out_embed(pos_labels)      # (batch, n_pos, embed)
        neg_embedding = self.out_embed(neg_labels)      # (batch, n_neg, embed)

        input_embedding = input_embedding.unsqueeze(2)  # (batch, embed, 1)
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2)
        # Negate the centre vector so noise words are scored as sigmoid(-u . v)
        # and are pushed away from the centre word.
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2)

        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)

        loss = log_pos + log_neg

        return -loss

    def input_embeddings(self):
        """Return the centre-word embedding matrix as a NumPy array."""
        # detach + cpu so this also works when the model is on a GPU.
        return self.in_embed.weight.detach().cpu().numpy()

In [232]:
# Instantiate the skip-gram model sized to the vocabulary.
model = EmbeddingModel(vocab_size=VOCAB_SIZE, embed_size=EMBEDDING_SIZE)

In [None]:
# Plain SGD over all model parameters; one optimiser step per mini-batch.
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCH):
    for i, batch in enumerate(dataloader):
        # Cast every index tensor in the batch to int64 before the lookup.
        input_labels, pos_labels, neg_labels = (part.long() for part in batch)
        optimizer.zero_grad()
        # The model returns one loss per example; average over the batch.
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

