In [0]:
!pip install -I --no-cache-dir pillow

In [0]:
import torch
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import collections #for counter
# Evaluate CUDA availability once. The call parentheses matter: the bare
# function object `torch.cuda.is_available` is always truthy, so every
# `if is_cuda:` guard would fire (and `.cuda()` would fail) on CPU-only machines.
is_cuda = torch.cuda.is_available()


def tensor_to_numpy(t):
    """Detach tensor `t` from autograd, copy it to the CPU and return it as a NumPy array."""
    return t.detach().cpu().numpy()

Get the data

In [0]:
!wget https://s3.amazonaws.com/video.udacity-data.com/topher/2018/October/5bbe6499_text8/text8.zip 
!unzip text8.zip

In [0]:
# Read the whole text8 corpus in binary mode, lower-case it and split it on
# whitespace. Because the file is opened with 'rb', every token in `text`
# is a `bytes` object rather than a `str`.
with open('text8', 'rb') as f:
    text = f.read().lower().split()


1. Remove rare words.
2. Build the word-to-integer and integer-to-word mappings.


In [0]:
def remove_rare(text,thresh_freq=5):
    """Return the tokens of `text` whose total count is strictly above `thresh_freq`.

    Args:
        text: sequence of hashable tokens.
        thresh_freq: tokens occurring this many times or fewer are dropped.

    Returns:
        A new list with the surviving tokens in their original order.
    """
    frequency = collections.Counter(text)
    # Decide once per distinct word instead of once per occurrence.
    frequent_enough = {word for word, count in frequency.items() if count > thresh_freq}
    return [token for token in text if token in frequent_enough]

# Replace the corpus with its rare-word-filtered version and preview the first
# 100 tokens. NOTE: this rebinds `text`, so re-running the cell filters the
# already-filtered list again.
text = remove_rare(text)
print(text[:100])


In [0]:
def word_integer_mappings(text):
    """Build vocabulary lookups ordered by descending word frequency.

    Args:
        text: sequence of hashable tokens.

    Returns:
        (int_to_word, word_to_int): `int_to_word` is a list of the distinct
        tokens, most frequent first (ties keep first-seen order);
        `word_to_int` maps each token to its position in that list.
    """
    ranked = collections.Counter(text).most_common()
    int_to_word = [word for word, _ in ranked]
    word_to_int = dict(zip(int_to_word, range(len(int_to_word))))
    return int_to_word,word_to_int
# Build both lookups from the filtered corpus and preview the first ten
# entries of each (index 0 is the most frequent word).
int_to_word,word_to_int = word_integer_mappings(text)
print(int_to_word[:10],
     list(word_to_int.keys())[:10])

Subsample extremely frequent words (Mikolov et al., 2013) so that they do not dominate the training pairs.

In [0]:
def subsample(text,thresh=1e-5):
    """Randomly thin out very frequent words (word2vec subsampling).

    Each occurrence of a word w is kept with probability
    sqrt(thresh / f(w)), where f(w) = count(w) / len(text) is the word's
    relative frequency. Words with f(w) <= thresh get a value >= 1 and are
    therefore always kept; the more frequent a word, the more of its
    occurrences are discarded.

    Args:
        text: sequence of hashable tokens.
        thresh: frequency threshold t of the subsampling formula.

    Returns:
        A new list with the surviving tokens in their original order.
        Uses NumPy's global random state (one draw per token).
    """
    nwords = len(text)
    word_counts = collections.Counter(text)
    # Keep probability per distinct word, computed once instead of per token.
    # (The previous test `random < 1 - sqrt(...)` kept tokens with the *drop*
    # probability, i.e. it retained frequent words and removed every rare one.)
    keep_prob = {w: np.sqrt(thresh * nwords / count) for w, count in word_counts.items()}
    subsampled = [w for w in text if np.random.random() < keep_prob[w]]
    return subsampled
# Compare the corpus size before and after subsampling.
print(len(text))
subsampled = subsample(text)
print(len(subsampled))

Make indexed representation for training words

In [0]:
# Integer-encode the subsampled corpus: one vocabulary index per token.
train_words = list(map(word_to_int.__getitem__, subsampled))


1. Build a context window for every word.
2. Group the (word, context) pairs into training batches.

In [0]:
def make_target(text,
               idx,
               max_context_window = 5):
    """Return the context tokens surrounding position `idx` of `text`.

    A window radius is drawn uniformly from 1..max_context_window (one call to
    NumPy's global RNG); the result is the tokens up to that many positions
    before and after `idx`, clipped at the sequence boundaries and excluding
    the token at `idx` itself.

    Args:
        text: sliceable sequence supporting `+` concatenation (e.g. a list).
        idx: position of the center token.
        max_context_window: largest radius that may be drawn.

    Returns:
        A list with the left-hand context followed by the right-hand context.
    """
    radius = np.random.randint(1, max_context_window + 1)
    first = max(idx - radius, 0)
    last = min(idx + radius, len(text))

    before = text[first:idx]
    # Slicing one past `last` is safe: slices clip at the end of the sequence.
    after = text[idx + 1:last + 1]
    return list(before + after)

    

In [0]:
def make_batches(text,
                batch_size,
                max_context_window = 5):
    """Yield skip-gram training pairs, `batch_size` center words at a time.

    For every center position a context is drawn with `make_target`; the
    center token is repeated once per context token so both output lists are
    aligned element by element.

    Args:
        text: indexable sequence of tokens.
        batch_size: number of center positions per batch (the last batch may
            be shorter).
        max_context_window: forwarded to `make_target`.

    Yields:
        (batchx, batchy): two equally long lists; `batchx[i]` is a center
        token and `batchy[i]` one of its context tokens.
    """
    total = len(text)
    n_chunks = (total + batch_size - 1) // batch_size  # ceil(total / batch_size)

    for chunk in range(n_chunks):
        start = batch_size * chunk
        stop = min(start + batch_size, total)
        centers, contexts = [], []

        for pos in range(start, stop):
            neighbours = make_target(text,
                                     pos,
                                     max_context_window = max_context_window)
            contexts.extend(neighbours)
            centers.extend([text[pos]] * len(neighbours))
        yield centers, contexts
# Smoke test for make_batches: pull the first batch from a toy corpus of the
# integers 0..19 and show the aligned center (x) and context (y) lists.
int_text = list(range(20))
x, y = next(make_batches(int_text, batch_size=4, max_context_window=5))

print('x\n', x)
print('y\n', y)

Define the Negatively Sampled SkipGram Model

In [0]:
class SkipGramNeg(torch.nn.Module):
    """Skip-gram word2vec model for negative-sampling training.

    Holds two embedding tables of shape (vocab_len, embed_len): `in_embed`
    for center words and `out_embed` for context and noise words. Both are
    initialised uniformly in [-1, 1).
    """
    def __init__(self,vocab_len,embed_len):
        super(SkipGramNeg,self).__init__()
        self.vocab_len = vocab_len
        self.embed_len = embed_len

        self.in_embed = torch.nn.Embedding(vocab_len,
                                           embed_len)

        self.out_embed = torch.nn.Embedding(vocab_len,
                                            embed_len)

        self.in_embed.weight.data.uniform_(-1,1)
        self.out_embed.weight.data.uniform_(-1,1)
    def forward_input(self,x):
        """Look up the input (center-word) embeddings for the index tensor `x`."""
        return self.in_embed(x)
    def forward_output(self,x):
        """Look up the output (context-word) embeddings for the index tensor `x`."""
        return self.out_embed(x)
    def forward_noise(self,batch_size,n_samples,word_probs = None):
        """Draw `n_samples` random noise words per input word and embed them.

        Args:
            batch_size: number of input words in the batch.
            n_samples: number of noise words drawn for each input word.
            word_probs: optional 1-D tensor of non-negative sampling weights,
                one per vocabulary entry; defaults to a uniform distribution.

        Returns:
            Tensor of shape (batch_size, n_samples, embed_len) taken from the
            output embedding table; e.g. batch_size=5, n_samples=10 gives a
            5 x 10 x embed_len tensor.
        """
        # `is None`, not `== None`: comparing a tensor with `==` is an
        # element-wise operation, not an identity check.
        if word_probs is None:
            word_probs = torch.ones(self.vocab_len,)
        noise_idx = torch.multinomial(word_probs,
                                       batch_size*n_samples,
                                       replacement=True)
        # Follow the device of the embedding table rather than a notebook-level
        # flag, so the lookup works on CPU and GPU alike.
        noise_idx = noise_idx.to(self.out_embed.weight.device)
        noise_embed = self.out_embed(noise_idx).view(batch_size,n_samples,self.embed_len)
        return noise_embed
        
        

Training Loop

In [0]:
# Number of training tokens left after subsampling.
len(train_words)

In [0]:
# Embedding dimensionality shared by the input and output tables.
embed_len = 300
# One embedding row per vocabulary entry.
model = SkipGramNeg(len(int_to_word),embed_len)


In [0]:
import tqdm.auto


def negative_sampling_loss(center_embed, context_embed, noise_embed):
    """Skip-gram negative-sampling loss, averaged over the training pairs.

    Args:
        center_embed: (B, E) input embeddings of the center words.
        context_embed: (B, E) output embeddings of the true context words.
        noise_embed: (B, K, E) output embeddings of K noise words per pair.

    Returns:
        Scalar tensor: mean over the B pairs of
        -(log sigmoid(u_ctx . v) + sum_k log sigmoid(-u_noise_k . v)).
        The leading minus makes this a quantity to *minimise*; the minus
        inside the noise term pushes noise words away from the center word.
    """
    ctx_dot = (center_embed * context_embed).sum(-1)                          # (B,)
    noise_dot = torch.bmm(noise_embed, center_embed.unsqueeze(2)).squeeze(2)  # (B, K)
    # logsigmoid is numerically stable where sigmoid().log() underflows to -inf.
    pos = torch.nn.functional.logsigmoid(ctx_dot)
    neg = torch.nn.functional.logsigmoid(-noise_dot).sum(-1)                  # (B,)
    return -(pos + neg).mean()


def print_nearest_words(model, int_to_word, nvalid=16, topk=6):
    """Print the `topk` nearest words (cosine similarity) for random probe words.

    Half of the `nvalid` probes are drawn from vocabulary indices 0..99 and
    half from 1000..1099 (both ranges are clipped to the vocabulary size).
    The first word listed for each probe is the probe itself.
    """
    with torch.no_grad():  # inspection only; keep it out of the autograd graph
        embed_vec = model.in_embed.weight
        embed_vec = embed_vec / embed_vec.pow(2).sum(-1).sqrt().unsqueeze(-1)
        vocab_len = embed_vec.shape[0]
        frequent = np.random.randint(0, min(100, vocab_len), size=(nvalid // 2,))
        rare_lo = min(1000, max(vocab_len - 100, 0))
        rarer = np.random.randint(rare_lo, min(rare_lo + 100, vocab_len), size=(nvalid // 2,))
        valid_idx = torch.LongTensor(np.append(frequent, rarer)).to(embed_vec.device)
        similarities = torch.mm(embed_vec[valid_idx], embed_vec.t())
        nearest = similarities.topk(min(topk, vocab_len), dim=-1)[1]
    for v, row in zip(valid_idx.tolist(), nearest.tolist()):
        # Vocabulary entries are bytes (the corpus was read in binary mode).
        w = int_to_word[v].decode('utf-8')
        top_s = [int_to_word[s_i].decode('utf-8') for s_i in row]
        print(w + ':' + ' '.join(top_s))
    print('-'*20)


# Pick the device from an actual availability check and move the model there
# before the optimiser is created.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
opt = torch.optim.Adam(model.parameters(),lr=3e-3)

nepochs = 10
batch_size = 512
nwords = len(train_words)
nbatches = (nwords + batch_size - 1)//batch_size

nnoise = 5                         # noise words per (center, context) pair
vis_every = max(nepochs - 1, 1)    # never 0, so `e % vis_every` is always defined
nvalid = 16
trends = {'loss':[]}
for e in tqdm.auto.tqdm(range(nepochs)):
    batches = make_batches(train_words,batch_size)
    for bi,(bx,by) in enumerate(tqdm.auto.tqdm(batches,total = nbatches)):
        if not bx:
            # No (center, context) pairs in this batch; the mean would be NaN.
            continue

        # Embeddings for the center words, their contexts and the noise words.
        bx = torch.LongTensor(bx).to(device)
        by = torch.LongTensor(by).to(device)
        bx_embed = model.forward_input(bx)
        by_embed = model.forward_output(by)
        noise_embed = model.forward_noise(batch_size=bx_embed.shape[0],n_samples=nnoise)

        loss = negative_sampling_loss(bx_embed, by_embed, noise_embed)
        loss_value = loss.item()
        if not np.isfinite(loss_value):
            # Fail loudly instead of dropping into an interactive debugger.
            raise RuntimeError('non-finite loss %r at epoch %d, batch %d' % (loss_value, e, bi))

        opt.zero_grad()
        loss.backward()
        opt.step()

        trends['loss'].append(loss_value)

    if e%vis_every == 0:
        fig, ax = plt.subplots()
        ax.plot(trends['loss'])
        ax.set(xlabel='training step', ylabel='negative-sampling loss',
               title='Training loss (through epoch %d)' % e)
        plt.show()
        plt.close(fig)

        print_nearest_words(model, int_to_word, nvalid=nvalid)

In [0]:
# Scratch check of the tensor shapes that enter the noise dot product.
# Relies on `bx_embed` left over from the last training batch, so the training
# cell must have been run first.
noise_embed = model.forward_noise(batch_size=bx_embed.shape[0],n_samples=5)
bx_embed.unsqueeze(1).shape,noise_embed.shape

In [0]:
# Inspect the input-embedding matrix (one row per vocabulary entry).
model.in_embed.weight