In [4]:
# Read the raw lines of the book (trailing "\n" kept on each line).
# The encoding is passed explicitly so decoding does not depend on the
# platform's locale default (e.g. cp1252 on Windows), which would mangle
# or reject non-ASCII punctuation in the text.
with open("./pride_and_prejudice.txt", "r", encoding="utf-8") as f:
    corpus = f.readlines()

In [8]:
# First 100 raw lines of the file, taken before `corpus` is rebound to a
# token list further down the notebook.
sample = corpus[:100]

In [36]:
import re
def preprocess_text(corpus):
    """Tokenize raw text lines into a flat list of normalized word tokens.

    Each line is lowercased and split on whitespace.  Every character outside
    ``[A-Za-z0-9]`` is then removed from each token, empty tokens are dropped,
    and tokens made only of digits are replaced by ``"_NUM_"``.

    NOTE(review): punctuation is deleted *inside* a token instead of being
    treated as a separator.  Words joined by punctuation with no surrounding
    whitespace are therefore merged (the vocabulary output contains
    ``accomplishedshe``).  Consider splitting on such characters first if
    that matters.

    Args:
        corpus: iterable of text lines (e.g. the result of ``file.readlines()``).

    Returns:
        list of token strings, in corpus order.
    """
    non_alnum = re.compile(r"[^A-Za-z0-9]+")
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    only_digits = re.compile(r"^\d+$")

    tokens = []
    # A nested loop flattens in linear time.  The previous
    # sum(list_of_lists, []) re-copied the growing list on every addition.
    # Blank lines need no special case: "\n".split() is already [].
    for line in corpus:
        for word in line.lower().split():
            token = non_alnum.sub("", word)
            if token:
                tokens.append(only_digits.sub("_NUM_", token))
    return tokens

# Rebinds `corpus` from raw lines to a token list.  This cell is NOT
# idempotent: re-running it on the token list would, for example, turn
# every "_NUM_" placeholder into "num".  Re-run the file-loading cell first.
corpus = preprocess_text(corpus)

In [53]:
# Number of tokens in the preprocessed corpus (displayed as the cell output).
len(corpus)

124543

In [52]:
# Distinct tokens in sorted order, so the word -> id assignment is deterministic.
vocab = sorted(set(corpus))
print("Vocabulary size: ", len(vocab))
print(vocab[0:100])

Vocabulary size:  7044
['13420txt', '13420zip', '15th', '18th', '1a', '1b', '1c', '1d', '1e', '1e1', '1e2', '1e3', '1e4', '1e5', '1e6', '1e7', '1e8', '1e9', '1f', '1f1', '1f2', '1f3', '1f4', '1f5', '1f6', '26th', '501c3', '_NUM_', 'a', 'abatement', 'abhorrence', 'abhorrent', 'abide', 'abiding', 'abilities', 'able', 'ablution', 'abode', 'abominable', 'abominably', 'abominate', 'abound', 'about', 'above', 'abroad', 'abrupt', 'abruptly', 'abruptness', 'absence', 'absent', 'absolute', 'absolutely', 'absurd', 'absurdities', 'absurdity', 'abundant', 'abundantly', 'abuse', 'abused', 'abusing', 'abusive', 'accede', 'acceded', 'acceding', 'accent', 'accents', 'accept', 'acceptable', 'acceptance', 'accepted', 'accepting', 'access', 'accessed', 'accessible', 'accident', 'accidental', 'accidentally', 'accompanied', 'accompany', 'accompanying', 'accomplished', 'accomplishedshe', 'accomplishment', 'accomplishments', 'accordance', 'according', 'accordingly', 'accosted', 'account', 'accounted', 'accou

- Total tokens: 124,543 (≈125k)
- Vocabulary size: 7,044 distinct tokens (≈7k; includes file boilerplate such as `13420txt` and `501c3`)


## Word-to-ID representation

In [55]:
# Ids are assigned in vocab order starting at 1, so id 0 is never used for a word.
word_to_ids = {word: idx for idx, word in enumerate(vocab, start=1)}
id_to_words = dict(enumerate(vocab, start=1))

In [56]:
# Encode the token stream as integer ids (raises KeyError for a token not in vocab).
corpus_ids = [word_to_ids[word] for word in corpus]
print(corpus_ids[0:100])

[6267, 4911, 2887, 2025, 4332, 4852, 319, 4797, 859, 3545, 548, 6302, 2025, 3532, 2574, 6267, 6655, 4332, 365, 367, 504, 4222, 1409, 319, 6937, 269, 4222, 5367, 6867, 7026, 3931, 1390, 3534, 2754, 3534, 572, 4388, 5390, 3534, 6516, 6267, 6251, 4332, 6267, 4911, 2887, 3735, 3274, 6937, 6302, 2025, 4388, 4369, 504, 7002, 6347, 4852, 319, 4797, 550, 3545, 548, 4756, 1521, 545, 28, 28, 2025, 28, 5213, 1521, 3587, 28, 3662, 6645, 4326, 28, 28, 3655, 2146, 974, 5658, 2114, 6665, 5953, 4332, 6302, 4911, 2887, 2025, 4852, 319, 4797, 4889, 859, 344, 6757, 4852, 319, 4797]


In [57]:
import tensorflow as tf

In [59]:
# word_to_ids assigns ids 1..len(vocab) (0 is unused), so an embedding table
# indexed by those ids needs len(vocab) + 1 rows.  With len(vocab) rows the
# largest id would be out of range for tf.nn.embedding_lookup.
V = len(vocab) + 1
C = 3  # context size: number of words taken on EACH side of the center word
D = 20  # embedding dimension (the P and Q matrices are V x D)
N = len(corpus)  # number of tokens in the preprocessed corpus
batch_size = 128  # was `batch size = 128`, a SyntaxError

In [93]:
def create_word_pair(word_ids, C):
    """Build skip-gram style (center, context) pairs from a sequence of word ids.

    Every position with at least ``C`` neighbours on both sides is used as a
    center.  Positions within ``C`` of either end are skipped.  For a center at
    position ``i`` the pairs follow the context words left to right:
    ``word_ids[i-C], ..., word_ids[i-1], word_ids[i+1], ..., word_ids[i+C]``.

    Args:
        word_ids: indexable sequence of word ids.
        C: context window size (number of words on each side).

    Returns:
        list of ``(center_id, context_id)`` tuples, ``2*C`` per center.
        The list is empty when the sequence is shorter than ``2*C + 1``.
    """
    # Relative positions of the context words around a center, excluding 0.
    offsets = list(range(-C, 0)) + list(range(1, C + 1))
    return [
        (word_ids[pos], word_ids[pos + off])
        for pos in range(C, len(word_ids) - C)
        for off in offsets
    ]

In [97]:
train_pairs = create_word_pair([1,2,3,4,5,6,7,8],2)
print(train_pairs)

[(3, 1), (3, 2), (3, 4), (3, 5), (4, 2), (4, 3), (4, 5), (4, 6), (5, 3), (5, 4), (5, 6), (5, 7), (6, 4), (6, 5), (6, 7), (6, 8)]


In [109]:
import random 
def create_batches(train_pairs, batch_size, rng=None):
    """Shuffle training pairs and cut them into consecutive mini-batches.

    The shuffle happens IN PLACE on ``train_pairs`` (as in the original
    implementation), so the caller's list is reordered.  Every batch holds
    ``batch_size`` items except possibly the last, which holds the remainder.

    Args:
        train_pairs: mutable sequence (e.g. a list) of training examples.
        batch_size: number of examples per batch; must be >= 1.
        rng: optional ``random.Random`` instance used for the shuffle.  Pass a
            seeded one for reproducible batches.  By default the global
            ``random`` module state is used (original behaviour).

    Returns:
        list of consecutive slices of the shuffled ``train_pairs``.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # Previously 0 raised ZeroDivisionError and negatives silently gave [].
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    (random if rng is None else rng).shuffle(train_pairs)
    # Stepping by batch_size yields ceil(len / batch_size) batches.
    return [train_pairs[start:start + batch_size]
            for start in range(0, len(train_pairs), batch_size)]

## Model

- Input: a batch of `(center_word_id, context_word_id)` pairs
- Two embedding matrices, $P$ and $Q$, each of size $V \times D$
- **TODO:** complete the model description


In [75]:
# Toy demo of tf.nn.embedding_lookup: a 4x3 random table and a 2x2 batch of
# row ids.  Each id is replaced by the corresponding row of `params`, giving
# a 2x2x3 result.
# NOTE(review): this uses the TF 1.x graph API (tf.random_normal, tf.Session).
# Under TF 2.x use the tf.compat.v1 equivalents with eager execution disabled.
params = tf.Variable(tf.random_normal([4, 3], mean=0.0, stddev=1.0, seed=0))
ids = tf.constant([[1, 3], [0, 2]])
lookup = tf.nn.embedding_lookup(params, ids)
initializer = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer)
    # Fetch everything in one run so all three values come from the same step.
    ids_val, params_val, lookup_val = sess.run([ids, params, lookup])
    # Each value is printed under its own label (previously "parameters"
    # was printed above the ids tensor).
    print("ids")
    print(ids_val)
    print("parameters")
    print(params_val)
    print("lookup")
    print(lookup_val)


parameters
[[1 3]
 [0 2]]
[[-0.39915761  2.10443926  0.17107224]
 [ 0.54651815 -2.42340255  0.42255393]
 [ 0.28943786 -0.50430411 -0.96068907]
 [-0.65109813  0.11453361 -0.10354779]]
lookup
[[[ 0.54651815 -2.42340255  0.42255393]
  [-0.65109813  0.11453361 -0.10354779]]

 [[-0.39915761  2.10443926  0.17107224]
  [ 0.28943786 -0.50430411 -0.96068907]]]
