In [1]:
import collections
import os
import random
import urllib
import urllib.request
import zipfile

import numpy as np
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  from ._conv import register_converters as _register_converters
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [37]:
# ---- Training schedule -------------------------------------------------
# SGD step size
learning_rate = 0.1
# Number of (center word, context word) pairs per batch
batch_size = 128
# Total number of optimisation steps
num_steps = 3000000
# Print the averaged loss every `display_step` steps
display_step = 100
# Print nearest neighbours of the evaluation words every `eval_step` steps
eval_step = 200

# ---- Evaluation --------------------------------------------------------
# Words whose nearest neighbours are printed during training (bytes, like the corpus tokens)
eval_words = [b'five',b'of', b'going', b'hardware', b'american', b'britain']

# ---- Word2Vec ----------------------------------------------------------
# Dimension of the embedding vector
embedding_size = 200
# Total number of different words in the vocabulary
max_vocabulary_size = 50000
# Drop every word that appears fewer than this many times
min_occurrence = 10
# How many words to consider to the left and to the right of the center word
skip_window = 3
# How many times to reuse an input word to generate a label
num_skips = 2
# Number of negative examples to sample
num_sampled = 64

In [3]:
# Download a small chunk of the Wikipedia articles collection (text8)
url = 'http://mattmahoney.net/dc/text8.zip'
data_path = 'text8.zip'
if not os.path.exists(data_path):
    print("Downloading the dataset... (It may take some time)")
    # Download under a temporary name and rename on success, so an
    # interrupted transfer never leaves a truncated 'text8.zip' behind that
    # the existence check above would accept on the next run.
    partial_path = data_path + '.part'
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, data_path)
    print("Done!")
# Unzip the dataset file. Text has already been processed.
# The archive member is read as raw bytes, so every token in `text_words`
# is a `bytes` object, not a `str`.
with zipfile.ZipFile(data_path) as f:
    text_words = f.read(f.namelist()[0]).lower().split()

In [4]:
# Build the vocabulary and encode the corpus; rare words collapse to 'UNK'.
# Slot 0 is reserved for 'UNK'; its real count is filled in further down.
count = [('UNK', -1)]
count.extend(collections.Counter(text_words).most_common(max_vocabulary_size - 1))

# most_common() is ordered by descending frequency, so the rare words sit at
# the tail: keep dropping the last entry until a frequent enough one is found.
while count and count[-1][1] < min_occurrence:
    count.pop()

# Size of the vocabulary that survived the frequency cut
vocabulary_size = len(count)

# A word's id is its position in `count`
word2id = {token: position for position, (token, _freq) in enumerate(count)}

# Encode the corpus; anything outside the vocabulary maps to id 0 ('UNK')
data = [word2id.get(token, 0) for token in text_words]
unk_count = data.count(0)
count[0] = ('UNK', unk_count)

# Reverse lookup: id -> word
id2word = {position: token for token, position in word2id.items()}

In [24]:
# Largest word id present in the encoded corpus (ids are positions in `count`)
max(data)

47134

In [25]:
# Read cursor into the global `data` list; advanced by next_batch()
data_index = 0


def next_batch(batch_size,num_skips,skip_window):
    """Generate one skip-gram training batch from the global ``data`` list.

    A window of ``2 * skip_window + 1`` consecutive word ids slides over
    ``data`` starting at the global ``data_index``. For every window position
    the center word is paired with ``num_skips`` distinct, randomly chosen
    context words from the same window.

    Args:
        batch_size: number of (center, context) pairs to produce; must be a
            multiple of ``num_skips``.
        num_skips: number of context words sampled per center word; must be
            between 1 and ``2 * skip_window``.
        skip_window: number of words considered on each side of the center.

    Returns:
        ``(batch, labels)``: ``batch`` is an int32 array of shape
        ``(batch_size,)`` holding center-word ids, ``labels`` an int32 array
        of shape ``(batch_size, 1)`` holding the matching context-word ids.

    Raises:
        ValueError: if the arguments are inconsistent or ``data`` is shorter
            than one window.

    Side effects:
        Updates the global ``data_index`` so the next call continues where
        this one stopped (wrapping around at the end of ``data``).
    """
    global data_index

    span = 2*skip_window+1

    # Without these checks the arrays below (allocated uninitialised) would be
    # returned with garbage in the slots that the loop never fills.
    if num_skips < 1 or batch_size % num_skips != 0:
        raise ValueError("batch_size must be a positive multiple of num_skips")
    if num_skips > span-1:
        raise ValueError("num_skips cannot exceed 2*skip_window")
    if len(data) < span:
        raise ValueError("data is shorter than one context window")

    batch = np.ndarray(shape=(batch_size),dtype=np.int32)
    labels = np.ndarray(shape=(batch_size,1),dtype=np.int32)

    # The cursor saved by the previous call may be closer than `span` to the
    # end of the corpus; restart from the beginning so the window is complete.
    if data_index+span > len(data):
        data_index = 0

    buffer = collections.deque(data[data_index:data_index+span],maxlen=span)
    data_index += span

    # Every window offset except the center one
    context_index = [w for w in range(span) if w!=skip_window]

    for i in range(batch_size//num_skips):
        random_context = random.sample(context_index,num_skips)

        for j,word_id in enumerate(random_context):
            batch[i*num_skips+j] = buffer[skip_window]
            labels[i*num_skips+j,0] = buffer[word_id]

        if data_index==len(data):
            # End of corpus reached: refill the window from the start
            buffer.extend(data[0:span])
            data_index = span
        else:
            # Slide the window one word to the right (deque drops the oldest)
            buffer.append(data[data_index])
            data_index += 1

    # Step back by one window so the words at the end of this batch are not
    # skipped as center words by the next call.
    data_index = (data_index-span+len(data))%(len(data))
    return batch,labels


In [26]:
# Smoke test: draw one batch, then show its length and its first center-word id
batch, labels = next_batch(batch_size, num_skips, skip_window)

for shown in (len(batch), batch[0]):
    print(shown)

128
6


In [34]:
# Build the model with TensorFlow (graph-mode API)
# Placeholders: center-word ids and their context-word ids
train_x = tf.placeholder(shape=[None],dtype=tf.int32)
train_y = tf.placeholder(shape=[None,1],dtype=tf.int32)

# Embedding matrix: one row of size `embedding_size` per vocabulary word
embedding = tf.Variable(tf.random_normal(shape=[vocabulary_size,embedding_size],dtype=tf.float32))

# Embedding vectors of the words fed through `train_x`
X_embed = tf.nn.embedding_lookup(embedding,train_x)

# Output-layer parameters used by the NCE loss
weights = tf.Variable(tf.random_normal(shape=[vocabulary_size,embedding_size],dtype=tf.float32))
bias    = tf.Variable(tf.zeros(shape=[vocabulary_size],dtype=tf.float32))

# Batch-averaged NCE loss, drawing `num_sampled` negative classes per batch
nce_loss = tf.reduce_mean(tf.nn.nce_loss(
        weights = weights,
        biases    = bias,
        labels  = train_y,
        inputs  = X_embed,
        num_sampled=num_sampled,
        num_classes = vocabulary_size
))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(nce_loss)

# Cosine similarity between the fed words and every vocabulary word, used as a
# rough quality check. Rows are L2-normalised first, so the matrix product
# below yields cosines with shape [len(train_x), vocabulary_size].
# `keepdims` is the current spelling of the deprecated `keep_dims` argument.
X_norm = X_embed/(tf.sqrt(tf.reduce_sum(tf.square(X_embed),axis=1,keepdims=True)))
embedding_norm = embedding/(tf.sqrt(tf.reduce_sum(tf.square(embedding),axis=1,keepdims=True)))
dis_consine = tf.matmul(X_norm,embedding_norm,transpose_b=True)
                                

In [None]:
# Train the model, periodically printing the average loss and, for each
# evaluation word, its nearest neighbours in embedding space.

init = tf.global_variables_initializer()

# Ids of the evaluation words (KeyError if one is missing from the vocabulary)
x_test = np.array([word2id[w] for w in eval_words])

with tf.Session() as sess:
    sess.run(init)
    # Running mean of the loss over the current `display_step` window
    step_loss = 0.
    for step in range(num_steps):
        Batch,Labels = next_batch(batch_size=batch_size,skip_window=skip_window,num_skips=num_skips)
        _,loss = sess.run([optimizer,nce_loss],feed_dict={train_x:Batch,train_y:Labels})

        step_loss+=(loss/display_step)
        if (step+1)%(display_step)==0:
            print("average step loss is %f"%(step_loss))
            step_loss = 0.

        if (step+1)%(eval_step)==0:
            cos = sess.run(dis_consine,feed_dict={train_x:x_test})
            print("Eval is starting.......")
            for i in range(len(x_test)):
                # Sort by descending similarity and skip the first hit, which
                # is normally the query word itself (cosine 1 with itself).
                nearest_id = (-cos[i,:]).argsort()[1:8]
                print("the nearst words of %s are"%(eval_words[i]))
                # Build the whole neighbour list once per query word; the
                # accumulator used to be reset on every iteration, so only the
                # last neighbour was ever printed.
                nearest_word = ",".join("%s"%(id2word[word],) for word in nearest_id)
                print(nearest_word)