In [1]:
#configuration variables
# Number of negative targets sampled for every positive (context, target) pair.
neg_sample_size = 5
# Width of each embedding vector (second dimension of both weight matrices).
embedding_size    = 10
# Initial value only: this name is rebound to the real vocabulary length when
# the result of buildVocabulary() is unpacked further down.
vocab_size       = 10000
# A word is kept in the vocabulary only if its count is strictly greater than this.
vocab_min_frequency = 0
# Number of (context, target) pairs fed to one training step.
batch_size          = 2
# Path of the whitespace-tokenised training corpus.
fileLocation        = "./corpus.txt"
# Number of slots in the sampling table built by fill_table().
table_size          = 100
# Bounds the randomly drawn context radius used by fullTraining().
window              = 3

In [2]:
#essential imports
import numpy as np
import tensorflow as tf
from collections import Counter
import random
import math

In [3]:
def getWord2idx(counter):
    """Map every word in `counter` to its position in that sequence.

    `counter` is an iterable of two-item (word, count) entries; the count is
    ignored. If a word appears more than once, its last position wins.
    """
    return {word: position for position, (word, _) in enumerate(counter)}

def getDataList(words, word2idx, counter):
    """Encode `words` as a list of vocabulary indices.

    Words missing from `word2idx` are encoded as 0. The number of such words
    is written into `counter[0][1]`, so `counter` is modified in place.
    Returns the list of indices, one per input word.
    """
    encoded = []
    missing = 0
    for token in words:
        known = token in word2idx
        encoded.append(word2idx[token] if known else 0)
        # bool adds as 0/1, so this counts the out-of-vocabulary tokens
        missing += not known
    counter[0][1] = missing
    return encoded

def getCounterDict(word2idx, counter):
    """Return {word index: count} for every (word, count) entry in `counter`.

    Each word is translated through `word2idx`; a word absent from that
    mapping raises KeyError.
    """
    return {word2idx[word]: count for word, count in counter}

def buildVocabulary(filename):
    """Read a whitespace-tokenised corpus and derive the vocabulary structures.

    Returns an 8-tuple:
        words        - every token of the file, in reading order
        total_words  - len(words)
        counter      - [['UNK', n_unknown]] followed by [word, count] lists in
                       most-common-first order, keeping only words whose count
                       is greater than the module-level vocab_min_frequency
        counterDict  - {word index: count} built from `counter`
        vocab_size   - len(counter)
        word2idx     - {word: index}
        data         - the corpus encoded as indices (0 for unknown words)
        idx2word     - inverse of word2idx

    NOTE(review): counterDict is built before getDataList() writes the unknown
    count into counter[0][1], so counterDict[0] is always 0 even when
    counter[0][1] is not.
    """
    with open(filename) as corpus:
        words = [token for line in corpus.readlines() for token in line.split()]

    frequent = [list(pair) for pair in Counter(words).most_common()
                if pair[1] > vocab_min_frequency]
    counter = [['UNK', 0]] + frequent

    word2idx = getWord2idx(counter)
    counterDict = getCounterDict(word2idx, counter)
    # Side effect: stores the number of unknown tokens in counter[0][1].
    data = getDataList(words, word2idx, counter)
    idx2word = {index: word for word, index in word2idx.items()}

    return words, len(words), counter, counterDict, len(counter), word2idx, data, idx2word

In [4]:
# Build all vocabulary structures from the corpus. This rebinds the module-level
# vocab_size to the length of the vocabulary actually found in the file.
words, total_words, counter, counterDict, vocab_size, word2idx, data, idx2word = buildVocabulary(fileLocation)

In [5]:
def get_labels_and_targets(batch_size, num_negatives=None):
    """Allocate the label matrix and the target-index buffer for one batch.

    Args:
        batch_size: number of rows in both arrays.
        num_negatives: negatives per row; defaults to the module-level
            neg_sample_size when omitted.

    Returns:
        (labels, targets), both of shape [batch_size, 1 + num_negatives].
        labels is float32 with 1.0 in column 0 and 0.0 elsewhere.
        targets is int32 and zero-filled.
    """
    if num_negatives is None:
        num_negatives = neg_sample_size
    width = 1 + num_negatives

    labels = np.zeros([batch_size, width], dtype=np.float32)
    labels[:, 0] = 1.0
    # np.zeros rather than np.ndarray: the bare constructor hands back
    # uninitialised memory, i.e. arbitrary integers that are not valid indices.
    targets = np.zeros([batch_size, width], dtype=np.int32)
    return labels, targets

# Module-level buffers shared by the training functions: `labels` is fed as Y
# unchanged, while `targets` is overwritten in place by sample_targets().
labels, targets = get_labels_and_targets(batch_size)

In [8]:
#create model
def model(Context, Target, embeddingLayer, lrWeights):
    """Score every candidate target against its context word.

    Args:
        Context: integer tensor of row indices into `embeddingLayer`.
        Target: integer tensor of row indices into `lrWeights`.
        embeddingLayer: weight matrix looked up with `Context`.
        lrWeights: weight matrix looked up with `Target`.

    Returns:
        The batched product context_embedding x target_embedding^T, i.e. one
        unscaled dot-product score per (context, candidate) pair. With the
        placeholders defined in this notebook ([None, 1] and
        [None, 1 + neg_sample_size]) the result has shape
        [batch, 1, 1 + neg_sample_size].
    """
    context_embedding = tf.nn.embedding_lookup(embeddingLayer, Context, name="context_embeddings")
    target_embedding = tf.nn.embedding_lookup(lrWeights, Target, name="target_lookup")
    # Return the raw scores. The value is consumed as `logits` by a
    # *_cross_entropy_with_logits op, which applies its own normalisation and
    # expects unscaled inputs; squashing with a sigmoid first confines every
    # score to (0, 1) and keeps the resulting distribution close to uniform.
    softmax_logits = tf.matmul(context_embedding, target_embedding, transpose_b=True)
    return softmax_logits

In [45]:
# Create the placeholders, the two weight matrices, the model output, the loss
# and the training op. All names below are module-level and are read by the
# training functions.

# Context: one word index per example; Target: the positive index followed by
# neg_sample_size sampled indices; Y: the matching labels.
Context        = tf.placeholder(tf.int32, [None, 1], name="Context")
Target         = tf.placeholder(tf.int32, [None, 1 + neg_sample_size], name="Target")
Y              = tf.placeholder(tf.float32, [None, 1 + neg_sample_size], name="Y")
# Context-side weights: uniform in [-0.5/embedding_size, 0.5/embedding_size].
init_width     = 0.5/embedding_size 
contextLayer   = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -init_width, init_width), name='embed_wts')
# Target-side weights: truncated normal with stddev 1/sqrt(embedding_size).
init_weight    = 1.0/math.sqrt(embedding_size)
targetLayer    = tf.Variable(tf.truncated_normal([vocab_size, embedding_size], stddev = init_weight), name='target_wts')
softmax_logits = model(Context, Target, contextLayer, targetLayer)
# Softmax cross-entropy over the 1 + neg_sample_size candidate scores returned
# by model(), summed over the batch.
loss           = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=softmax_logits, labels=Y))
# Plain gradient descent with a fixed learning rate of 0.001.
train          = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

In [11]:
# Tasks: 1. Create the table structure.
# 2. Build a new batch for each structure.
# 3. Pass those batches to train and print the loss.

def fill_table(counter, table_size, idx2word, alpha=0.75):
    """Build the table from which negative samples are drawn.

    Word indices are laid out in increasing order starting at 1 (index 0 is
    never placed by the cumulative walk). A word keeps receiving slots until
    the fraction of the table already filled exceeds the cumulative share
    count**alpha / sum(count**alpha) of the words placed so far.

    Args:
        counter: dict {word index: count} with keys 0 .. len(counter) - 1.
            The count stored under index 0 is included in the normaliser.
        table_size: number of slots in the returned table.
        idx2word: unused; kept so existing callers keep working.
        alpha: exponent applied to every count.

    Returns:
        int32 numpy array of length `table_size` holding word indices >= 1.

    Raises:
        ValueError: if `counter` has no index other than 0, or if the powered
            counts do not sum to a positive number.
    """
    # Use the size of the mapping actually passed in rather than a global, so
    # the bounds check below always matches the keys being indexed.
    num_words = len(counter)
    if num_words < 2:
        raise ValueError("counter must contain at least one word besides index 0")

    total_count_power = sum(math.pow(count, alpha) for count in counter.values())
    if total_count_power <= 0:
        raise ValueError("sum of count**alpha must be positive")

    table = np.zeros([table_size], dtype=np.int32)
    word_idx = 1
    word_prob = math.pow(counter[word_idx], alpha) / total_count_power
    exhausted = False
    for slot in range(table_size):
        if exhausted:
            # Every word has been placed but slots remain: fill them with
            # uniformly drawn valid word indices.
            table[slot] = random.randrange(1, num_words)
            continue
        table[slot] = word_idx
        if word_prob < float(slot) / table_size:
            word_idx += 1
            # Check the bound BEFORE indexing counter; the last valid index is
            # num_words - 1.
            if word_idx >= num_words:
                exhausted = True
            else:
                word_prob += math.pow(counter[word_idx], alpha) / total_count_power
    return table
# Build the sampling table once; singleTrainOperation() and fullTraining() read
# it through this module-level name.
table = fill_table(counterDict, table_size, idx2word, 0.75)

In [12]:
# Show the word indices stored in the sampling table returned by fill_table().
print(table)

[ 1  1  1  1  1  1  2  2  2  3  3  3  4  4  4  4  5  5  6  6  7  7  8  8  9
  9 10 11 11 12 12 13 13 14 14 15 15 16 16 17 17 18 18 19 19 20 20 21 21 22
 22 23 23 24 24 25 25 26 26 27 27 28 28 29 30 30 31 31 32 32 33 33 34 34 35
 35 36 36 37 37 38 38 39 39 40 40 41 41 42 42 43 43 44 44 45 45 46 46 47 47]


In [30]:
def sample_targets_per_batch(target_per_batch, context, targetTable):
    """Fill one row of targets: `context` first, then sampled entries.

    Position 0 of `target_per_batch` receives `context`. The following
    neg_sample_size positions receive entries drawn uniformly from
    `targetTable`; a draw equal to `context` is discarded and redrawn.
    The row is modified in place and also returned.
    """
    target_per_batch[0] = context
    table_len = len(targetTable)
    filled = 0
    while filled < neg_sample_size:
        candidate = targetTable[random.randrange(table_len)]
        if candidate == context:
            # rejected draw: try again without advancing
            continue
        filled += 1
        target_per_batch[filled] = candidate
    return target_per_batch

# Takes the batch size, targets of shape [batch_size, 1 + neg_sample_size] and contexts of length batch_size.
def sample_targets(batch_size, targets, contexts, targetTable):
    """Fill the first `batch_size` rows of `targets` via sample_targets_per_batch.

    Row i is filled using contexts[i] and `targetTable`. `targets` is
    modified in place and returned.
    """
    for row in range(batch_size):
        targets[row, :] = sample_targets_per_batch(targets[row, :], contexts[row], targetTable)
    return targets
# Next: collect batch_size contexts, sample targets for them, and build the training data from the result.

In [46]:
#run single training operation
def singleTrainOperation(sess, batch_size, xBatch, targets, positiveTargetBatch):
    """Run the `train` op once on a freshly sampled batch.

    The rows of `targets` are refilled from `positiveTargetBatch` and the
    module-level `table`, then fed together with `xBatch` and the
    module-level `labels`. Returns None.
    """
    sampled = sample_targets(batch_size, targets, positiveTargetBatch, table)
    feed = {Context: xBatch, Target: sampled, Y: labels}
    sess.run(train, feed_dict=feed)
    
    
def fullTraining(sess, filename):
    """Make one pass over the corpus, training on (context, target) pairs.

    For every position in the corpus a radius is drawn uniformly from
    1 .. window, and each neighbour within that radius on either side is
    paired with the centre word. Pairs are accumulated until batch_size of
    them are available, at which point one training step is run and the loss
    is printed. A final, incomplete batch is discarded.

    Relies on the module-level names word2idx, window, batch_size, targets,
    table, labels, loss, Context, Target and Y.

    Args:
        sess: session in which the graph variables are already initialised.
        filename: path of the whitespace-tokenised corpus.
    """
    # Read everything up front so the file is closed before training starts.
    with open(filename) as f:
        words = [word for line in f for word in line.split()]

    num_words = len(words)
    log_every = 1  # print the loss after every `log_every` training steps
    xContextList = []
    xTargetList = []
    trainingOperations = 0

    for idx, word in enumerate(words):
        # Out-of-vocabulary words fall back to index 0, as in getDataList().
        context = word2idx.get(word, 0)
        # Upper bound is window + 1 so the full window can be drawn and
        # window == 1 does not produce an empty range.
        reduced_window = random.randrange(1, window + 1)
        first = max(0, idx - reduced_window)
        # + 1 so the right-hand side reaches as far as the left-hand side.
        last = min(num_words, idx + reduced_window + 1)
        for jdx in range(first, last):
            if jdx == idx:
                continue
            xContextList.append([context])
            xTargetList.append(word2idx.get(words[jdx], 0))
            if len(xContextList) != batch_size:
                continue

            singleTrainOperation(sess, batch_size, xContextList, targets, xTargetList)
            trainingOperations += 1
            if trainingOperations % log_every == 0:
                # The loss is evaluated on a newly drawn set of sampled
                # targets, not the set used for the step just taken.
                completeTargets = sample_targets(batch_size, targets, xTargetList, table)
                myLoss = sess.run(loss, feed_dict={Context:xContextList, Target:completeTargets, Y:labels})
                print("Loss this batch:- " )
                print(myLoss)
            xContextList = []
            xTargetList = []


In [47]:
# Initialise the variables and run one training pass over the corpus.
# The finally clause releases the session even when training raises.
sess  = tf.Session()
try:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    fullTraining(sess, fileLocation)
finally:
    sess.close()

Loss this batch:- 
3.58272
Loss this batch:- 
3.58254
Loss this batch:- 
3.587
Loss this batch:- 
3.59747
Loss this batch:- 
3.58961
Loss this batch:- 
3.5849
Loss this batch:- 
3.58239
Loss this batch:- 
3.58081
Loss this batch:- 
3.57663
Loss this batch:- 
3.60052
Loss this batch:- 
3.58189
Loss this batch:- 
3.5738
Loss this batch:- 
3.57694
Loss this batch:- 
3.58195
Loss this batch:- 
3.58882
Loss this batch:- 
3.58384
Loss this batch:- 
3.57385
Loss this batch:- 
3.57357
Loss this batch:- 
3.59244
Loss this batch:- 
3.59503
Loss this batch:- 
3.57119
Loss this batch:- 
3.58252
Loss this batch:- 
3.59103
Loss this batch:- 
3.57106
Loss this batch:- 
3.58966
Loss this batch:- 
3.58012
Loss this batch:- 
3.57467
Loss this batch:- 
3.57738
Loss this batch:- 
3.58734
Loss this batch:- 
3.59572
Loss this batch:- 
3.58566
Loss this batch:- 
3.5809
Loss this batch:- 
3.56582
Loss this batch:- 
3.59394
Loss this batch:- 
3.57744
Loss this batch:- 
3.58871
Loss this batch:- 
3.59005
Loss t