In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random

In [3]:
device = 'cuda' if torch.cuda.is_available else 'cpu'

In [4]:
raw_text = "We are about to study the idea of a computational process. Computational processes are abstract beings that inhabit computers. As they evolve, processes manipulate other abstract things called data. The evolution of a process is directed by a pattern of rules called a program. People create programs to direct processes. In effect, we conjure the spirits of the computer with our spells."

In [5]:
def clean_text(text):
    text = text.replace(',', '').replace('.','')
    return text

In [6]:
text = list(set(clean_text(raw_text).split()))

In [26]:
def build_vocab(words, window=3):
    start_token = '<'
    end_token = '>'
    
    ngram_list = []
    all_tokens = []
    for word in words:
        word = start_token + word + end_token
        ngrams = []
        for ix in range(0, len(word) - window//2 - 1):
            ngram = word[ix:ix+window]
            ngrams.append(ngram)
            all_tokens.append(ngram)
        ngram_list.append(ngrams)
    
    return ngram_list, all_tokens

In [27]:
ngram_list, all_tokens = build_vocab(text)

In [28]:
print(ngram_list)

[['<sp', 'spe', 'pel', 'ell', 'lls', 'ls>'], ['<As', 'As>'], ['<Th', 'The', 'he>'], ['<co', 'com', 'omp', 'mpu', 'put', 'ute', 'ter', 'ers', 'rs>'], ['<pr', 'pro', 'rog', 'ogr', 'gra', 'ram', 'am>'], ['<th', 'the', 'hey', 'ey>'], ['<is', 'is>'], ['<In', 'In>'], ['<ru', 'rul', 'ule', 'les', 'es>'], ['<th', 'the', 'he>'], ['<id', 'ide', 'dea', 'ea>'], ['<sp', 'spi', 'pir', 'iri', 'rit', 'its', 'ts>'], ['<ab', 'abo', 'bou', 'out', 'ut>'], ['<We', 'We>'], ['<ma', 'man', 'ani', 'nip', 'ipu', 'pul', 'ula', 'lat', 'ate', 'te>'], ['<ab', 'abs', 'bst', 'str', 'tra', 'rac', 'act', 'ct>'], ['<ev', 'evo', 'vol', 'olv', 'lve', 've>'], ['<by', 'by>'], ['<to', 'to>'], ['<ef', 'eff', 'ffe', 'fec', 'ect', 'ct>'], ['<co', 'com', 'omp', 'mpu', 'put', 'ute', 'ter', 'er>'], ['<a>'], ['<in', 'inh', 'nha', 'hab', 'abi', 'bit', 'it>'], ['<th', 'thi', 'hin', 'ing', 'ngs', 'gs>'], ['<pr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ss>'], ['<ca', 'cal', 'all', 'lle', 'led', 'ed>'], ['<di', 'dir', 'ire', 'rec', 'ect', '

In [30]:
print(all_tokens)

['<sp', 'spe', 'pel', 'ell', 'lls', 'ls>', '<As', 'As>', '<Th', 'The', 'he>', '<co', 'com', 'omp', 'mpu', 'put', 'ute', 'ter', 'ers', 'rs>', '<pr', 'pro', 'rog', 'ogr', 'gra', 'ram', 'am>', '<th', 'the', 'hey', 'ey>', '<is', 'is>', '<In', 'In>', '<ru', 'rul', 'ule', 'les', 'es>', '<th', 'the', 'he>', '<id', 'ide', 'dea', 'ea>', '<sp', 'spi', 'pir', 'iri', 'rit', 'its', 'ts>', '<ab', 'abo', 'bou', 'out', 'ut>', '<We', 'We>', '<ma', 'man', 'ani', 'nip', 'ipu', 'pul', 'ula', 'lat', 'ate', 'te>', '<ab', 'abs', 'bst', 'str', 'tra', 'rac', 'act', 'ct>', '<ev', 'evo', 'vol', 'olv', 'lve', 've>', '<by', 'by>', '<to', 'to>', '<ef', 'eff', 'ffe', 'fec', 'ect', 'ct>', '<co', 'com', 'omp', 'mpu', 'put', 'ute', 'ter', 'er>', '<a>', '<in', 'inh', 'nha', 'hab', 'abi', 'bit', 'it>', '<th', 'thi', 'hin', 'ing', 'ngs', 'gs>', '<pr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ss>', '<ca', 'cal', 'all', 'lle', 'led', 'ed>', '<di', 'dir', 'ire', 'rec', 'ect', 'cte', 'ted', 'ed>', '<ou', 'our', 'ur>', '<pr', 'pro'

In [31]:
id2word = dict()
word2id = dict()
for i, word in enumerate(list(set(all_tokens))):
    id2word[i] = word
    word2id[word] = i

In [32]:
def skipgram(sentence, window_size=2, neg_samples=5, raw_sentence=False):
    pairs = []
    if raw_sentence:
        sentence = sentence.lower().split()
    
    for i, word in enumerate(sentence):
        cnt = 0
        for j in range(-window_size, window_size+1):
            if j != 0 and (i+j) >= 0 and (i+j) < len(sentence):
                pairs.append((sentence[i], sentence[i+j], 1))
                cnt += 1
                
        #NAIVE negative sampling
        for _ in range(neg_samples):
            ran_num = random.randint(0, len(word2id)-1)
            while ran_num == word2id[word]:
                ran_num = random.randint(0, len(word2id)-1)
            neg_sample = id2word[ran_num]
            pairs.append((sentence[i], neg_sample, 0))

    return pairs

In [35]:
pairs = skipgram(all_tokens)

In [38]:
def pair_to_input(pairs, id2word, word2id):
    center_ix = []
    context_ix = []
    targets = []
    for pair in pairs:
        center_ix.append(word2id[pair[0]])
        context_ix.append(word2id[pair[1]])
        targets.append(pair[2])
    
    return center_ix, context_ix, targets

In [39]:
center_ix, context_ix, targets = pair_to_input(pairs, id2word, word2id)

In [37]:
class Skipgram(nn.Module):
    
    def __init__(self, vocab_size, embed_dim):
        super(Skipgram, self).__init__()
        self.center_emb = nn.Embedding(vocab_size, embed_dim)
        self.context_emb = nn.Embedding(vocab_size, embed_dim)
        
    def forward(self, u_ix, v_ix):
        
        u = self.context_emb(u_ix).view(1,-1)
        v = self.center_emb(v_ix).view(1,-1)
        score = torch.mm(u, v.transpose(1,0))
        
        return torch.sigmoid(score)

In [40]:
num_epoch = 30

In [41]:
model = Skipgram(len(word2id), 100)
model.to(device)

Skipgram(
  (center_emb): Embedding(191, 100)
  (context_emb): Embedding(191, 100)
)

In [46]:
def train():
    
    optimizer = optim.SGD(lr=1e-3, params=model.parameters())
    
    for epoch in range(num_epoch):
        total_loss = 0.0
        for i in range(len(pairs)):
            center = torch.tensor(center_ix[i]).to(device).long()
            context = torch.tensor(context_ix[i]).to(device).long()
            target = torch.tensor(targets[i]).view(1,-1).to(device).float()
            
            out = model.forward(context, center)
            loss = F.binary_cross_entropy(out, target)
            total_loss += loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #if (i+1) % 300 == 0:
            #    print('Epoch %d  %d steps, loss : %0.4f' %(epoch+1, i+1, loss))
        print('Epoch %d | loss : %0.4f' %(epoch+1, total_loss / i))

In [47]:
train()

Epoch 1 | loss : 1.6873
Epoch 2 | loss : 1.6495
Epoch 3 | loss : 1.6125
Epoch 4 | loss : 1.5770
Epoch 5 | loss : 1.5426
Epoch 6 | loss : 1.5093
Epoch 7 | loss : 1.4772
Epoch 8 | loss : 1.4463
Epoch 9 | loss : 1.4165
Epoch 10 | loss : 1.3878
Epoch 11 | loss : 1.3501
Epoch 12 | loss : 1.3233
Epoch 13 | loss : 1.2975
Epoch 14 | loss : 1.2726
Epoch 15 | loss : 1.2487
Epoch 16 | loss : 1.2252
Epoch 17 | loss : 1.2026
Epoch 18 | loss : 1.1808
Epoch 19 | loss : 1.1599
Epoch 20 | loss : 1.1396
Epoch 21 | loss : 1.1201
Epoch 22 | loss : 1.1014
Epoch 23 | loss : 1.0833
Epoch 24 | loss : 1.0659
Epoch 25 | loss : 1.0491
Epoch 26 | loss : 1.0331
Epoch 27 | loss : 1.0177
Epoch 28 | loss : 1.0029
Epoch 29 | loss : 0.9887
Epoch 30 | loss : 0.9751
