In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

USE_CUDA = torch.cuda.is_available()

# add seed to make sure the result can be reproduced
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)

In [9]:
USE_CUDA

True

In [2]:
# hyper-parameters
C = 3 # window size
K = 100 # number of negative samples
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.1
EMBEDDING_SIZE = 100

In [3]:
def word_tokenize(text):
    return text.split()

In [4]:
with open("text/text.train.txt", "r") as fin:
    text = fin.read()

In [5]:
text = [w for w in word_tokenize(text)]

In [6]:
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1)) # key: word ； value: count

In [7]:
vocab["<unk>"] = len(text) - np.sum(list(vocab.values())) # put the words other than top MAX_VOCAB_SIZE - 1 to UNK. Count the number occurance

In [8]:
vocab

{'the': 958035,
 'of': 536684,
 'and': 375233,
 'one': 371796,
 'in': 335503,
 'a': 292250,
 'to': 285093,
 'zero': 235406,
 'nine': 224705,
 'two': 172079,
 'is': 164575,
 'as': 118931,
 'eight': 113412,
 'for': 106452,
 's': 104935,
 'five': 104416,
 'three': 103344,
 'was': 101939,
 'by': 100587,
 'that': 98710,
 'four': 97719,
 'six': 91897,
 'seven': 90940,
 'with': 85741,
 'on': 82022,
 'are': 68769,
 'it': 66093,
 'from': 65738,
 'or': 62113,
 'his': 58234,
 'an': 55741,
 'be': 55471,
 'this': 53073,
 'he': 49838,
 'at': 49636,
 'which': 49410,
 'not': 39778,
 'also': 39726,
 'have': 35792,
 'were': 35190,
 'has': 33922,
 'but': 32114,
 'other': 29018,
 'their': 28410,
 'its': 26355,
 'first': 25976,
 'they': 25969,
 'had': 25522,
 'some': 25219,
 'more': 23659,
 'all': 23598,
 'can': 22954,
 'most': 22880,
 'been': 22853,
 'such': 21830,
 'who': 21770,
 'many': 21703,
 'new': 21281,
 'there': 20450,
 'used': 20253,
 'after': 19196,
 'american': 18852,
 'when': 18698,
 'time': 1

In [10]:
# build index to word and word to index mapping
# the order of embedding is the same as word_to_idx. Like embedding[0] is the embedding for word[0]
# these two sets will also be used in the evaluation
idx_to_word = list(vocab.keys())
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

In [11]:
list(word_to_idx.items())[:10]

[('the', 0),
 ('of', 1),
 ('and', 2),
 ('one', 3),
 ('in', 4),
 ('a', 5),
 ('to', 6),
 ('zero', 7),
 ('nine', 8),
 ('two', 9)]

In [12]:
 # store word frequency that will be used in negative sampling
word_counts = np.array(list(vocab.values()), dtype = np.float32)
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (3./4.)
word_freqs = word_freqs / np.sum(word_freqs)

In [13]:
word_freqs

array([1.6231162e-02, 1.0509998e-02, 8.0359895e-03, ..., 5.0128656e-06,
       5.0128656e-06, 1.1670408e-02], dtype=float32)

In [14]:
VOCAB_SIZE = len(idx_to_word)
VOCAB_SIZE

30000

In [15]:
word_to_idx["<unk>"]

29999

In [16]:
# implement data loader
class WordEmbeddingDataLoader(tud.Dataset):
    def __init__(self, text, idx_to_word, word_to_idx, word_counts, word_freqs):
        '''
        text: a list of words
        '''
        super(WordEmbeddingDataLoader, self).__init__()
        self.text_encoded = [word_to_idx.get(t, word_to_idx["<unk>"]) for t in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.word_counts = torch.Tensor(word_counts)
        self.word_freqs = torch.Tensor(word_freqs)
        
    def __len__(self):
        '''
        return len of the dataset
        '''
        return len(self.text_encoded)
    
    def __getitem__(self, idx):
        '''
        return (centor_word, positive words around centor word, K negative samples for each positive word)
        '''
        
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C+1))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices] # avoid index out of bound
        pos_words = self.text_encoded[pos_indices]
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True) # for each pos_word, select K negative words based on word_freqs
        
        return center_word, pos_words, neg_words

In [17]:
dataset = WordEmbeddingDataLoader(text, idx_to_word, word_to_idx, word_counts, word_freqs)
dataLoader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [22]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.centor_word_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False) # (vocab_size, embed_size)
        self.context_word_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False) # (vocab_size, embed_size)
        
    def forward(self, input_labels, pos_labels, neg_labels):
        '''
        input_labels: (batch_size,)
        pos_labels: (batch_size, 2 * C)
        neg_labels: (batch_size, (2 * C * K))
        '''
        
        # implement Equation (4) in https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
        batch_size = input_labels.shape[0]
        
        input_embedding = self.centor_word_embed(input_labels) # (batch_size, embed_size)
        pos_embedding = self.context_word_embed(pos_labels) # (batch_size, 2C, embed_size)
        neg_embedding = self.context_word_embed(neg_labels) # (batch_size, 2CK, embed_size)
        
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze() # (batch_size, 2C). The shape of input_embedding.unsqueeze(2) is (batch_size, embed_size, 1)
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze() # (batch_size, 2CK)
        
        log_pos = F.logsigmoid(log_pos).sum(axis=1) # (batch_size,)
        log_neg = F.logsigmoid(log_neg).sum(axis=1) # (batch_size,)
        
        loss = log_pos + log_neg
        return -loss # since Equation (4) in the paper is the objective. We maximize the objective but minimize loss
    
    def input_embeddings(self):
        # use center word embedding as the final word embedding
        return self.centor_word_embed.weight.data.cpu().numpy()

In [23]:
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()

In [24]:
optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE)

In [25]:
for epoch in range(NUM_EPOCHS):
    for i, (center_word, pos_words, neg_words) in enumerate(dataLoader):
        
        center_word = center_word.long()
        pos_words = pos_words.long()
        neg_words = neg_words.long()
        if USE_CUDA:
            center_word = center_word.cuda()
            pos_words = pos_words.cuda()
            neg_words = neg_words.cuda()
        
        optimizer.zero_grad()
        loss = model(center_word, pos_words, neg_words).mean()
        loss.backward()
        optimizer.step()
        
        if i % 1000 == 0:
            print("epoch: {}, batch: {}, loss: {}".format(epoch, i, loss))

epoch: 0, batch: 0, loss: 2466.075927734375
epoch: 0, batch: 1000, loss: 562.1318969726562
epoch: 0, batch: 2000, loss: 408.2147216796875
epoch: 0, batch: 3000, loss: 280.123046875
epoch: 0, batch: 4000, loss: 195.11968994140625
epoch: 0, batch: 5000, loss: 216.25302124023438
epoch: 0, batch: 6000, loss: 143.4933319091797
epoch: 0, batch: 7000, loss: 116.78927612304688
epoch: 0, batch: 8000, loss: 126.276611328125
epoch: 0, batch: 9000, loss: 144.72784423828125
epoch: 0, batch: 10000, loss: 96.56869506835938
epoch: 0, batch: 11000, loss: 87.42467498779297
epoch: 0, batch: 12000, loss: 103.14344787597656
epoch: 0, batch: 13000, loss: 96.37692260742188
epoch: 0, batch: 14000, loss: 89.42659759521484
epoch: 0, batch: 15000, loss: 69.29397583007812
epoch: 0, batch: 16000, loss: 64.8114013671875
epoch: 0, batch: 17000, loss: 68.66221618652344
epoch: 0, batch: 18000, loss: 82.23294067382812
epoch: 0, batch: 19000, loss: 54.06744384765625
epoch: 0, batch: 20000, loss: 70.83740234375
epoch: 0,

epoch: 1, batch: 51000, loss: 34.474151611328125
epoch: 1, batch: 52000, loss: 35.43360900878906
epoch: 1, batch: 53000, loss: 34.22370147705078
epoch: 1, batch: 54000, loss: 34.114776611328125
epoch: 1, batch: 55000, loss: 34.41321563720703
epoch: 1, batch: 56000, loss: 33.7686882019043
epoch: 1, batch: 57000, loss: 37.052940368652344
epoch: 1, batch: 58000, loss: 36.223026275634766
epoch: 1, batch: 59000, loss: 34.479522705078125
epoch: 1, batch: 60000, loss: 34.014495849609375
epoch: 1, batch: 61000, loss: 34.220611572265625
epoch: 1, batch: 62000, loss: 33.95252227783203
epoch: 1, batch: 63000, loss: 34.91365051269531
epoch: 1, batch: 64000, loss: 33.761741638183594
epoch: 1, batch: 65000, loss: 35.484825134277344
epoch: 1, batch: 66000, loss: 34.574745178222656
epoch: 1, batch: 67000, loss: 34.67189025878906
epoch: 1, batch: 68000, loss: 34.89488220214844
epoch: 1, batch: 69000, loss: 34.25187683105469
epoch: 1, batch: 70000, loss: 33.44493103027344
epoch: 1, batch: 71000, loss: 3

In [26]:
# the simplex-999.txt contains word1, word2, similarity by human
# the idea is to compute the spearman correlation between predicted similarity and human similarity
# the better the correlation, the better of our embedding
def evaluate(filename, embedding_weights, word_to_idx): 
    data = pd.read_csv(filename, sep="\t")
    human_similarity = []
    model_similarity = []
    for i in data.iloc[:, 0:2].index:
        word1, word2 = data.iloc[i, 0], data.iloc[i, 1]
        if word1 not in word_to_idx or word2 not in word_to_idx:
            continue
        else:
            word1_idx, word2_idx = word_to_idx[word1], word_to_idx[word2]
            word1_embed, word2_embed = embedding_weights[[word1_idx]], embedding_weights[[word2_idx]]
            model_similarity.append(float(sklearn.metrics.pairwise.cosine_similarity(word1_embed, word2_embed)))
            human_similarity.append(float(data.iloc[i, 2]))

    # (correlation, p-value). 1 - p-value sorts of represent the confidence of the computed correlation
    return scipy.stats.spearmanr(human_similarity, model_similarity)

def find_nearest(word, word_to_idx, idx_to_word, embedding_weights):
    index = word_to_idx[word]
    embedding = embedding_weights[index]
    # loop through all the embedding and compute cosine distance between each embedding and the embedding of given word
    cosine_distance = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
    # select top 10 words with smallest distance to the given word
    return [idx_to_word[i] for i in cosine_distance.argsort()[:10]]

In [27]:
embedding_weights = model.input_embeddings()
# np.save("embedding-{}".format(EMBEDDING_SIZE), embedding_weights)
# torch.save(model.state_dict(), "embedding-{}.th".format(EMBEDDING_SIZE))

In [28]:
print("simlex-999", evaluate("simlex-999.txt", embedding_weights, word_to_idx))

simlex-999 SpearmanrResult(correlation=-0.053366410819609765, pvalue=0.09948912566926381)


In [29]:
for word in ["good", "fresh", "monster", "green", "like", "america", "chicago", "work", "computer", "language"]:
    print(word, find_nearest(word, word_to_idx, idx_to_word, embedding_weights))

good ['good', 'either', 'even', 'given', 'means', 'because', 'without', 'usually', 'therefore', 'them']
fresh ['fresh', 'shady', 'predicting', 'folk', 'recall', 'cores', 'aclu', 'associated', 'abused', 'fin']
monster ['monster', 'preparedness', 'infections', 'routledge', 'paint', 'trio', 'argument', 'geology', 'crows', 'presented']
green ['green', 'black', 'group', 'red', 'and', 'white', 'etc', 'through', 'with', 'from']
like ['like', 'or', 'called', 'also', 'thus', 'being', 'with', 'include', 'using', 'instead']
america ['america', 'north', 'europe', 'south', 'central', 'east', 'africa', 'southern', 'west', 'western']
chicago ['chicago', 'university', 'press', 'state', 'japan', 'history', 'city', 'studies', 'education', 'uk']
work ['work', 'being', 'created', 'while', 'their', 'having', 'was', 'made', 'were', 'which']
computer ['computer', 'information', 'systems', 'software', 'based', 'technology', 'uses', 'system', 'using', 'source']
language ['language', 'modern', 'languages', 'cul