<a href="https://colab.research.google.com/github/shotahorii/bareml/blob/master/word2vec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# clone bareml repo
!git clone https://github.com/shotahorii/bareml.git

# move to the bareml repo
import os
path = '/content/bareml'
os.chdir(path)

Cloning into 'bareml'...
remote: Enumerating objects: 647, done.[K
remote: Counting objects: 100% (647/647), done.[K
remote: Compressing objects: 100% (447/447), done.[K
remote: Total 1188 (delta 407), reused 418 (delta 187), pack-reused 541[K
Receiving objects: 100% (1188/1188), 18.05 MiB | 17.98 MiB/s, done.
Resolving deltas: 100% (719/719), done.


In [2]:
import pickle
import numpy as np

# Pick the compute device: 'cuda' when cupy can be imported, otherwise 'cpu'.
# Only ImportError is caught so that Ctrl-C (KeyboardInterrupt) and other
# unrelated failures are not silently turned into a CPU fallback.
try:
  import cupy as cp
  device = 'cuda'
except ImportError:
  device = 'cpu'

from bareml.deeplearning import functions as F
from bareml.deeplearning.models import CBOW
from bareml.deeplearning.optimisers import Adam
from bareml.deeplearning.data import NewsGroupsWordInference, DataLoader
from bareml.deeplearning.metrics import cos_similarity

In [3]:
# training hyper-parameters: passes over the dataset, samples per mini-batch
num_epoch, batch_size = 10, 256

In [4]:
# build the dataset (the first run downloads the archive and builds the corpus)
dataset = NewsGroupsWordInference()

# floor division: only complete batches are counted here
num_batches_per_epoch = len(dataset) // batch_size

# iterate over the dataset in mini-batches of `batch_size`
dataloader = DataLoader(dataset, batch_size)

Creating NewsGroupsWordInference Dataset: Step 1/2
Downloading: 20news-bydate.tar.gz
[##############################] 100.00% Done
-- Creating the corpus --
1000/11314 docs
2000/11314 docs
3000/11314 docs
4000/11314 docs
5000/11314 docs
6000/11314 docs
7000/11314 docs
8000/11314 docs
9000/11314 docs
10000/11314 docs
11000/11314 docs
Creating NewsGroupsWordInference Dataset: Step 2/2
1000/11314 docs
2000/11314 docs
3000/11314 docs
4000/11314 docs
5000/11314 docs
6000/11314 docs
7000/11314 docs
8000/11314 docs
9000/11314 docs
10000/11314 docs
11000/11314 docs


In [5]:
# CBOW model over the dataset's corpus, using 100-dimensional embeddings
model = CBOW(
    dataset.corpus.corpus,
    embedding_dim=100,
)
# place the model on the device selected at import time
model = model.to(device)

# Adam with its default settings, updating every model parameter
optimiser = Adam(model.parameters())

In [6]:
# train the model
avg_losses = []   # mean loss of every completed logging window, in order
log_span = 500    # number of batches per logging window

model.train()
for epoch in range(num_epoch):
    # running sum of the losses in the current window; restarted every epoch,
    # so an incomplete window at the end of an epoch is not reported
    cumulative_loss = 0
    for batch_idx, (context, target) in enumerate(dataloader):
        optimiser.zero_grad()
        y, correct = model(context, target)
        loss = F.binary_cross_entropy(y, correct)
        loss.backward()
        optimiser.step()

        cumulative_loss += loss.data

        # Count finished batches (batch_idx is 0-based) so that every window
        # holds exactly `log_span` losses; testing `batch_idx % log_span`
        # made the first window sum log_span + 1 losses while still dividing
        # by log_span.
        batches_done = batch_idx + 1
        if batches_done % log_span == 0:
            avg_loss = cumulative_loss/log_span
            avg_losses.append(avg_loss)
            print("("+str(epoch+1)+"/"+str(num_epoch)+")", round(100 * batches_done/num_batches_per_epoch, 2), "%")
            print("       avg loss: "+str(avg_loss))
            cumulative_loss = 0

(1/10) 3.84 %
       avg loss: 1.9115759113668642
(1/10) 7.67 %
       avg loss: 1.7710422988155148
(1/10) 11.51 %
       avg loss: 1.6607071168109897
(1/10) 15.34 %
       avg loss: 1.5535514232589929
(1/10) 19.18 %
       avg loss: 1.4531147077515192
(1/10) 23.02 %
       avg loss: 1.353547775973962
(1/10) 26.85 %
       avg loss: 1.2583664865234525
(1/10) 30.69 %
       avg loss: 1.1777139567381907
(1/10) 34.53 %
       avg loss: 1.1032481575648163
(1/10) 38.36 %
       avg loss: 1.0332345048604443
(1/10) 42.2 %
       avg loss: 0.9765158899455155
(1/10) 46.03 %
       avg loss: 0.9218838343109161
(1/10) 49.87 %
       avg loss: 0.8768610247172589
(1/10) 53.71 %
       avg loss: 0.8351331205468342
(1/10) 57.54 %
       avg loss: 0.8023242806576035
(1/10) 61.38 %
       avg loss: 0.7678409511140464
(1/10) 65.21 %
       avg loss: 0.7422213706221715
(1/10) 69.05 %
       avg loss: 0.7161778740815138
(1/10) 72.89 %
       avg loss: 0.6902189871027806
(1/10) 76.72 %
       avg loss: 0.6

In [7]:
# save the embedding, word2id and id2word
embedding_weights = model.emb_in.weight.data
if device == 'cuda':
    # `cp` only exists when the cupy import succeeded (device == 'cuda');
    # calling cp.asnumpy unconditionally raised NameError on the CPU path.
    embedding_weights = cp.asnumpy(embedding_weights)
np.save('w2v_weight.npy', embedding_weights)

with open('word2id.pkl', 'wb') as handle:
    pickle.dump(dataset.corpus.word2id, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('id2word.pkl', 'wb') as handle:
    pickle.dump(dataset.corpus.id2word, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [10]:
# evaluation
def most_similar(query, top=5):
    """Print the words whose embeddings score highest against `query`.

    Every row of `model.emb_in.weight.data` is compared with the row of
    `query` using `cos_similarity`; the best-scoring words other than
    `query` itself are printed, one per line, as ``word : score``.

    Args:
        query: word to look up in `dataset.corpus.word2id`.
        top: maximum number of neighbours to print. Values <= 0 print
            only the ``query:`` header.

    Returns:
        None. Prints ``query not found`` when `query` is not in the vocabulary.
    """
    word2id = dataset.corpus.word2id
    id2word = dataset.corpus.id2word
    embedding = model.emb_in.weight.data

    if query not in word2id:
        print('query not found')
        return

    print('query:', query)

    # The limit used to be checked only after printing a line, so top=0
    # still printed one neighbour.
    if top <= 0:
        return

    query_vec = embedding[word2id[query]]

    vocab_size = len(id2word)
    similarity = np.zeros(vocab_size)
    for word_id in range(vocab_size):
        similarity[word_id] = cos_similarity(embedding[word_id], query_vec)

    shown = 0
    # argsort of the negated scores walks the ids from most to least similar
    for word_id in (-1 * similarity).argsort():
        word = id2word[word_id]
        if word == query:
            continue
        print(word, ':', similarity[word_id])

        shown += 1
        if shown >= top:
            return

In [11]:
# print the closest words to 'children' (default top=5)
most_similar('children')

query: children
men : 0.5821971893310547
people : 0.5753355622291565
women : 0.5361747741699219
soldiers : 0.5230221152305603
civilians : 0.5228411555290222


In [12]:
# print the closest words to 'year' (default top=5)
most_similar('year')

query: year
month : 0.6843153238296509
years : 0.655612051486969
week : 0.5653883814811707
months : 0.5553741455078125
season : 0.554878830909729


In [13]:
# evaluation 2
def analogy(a, b, c, top=5):
    """Print candidate answers to the analogy ``a : b = c : ?``.

    The query vector is ``vec(c) + vec(b) - vec(a)``, built from rows of
    `model.emb_in.weight.data`. Every vocabulary row is scored against it
    with `cos_similarity`, and the best-scoring words other than `a`, `b`
    and `c` are printed, one per line, as ``word : score``.

    Args:
        a, b, c: words to look up in `dataset.corpus.word2id`.
        top: maximum number of candidates to print. Values <= 0 print
            only the ``[analogy]`` header.

    Returns:
        None. Prints ``query not found`` when any of the three words is not
        in the vocabulary.
    """
    word2id = dataset.corpus.word2id
    id2word = dataset.corpus.id2word
    embedding = model.emb_in.weight.data

    if a not in word2id or b not in word2id or c not in word2id:
        print('query not found')
        return

    print('[analogy]', a + ':' + b + ' = ' + c + ':?')

    # The limit used to be checked only after printing a line, so top=0
    # still printed one candidate.
    if top <= 0:
        return

    a_vec = embedding[word2id[a]]
    b_vec = embedding[word2id[b]]
    c_vec = embedding[word2id[c]]

    query_vec = c_vec + b_vec - a_vec

    vocab_size = len(id2word)
    similarity = np.zeros(vocab_size)
    for word_id in range(vocab_size):
        similarity[word_id] = cos_similarity(embedding[word_id], query_vec)

    excluded = (a, b, c)
    shown = 0
    # argsort of the negated scores walks the ids from most to least similar
    for word_id in (-1 * similarity).argsort():
        word = id2word[word_id]
        if word in excluded:
            continue
        print(word, ':', similarity[word_id])

        shown += 1
        if shown >= top:
            return

In [14]:
# car is to cars as child is to ? (default top=5)
analogy('car','cars','child')

[analogy] car:cars = child:?
children : 0.481119304895401
<n>j<n>qshz : 0.42697370052337646
unswervingly : 0.42480939626693726
djs : 0.4098614454269409
modernizing : 0.4080216884613037


In [18]:
# take is to took as go is to ? (default top=5)
analogy('take','took','go')

[analogy] take:took = go:?
went : 0.6760730743408203
came : 0.5314589142799377
turned : 0.5067143440246582
knew : 0.4969695806503296
skated : 0.4861728549003601


In [20]:
# good is to bad as more is to ? (default top=5)
analogy('good','bad','more')

[analogy] good:bad = more:?
less : 0.6630487442016602
worse : 0.47593769431114197
echte : 0.455361932516098
halsted : 0.4298776388168335
rather : 0.42482250928878784
