In [23]:
import nltk
from nltk.corpus import reuters

import numpy as np
import random
import re
import time

import torch
import torch.nn as nn
import torch.optim as optim

from scipy.stats import spearmanr
from collections import Counter

In [24]:
def build_corpus(corpus_reader=None):
    """Tokenize each document of an NLTK corpus into lowercase alphabetic words.

    Args:
        corpus_reader: object exposing ``fileids()`` and ``words(fileid)``;
            when omitted, the NLTK Reuters corpus imported at the top of the
            notebook is used (the original behavior).

    Returns:
        List with one entry per file id: that document's tokens, lowercased,
        keeping only purely alphabetic tokens (``str.isalpha``), so numbers
        and punctuation tokens are dropped.
    """
    reader = reuters if corpus_reader is None else corpus_reader
    return [
        [word.lower() for word in reader.words(file_id) if word.isalpha()]
        for file_id in reader.fileids()
    ]


corpus = build_corpus()
len(corpus)  # 10788 documents (one token list per Reuters file id)

10788

In [25]:
def build_vocab(corpus):
    """Build a deterministic vocabulary and word-to-index mapping from ``corpus``.

    Args:
        corpus: iterable of sentences, each a list of word strings.

    Returns:
        Tuple ``(vocab, vocab_size, word2index, word_counts)``:
        - ``vocab``: sorted list of distinct corpus words, followed by the
          ``"<UNKNOWN>"`` token (appended only if the corpus lacks it);
        - ``vocab_size``: ``len(vocab)``;
        - ``word2index``: maps each vocab entry to its position in ``vocab``;
        - ``word_counts``: ``Counter`` of token frequencies over the corpus.
    """
    word_counts = Counter(word for sentence in corpus for word in sentence)

    # Sort instead of list(set(...)): iteration order of a set of strings
    # changes between interpreter runs (hash randomization), which would
    # silently reshuffle word indices -- and thus embedding rows -- on every
    # kernel restart.
    vocab = sorted(word_counts)
    if "<UNKNOWN>" not in word_counts:
        # Avoid a duplicate entry (and an inflated vocab_size) if the corpus
        # already contains the token.
        vocab.append("<UNKNOWN>")

    word2index = {word: index for index, word in enumerate(vocab)}

    return vocab, len(vocab), word2index, word_counts


vocab, vocab_size, word2index, word_counts = build_vocab(corpus)
vocab_size  # 29174 entries: distinct corpus words plus the "<UNKNOWN>" token

29174

In [26]:
def build_skipgrams(corpus, word2index, window_size=2):
    """Generate ``(center_index, context_index)`` pairs for skip-gram training.

    For every token in every sentence, each other token at most
    ``window_size`` positions away (clipped to the sentence bounds) becomes
    a context word. Pairs are emitted sentence by sentence, center position
    by center position, with contexts in left-to-right order. Windows never
    cross sentence boundaries.

    Args:
        corpus: iterable of sentences, each a list of word strings.
        word2index: mapping from word to integer index; every corpus word
            must be present (a missing word raises ``KeyError``).
        window_size: maximum distance between center and context word.

    Returns:
        List of ``(center_index, context_index)`` integer tuples.
    """
    skip_grams = []

    for sentence in corpus:
        # Look each word up once per sentence instead of once per pair.
        indices = [word2index[word] for word in sentence]
        n_words = len(indices)
        for position, center_index in enumerate(indices):
            start = max(position - window_size, 0)
            stop = min(position + window_size + 1, n_words)
            skip_grams.extend(
                (center_index, indices[i]) for i in range(start, stop) if i != position
            )

    return skip_grams


skip_grams = build_skipgrams(corpus, word2index, window_size=2)
len(skip_grams)  # 5243836 (center_index, context_index) training pairs

5243836

In [27]:
def build_unigram_table(word_counts, power=0.75, scale=1e6):
    """Build a table in which each word repeats in proportion to frequency**power.

    Each word appears ``int((count / total) ** power * scale)`` times, so
    drawing uniformly from the table samples words from the power-smoothed
    unigram distribution, up to integer truncation. Words whose scaled score
    is below 1 get no entries.

    Args:
        word_counts: mapping word -> count (e.g. a ``collections.Counter``).
        power: exponent applied to relative frequencies; values below 1
            raise the relative weight of rare words.
        scale: multiplier setting the table's resolution (and memory use);
            the default ``1e6`` matches the original hard-coded value.

    Returns:
        List of words in ``word_counts`` iteration order. Empty when there
        are no words or the counts sum to zero.
    """
    total_count = sum(word_counts.values())
    if total_count == 0:
        # Nothing to sample from; also avoids a ZeroDivisionError below.
        return []

    unigram_table = []
    for word, count in word_counts.items():
        score = (count / total_count) ** power
        unigram_table.extend([word] * int(score * scale))
    return unigram_table


# Word sampling table over corpus counts, smoothed with the default power 0.75.
# NOTE(review): presumably used to draw negative samples in a later cell -- not shown here.
unigram_table = build_unigram_table(word_counts)

In [28]:
class Skipgram(nn.Module):
    """Skip-gram embedding model scored with a full softmax over candidates.

    Two ``vocab_size x embed_size`` embedding tables are kept:
    ``embedding_v`` embeds center words and ``embedding_u`` embeds context
    (and candidate) words.
    """

    def __init__(self, vocab_size, embed_size, mode="softmax"):
        super().__init__()

        # Not consulted by ``forward``, which always computes the full softmax.
        self.mode = mode
        self.embedding_v = nn.Embedding(vocab_size, embed_size)  # center words
        self.embedding_u = nn.Embedding(vocab_size, embed_size)  # context words

    def forward(self, center_words, next_words, all_vocabs):
        """Return the mean negative log-likelihood of the observed context words.

        Args:
            center_words: 2-D index tensor of center words (``bmm`` needs the
                embeddings to be 3-D) -- presumably (batch, 1); TODO confirm.
            next_words: 2-D index tensor (batch, k) of observed context words.
            all_vocabs: 2-D index tensor (batch, V) of candidate words forming
                the softmax denominator -- presumably the whole vocabulary
                repeated per row; verify against the caller.

        Returns:
            Scalar tensor ``-mean(log softmax)`` of context scores
            ``u_o . v_c`` normalized over the candidate scores ``u_w . v_c``.
        """
        center_embeds = self.embedding_v(center_words)  # (batch, c, embed)
        next_embeds = self.embedding_u(next_words)      # (batch, k, embed)
        all_embeds = self.embedding_u(all_vocabs)       # (batch, V, embed)

        center_t = center_embeds.transpose(1, 2)        # (batch, embed, c)
        scores = next_embeds.bmm(center_t).squeeze(2)       # u_o . v_c
        norm_scores = all_embeds.bmm(center_t).squeeze(2)   # u_w . v_c per candidate

        # log(exp(s) / sum(exp(n))) == s - logsumexp(n). The log-sum-exp form
        # never exponentiates raw scores, so large dot products no longer
        # overflow to inf and turn the loss into NaN.
        log_probs = scores - torch.logsumexp(norm_scores, dim=1, keepdim=True)

        return -torch.mean(log_probs)