# 4章 word2vecの高速化

## 4.3 改良版word2vecの学習

### 4.3.1 CBOWモデルの学習

In [1]:
import sys
sys.path.append('..')
from common.np import *  # import numpy as np
from common.layers import Embedding
from ch04.negative_sampling_layer import NegativeSamplingLoss


class CBOW:
    """CBOW word2vec model: shared-weight Embedding inputs + negative-sampling loss.

    The model owns ``2 * window_size`` ``Embedding`` layers that are all
    constructed from the same input weight array, and one
    ``NegativeSamplingLoss`` layer constructed from the output weight array.

    Attributes:
        in_layers: list of the ``2 * window_size`` input ``Embedding`` layers.
        ns_loss: the ``NegativeSamplingLoss`` layer.
        params: concatenation of every layer's ``params`` (input layers first).
        grads: concatenation of every layer's ``grads`` (same order).
        word_vecs: the input weight array, exposed as the word embeddings.
    """

    def __init__(self, vocab_size, hidden_size, window_size, corpus):
        """Create the weights and layers.

        Args:
            vocab_size: number of rows of both weight matrices.
            hidden_size: number of columns of both weight matrices.
            window_size: context words on each side; ``2 * window_size``
                input layers are created.
            corpus: forwarded unchanged to ``NegativeSamplingLoss``.
        """
        weight_shape = (vocab_size, hidden_size)

        # Small random float32 weights; W_in is drawn before W_out.
        W_in = 0.01 * np.random.randn(*weight_shape).astype('f')
        W_out = 0.01 * np.random.randn(*weight_shape).astype('f')

        # One Embedding layer per context position, every one wrapping the
        # very same W_in object.
        self.in_layers = [Embedding(W_in) for _ in range(2 * window_size)]
        self.ns_loss = NegativeSamplingLoss(
            W_out, corpus, power=0.75, sample_size=5)

        # Gather all weights and gradients into flat lists.
        self.params = []
        self.grads = []
        for layer in (*self.in_layers, self.ns_loss):
            self.params.extend(layer.params)
            self.grads.extend(layer.grads)

        # Expose the input weights as the distributed word representations.
        self.word_vecs = W_in

    def forward(self, contexts, target):
        """Return the loss for a batch.

        Args:
            contexts: array indexed as ``contexts[:, i]``; column ``i`` is fed
                to the ``i``-th input layer.
            target: passed through to ``ns_loss.forward`` together with the
                averaged hidden activation.

        Returns:
            Whatever ``ns_loss.forward`` returns.
        """
        layer_count = len(self.in_layers)

        # Sum the outputs of all input layers, then scale to the mean.
        hidden = 0
        for column, embed in enumerate(self.in_layers):
            hidden += embed.forward(contexts[:, column])
        hidden *= 1 / layer_count

        return self.ns_loss.forward(hidden, target)

    def backward(self, dout=1):
        """Back-propagate ``dout`` through the loss layer and every input layer.

        The gradient coming out of the loss layer is scaled by
        ``1 / len(self.in_layers)`` (the derivative of the averaging done in
        ``forward``) and the same array is handed to each input layer.

        Returns:
            None.
        """
        grad_hidden = self.ns_loss.backward(dout)
        grad_hidden *= 1 / len(self.in_layers)
        for embed in self.in_layers:
            embed.backward(grad_hidden)
        return None


### 4.3.2 CBOWモデルの学習コード

In [2]:
import sys
sys.path.append('..')
from common import config
# GPUで実行する場合は、下記のコメントアウトを消去（要cupy）
# ===============================================
# config.GPU = True
# ===============================================
from common.np import *
import pickle
from common.trainer import Trainer
from common.optimizer import Adam
from cbow import CBOW
from skip_gram import SkipGram
from common.util import create_contexts_target, to_cpu, to_gpu
from dataset import ptb


# Hyperparameters
window_size = 5
hidden_size = 100
batch_size = 100
max_epoch = 10

# Load the 'train' split via ptb.load_data; the vocabulary size is taken
# from the number of entries in word_to_id.
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)

# Build the (contexts, target) arrays and move them to the GPU if enabled.
contexts, target = create_contexts_target(corpus, window_size)
if config.GPU:
    contexts = to_gpu(contexts)
    target = to_gpu(target)

# Build the model, optimizer and trainer.
# To train skip-gram instead, use:
#   model = SkipGram(vocab_size, hidden_size, window_size, corpus)
model = CBOW(vocab_size, hidden_size, window_size, corpus)
optimizer = Adam()
trainer = Trainer(model, optimizer)

# Run training, then call the trainer's plot.
trainer.fit(contexts, target, max_epoch, batch_size)
trainer.plot()

# Save the data needed later (embeddings are stored as float16).
word_vecs = model.word_vecs
if config.GPU:
    word_vecs = to_cpu(word_vecs)
params = {
    'word_vecs': word_vecs.astype(np.float16),
    'word_to_id': word_to_id,
    'id_to_word': id_to_word,
}
pkl_file = 'cbow_params.pkl'  # or 'skipgram_params.pkl'
with open(pkl_file, 'wb') as f:
    # A negative protocol number selects pickle.HIGHEST_PROTOCOL.
    pickle.dump(params, f, pickle.HIGHEST_PROTOCOL)


| epoch 1 |  iter 1 / 9295 | time 0[s] | loss 4.16
| epoch 1 |  iter 21 / 9295 | time 2[s] | loss 4.16
| epoch 1 |  iter 41 / 9295 | time 4[s] | loss 4.15
| epoch 1 |  iter 61 / 9295 | time 5[s] | loss 4.12
| epoch 1 |  iter 81 / 9295 | time 7[s] | loss 4.05
| epoch 1 |  iter 101 / 9295 | time 9[s] | loss 3.92
| epoch 1 |  iter 121 / 9295 | time 11[s] | loss 3.78
| epoch 1 |  iter 141 / 9295 | time 13[s] | loss 3.62
| epoch 1 |  iter 161 / 9295 | time 15[s] | loss 3.48
| epoch 1 |  iter 181 / 9295 | time 17[s] | loss 3.35
| epoch 1 |  iter 201 / 9295 | time 19[s] | loss 3.26
| epoch 1 |  iter 221 / 9295 | time 21[s] | loss 3.15
| epoch 1 |  iter 241 / 9295 | time 22[s] | loss 3.08
| epoch 1 |  iter 261 / 9295 | time 24[s] | loss 3.01
| epoch 1 |  iter 281 / 9295 | time 26[s] | loss 2.98
| epoch 1 |  iter 301 / 9295 | time 28[s] | loss 2.93
| epoch 1 |  iter 321 / 9295 | time 30[s] | loss 2.88
| epoch 1 |  iter 341 / 9295 | time 32[s] | loss 2.84
| epoch 1 |  iter 361 / 9295 | time 34[s

| epoch 1 |  iter 2981 / 9295 | time 276[s] | loss 2.45
| epoch 1 |  iter 3001 / 9295 | time 278[s] | loss 2.47
| epoch 1 |  iter 3021 / 9295 | time 280[s] | loss 2.44
| epoch 1 |  iter 3041 / 9295 | time 281[s] | loss 2.46
| epoch 1 |  iter 3061 / 9295 | time 283[s] | loss 2.45
| epoch 1 |  iter 3081 / 9295 | time 285[s] | loss 2.43
| epoch 1 |  iter 3101 / 9295 | time 287[s] | loss 2.45
| epoch 1 |  iter 3121 / 9295 | time 289[s] | loss 2.44
| epoch 1 |  iter 3141 / 9295 | time 291[s] | loss 2.43
| epoch 1 |  iter 3161 / 9295 | time 293[s] | loss 2.42
| epoch 1 |  iter 3181 / 9295 | time 295[s] | loss 2.46
| epoch 1 |  iter 3201 / 9295 | time 298[s] | loss 2.43
| epoch 1 |  iter 3221 / 9295 | time 301[s] | loss 2.45
| epoch 1 |  iter 3241 / 9295 | time 303[s] | loss 2.44
| epoch 1 |  iter 3261 / 9295 | time 305[s] | loss 2.44
| epoch 1 |  iter 3281 / 9295 | time 307[s] | loss 2.43
| epoch 1 |  iter 3301 / 9295 | time 308[s] | loss 2.43
| epoch 1 |  iter 3321 / 9295 | time 311[s] | lo

| epoch 1 |  iter 5921 / 9295 | time 551[s] | loss 2.30
| epoch 1 |  iter 5941 / 9295 | time 553[s] | loss 2.33
| epoch 1 |  iter 5961 / 9295 | time 555[s] | loss 2.31
| epoch 1 |  iter 5981 / 9295 | time 558[s] | loss 2.34
| epoch 1 |  iter 6001 / 9295 | time 560[s] | loss 2.34
| epoch 1 |  iter 6021 / 9295 | time 562[s] | loss 2.30
| epoch 1 |  iter 6041 / 9295 | time 564[s] | loss 2.32
| epoch 1 |  iter 6061 / 9295 | time 567[s] | loss 2.31
| epoch 1 |  iter 6081 / 9295 | time 569[s] | loss 2.32
| epoch 1 |  iter 6101 / 9295 | time 571[s] | loss 2.32
| epoch 1 |  iter 6121 / 9295 | time 573[s] | loss 2.31
| epoch 1 |  iter 6141 / 9295 | time 575[s] | loss 2.32
| epoch 1 |  iter 6161 / 9295 | time 577[s] | loss 2.28
| epoch 1 |  iter 6181 / 9295 | time 579[s] | loss 2.32
| epoch 1 |  iter 6201 / 9295 | time 582[s] | loss 2.29
| epoch 1 |  iter 6221 / 9295 | time 584[s] | loss 2.28
| epoch 1 |  iter 6241 / 9295 | time 586[s] | loss 2.33
| epoch 1 |  iter 6261 / 9295 | time 589[s] | lo

| epoch 1 |  iter 8861 / 9295 | time 849[s] | loss 2.24
| epoch 1 |  iter 8881 / 9295 | time 852[s] | loss 2.20
| epoch 1 |  iter 8901 / 9295 | time 854[s] | loss 2.24
| epoch 1 |  iter 8921 / 9295 | time 856[s] | loss 2.19
| epoch 1 |  iter 8941 / 9295 | time 857[s] | loss 2.25
| epoch 1 |  iter 8961 / 9295 | time 859[s] | loss 2.21
| epoch 1 |  iter 8981 / 9295 | time 861[s] | loss 2.23
| epoch 1 |  iter 9001 / 9295 | time 863[s] | loss 2.25
| epoch 1 |  iter 9021 / 9295 | time 864[s] | loss 2.24
| epoch 1 |  iter 9041 / 9295 | time 866[s] | loss 2.22
| epoch 1 |  iter 9061 / 9295 | time 868[s] | loss 2.19
| epoch 1 |  iter 9081 / 9295 | time 870[s] | loss 2.22
| epoch 1 |  iter 9101 / 9295 | time 871[s] | loss 2.21
| epoch 1 |  iter 9121 / 9295 | time 873[s] | loss 2.22
| epoch 1 |  iter 9141 / 9295 | time 875[s] | loss 2.21
| epoch 1 |  iter 9161 / 9295 | time 877[s] | loss 2.19
| epoch 1 |  iter 9181 / 9295 | time 878[s] | loss 2.20
| epoch 1 |  iter 9201 / 9295 | time 880[s] | lo

| epoch 2 |  iter 2501 / 9295 | time 1137[s] | loss 2.10
| epoch 2 |  iter 2521 / 9295 | time 1139[s] | loss 2.12
| epoch 2 |  iter 2541 / 9295 | time 1141[s] | loss 2.14
| epoch 2 |  iter 2561 / 9295 | time 1143[s] | loss 2.16
| epoch 2 |  iter 2581 / 9295 | time 1145[s] | loss 2.11
| epoch 2 |  iter 2601 / 9295 | time 1147[s] | loss 2.11
| epoch 2 |  iter 2621 / 9295 | time 1148[s] | loss 2.11
| epoch 2 |  iter 2641 / 9295 | time 1150[s] | loss 2.14
| epoch 2 |  iter 2661 / 9295 | time 1152[s] | loss 2.13
| epoch 2 |  iter 2681 / 9295 | time 1154[s] | loss 2.14
| epoch 2 |  iter 2701 / 9295 | time 1156[s] | loss 2.11
| epoch 2 |  iter 2721 / 9295 | time 1158[s] | loss 2.13
| epoch 2 |  iter 2741 / 9295 | time 1160[s] | loss 2.11
| epoch 2 |  iter 2761 / 9295 | time 1162[s] | loss 2.11
| epoch 2 |  iter 2781 / 9295 | time 1163[s] | loss 2.14
| epoch 2 |  iter 2801 / 9295 | time 1165[s] | loss 2.09
| epoch 2 |  iter 2821 / 9295 | time 1167[s] | loss 2.13
| epoch 2 |  iter 2841 / 9295 |

| epoch 2 |  iter 5381 / 9295 | time 1422[s] | loss 2.11
| epoch 2 |  iter 5401 / 9295 | time 1424[s] | loss 2.09
| epoch 2 |  iter 5421 / 9295 | time 1426[s] | loss 2.10
| epoch 2 |  iter 5441 / 9295 | time 1428[s] | loss 2.08
| epoch 2 |  iter 5461 / 9295 | time 1429[s] | loss 2.08
| epoch 2 |  iter 5481 / 9295 | time 1431[s] | loss 2.02
| epoch 2 |  iter 5501 / 9295 | time 1433[s] | loss 2.07
| epoch 2 |  iter 5521 / 9295 | time 1435[s] | loss 2.06
| epoch 2 |  iter 5541 / 9295 | time 1437[s] | loss 2.07
| epoch 2 |  iter 5561 / 9295 | time 1439[s] | loss 2.05
| epoch 2 |  iter 5581 / 9295 | time 1441[s] | loss 2.07
| epoch 2 |  iter 5601 / 9295 | time 1443[s] | loss 2.09
| epoch 2 |  iter 5621 / 9295 | time 1445[s] | loss 2.08
| epoch 2 |  iter 5641 / 9295 | time 1447[s] | loss 2.06
| epoch 2 |  iter 5661 / 9295 | time 1449[s] | loss 2.09
| epoch 2 |  iter 5681 / 9295 | time 1451[s] | loss 2.07
| epoch 2 |  iter 5701 / 9295 | time 1452[s] | loss 2.11
| epoch 2 |  iter 5721 / 9295 |

| epoch 2 |  iter 8261 / 9295 | time 1700[s] | loss 2.05
| epoch 2 |  iter 8281 / 9295 | time 1702[s] | loss 2.03
| epoch 2 |  iter 8301 / 9295 | time 1704[s] | loss 2.02
| epoch 2 |  iter 8321 / 9295 | time 1706[s] | loss 2.02
| epoch 2 |  iter 8341 / 9295 | time 1708[s] | loss 2.03
| epoch 2 |  iter 8361 / 9295 | time 1710[s] | loss 2.00
| epoch 2 |  iter 8381 / 9295 | time 1712[s] | loss 2.03
| epoch 2 |  iter 8401 / 9295 | time 1713[s] | loss 2.02
| epoch 2 |  iter 8421 / 9295 | time 1716[s] | loss 2.01
| epoch 2 |  iter 8441 / 9295 | time 1718[s] | loss 2.03
| epoch 2 |  iter 8461 / 9295 | time 1720[s] | loss 2.01
| epoch 2 |  iter 8481 / 9295 | time 1722[s] | loss 2.02
| epoch 2 |  iter 8501 / 9295 | time 1724[s] | loss 2.04
| epoch 2 |  iter 8521 / 9295 | time 1726[s] | loss 2.02
| epoch 2 |  iter 8541 / 9295 | time 1729[s] | loss 2.06
| epoch 2 |  iter 8561 / 9295 | time 1731[s] | loss 2.00
| epoch 2 |  iter 8581 / 9295 | time 1733[s] | loss 1.99
| epoch 2 |  iter 8601 / 9295 |

| epoch 3 |  iter 1861 / 9295 | time 2010[s] | loss 1.91
| epoch 3 |  iter 1881 / 9295 | time 2012[s] | loss 1.96
| epoch 3 |  iter 1901 / 9295 | time 2013[s] | loss 1.91
| epoch 3 |  iter 1921 / 9295 | time 2015[s] | loss 1.91
| epoch 3 |  iter 1941 / 9295 | time 2017[s] | loss 1.94
| epoch 3 |  iter 1961 / 9295 | time 2019[s] | loss 1.93
| epoch 3 |  iter 1981 / 9295 | time 2021[s] | loss 1.92
| epoch 3 |  iter 2001 / 9295 | time 2023[s] | loss 1.94
| epoch 3 |  iter 2021 / 9295 | time 2024[s] | loss 1.92
| epoch 3 |  iter 2041 / 9295 | time 2026[s] | loss 1.92
| epoch 3 |  iter 2061 / 9295 | time 2028[s] | loss 1.94
| epoch 3 |  iter 2081 / 9295 | time 2030[s] | loss 1.95
| epoch 3 |  iter 2101 / 9295 | time 2032[s] | loss 1.95
| epoch 3 |  iter 2121 / 9295 | time 2033[s] | loss 1.92
| epoch 3 |  iter 2141 / 9295 | time 2035[s] | loss 1.91
| epoch 3 |  iter 2161 / 9295 | time 2037[s] | loss 1.91
| epoch 3 |  iter 2181 / 9295 | time 2039[s] | loss 1.90
| epoch 3 |  iter 2201 / 9295 |

| epoch 3 |  iter 4741 / 9295 | time 2267[s] | loss 1.91
| epoch 3 |  iter 4761 / 9295 | time 2269[s] | loss 1.97
| epoch 3 |  iter 4781 / 9295 | time 2270[s] | loss 1.91
| epoch 3 |  iter 4801 / 9295 | time 2272[s] | loss 1.93
| epoch 3 |  iter 4821 / 9295 | time 2274[s] | loss 1.88
| epoch 3 |  iter 4841 / 9295 | time 2276[s] | loss 1.93
| epoch 3 |  iter 4861 / 9295 | time 2277[s] | loss 1.91
| epoch 3 |  iter 4881 / 9295 | time 2279[s] | loss 1.90
| epoch 3 |  iter 4901 / 9295 | time 2281[s] | loss 1.90
| epoch 3 |  iter 4921 / 9295 | time 2283[s] | loss 1.93
| epoch 3 |  iter 4941 / 9295 | time 2284[s] | loss 1.90
| epoch 3 |  iter 4961 / 9295 | time 2286[s] | loss 1.94
| epoch 3 |  iter 4981 / 9295 | time 2288[s] | loss 1.91
| epoch 3 |  iter 5001 / 9295 | time 2290[s] | loss 1.90
| epoch 3 |  iter 5021 / 9295 | time 2291[s] | loss 1.93
| epoch 3 |  iter 5041 / 9295 | time 2293[s] | loss 1.94
| epoch 3 |  iter 5061 / 9295 | time 2295[s] | loss 1.89
| epoch 3 |  iter 5081 / 9295 |

| epoch 3 |  iter 7621 / 9295 | time 2523[s] | loss 1.92
| epoch 3 |  iter 7641 / 9295 | time 2525[s] | loss 1.89
| epoch 3 |  iter 7661 / 9295 | time 2526[s] | loss 1.91
| epoch 3 |  iter 7681 / 9295 | time 2528[s] | loss 1.91
| epoch 3 |  iter 7701 / 9295 | time 2530[s] | loss 1.87
| epoch 3 |  iter 7721 / 9295 | time 2532[s] | loss 1.89
| epoch 3 |  iter 7741 / 9295 | time 2533[s] | loss 1.92
| epoch 3 |  iter 7761 / 9295 | time 2535[s] | loss 1.89
| epoch 3 |  iter 7781 / 9295 | time 2537[s] | loss 1.92
| epoch 3 |  iter 7801 / 9295 | time 2539[s] | loss 1.90
| epoch 3 |  iter 7821 / 9295 | time 2540[s] | loss 1.90
| epoch 3 |  iter 7841 / 9295 | time 2542[s] | loss 1.88
| epoch 3 |  iter 7861 / 9295 | time 2544[s] | loss 1.90
| epoch 3 |  iter 7881 / 9295 | time 2546[s] | loss 1.90
| epoch 3 |  iter 7901 / 9295 | time 2548[s] | loss 1.90
| epoch 3 |  iter 7921 / 9295 | time 2549[s] | loss 1.90
| epoch 3 |  iter 7941 / 9295 | time 2551[s] | loss 1.89
| epoch 3 |  iter 7961 / 9295 |

| epoch 4 |  iter 1221 / 9295 | time 5493[s] | loss 1.81
| epoch 4 |  iter 1241 / 9295 | time 5495[s] | loss 1.82
| epoch 4 |  iter 1261 / 9295 | time 5497[s] | loss 1.82
| epoch 4 |  iter 1281 / 9295 | time 5501[s] | loss 1.84
| epoch 4 |  iter 1301 / 9295 | time 5505[s] | loss 1.85
| epoch 4 |  iter 1321 / 9295 | time 5508[s] | loss 1.78
| epoch 4 |  iter 1341 / 9295 | time 5512[s] | loss 1.83
| epoch 4 |  iter 1361 / 9295 | time 5515[s] | loss 1.81
| epoch 4 |  iter 1381 / 9295 | time 5518[s] | loss 1.84
| epoch 4 |  iter 1401 / 9295 | time 5522[s] | loss 1.83
| epoch 4 |  iter 1421 / 9295 | time 5525[s] | loss 1.81
| epoch 4 |  iter 1441 / 9295 | time 5529[s] | loss 1.78
| epoch 4 |  iter 1461 / 9295 | time 5532[s] | loss 1.81
| epoch 4 |  iter 1481 / 9295 | time 12735[s] | loss 1.80
| epoch 4 |  iter 1501 / 9295 | time 12738[s] | loss 1.82
| epoch 4 |  iter 1521 / 9295 | time 12740[s] | loss 1.84
| epoch 4 |  iter 1541 / 9295 | time 12742[s] | loss 1.82
| epoch 4 |  iter 1561 / 92

KeyboardInterrupt: 

### 4.3.3 CBOWモデルの評価

In [4]:
import sys
sys.path.append('..')
from common.util import most_similar, analogy
import pickle


pkl_file = 'cbow_params.pkl'
# pkl_file = 'skipgram_params.pkl'

# Load the saved parameters. NOTE(review): pickle.load executes arbitrary
# code from the file, so only load parameter files you produced yourself.
with open(pkl_file, 'rb') as f:
    params = pickle.load(f)
word_vecs = params['word_vecs']
word_to_id = params['word_to_id']
id_to_word = params['id_to_word']

# Most-similar task: show the top 5 neighbours of each query word.
querys = ['you', 'year', 'car', 'toyota']
for query in querys:
    most_similar(query, word_to_id, id_to_word, word_vecs, top=5)

# Analogy task: each triple (a, b, c) asks "a : b = c : ?".
print('-'*50)
analogy_triples = (
    ('king', 'man', 'queen'),
    ('take', 'took', 'go'),
    ('car', 'cars', 'child'),
    ('good', 'better', 'bad'),
)
for a, b, c in analogy_triples:
    analogy(a, b, c, word_to_id, id_to_word, word_vecs)


[query] you
 we: 0.6103515625
 someone: 0.59130859375
 i: 0.55419921875
 something: 0.48974609375
 anyone: 0.47314453125

[query] year
 month: 0.71875
 week: 0.65234375
 spring: 0.62744140625
 summer: 0.6259765625
 decade: 0.603515625

[query] car
 luxury: 0.497314453125
 arabia: 0.47802734375
 auto: 0.47119140625
 disk-drive: 0.450927734375
 travel: 0.4091796875

[query] toyota
 ford: 0.55078125
 instrumentation: 0.509765625
 mazda: 0.49365234375
 bethlehem: 0.47509765625
 nissan: 0.474853515625
--------------------------------------------------

[analogy] king:man = queen:?
 woman: 5.16015625
 veto: 4.9296875
 ounce: 4.69140625
 earthquake: 4.6328125
 successor: 4.609375

[analogy] take:took = go:?
 went: 4.55078125
 points: 4.25
 began: 4.09375
 comes: 3.98046875
 oct.: 3.90625

[analogy] car:cars = child:?
 children: 5.21875
 average: 4.7265625
 yield: 4.20703125
 cattle: 4.1875
 priced: 4.1796875

[analogy] good:better = bad:?
 more: 6.6484375
 less: 6.0625
 rather: 5.21875
 slow