In [0]:
# Force a fresh, uncached reinstall of Pillow (-I = --ignore-installed).
# NOTE(review): presumably an environment (Colab) workaround for an image-library
# version conflict — confirm it is still needed, and pin a version if it is.
!pip install --no-cache-dir -I pillow

In [0]:
#get dataset
!wget https://s3.amazonaws.com/video.udacity-data.com/topher/2018/October/5bbe6499_text8/text8.zip
!unzip text8.zip
!ls

In [0]:
import os
# Read the extracted text8 file into one string and preview its first 100 characters.
with open('text8') as corpus_file:
    text = corpus_file.read()
print(text[:100])

In [0]:
import torch
import numpy as np
import matplotlib
# Turn off axes grid lines by default for matplotlib axes created from here on.
matplotlib.rcParams['axes.grid'] = False
from matplotlib import pyplot as plt
# Fixed seeds make the NumPy and PyTorch random draws in this notebook
# repeatable under Restart & Run All. This covers word subsampling, context-size
# sampling and model weight initialisation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# torch.cuda.is_available is a function and must be *called*. The bare function
# object is always truthy, which would move tensors to a GPU that may not exist.
is_cuda = torch.cuda.is_available()
import tqdm
def tensor_to_numpy(t):
    """Return `t` as a NumPy array, detached from autograd and copied to the CPU."""
    host_tensor = t.detach().cpu()
    return host_tensor.numpy()

In [0]:
# Scratch check: sorting a dict with the values as key (d.get) orders its keys by their values.
d = dict([(1, 'a'), (2, 'b')])
d, sorted(d.keys(), key=lambda k: d[k])

In [0]:
import collections
def replace_punctuation_with_token(text):
    """Lower-case `text` and replace punctuation marks with word-like tokens.

    Each mark is replaced by its token padded with spaces (e.g. '.' becomes
    ' <PERIOD> '), so that a later whitespace split yields the token as a
    separate word rather than gluing it onto the preceding word.

    Args:
        text: input string.

    Returns:
        The lower-cased string with punctuation replaced by tokens.
    """
    text = text.lower()
    tokens = {'.':'<PERIOD>',
             ',':'<COMMA>',
             '"':'<QUOTATION_MARK>',
             ';':'<SEMICOLON>',
             '!':'<EXCLAMATION_MARK>',
             '?':'<QUESTION_MARK>',
             '(':'<PAREN_OPEN>',
             ')':'<PAREN_CLOSE>',
             '--':'<DOUBLE_DASH>',
             ':':'<COLON>',
             }
    for punct, token in tokens.items():
        # str.replace returns a new string (strings are immutable); reassign it.
        text = text.replace(punct, f' {token} ')
    return text

def remove_rare_words(text, rare_threshold=5):
    """Split `text` on whitespace and drop rare words.

    Args:
        text: input string.
        rare_threshold: words occurring this many times or fewer are removed.
            The default of 5 means a word must occur at least 6 times to be kept.

    Returns:
        List of the remaining words, in their original order.
    """
    words = text.split()
    counts = collections.Counter(words)
    return [w for w in words if counts[w] > rare_threshold]

def mappings(words):
    """Build lookup tables between words and integer ids.

    Ids are assigned by descending frequency, so id 0 is the most common word.
    Words with equal counts keep the order in which they first appear.

    Args:
        words: iterable of words.

    Returns:
        (word_to_idx, idx_to_word): a dict mapping word -> id, and a list
        mapping id -> word.
    """
    ranked = [word for word, _count in collections.Counter(words).most_common()]
    return {word: rank for rank, word in enumerate(ranked)}, ranked

# Use new names instead of overwriting `text`, so re-running this cell always
# starts again from the raw corpus that was read from disk.
text_tokenized = replace_punctuation_with_token(text)
trimmed = remove_rare_words(text_tokenized)
word_to_idx,idx_to_word = mappings(trimmed)

In [0]:
# Encode the trimmed corpus as vocabulary ids (KeyError on any unmapped word).
int_words = list(map(word_to_idx.__getitem__, trimmed))
total_words = len(int_words)

In [0]:
# Subsampling: each occurrence of word w is discarded with probability
# 1 - sqrt(thresh / freq(w)). This removes many occurrences of very frequent
# words such as 'the', 'a', 'an'. Any word with freq(w) <= thresh is always
# kept, because its discard probability is <= 0.
thresh = 1e-5
word_counts = collections.Counter(int_words)
word_freq = {word: count / total_words for word, count in word_counts.items()}
discard_prob = {word: 1. - np.sqrt(thresh / freq) for word, freq in word_freq.items()}
# One uniform [0, 1) draw per token. The token is kept when its draw falls
# below its keep probability, 1 - discard_prob.
tosses = np.random.random(size=(len(int_words),))
train_words = [
    word
    for word, toss in zip(int_words, tosses)
    if toss < (1. - discard_prob[word])
]
print('check if all words remain',
      len(np.unique(int_words)),
      len(np.unique(train_words)))


In [0]:
def get_context(words,idx,window=5):
    """Return the words around position `idx` within a randomly drawn radius.

    A radius R is drawn uniformly from 1..window (inclusive). The context is the
    up-to-R words before `idx` followed by the up-to-R words after it; the centre
    word itself is excluded. Near either end of `words` the context is truncated.

    Args:
        words: sequence of word ids (list, tuple or 1-D NumPy array).
        idx: *position* of the centre word in `words`.
        window: maximum context radius.

    Returns:
        List of context word ids.
    """
    radius = np.random.randint(1, window + 1)
    start = max(idx - radius, 0)
    # The future context ends R words *after* idx. Slicing clamps to len(words).
    end = idx + radius + 1
    # Convert both halves to lists so `+` concatenates even for NumPy arrays
    # (on arrays it would add element-wise).
    return list(words[start:idx]) + list(words[idx + 1:end])

if False:
    # Debug: show one centre word and its sampled context. `idx` is a position in
    # train_words, so the centre word's id is train_words[idx], not idx itself.
    idx = 15
    ctx = get_context(train_words,idx)
    print(idx_to_word[train_words[idx]],[idx_to_word[i] for i in ctx])

In [0]:
import itertools
def get_batches(words,
               batch_size,
               window=5):
    """Yield skip-gram (inputs, targets) pairs in mini-batches.

    `words` is split into consecutive chunks of `batch_size` centre positions;
    the last chunk may be shorter. For each centre position, `get_context`
    samples the surrounding words. The centre id is repeated once per context
    word, so `x` and `y` are flat, equal-length lists of ids.

    Args:
        words: sequence of word ids.
        batch_size: number of centre words per batch.
        window: maximum context radius, passed through to `get_context`.

    Yields:
        (x, y): list of centre ids and the matching list of context ids.
    """
    n_words = len(words)
    for batch_start in range(0, n_words, batch_size):
        x, y = [], []
        for idx in range(batch_start, min(batch_start + batch_size, n_words)):
            # get_context needs the *position* of the centre word, not its id.
            context = get_context(words, idx, window)
            x.extend([words[idx]] * len(context))
            y.extend(context)
        yield x, y

if False:
    # Debug: print one batch from a toy corpus. The ids deliberately differ from
    # their positions, so mixing up ids and positions would show up here.
    int_text = list(range(100, 120))
    x,y = next(get_batches(int_text, batch_size=4, window=5))

    print('x\n', x)
    print('y\n', y)

In [0]:
# Defining the model
class SkipGram(torch.nn.Module):
    """Skip-gram model with a full-softmax output layer.

    A word id is looked up in an embedding table, projected to one score per
    vocabulary word by a linear layer, and normalised with log-softmax.

    Args:
        n_vocab: vocabulary size (embedding rows and number of output scores).
        n_embed: embedding dimension.
    """

    def __init__(self, n_vocab, n_embed):
        super().__init__()
        # Submodules are created in this order (embed, then output) so weight
        # initialisation draws from the RNG in the same sequence as before.
        self.embed = torch.nn.Embedding(n_vocab, n_embed)
        self.output = torch.nn.Linear(n_embed, n_vocab)
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for the word ids `x`.

        Normalisation is over dim 1, so `x` is expected to be a 1-D LongTensor
        of ids (batch,), giving an output of shape (batch, n_vocab).
        """
        scores = self.output(self.embed(x))
        return self.log_softmax(scores)

In [0]:
# Instantiate the model and smoke-test one forward pass on word id 1.
embed_dim = 300
vocab_size = len(idx_to_word)
skipgram = SkipGram(vocab_size, embed_dim)
v = torch.tensor([1], dtype=torch.long)

if is_cuda:
    skipgram.cuda()
    v = v.cuda()
skipgram(v)

In [0]:
# Train the full-softmax skip-gram with Adam. Every `val_freq` epochs, print the
# nearest neighbours (cosine similarity of the input embeddings) of a few randomly
# chosen frequent and mid-frequency words as a qualitative sanity check.
import tqdm.auto  # notebook-aware progress bar; tqdm.tqdm_notebook is deprecated

nepochs = 100
val_freq = 1
batch_size = 512
opt = torch.optim.Adam(skipgram.parameters(),lr=3e-3)
nvalid = 16             # validation words per check: half frequent, half mid-frequency
freq_window = 100       # size of each candidate pool of validation words
mid_freq_start = 1000   # vocab id where the mid-frequency pool starts (ids are frequency-ranked)
n_neighbours = 5        # neighbours printed per validation word
idx_to_word_np = np.array(idx_to_word)
for e in tqdm.auto.tqdm(range(nepochs)):
    for w,ctx in get_batches(train_words,batch_size):
        w,ctx = torch.LongTensor(w),torch.LongTensor(ctx)
        if is_cuda:
            w = w.cuda()
            ctx = ctx.cuda()
        log_preds = skipgram(w)
        loss = torch.nn.functional.nll_loss(log_preds, ctx)
        opt.zero_grad()
        loss.backward()
        opt.step()

    if e % val_freq == 0:
        # Validation needs no gradients; no_grad avoids building an autograd graph.
        with torch.no_grad():
            embed_matrix = skipgram.embed.weight
            # Unit-normalise every embedding so the dot products are cosine similarities.
            embed_unit = embed_matrix / embed_matrix.norm(dim=1, keepdim=True)

            # Sample without replacement so no validation word repeats within a check.
            valid_words_freq = np.random.choice(freq_window, size=nvalid//2, replace=False)
            valid_words_mid = mid_freq_start + np.random.choice(freq_window, size=nvalid//2, replace=False)
            valid_words = torch.tensor(np.concatenate([valid_words_freq,
                                                       valid_words_mid])).long()
            if is_cuda:
                valid_words = valid_words.cuda()
            valid_similarities = torch.mm(embed_unit[valid_words], embed_unit.t())
            # Ask for one extra neighbour: every word is its own closest match.
            _,closest_idxs = valid_similarities.topk(n_neighbours + 1)
        valid_words,closest_idxs = tensor_to_numpy(valid_words),tensor_to_numpy(closest_idxs)
        for w,syn in zip(valid_words,closest_idxs):
            # Drop the query word itself from its neighbour list.
            closest_words = idx_to_word_np[syn[syn != w][:n_neighbours]]
            print(f'Closest word to {idx_to_word_np[w]}:{closest_words}')
        print('-'*50)

In [0]:
# Inspect the ten most frequent vocabulary words (ids are sorted by descending count).
idx_to_word[:10]