In [165]:
import os
import torch
import torch.nn as nn
from gensim.models import KeyedVectors
from gensim.models import Word2Vec
# The pytorch-nlp package must be installed from the command line first:
# pip install pytorch-nlp
import torchnlp
from torchnlp import word_to_vector
from torchtext import vocab 

# 自定义词嵌入(torch.nn.Embedding)

In [166]:
# Toy vocabulary: maps each token to its integer index.
word_to_idx = dict(hello=0, world=1)
# Bare tuple expression so the notebook displays both indices.
word_to_idx['hello'], word_to_idx['world']

(0, 1)

In [167]:
# Word embedding: represent each word as a dense vector.
# num_embeddings: number of rows in the lookup table. It is derived from the
#   vocabulary size so the table cannot fall out of step with word_to_idx
#   (the vocabulary currently holds 2 words).
# embedding_dim: length of each word vector (5 here).
word_embedding = nn.Embedding(num_embeddings=len(word_to_idx), embedding_dim=5)

In [168]:
# Wrap the index of 'hello' in a 1-element tensor of dtype torch.long.
lookup_tensor = torch.tensor([word_to_idx['hello']], dtype=torch.long)
# Bare expression so the notebook displays the tensor.
lookup_tensor

tensor([0])

In [169]:
# Look up the embedding rows for the indices held in lookup_tensor.
embed_result = word_embedding(lookup_tensor)
# Each looked-up row has 5 components (the embedding dimension).
print(embed_result)

tensor([[-0.2567, -1.5409, -0.7245,  0.8061,  0.5982]],
       grad_fn=<EmbeddingBackward0>)


In [170]:
# Index tensor (dtype torch.long) for both words, in the order 'hello', 'world'.
lookup_tensor = torch.tensor(
    [word_to_idx[token] for token in ('hello', 'world')],
    dtype=torch.long,
)

In [171]:
# Look up the embedding rows for both indices; the result has one row per index.
embed_result = word_embedding(lookup_tensor)
print(embed_result)

tensor([[-0.2567, -1.5409, -0.7245,  0.8061,  0.5982],
        [ 0.1217, -1.3175, -1.5157, -1.3376, -0.9417]],
       grad_fn=<EmbeddingBackward0>)


# Word2Vec词嵌入

## 预训练Word2Vec模型(gensim)

In [172]:
# Path to the Word2Vec vectors file (relative to the working directory).
word2vec_file = './model/word2vec/GoogleNews-vectors-negative300.bin'
# The file is about 3.4 GB; binary=True reads it in the binary word2vec format.
# Note: this rebinds word_embedding, replacing the nn.Embedding created earlier.
word_embedding = KeyedVectors.load_word2vec_format(word2vec_file, binary=True) 

In [173]:
# Fetch the vector stored for 'hello' in the loaded Word2Vec vectors.
word_embedding['hello']

array([-0.05419922,  0.01708984, -0.00527954,  0.33203125, -0.25      ,
       -0.01397705, -0.15039062, -0.265625  ,  0.01647949,  0.3828125 ,
       -0.03295898, -0.09716797, -0.16308594, -0.04443359,  0.00946045,
        0.18457031,  0.03637695,  0.16601562,  0.36328125, -0.25585938,
        0.375     ,  0.171875  ,  0.21386719, -0.19921875,  0.13085938,
       -0.07275391, -0.02819824,  0.11621094,  0.15332031,  0.09082031,
        0.06787109, -0.0300293 , -0.16894531, -0.20800781, -0.03710938,
       -0.22753906,  0.26367188,  0.012146  ,  0.18359375,  0.31054688,
       -0.10791016, -0.19140625,  0.21582031,  0.13183594, -0.03515625,
        0.18554688, -0.30859375,  0.04785156, -0.10986328,  0.14355469,
       -0.43554688, -0.0378418 ,  0.10839844,  0.140625  , -0.10595703,
        0.26171875, -0.17089844,  0.39453125,  0.12597656, -0.27734375,
       -0.28125   ,  0.14746094, -0.20996094,  0.02355957,  0.18457031,
        0.00445557, -0.27929688, -0.03637695, -0.29296875,  0.19

In [174]:
# Fetch the vector stored for 'world' in the loaded Word2Vec vectors.
word_embedding['world'] 

array([-6.39648438e-02,  6.83593750e-02,  2.24609375e-01,  1.31835938e-01,
       -5.95703125e-02,  3.88183594e-02,  7.56835938e-02, -1.41601562e-01,
        7.08007812e-02,  1.51367188e-01, -8.83789062e-02, -6.88476562e-02,
        2.05078125e-01, -2.11914062e-01, -1.10351562e-01,  9.57031250e-02,
       -8.15429688e-02,  2.13867188e-01,  1.62109375e-01,  3.93066406e-02,
       -7.81250000e-02,  7.66601562e-02,  9.08203125e-02, -8.98437500e-02,
        9.13085938e-02,  2.14843750e-02, -2.61718750e-01,  6.05468750e-02,
       -5.59082031e-02, -1.35742188e-01, -1.33666992e-02, -5.24902344e-02,
       -1.85546875e-01,  7.61718750e-02,  1.03027344e-01, -3.05175781e-02,
       -1.60156250e-01, -8.10546875e-02,  1.36718750e-01,  1.43554688e-01,
       -2.57568359e-02,  1.64062500e-01, -4.51660156e-02,  2.89062500e-01,
        2.13867188e-01,  3.30078125e-01, -5.43212891e-03, -1.20544434e-03,
       -8.49609375e-02,  1.86523438e-01, -1.74804688e-01,  7.71484375e-02,
        9.64355469e-03, -

In [175]:
# Ask for the 5 entries most similar to 'obama'; the recorded output pairs each
# word with its similarity score.
word_embedding.most_similar('obama', topn=5)

[('mccain', 0.7319012880325317),
 ('hillary', 0.7284600138664246),
 ('obamas', 0.7229632139205933),
 ('george_bush', 0.7205674648284912),
 ('barack_obama', 0.7045838832855225)]

## 自定义语料库训练Word2Vec模型(gensim)

In [176]:
# Sample corpus: a list of tokenised sentences (each sentence is a list of str).
sentences = [
    list(tokens)
    for tokens in (
        ("I", "love", "natural", "language", "processing"),
        ("Word2Vec", "is", "a", "popular", "embedding", "model"),
        ("hello", "world"),
    )
]

In [177]:
# Train a Word2Vec model on the sample corpus and keep only its word vectors (.wv).
# min_count=1 keeps words that occur only once, which is every word in this corpus.
# seed=1 with workers=1 follows gensim's guidance for repeatable training: a single
# worker thread removes ordering jitter from thread scheduling. Per the gensim docs,
# reproducibility across interpreter launches may additionally require setting
# PYTHONHASHSEED.
# Note: this rebinds word_embedding, replacing the vectors loaded earlier.
word_embedding = Word2Vec(sentences, min_count=1, seed=1, workers=1).wv

In [178]:
# Vector learned for 'hello' by the Word2Vec model trained on the sample corpus.
word_embedding["hello"]

array([-8.6196875e-03,  3.6657380e-03,  5.1898835e-03,  5.7419385e-03,
        7.4669183e-03, -6.1676754e-03,  1.1056137e-03,  6.0472824e-03,
       -2.8400505e-03, -6.1735227e-03, -4.1022300e-04, -8.3689485e-03,
       -5.6000124e-03,  7.1045388e-03,  3.3525396e-03,  7.2256695e-03,
        6.8002474e-03,  7.5307419e-03, -3.7891543e-03, -5.6180597e-04,
        2.3483764e-03, -4.5190323e-03,  8.3887316e-03, -9.8581640e-03,
        6.7646410e-03,  2.9144168e-03, -4.9328315e-03,  4.3981876e-03,
       -1.7395747e-03,  6.7113843e-03,  9.9648498e-03, -4.3624435e-03,
       -5.9933780e-04, -5.6956373e-03,  3.8508223e-03,  2.7866268e-03,
        6.8910765e-03,  6.1010956e-03,  9.5384968e-03,  9.2734173e-03,
        7.8980681e-03, -6.9895042e-03, -9.1558648e-03, -3.5575271e-04,
       -3.0998408e-03,  7.8943167e-03,  5.9385742e-03, -1.5456629e-03,
        1.5109634e-03,  1.7900408e-03,  7.8175711e-03, -9.5101865e-03,
       -2.0553112e-04,  3.4691966e-03, -9.3897223e-04,  8.3817719e-03,
      

In [179]:
# Vector learned for 'world' by the Word2Vec model trained on the sample corpus.
word_embedding['world'] 

array([-5.3622725e-04,  2.3643136e-04,  5.1033497e-03,  9.0092728e-03,
       -9.3029495e-03, -7.1168090e-03,  6.4588725e-03,  8.9729885e-03,
       -5.0154282e-03, -3.7633716e-03,  7.3805046e-03, -1.5334714e-03,
       -4.5366134e-03,  6.5540518e-03, -4.8601604e-03, -1.8160177e-03,
        2.8765798e-03,  9.9187379e-04, -8.2852151e-03, -9.4488179e-03,
        7.3117660e-03,  5.0702621e-03,  6.7576934e-03,  7.6286553e-04,
        6.3508903e-03, -3.4053659e-03, -9.4640139e-04,  5.7685734e-03,
       -7.5216377e-03, -3.9361035e-03, -7.5115822e-03, -9.3004224e-04,
        9.5381187e-03, -7.3191668e-03, -2.3337686e-03, -1.9377411e-03,
        8.0774371e-03, -5.9308959e-03,  4.5162440e-05, -4.7537340e-03,
       -9.6035507e-03,  5.0072931e-03, -8.7595852e-03, -4.3918253e-03,
       -3.5099984e-05, -2.9618145e-04, -7.6612402e-03,  9.6147433e-03,
        4.9820580e-03,  9.2331432e-03, -8.1579173e-03,  4.4957981e-03,
       -4.1370760e-03,  8.2453608e-04,  8.4986202e-03, -4.4621765e-03,
      

In [180]:
# Ask for the 5 words most similar to 'hello' under the freshly trained vectors.
# The recorded scores are all below 0.07 -- presumably because the three-sentence
# corpus is far too small to learn meaningful similarities.
word_embedding.most_similar('hello', topn=5)

[('embedding', 0.06797593832015991),
 ('processing', 0.03364058583974838),
 ('a', 0.009391162544488907),
 ('love', 0.008315935730934143),
 ('popular', 0.0045030261389911175)]

# GloVe词嵌入

## 预训练GloVe模型(gensim)

In [181]:
# Path to the GloVe vectors text file (relative to the working directory).
glove_file = './model/glove/glove.6B.100d.txt'
# Load the word vectors through gensim's word2vec-format reader.
# binary=False reads the file as text; no_header=True tells the reader the file
# has no leading word2vec header line (see the gensim documentation).
# Note: this rebinds word_embedding, replacing the vectors trained earlier.
word_embedding = KeyedVectors.load_word2vec_format(glove_file, binary=False, no_header=True)

In [182]:
# Fetch the vector for 'hello' from the GloVe vectors loaded through gensim.
word_embedding['hello'] 

array([ 0.26688  ,  0.39632  ,  0.6169   , -0.77451  , -0.1039   ,
        0.26697  ,  0.2788   ,  0.30992  ,  0.0054685, -0.085256 ,
        0.73602  , -0.098432 ,  0.5479   , -0.030305 ,  0.33479  ,
        0.14094  , -0.0070003,  0.32569  ,  0.22902  ,  0.46557  ,
       -0.19531  ,  0.37491  , -0.7139   , -0.51775  ,  0.77039  ,
        1.0881   , -0.66011  , -0.16234  ,  0.9119   ,  0.21046  ,
        0.047494 ,  1.0019   ,  1.1133   ,  0.70094  , -0.08696  ,
        0.47571  ,  0.1636   , -0.44469  ,  0.4469   , -0.93817  ,
        0.013101 ,  0.085964 , -0.67456  ,  0.49662  , -0.037827 ,
       -0.11038  , -0.28612  ,  0.074606 , -0.31527  , -0.093774 ,
       -0.57069  ,  0.66865  ,  0.45307  , -0.34154  , -0.7166   ,
       -0.75273  ,  0.075212 ,  0.57903  , -0.1191   , -0.11379  ,
       -0.10026  ,  0.71341  , -1.1574   , -0.74026  ,  0.40452  ,
        0.18023  ,  0.21449  ,  0.37638  ,  0.11239  , -0.53639  ,
       -0.025092 ,  0.31886  , -0.25013  , -0.63283  , -0.0118

In [183]:
# Fetch the vector for 'world' from the GloVe vectors loaded through gensim.
word_embedding['world'] 

array([ 4.9177e-01,  1.1164e+00,  1.1424e+00,  1.4381e-01, -1.0696e-01,
       -4.6727e-01, -4.4374e-01, -8.8024e-03, -5.0406e-01, -2.0549e-01,
        5.0910e-01, -6.0904e-01,  2.0980e-01, -4.4836e-01, -7.0383e-01,
        2.1516e-01,  6.6189e-01,  3.4620e-01, -8.9294e-01, -4.8032e-01,
        4.3069e-01,  3.5697e-01,  8.4277e-01,  5.2344e-01,  8.2065e-01,
        5.3183e-04,  2.4835e-01, -2.0887e-01,  8.1657e-01,  2.5048e-01,
       -7.4761e-01, -1.1309e-02, -4.7481e-01,  6.4520e-02,  5.4517e-01,
        2.0714e-01, -4.6237e-01,  1.0724e+00, -1.0526e+00, -1.5567e-01,
       -7.9339e-01, -2.8366e-02,  1.0138e-01, -2.0909e-01,  4.5513e-01,
        4.7330e-01,  6.8859e-01, -2.3840e-01, -5.5178e-02, -8.3022e-01,
       -4.7127e-01,  2.2713e-01,  4.2651e-02,  1.1273e+00, -8.4776e-02,
       -3.0378e+00, -1.8389e-01,  7.8244e-01,  1.6395e+00,  7.6146e-01,
       -1.4258e-01,  6.5115e-01, -1.3549e-02, -5.1465e-01,  6.6951e-01,
       -3.4464e-01, -1.4525e-01,  4.9258e-01,  8.0085e-01, -5.49

In [184]:
# Ask for the 5 entries most similar to 'obama' under the GloVe vectors; the
# recorded output pairs each word with its similarity score.
word_embedding.most_similar('obama', topn=5)

[('barack', 0.937216579914093),
 ('bush', 0.927285373210907),
 ('clinton', 0.896000325679779),
 ('mccain', 0.8875633478164673),
 ('gore', 0.8000321388244629)]

## 预训练GloVe模型(torchnlp)

In [185]:
# Build a torchnlp GloVe object for the '6B' vectors with dim=300, using
# ./model/glove as the cache directory.
# Note: this rebinds word_embedding, replacing the gensim KeyedVectors above.
word_embedding = torchnlp.word_to_vector.GloVe(name='6B', dim=300, cache="./model/glove")

In [186]:
# Fetch the vector for 'hello' from the torchnlp GloVe object; the recorded
# output is a torch tensor.
word_embedding['hello']

tensor([-3.3712e-01, -2.1691e-01, -6.6365e-03, -4.1625e-01, -1.2555e+00,
        -2.8466e-02, -7.2195e-01, -5.2887e-01,  7.2085e-03,  3.1997e-01,
         2.9425e-02, -1.3236e-02,  4.3511e-01,  2.5716e-01,  3.8995e-01,
        -1.1968e-01,  1.5035e-01,  4.4762e-01,  2.8407e-01,  4.9339e-01,
         6.2826e-01,  2.2888e-01, -4.0385e-01,  2.7364e-02,  7.3679e-03,
         1.3995e-01,  2.3346e-01,  6.8122e-02,  4.8422e-01, -1.9578e-02,
        -5.4751e-01, -5.4983e-01, -3.4091e-02,  8.0017e-03, -4.3065e-01,
        -1.8969e-02, -8.5670e-02, -8.1123e-01, -2.1080e-01,  3.7784e-01,
        -3.5046e-01,  1.3684e-01, -5.5661e-01,  1.6835e-01, -2.2952e-01,
        -1.6184e-01,  6.7345e-01, -4.6597e-01, -3.1834e-02, -2.6037e-01,
        -1.7797e-01,  1.9436e-02,  1.0727e-01,  6.6534e-01, -3.4836e-01,
         4.7833e-02,  1.6440e-01,  1.4088e-01,  1.9204e-01, -3.5009e-01,
         2.6236e-01,  1.7626e-01, -3.1367e-01,  1.1709e-01,  2.0378e-01,
         6.1775e-01,  4.9075e-01, -7.5210e-02, -1.1

In [188]:
# Fetch the vector for 'world' from the torchnlp GloVe object.
word_embedding['world']

tensor([-0.2583,  0.4364, -0.1138, -0.5259,  0.2021,  0.9525, -0.5876, -0.0470,
        -0.0537, -1.7440,  0.9958,  0.0635, -0.0931, -0.2644, -0.2868, -0.5236,
        -0.1787,  0.1817, -0.7170, -0.1330,  0.4248,  0.4204,  0.3775,  0.0824,
         0.1315, -0.1015, -0.1190,  0.0295, -0.3963,  0.2652, -0.5509,  0.2380,
        -0.0187, -0.0399, -1.1972,  0.1357,  0.0937, -0.6013,  0.1289,  0.3488,
        -0.2559, -0.3347,  0.0697,  0.5429,  0.2525,  0.1725,  0.0999,  0.0995,
        -0.0159,  0.2617,  0.3616, -0.1242,  0.2752,  0.0374, -0.0750,  0.6110,
         0.0526,  0.0173,  0.1258, -0.1195, -0.4908,  0.0267, -0.2719, -0.1527,
        -0.2215,  0.1813, -0.0453,  0.7615,  0.1749, -0.4411,  0.0273,  0.4268,
        -0.0070, -0.6023, -0.0166,  0.1842,  0.0218, -0.3418, -0.5515,  0.3501,
         0.4214, -0.2679, -0.1804,  0.0532,  0.1408,  0.2905,  0.1520,  0.0144,
         0.3866,  0.0303, -0.1470, -0.0361,  0.2735, -0.2179,  0.1987,  0.1249,
        -0.0498,  0.4140, -0.1476, -0.41

## 预训练GloVe模型(torchtext)

In [189]:
# Build a torchtext GloVe object for the '6B' vectors with dim=300, using
# ./model/glove as the cache directory.
# Note: this rebinds word_embedding, replacing the torchnlp GloVe object above.
word_embedding = vocab.GloVe(name='6B', dim=300, cache="./model/glove")

In [190]:
# Fetch the vector for 'hello' from the torchtext GloVe object by subscripting.
word_embedding['hello']

tensor([-3.3712e-01, -2.1691e-01, -6.6365e-03, -4.1625e-01, -1.2555e+00,
        -2.8466e-02, -7.2195e-01, -5.2887e-01,  7.2085e-03,  3.1997e-01,
         2.9425e-02, -1.3236e-02,  4.3511e-01,  2.5716e-01,  3.8995e-01,
        -1.1968e-01,  1.5035e-01,  4.4762e-01,  2.8407e-01,  4.9339e-01,
         6.2826e-01,  2.2888e-01, -4.0385e-01,  2.7364e-02,  7.3679e-03,
         1.3995e-01,  2.3346e-01,  6.8122e-02,  4.8422e-01, -1.9578e-02,
        -5.4751e-01, -5.4983e-01, -3.4091e-02,  8.0017e-03, -4.3065e-01,
        -1.8969e-02, -8.5670e-02, -8.1123e-01, -2.1080e-01,  3.7784e-01,
        -3.5046e-01,  1.3684e-01, -5.5661e-01,  1.6835e-01, -2.2952e-01,
        -1.6184e-01,  6.7345e-01, -4.6597e-01, -3.1834e-02, -2.6037e-01,
        -1.7797e-01,  1.9436e-02,  1.0727e-01,  6.6534e-01, -3.4836e-01,
         4.7833e-02,  1.6440e-01,  1.4088e-01,  1.9204e-01, -3.5009e-01,
         2.6236e-01,  1.7626e-01, -3.1367e-01,  1.1709e-01,  2.0378e-01,
         6.1775e-01,  4.9075e-01, -7.5210e-02, -1.1

In [191]:
# Fetch the vector for 'world' from the torchtext GloVe object by subscripting.
word_embedding['world']

tensor([-0.2583,  0.4364, -0.1138, -0.5259,  0.2021,  0.9525, -0.5876, -0.0470,
        -0.0537, -1.7440,  0.9958,  0.0635, -0.0931, -0.2644, -0.2868, -0.5236,
        -0.1787,  0.1817, -0.7170, -0.1330,  0.4248,  0.4204,  0.3775,  0.0824,
         0.1315, -0.1015, -0.1190,  0.0295, -0.3963,  0.2652, -0.5509,  0.2380,
        -0.0187, -0.0399, -1.1972,  0.1357,  0.0937, -0.6013,  0.1289,  0.3488,
        -0.2559, -0.3347,  0.0697,  0.5429,  0.2525,  0.1725,  0.0999,  0.0995,
        -0.0159,  0.2617,  0.3616, -0.1242,  0.2752,  0.0374, -0.0750,  0.6110,
         0.0526,  0.0173,  0.1258, -0.1195, -0.4908,  0.0267, -0.2719, -0.1527,
        -0.2215,  0.1813, -0.0453,  0.7615,  0.1749, -0.4411,  0.0273,  0.4268,
        -0.0070, -0.6023, -0.0166,  0.1842,  0.0218, -0.3418, -0.5515,  0.3501,
         0.4214, -0.2679, -0.1804,  0.0532,  0.1408,  0.2905,  0.1520,  0.0144,
         0.3866,  0.0303, -0.1470, -0.0361,  0.2735, -0.2179,  0.1987,  0.1249,
        -0.0498,  0.4140, -0.1476, -0.41

In [192]:
# Same lookup through get_vecs_by_tokens; the visible part of the recorded output
# matches the word_embedding['hello'] output shown earlier for this object.
word_embedding.get_vecs_by_tokens('hello')

tensor([-3.3712e-01, -2.1691e-01, -6.6365e-03, -4.1625e-01, -1.2555e+00,
        -2.8466e-02, -7.2195e-01, -5.2887e-01,  7.2085e-03,  3.1997e-01,
         2.9425e-02, -1.3236e-02,  4.3511e-01,  2.5716e-01,  3.8995e-01,
        -1.1968e-01,  1.5035e-01,  4.4762e-01,  2.8407e-01,  4.9339e-01,
         6.2826e-01,  2.2888e-01, -4.0385e-01,  2.7364e-02,  7.3679e-03,
         1.3995e-01,  2.3346e-01,  6.8122e-02,  4.8422e-01, -1.9578e-02,
        -5.4751e-01, -5.4983e-01, -3.4091e-02,  8.0017e-03, -4.3065e-01,
        -1.8969e-02, -8.5670e-02, -8.1123e-01, -2.1080e-01,  3.7784e-01,
        -3.5046e-01,  1.3684e-01, -5.5661e-01,  1.6835e-01, -2.2952e-01,
        -1.6184e-01,  6.7345e-01, -4.6597e-01, -3.1834e-02, -2.6037e-01,
        -1.7797e-01,  1.9436e-02,  1.0727e-01,  6.6534e-01, -3.4836e-01,
         4.7833e-02,  1.6440e-01,  1.4088e-01,  1.9204e-01, -3.5009e-01,
         2.6236e-01,  1.7626e-01, -3.1367e-01,  1.1709e-01,  2.0378e-01,
         6.1775e-01,  4.9075e-01, -7.5210e-02, -1.1