In [19]:
import torch
from torch.autograd import Variable
import torch.functional as F
import torch.nn.functional as F
import numpy as np
import pandas as pd
from copy import deepcopy
from pprint import pprint

In [38]:
# Function defs

def tokenize(corpus: "str | list[str]") -> list:
    """Split every sentence of a corpus into whitespace-delimited tokens.

    Args:
        corpus: An iterable of sentence strings. A single bare string is
            treated as a one-sentence corpus; without this guard a string
            would be iterated character by character.

    Returns:
        A list with one entry per sentence, each entry being the list of
        tokens produced by ``str.split()`` for that sentence.
    """
    if isinstance(corpus, str):
        corpus = [corpus]
    return [sentence.split() for sentence in corpus]

def remove_stops(corpus, stop_words=["is", "a"]):
    """Drop stop words from every sentence of a corpus.

    Args:
        corpus: Iterable of sentence strings.
        stop_words: Words to remove (exact, case-sensitive match). The
            default list is only read, never mutated.

    Returns:
        A list of sentences, each rebuilt from the surviving words joined by
        single spaces (an empty string when nothing survives).
    """
    cleaned = []
    for sentence in corpus:
        kept = [token for token in sentence.split() if token not in stop_words]
        cleaned.append(" ".join(kept))
    return cleaned

In [4]:
def word2index(tokens):
    """Map every distinct token to an integer index in first-seen order.

    Args:
        tokens: Iterable of sentences, each an iterable of hashable tokens.

    Returns:
        A dict ``{token: index}`` where indices start at 0 and follow the
        order in which tokens are first encountered.
    """
    word2idx = {}
    for sentence in tokens:
        for token in sentence:
            # Dict membership is O(1); the next free index is simply the
            # number of tokens registered so far.
            if token not in word2idx:
                word2idx[token] = len(word2idx)
    return word2idx

In [5]:
def generate_center_context_pair(tokens, window: int) -> dict:
    """Collect, for each center word, the words found within ``window`` of it.

    Args:
        tokens: List of sentences, each a list of tokens.
        window: Number of positions to look at on each side of the center.

    Returns:
        A dict mapping each center word to the list of its context words,
        accumulated across all occurrences (duplicates are kept). A word
        with no neighbours maps to an empty list.
    """
    contexts_by_center = {}
    for sentence in tokens:
        for pos, center in enumerate(sentence):
            neighbours = contexts_by_center.setdefault(center, [])
            # Clamp the window to the sentence boundaries up front.
            first = max(pos - window, 0)
            last = min(pos + window + 1, len(sentence))
            neighbours.extend(sentence[j] for j in range(first, last) if j != pos)
    return contexts_by_center

In [6]:
def get_idxpairs(cc_pair: dict, w2idx: dict) -> list:
    """Flatten a center->contexts mapping into [center_idx, context_idx] pairs.

    ``generate_center_context_pair`` returns a dictionary like::

        {'centerword1': ['contextword1', 'contextword2', ...],
         'centerword2': ['contextword1', 'contextword2', ...]}

    whereas the training loop consumes a flat list of two-element pairs.
    This function converts the former into the latter, replacing each word
    by its integer index.

    Args:
        cc_pair: Mapping of center word -> list of context words.
        w2idx: Mapping of word -> integer index (as built by ``word2index``).

    Returns:
        A list of ``[center_index, context_index]`` lists, in the iteration
        order of ``cc_pair`` and of each context list.

    Raises:
        KeyError: If a center or context word is missing from ``w2idx``.
    """
    return [
        [w2idx[center], w2idx[context]]
        for center, contexts in cc_pair.items()
        for context in contexts
    ]

In [7]:
def generate_jdt(cc_pair: dict) -> list:
    """Flatten a center->contexts mapping into [center, context] word pairs.

    Args:
        cc_pair: Mapping of center word -> list of context words.

    Returns:
        A list of ``[center, context]`` lists, one per context occurrence,
        in the iteration order of ``cc_pair`` and of each context list.
        (Presumably the rows of the joint distribution table consumed by
        ``all_p_of_context_given_center`` - verify against the caller.)
    """
    return [
        [center_word, context_word]
        for center_word, context_words in cc_pair.items()
        for context_word in context_words
    ]

In [8]:
def all_p_of_context_given_center(joint_distrib_table: pd.DataFrame):
    """Return the counts needed for P(context | center) for every pair.

    Args:
        joint_distrib_table: DataFrame with a ``center`` and a ``context``
            column, one row per observed (center, context) occurrence.

    Returns:
        A dict keyed by ``(center, context)`` whose value is the two-element
        list ``[pair_count, center_total]``: the numerator and denominator
        of the conditional probability. The division is left to the caller.
    """
    # Numerator: how often each (center, context) pair occurs.
    pair_counts = joint_distrib_table.groupby(['center', 'context']).size().to_dict()

    # Denominator: how often each center word occurs at all.
    center_totals = joint_distrib_table.groupby('center').size().to_dict()

    # Look the total up by key (equality-based) instead of scanning every
    # pair for every center and comparing with `is`, which tests object
    # identity and can miss equal strings.
    return {
        pair: [count, center_totals[pair[0]]]
        for pair, count in pair_counts.items()
    }

In [43]:
# Toy training corpus: one lowercase, whitespace-separated sentence per entry.
# main() tokenizes these sentences and trains the embedding on them.
corpus = [
        "he is a king",
        "she is a queen",
        "he is a man",
        "she is a woman",
        "warsaw is poland capital",
        "berlin is germany capital",
        "paris is france capital",
        # "Sxi este juna kaj bela",  # disabled extra sentence (appears to be Esperanto)
]

In [44]:

def get_input_layer(word_idx, vocab_size):
    """Build the one-hot input vector for a word.

    Args:
        word_idx: Index of the word in the vocabulary.
        vocab_size: Length of the vector to create.

    Returns:
        A float32 tensor of length ``vocab_size`` that is 0.0 everywhere
        except for a 1.0 at position ``word_idx``.
    """
    one_hot = torch.zeros(vocab_size, dtype=torch.float32)
    one_hot[word_idx] = 1.0
    return one_hot

In [45]:
def main(window=1, embedding_dims=5, max_iter=200, learning_rate=0.001):
    """Train a tiny skip-gram style model on ``corpus`` and print predictions.

    The model is two weight matrices: ``W1`` projects a one-hot center word
    to an embedding, ``W2`` projects the embedding to vocabulary scores.
    Training is plain per-pair SGD on the negative log-likelihood of the
    observed context word.

    Args:
        window: Context window size on each side of the center word.
        embedding_dims: Size of the embedding produced by ``W1``.
        max_iter: Number of passes over all (center, context) pairs.
        learning_rate: SGD step size.

    Side effects:
        Reads the module-level ``corpus``. Prints both vocabulary mappings,
        the mean loss every 10 iterations, and finally the most probable
        context word for each vocabulary word.
    """
    tokens = tokenize(corpus)
    cc_pair = generate_center_context_pair(tokens, window)
    # pprint(cc_pair)

    word2idx = word2index(tokens)
    idx2word = {idx: word for word, idx in word2idx.items()}
    # word2idx has exactly one entry per distinct token.
    vocab_size = len(word2idx)
    print(word2idx)
    print(idx2word)

    idx_pairs = get_idxpairs(cc_pair, word2idx)

    # Leaf tensors tracked by autograd (replaces the deprecated Variable wrapper).
    W1 = torch.randn(embedding_dims, vocab_size, dtype=torch.float32,
                     requires_grad=True)
    W2 = torch.randn(vocab_size, embedding_dims, dtype=torch.float32,
                     requires_grad=True)

    for i in range(max_iter):
        loss_val = 0
        for data, target in idx_pairs:
            x = get_input_layer(data, vocab_size)
            y_true = torch.tensor([target], dtype=torch.long)

            z1 = torch.matmul(W1, x)
            z2 = torch.matmul(W2, z1)

            log_softmax = F.log_softmax(z2, dim=0)

            # nll_loss expects a (batch, classes) input, hence the view.
            loss = F.nll_loss(log_softmax.view(1, -1), y_true)
            loss_val += loss.item()
            loss.backward()

            # Manual SGD step; no_grad keeps the in-place update out of the
            # autograd graph (replaces mutating `.data`).
            with torch.no_grad():
                W1 -= learning_rate * W1.grad
                W2 -= learning_rate * W2.grad
                W1.grad.zero_()
                W2.grad.zero_()
        if i % 10 == 0:
            print(f"Loss at iter {i}: {loss_val/len(idx_pairs)}")

    # Lets see the word predictions for each word in our vocabulary.
    # Iterating word2idx (insertion-ordered) gives a stable print order,
    # unlike iterating a set of words.
    with torch.no_grad():
        for word, widx in word2idx.items():
            x = get_input_layer(widx, vocab_size)
            z1 = torch.matmul(W1, x)
            z2 = torch.matmul(W2, z1)

            softmax = F.softmax(z2, dim=0)
            max_arg = torch.argmax(softmax).item()
            pred_word = idx2word[max_arg]
            print(f"Center: {word} ; Context: {pred_word}")

# Entry point: run the training demo when this code executes as the main module.
if __name__ == "__main__":
    main()


{'he': 0, 'is': 1, 'a': 2, 'king': 3, 'she': 4, 'queen': 5, 'man': 6, 'woman': 7, 'warsaw': 8, 'poland': 9, 'capital': 10, 'berlin': 11, 'germany': 12, 'paris': 13, 'france': 14}
{0: 'he', 1: 'is', 2: 'a', 3: 'king', 4: 'she', 5: 'queen', 6: 'man', 7: 'woman', 8: 'warsaw', 9: 'poland', 10: 'capital', 11: 'berlin', 12: 'germany', 13: 'paris', 14: 'france'}
Loss at iter 0: 4.081058734939212
Loss at iter 10: 3.7357976039250693
Loss at iter 20: 3.4894746031079973
Loss at iter 30: 3.2941016469682967
Loss at iter 40: 3.1317113240559897
Loss at iter 50: 2.993488834017799
Loss at iter 60: 2.874097398349217
Loss at iter 70: 2.7697478163810003
Loss at iter 80: 2.677462305341448
Loss at iter 90: 2.59481250955945
Loss at iter 100: 2.519842803478241
Loss at iter 110: 2.4510283299854825
Loss at iter 120: 2.387219148022788
Loss at iter 130: 2.3275692803519115
Loss at iter 140: 2.271466456708454
Loss at iter 150: 2.218468490101042
Loss at iter 160: 2.168257568563734
Loss at iter 170: 2.120602181979588

In [12]:
# Tokenize the demo corpus into one token list per sentence.
t = tokenize(corpus)

In [13]:
# Show the tokenized corpus.
print(t)

[['he', 'is', 'a', 'king'], ['she', 'is', 'a', 'queen'], ['he', 'is', 'a', 'man'], ['she', 'is', 'a', 'woman'], ['warsaw', 'is', 'poland', 'capital'], ['berlin', 'is', 'germany', 'capital'], ['paris', 'is', 'france', 'capital']]


In [17]:
# NOTE(review): the recorded output has the stop words removed, which does not
# match t = tokenize(corpus) from the earlier cell - looks like out-of-order
# execution; re-run on a fresh kernel to confirm.
t

[['he', 'king'],
 ['she', 'queen'],
 ['he', 'man'],
 ['she', 'woman'],
 ['warsaw', 'poland', 'capital'],
 ['berlin', 'germany', 'capital'],
 ['paris', 'france', 'capital']]

In [25]:
# Remove stop words from the raw sentences first, then tokenize:
# remove_stops calls .split() on each line, so it needs strings, not the
# token lists that tokenize returns.
tokens = tokenize(remove_stops(corpus))
print(tokens)

[['he', 'king'], ['she', 'queen'], ['he', 'man'], ['she', 'woman'], ['warsaw', 'poland', 'capital'], ['berlin', 'germany', 'capital'], ['paris', 'france', 'capital']]
[['he', 'king'], ['she', 'queen'], ['he', 'man'], ['she', 'woman'], ['warsaw', 'poland', 'capital'], ['berlin', 'germany', 'capital'], ['paris', 'france', 'capital']]


In [37]:
# NOTE(review): the recorded output (trailing spaces, an extra None) does not
# match the current remove_stops, which strips each rebuilt line - re-run to
# refresh the output.
print(remove_stops(corpus))

['he king ', 'she queen ', 'he man ', 'she woman ', 'warsaw poland capital ', 'berlin germany capital ', 'paris france capital ']
None
