In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
# Kaggle boilerplate: print the path of every file found under /kaggle/input
# (the captured output shows the single dataset file /kaggle/input/enwik8/enwik8).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/enwik8/enwik8


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import Counter
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

In [3]:
# Load the whole enwik8 dump as one lower-cased string. The `with` block closes
# the file handle deterministically. The encoding is pinned to UTF-8 (enwik8 is
# UTF-8 encoded XML) rather than left to the platform's locale default.
# NOTE(review): [:-1] drops the final character *before* strip(); the intent is
# unclear (a trailing newline would already be removed by strip), so it is kept as-is.
with open("/kaggle/input/enwik8/enwik8", encoding="utf-8") as fh:
    data = fh.read()[:-1].strip().lower()

In [4]:
# Raw whitespace tokens and their distinct set (before any frequency filtering).
words = data.split()
vocab = [*set(words)]
(len(vocab), len(words))

(253854, 17005206)

In [5]:
class Vectorizer:
    """Maps whitespace-separated words to contiguous integer ids and back.

    After ``fit(data)``:
      data        -- the raw text passed to ``fit``.
      vocab       -- sorted list of words occurring at least ``min_count`` times;
                     a word's id is its index in this list.
      vlen        -- ``len(vocab)``.
      word_to_int -- word -> id; int_to_word -- id -> word.
      int_words   -- the corpus as ids, in order, with rarer words removed.
      word_counts -- Counter of ids over ``int_words``.
      freqs       -- id -> count / len(int_words).
    """

    def __init__(self):
        # All attributes are populated by fit().
        self.data = None
        self.vocab = None
        self.vlen = None
        self.int_words = None
        self.word_counts = None
        self.word_to_int = None
        self.int_to_word = None
        self.freqs = None

    def fit(self, data, min_count=5):
        """Build the vocabulary and the id-encoded corpus from ``data``.

        Args:
            data: raw text; tokens are obtained with ``str.split()``.
            min_count: tokens occurring fewer times than this are dropped
                (default 5, the value previously hard-coded).
        """
        self.data = data
        words = data.split()
        raw_counts = Counter(words)
        # sorted() rather than list(set(...)). str hashes are randomised per
        # Python process, so set order (and with it every word id) used to differ
        # between sessions. That made saved embeddings impossible to map back to words.
        self.vocab = sorted(w for w, c in raw_counts.items() if c >= min_count)
        self.vlen = len(self.vocab)
        self.word_to_int = {w: i for i, w in enumerate(self.vocab)}
        self.int_to_word = dict(enumerate(self.vocab))
        w2i = self.word_to_int
        self.int_words = [w2i[w] for w in words if w in w2i]
        # Every kept occurrence of w becomes w2i[w], so the raw counts carry over.
        self.word_counts = Counter({w2i[w]: raw_counts[w] for w in self.vocab})
        total = len(self.int_words)
        self.freqs = {k: v / total for k, v in self.word_counts.items()}

    def enc(self, word):
        """Return the id of ``word`` (KeyError if it is not in the vocabulary)."""
        return self.word_to_int[word]

    def dec(self, idx):
        """Return the word whose id is ``idx`` (KeyError if unknown)."""
        return self.int_to_word[idx]

In [6]:
# Fit the word <-> id codec on the full lower-cased corpus.
V = Vectorizer()
V.fit(data=data)

In [7]:
# Filtered vocabulary size and number of corpus tokens kept after filtering.
(V.vlen, len(V.int_words))

(71290, 16718843)

In [8]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [9]:
#  p_discard = {k: (1 - np.sqrt(1e-4/v)) for k, v in V.freqs.items()}
# p_discard[V.enc('a')]
# a = [1,2,3,4]

In [10]:
def build_train_data(t=1e-5, vec=None, rng=None):
    """Subsample frequent words and return the ids that survive.

    Each occurrence of id ``w`` is discarded with probability
    ``1 - sqrt(t / freqs[w])``. A non-positive value means the word is always kept.

    Args:
        t: subsampling threshold (default 1e-5, as before).
        vec: object exposing ``int_words`` (list of ids) and ``freqs``
            (id -> relative frequency); defaults to the global ``V``.
        rng: object with a ``random(size)`` method, e.g.
            ``np.random.default_rng(seed)`` for reproducible runs. It defaults to
            NumPy's legacy global generator, as before.

    Returns:
        list[int] of kept ids, in corpus order. The length is also printed.
    """
    vec = V if vec is None else vec
    rng = np.random if rng is None else rng
    ids = np.asarray(vec.int_words, dtype=np.int64)
    if ids.size == 0:
        print(0)
        return []
    # Dense id -> discard-probability table, so the filter becomes one vectorised
    # comparison instead of a Python loop over every corpus token.
    p_discard = np.zeros(max(vec.freqs) + 1)
    for k, v in vec.freqs.items():
        p_discard[k] = 1 - np.sqrt(t / v)
    keep = rng.random(ids.size) >= p_discard[ids]
    int_train = ids[keep].tolist()
    print(len(int_train))
    return int_train

In [11]:
# Subsampled corpus ids, materialised once as an int64 tensor on the training device.
int_train = build_train_data()
tensor_train = torch.tensor(int_train, device=device)

4669215


In [12]:
def get_contexts(idx, k=5, seq=None):
    """Return the tokens within ``k`` positions of ``idx``, excluding ``idx`` itself.

    Args:
        idx: position in ``seq`` (int or 0-d integer tensor).
        k: window radius; the window is clipped at both ends of ``seq``.
        seq: 1-D tensor of token ids; defaults to the global ``tensor_train``.

    Returns:
        1-D tensor ``seq[max(0, idx-k):idx]`` followed by ``seq[idx+1:idx+k+1]``.
    """
    seq = tensor_train if seq is None else seq
    lo = max(0, idx - k)
    return torch.cat((seq[lo:idx], seq[idx + 1:idx + k + 1]), 0)
    

In [13]:
def get_batch(batch_ids, k=5, seq=None):
    """Expand a batch of corpus positions into (target id, context id) pairs.

    For every position ``i`` in ``batch_ids``, each token ``j`` with
    ``0 < |i - j| <= k`` that lies inside ``seq`` yields one pair
    ``(seq[i], seq[j])``. Pairs are grouped by position in batch order, with left
    context before right context. This is the same order ``get_contexts`` produces.

    Args:
        batch_ids: 1-D integer tensor (or sequence) of positions into ``seq``.
        k: window radius (default 5, as before).
        seq: 1-D tensor of token ids; defaults to the global ``tensor_train``.

    Returns:
        (target, pos): two 1-D tensors of equal length on ``seq.device``.
    """
    seq = tensor_train if seq is None else seq
    batch_ids = torch.as_tensor(batch_ids, device=seq.device).long()
    # Offsets -k..-1, 1..k applied to all positions at once. This replaces a
    # Python loop that built and stacked one 0-d tensor per context token.
    offsets = torch.cat((torch.arange(-k, 0), torch.arange(1, k + 1))).to(seq.device)
    positions = batch_ids.unsqueeze(1) + offsets              # (B, 2k)
    valid = (positions >= 0) & (positions < seq.shape[0])     # clip at corpus ends
    # Boolean-mask indexing flattens row-major, preserving per-position order.
    target = seq[batch_ids].unsqueeze(1).expand_as(positions)[valid]
    pos = seq[positions[valid]]
    return target, pos

In [14]:
# get_batch(torch.tensor([0,1,2]))

In [15]:
class SkipGramNS(nn.Module):
    """Skip-gram model with separate target (``V``) and context (``U``) embeddings.

    Both tables have shape (n_vocab, n_embed) and are initialised uniformly in [-1, 1].
    """

    def __init__(self, n_vocab, n_embed):
        super().__init__()
        self.n_vocab = n_vocab
        self.n_embed = n_embed
        self.V = nn.Embedding(n_vocab, n_embed)
        self.U = nn.Embedding(n_vocab, n_embed)
        self.V.weight.data.uniform_(-1, 1)
        self.U.weight.data.uniform_(-1, 1)

    def get_target(self, target):
        """Look up target (centre-word) vectors for the ids in ``target``."""
        return self.V(target)

    def get_pos_context(self, pos):
        """Look up context vectors for the observed (positive) context ids."""
        return self.U(pos)

    def get_neg_context(self, batch_size, num_samples, noise_dist):
        """Draw ``num_samples`` negative context vectors per batch element.

        Args:
            batch_size: number of (target, context) pairs in the batch.
            num_samples: negatives per pair.
            noise_dist: 1-D non-negative weights over the vocabulary. They need
                not sum to 1, because ``torch.multinomial`` normalises them.

        Returns:
            Tensor of shape (batch_size, num_samples, n_embed) on this module's device.
        """
        # Negatives must be drawn *with* replacement. Otherwise no word can repeat,
        # so large draws drift away from noise_dist, and the call errors once
        # batch_size * num_samples exceeds the number of non-zero weights.
        neg_samples = torch.multinomial(noise_dist, batch_size * num_samples, replacement=True)
        neg_samples = neg_samples.to(self.U.weight.device)
        neg_context = self.U(neg_samples)
        return neg_context.view(batch_size, num_samples, self.n_embed)
        

In [16]:
class NegativeSamplingLoss():
    """Skip-gram negative-sampling loss, averaged over the batch.

    loss = -mean_b [ logsigmoid(pos_b . target_b) + sum_k logsigmoid(-neg_bk . target_b) ]
    """

    def __init__(self):
        pass

    def __call__(self, target, pos, neg):
        """Compute the loss.

        Args:
            target: (bs, n_embed) target vectors.
            pos: (bs, n_embed) positive context vectors.
            neg: (bs, K, n_embed) negative context vectors.

        Returns:
            0-d tensor.
        """
        bs, n_embed = pos.shape[0], pos.shape[1]
        target = target.view(bs, n_embed, 1)
        pos = pos.view(bs, 1, n_embed)
        # Remove only the singleton matmul dims. A bare .squeeze() also dropped
        # the batch dim when bs == 1 and the K dim when K == 1, which made the
        # .sum(1) below raise.
        ploss = F.logsigmoid(torch.bmm(pos, target).view(bs))                    # (bs,)
        nloss = F.logsigmoid(torch.bmm(neg.neg(), target).squeeze(2)).sum(1)     # (bs,)
        return -(ploss + nloss).mean()

In [17]:
# testing nll
# x = torch.tensor([[1.,2,3,4],[4,5,6,7],[7,8,9,1]])
# n = torch.tensor([[[1.,2,3,4],[1,1,1,1]],[[1,2,3,4],[1,2,3,4]],[[1,2,3,4],[1,2,3,4]]])
# p = torch.tensor([[1.,1,1,1],[4,5,6,7],[0,0,1,0]])
# nll = NegativeSamplingLoss()
# nll(x, p, n)

In [21]:
# Noise distribution for negative sampling: unigram frequency ** 0.75, ordered by
# word id (ids are 0..vlen-1, so position == id). The weights are left unnormalised.
arr_freq = sorted(V.freqs.items(), key=lambda item: item[0])
unifreq = np.array([freq for _, freq in arr_freq])
noise_dist = torch.from_numpy(np.power(unifreq, 0.75))

In [22]:
# Iterate over corpus *positions*; get_batch expands each position into its
# (target, context) pairs. Reshuffle every epoch so that neighbouring, highly
# correlated positions are not always grouped into the same batch in the same order.
ds_index = TensorDataset(torch.arange(tensor_train.shape[0]))
dl_index = DataLoader(ds_index, batch_size=512, shuffle=True)

In [23]:
# model: 300-d target/context embeddings for every word in V's vocabulary
model = SkipGramNS(n_vocab=V.vlen, n_embed=300).to(device)
model

SkipGramNS(
  (V): Embedding(71290, 300)
  (U): Embedding(71290, 300)
)

In [24]:
# Loss over (target, positive, negatives) and an Adam optimiser over both embedding tables.
criterion = NegativeSamplingLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-3)

In [25]:
# hyperparameters: passes over the data, global step counter, print the loss every `stopt` steps
epochs, step, stopt = 5, 0, 1500

In [26]:
%%time
#train model
for epoch in range(epochs):
    model.train()
    
    for batch_id in dl_index:
        batch_id = batch_id[0]
        target_id, pos_id = get_batch(batch_id)
        
        target = model.get_target(target_id)
        pos = model.get_pos_context(pos_id)
        neg = model.get_neg_context(target_id.shape[0], 5, noise_dist)
        loss = criterion(target, pos, neg)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if(step % stopt == 0):
            print(f'epoch: {epoch} | step : {step} | loss: {loss.item():.5f}')
        step+=1

            

epoch: 0 | step : 0 | loss: 14.47642
epoch: 0 | step : 1500 | loss: 7.80835
epoch: 0 | step : 3000 | loss: 4.47510
epoch: 0 | step : 4500 | loss: 4.28944
epoch: 0 | step : 6000 | loss: 4.21582
epoch: 0 | step : 7500 | loss: 3.32213
epoch: 0 | step : 9000 | loss: 2.44685
epoch: 1 | step : 10500 | loss: 2.52549
epoch: 1 | step : 12000 | loss: 2.13516
epoch: 1 | step : 13500 | loss: 2.43689
epoch: 1 | step : 15000 | loss: 2.52590
epoch: 1 | step : 16500 | loss: 2.51305
epoch: 1 | step : 18000 | loss: 2.39601
epoch: 2 | step : 19500 | loss: 2.67092
epoch: 2 | step : 21000 | loss: 2.61265
epoch: 2 | step : 22500 | loss: 2.48848
epoch: 2 | step : 24000 | loss: 2.28900
epoch: 2 | step : 25500 | loss: 2.25769
epoch: 2 | step : 27000 | loss: 2.23295
epoch: 3 | step : 28500 | loss: 2.17622
epoch: 3 | step : 30000 | loss: 2.01068
epoch: 3 | step : 31500 | loss: 2.09654
epoch: 3 | step : 33000 | loss: 2.04180
epoch: 3 | step : 34500 | loss: 2.22053
epoch: 3 | step : 36000 | loss: 2.17948
epoch: 4 

In [171]:
# plots and testing: put the model in evaluation mode (same as model.eval())
model.train(False)

DataParallel(
  (module): SkipGramNS(
    (V): Embedding(253854, 300)
    (U): Embedding(253854, 300)
  )
)

In [27]:
# Persist the weights *and* the id -> word table. Embedding row i is the vector
# of V.vocab[i], so the weights are uninterpretable without the vocabulary.
torch.save(model.state_dict(), '/kaggle/working/modelx1.pt')
torch.save(V.vocab, '/kaggle/working/vocab_x1.pt')

In [28]:
# Snapshot of the target-word embeddings on the CPU (row i <-> V.vocab[i]).
# detach() drops requires_grad, and copy=True guarantees a real copy even when
# the model already lives on the CPU, so further training cannot change `embs`.
embs = model.V.weight.detach().to('cpu', copy=True)

In [29]:
@torch.no_grad()
def get_nearest(words, k=10, emb=None, vec=None):
    """Return the ``k`` vocabulary words most cosine-similar to each query word.

    Args:
        words: iterable of query words; each must be in the vocabulary
            (``vec.enc`` raises KeyError otherwise).
        k: non-negative number of neighbours per query. The query word itself
            is a candidate, so it normally ranks first.
        emb: (n_vocab, d) embedding matrix whose row ``i`` belongs to
            ``vec.dec(i)``; defaults to the global ``embs``.
        vec: object with ``enc``/``dec``; defaults to the global ``V``.

    Returns:
        ``{query: {neighbour: cosine_similarity, ...}}``, with neighbours in
        descending similarity.
    """
    emb = embs if emb is None else emb
    vec = V if vec is None else vec
    words = list(words)
    if not words:
        return {}
    # Normalise once and score every vocabulary row with one matmul, instead of
    # one cosine_similarity call per (vocab word, query) pair in Python.
    unit = torch.nn.functional.normalize(emb.float(), dim=1)
    query_ids = torch.tensor([vec.enc(w) for w in words], dtype=torch.long)
    sims = unit[query_ids] @ unit.T                     # (n_queries, n_vocab)
    top_sims, top_ids = sims.topk(min(k, sims.shape[1]), dim=1)
    return {
        w: {vec.dec(i): s for i, s in zip(ids.tolist(), vals.tolist())}
        for w, ids, vals in zip(words, top_ids, top_sims)
    }

In [51]:
def print_dicts(d):
    """For each entry of ``d``, print the upper-cased key followed by a colon,
    then the keys of its value joined by ", ", then a blank line."""
    for key, inner in d.items():
        print(f'{key.upper()}:')
        print(', '.join(map(str, inner)))
        print()

In [56]:
%%time
# nearest words
d = get_nearest(['war', 'art', 'happiness', 'depression', 'dog', 'apple'], 20)
print_dicts(d)

WAR:
war, fought, troops, army, allies, allied, battle, military, wars, occupation, civil, german, germany, forces, wwii, casualties, uprising, soviet, communist, generals

ART:
art, paintings, gallery, painting, arts, museum, sculpture, painters, artist, styles, exhibitions, practitioners, artists, renaissance, martial, artistic, impressionism, exhibition, drawings, artwork

HAPPINESS:
happiness, minds, deceiving, god, pursuit, moral, pleasure, soul, compassion, immortality, thinking, goodness, ultimate, salvation, truly, perfection, eternal, deeds, redemption, divine

DEPRESSION:
depression, bipolar, schizophrenia, mania, depressed, anxiety, severe, manic, irritability, psychosis, symptoms, anticonvulsants, antidepressant, depressions, psychotic, depressive, vomiting, chronic, zyprexa, traumatic

DOG:
dog, dogs, cat, hound, breeds, breed, greyhound, hounds, terrier, scent, plural, animal, hairless, slang, animals, eat, hunting, ai, smell, cats

APPLE:
apple, macintosh, mac, os, windo