In [1]:
from collections import defaultdict
import math
import time
import random
import torch

In [6]:
class WordEmbSkip(torch.nn.Module):
    def __init__(self, nwords, emb_size):
        super(WordEmbSkip, self).__init__()

        """ word embeddings """
        self.word_embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization
        torch.nn.init.uniform_(self.word_embedding.weight, -0.25, 0.25)
        """ context embeddings"""
        self.context_embedding = torch.nn.Parameter(torch.randn(emb_size, nwords))

    def forward(self, word):
        embed_word = self.word_embedding(word)    # 1 * emb_size
        # (1, emb_size) * (emb_size, nwords) = (1, nwords)
        out = torch.mm(embed_word, self.context_embedding)  # 1 * nwords
        return out

In [7]:
N = 2  # length of window on each side (so N=2 gives a total window size of 5, as in t-2 t-1 t t+1 t+2)
EMB_SIZE = 128  # The size of the embedding

embeddings_location = "embeddings.txt"  # the file to write the word embeddings to
labels_location = "labels.txt"  # the file to write the labels to

# We reuse the data reading from the language modeling class
w2i = defaultdict(lambda: len(w2i))
S = w2i["<s>"]
UNK = w2i["<unk>"]

def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            yield [w2i[x] for x in line.strip().split(" ")]


# Read in the data
train = list(read_dataset("../data/ptb/train.txt"))
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/ptb/valid.txt"))
i2w = {v: k for k, v in w2i.items()}
nwords = len(w2i)

import os 
if not os.path.exists(labels_location):
    with open(labels_location, 'w') as labels_file:
        for i in range(nwords):
            labels_file.write(i2w[i] + '\n')

In [9]:
# initialize the model
model = WordEmbSkip(nwords, EMB_SIZE)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data_type = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    data_type = torch.cuda.LongTensor
    model.cuda()

### [Understanding nn.CrossEntropyLoss](https://discuss.pytorch.org/t/how-exactly-should-i-understand-the-cross-entropy-loss-function/61183/2)

In [10]:
def calc_sent_loss(sent):
    losses = []
    for i, word in enumerate(sent):
        for j in range(1, N+1):
            for direction in [-1, 1]:
                # c is Tensor for center word
                c = torch.tensor([word]).type(data_type)
                context_id = sent[i+direction*j] if 0 <= i+direction*j < len(sent) else S
                # context is Tensor for context word
                context = torch.tensor([context_id]).type(data_type)
                logits = model(c)
                
                # Predict context given center word
                loss = criterion(logits, context)
                losses.append(loss)
    return torch.stack(losses).sum()

In [11]:
len(dev)

3370

In [12]:
for ITER in range(1):
    print("started iter %r" % ITER)
    
    # Start training
    random.shuffle(train)
    train_words, train_loss = 0, 0.0
    start = time.time()
    model.train()
    
    for sent_id, sent in enumerate(train):
        my_loss = calc_sent_loss(sent)
        train_loss += my_loss.item()
        train_words += len(sent)
        
        # Take step after calculating loss for all words in sent
        optimizer.zero_grad()  # Zero the gradients 
        my_loss.backward()
        optimizer.step()
        
        if (sent_id+1) % 5000 == 0:
            print("--finished {} sentences".format(sent_id+1))
    
    print("iter {}: train loss/word={}, ppl={}, time={}".format(
        ITER,
        train_loss/train_words,
        math.exp(train_loss/train_words),
        time.time()-start))
    
    # Evaluate on dev set 
    dev_words, dev_loss = 0, 0.0
    start = time.time()
    model.eval() 
    for sent_id, sent in enumerate(dev):
        my_loss = calc_sent_loss(sent)
        dev_loss += my_loss.item()
        dev_words += len(sent)
    
    # Why 709?
    dev_ppl = float('inf') if dev_loss / dev_words > 709 else math.exp(dev_loss / dev_words)
    print("iter {}: dev loss/word={}, ppl={}, time={}".format(
        ITER,
        dev_loss/dev_words,
        dev_ppl,
        time.time()-start))

started iter 0
--finished 5000 sentences
--finished 10000 sentences
--finished 15000 sentences
--finished 20000 sentences
--finished 25000 sentences
--finished 30000 sentences
--finished 35000 sentences
--finished 40000 sentences
iter 0: train loss/word=nan, ppl=nan, time=2507.6852102279663
iter 0: train loss/word=nan, ppl=nan, time=51.093786001205444
