<a href="https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/0512.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 10.3 word2vec 的实现

In [0]:
!pip install mxnet d2lzh

In [3]:
!git clone https://www.github.com/d2l-ai/d2l-zh.git

Cloning into 'd2l-zh'...
remote: Enumerating objects: 17, done.[K
remote: Counting objects:   5% (1/17)[Kremote: Counting objects:  11% (2/17)[Kremote: Counting objects:  17% (3/17)[Kremote: Counting objects:  23% (4/17)[Kremote: Counting objects:  29% (5/17)[Kremote: Counting objects:  35% (6/17)[Kremote: Counting objects:  41% (7/17)[Kremote: Counting objects:  47% (8/17)[Kremote: Counting objects:  52% (9/17)[Kremote: Counting objects:  58% (10/17)[Kremote: Counting objects:  64% (11/17)[Kremote: Counting objects:  70% (12/17)[Kremote: Counting objects:  76% (13/17)[Kremote: Counting objects:  82% (14/17)[Kremote: Counting objects:  88% (15/17)[Kremote: Counting objects:  94% (16/17)[Kremote: Counting objects: 100% (17/17)[Kremote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 15702 (delta 9), reused 8 (delta 4), pack-reused 15685[K
Receiving objects: 100% (15702/15702), 159.56 MiB | 34.

In [0]:
!mkdir ../data

In [0]:
!cp ./d2l-zh/data/ptb.zip ../data/

In [0]:
import collections 
import d2lzh as d2l 
import math 
from mxnet import autograd, gluon, nd 
from mxnet.gluon import data as gdata, loss as gloss, nn 
import random 
import sys 
import time 
import zipfile 

In [7]:
with zipfile.ZipFile('../data/ptb.zip', 'r') as zin:
    zin.extractall('../data/')

with open('../data/ptb/ptb.train.txt', 'r') as f:
    lines = f.readlines()
    raw_dataset = [st.split() for st in lines]

'# sentences: %d' % len(raw_dataset)

'# sentences: 42068'

In [8]:
for st in raw_dataset[:3]:
    print('# tokens:', len(st), st[:5])

# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']
# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']
# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']


In [0]:
counter = collections.Counter([tk for st in raw_dataset for tk in st])
counter = dict(filter(lambda x: x[1] >= 5, counter.items()))

In [10]:
idx_to_token = [tk for tk, _ in counter.items()]
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx] for st in raw_dataset]
num_tokens = sum([len(st) for st in dataset])
"# tokens: %d" % num_tokens

'# tokens: 887100'

In [11]:
def discard(idx):
    return random.uniform(0, 1) < 1 - math.sqrt(1e-4 / counter[idx_to_token[idx]] * num_tokens)

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
'# tokens: %d' % sum([len(st) for st in subsampled_dataset])

'# tokens: 375509'

In [12]:
def compare_counts(token):
    return '#%s: before=%d, after=%d' %(token, 
        sum([st.count(token_to_idx[token]) for st in dataset]), 
        sum([st.count(token_to_idx[token]) for st in subsampled_dataset]))

compare_counts('the')

'#the: before=50770, after=2075'

In [13]:
compare_counts('join')

'#join: before=45, after=45'

In [0]:
def get_centers_and_contexts(dataset, max_window_size):
    centers, contexts = [], []
    for st in dataset:
        if len(st) < 2:
            continue 
        centers += st 
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size)
            indices = list(range(max(0, center_i - window_size), min(len(st), center_i + 1 + window_size)))
            indices.remove(center_i)
            contexts.append([st[idx] for idx in indices])
    return centers, contexts

In [15]:
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('dataset', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('center', center, 'has contexts', context)

dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
center 0 has contexts [1, 2]
center 1 has contexts [0, 2]
center 2 has contexts [0, 1, 3, 4]
center 3 has contexts [2, 4]
center 4 has contexts [2, 3, 5, 6]
center 5 has contexts [3, 4, 6]
center 6 has contexts [5]
center 7 has contexts [8, 9]
center 8 has contexts [7, 9]
center 9 has contexts [8]


In [0]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

In [0]:
def get_negatives(all_context, sampling_weights, K):
    all_negatives, neg_candidates, i = [], [], 0 
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts:
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                i, neg_candidates = 0, random.choices(population, sampling_weights, k=int(1e5))
            neg, i = neg_candidates[i], i + 1 
            if neg not in set(contexts):
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives 

sampling_weights = [counter[w]**0.75 for w in idx_to_token]
all_negatives = get_negatives(all_contexts, sampling_weights, 5)

In [0]:
def batchify(data):
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
        contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)]
        labels += [[1] * len(context) + [0] * (max_len - len(context))]

    return (nd.array(centers).reshape((-1, 1)), nd.array(contexts_negatives), 
        nd.array(masks), nd.array(labels))

In [35]:
batch_size = 512 
num_workers = 4 
dataset = gdata.ArrayDataset(all_centers, all_contexts, all_negatives)
data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True, batchify_fn=batchify, num_workers=num_workers)
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
        print(name, 'shape:', data.shape)
    break 

centers shape: (512, 1)
contexts_negatives shape: (512, 60)
masks shape: (512, 60)
labels shape: (512, 60)


In [36]:
embed = nn.Embedding(input_dim=20, output_dim=4)
embed.initialize()
embed.weight

Parameter embedding0_weight (shape=(20, 4), dtype=float32)

In [37]:
x = nd.array([[1, 2, 3], [4, 5, 6]])
embed(x)


[[[ 0.01438687  0.05011239  0.00628365  0.04861524]
  [-0.01068833  0.01729892  0.02042518 -0.01618656]
  [-0.00873779 -0.02834515  0.05484822 -0.06206018]]

 [[ 0.06491279 -0.03182812 -0.01631819 -0.00312688]
  [ 0.0408415   0.04370362  0.00404529 -0.0028032 ]
  [ 0.00952624 -0.01501013  0.05958354  0.04705103]]]
<NDArray 2x3x4 @cpu(0)>

In [38]:
X = nd.ones((2, 1, 4))
Y = nd.ones((2, 4, 6))
nd.batch_dot(X, Y).shape 

(2, 1, 6)

In [0]:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    v = embed_v(center)
    u = embed_u(contexts_and_negatives)
    pred = nd.batch_dot(v, u.swapaxes(1, 2))
    return pred 

In [0]:
loss = gloss.SigmoidBinaryCrossEntropyLoss()

In [57]:
pred = nd.array([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])
label = nd.array([[1, 0, 0, 0], [1, 1, 0, 0]])
mask = nd.array([[1, 1, 1, 1], [1, 1, 1, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)


[0.8739896 1.2099689]
<NDArray 2 @cpu(0)>

In [58]:
def sigmd(x):
    return -math.log(1 / (1 + math.exp(-x)))

print('%.7f' % ((sigmd(1.5) + sigmd(-0.3) + sigmd(1) + sigmd(-2)) / 4))
print('%.7f' % ((sigmd(1.1) + sigmd(-0.6) + sigmd(-2.2)) / 3))

0.8739896
1.2099689


In [0]:
embed_size = 100 
net = nn.Sequential()
net.add(nn.Embedding(input_dim=len(idx_to_token), output_dim=embed_size),
    nn.Embedding(input_dim=len(idx_to_token), output_dim=embed_size))

In [0]:
def train(net, lr, num_epochs):
    ctx = d2l.try_gpu()
    net.initialize(ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})

    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0 
        for batch in data_iter:
            center, context_negative, mask, label = [data.as_in_context(ctx) for data in batch]
            with autograd.record():
                pred = skip_gram(center, context_negative, net[0], net[1])
                l = (loss(pred.reshape(label.shape), label, mask) * mask.shape[1] / mask.sum(axis=1))
            l.backward()
            trainer.step(batch_size)
            l_sum += l.sum().asscalar()
            n += l.size 
        print('epoch %d, loss %.2f, time %.2fs' % (epoch + 1, l_sum / n, time.time() - start))

In [61]:
train(net, 0.005, 5)

epoch 1, loss 0.46, time 158.52s
epoch 2, loss 0.39, time 155.11s
epoch 3, loss 0.35, time 147.60s
epoch 4, loss 0.32, time 148.85s
epoch 5, loss 0.30, time 154.30s


In [67]:
def get_similar_tokens(query_token, k, embed):
    W = embed.weight.data()
    x = W[token_to_idx[query_token]]
    cos = nd.dot(W, x) / (nd.sum(W * W, axis=1) * nd.sum(x * x) + 1e-9).sqrt()
    topk = nd.topk(cos, k=k+1, ret_typ='indices').asnumpy().astype('int32')
    for i in topk[1:]:
        print('cosine sim=%.3f: %s' % (cos[i].asscalar(), (idx_to_token[1])))

get_similar_tokens('chip', 3, net[0])

cosine sim=0.516: <unk>
cosine sim=0.480: <unk>
cosine sim=0.470: <unk>
