In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np

In [2]:
WINDOW_SIZE, K = 3, 100
VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE = 30000, 100, 100 
BATCH_SIZE = 128

np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)

vocab_file = 'text8'
#prepare vocabulary
def tokenize(text):
    return text.split(' ')

text = None
with open(vocab_file) as fr:
    text = fr.readlines()[0]
    token_list = tokenize(text)
vocab =  Counter(token_list).most_common(VOCAB_SIZE - 1)
idx_to_word = [item[0] for item in vocab]
idx_to_word.append('UNK')
word_to_idx = {item: i for i, item in enumerate(idx_to_word)}
word_counts = [item[1] for item in vocab]
word_counts.append(len(text) - np.sum(word_counts))
frequence = word_counts / np.sum(word_counts)
frequence = frequence ** (3 / 4)
frequence = frequence / np.sum(frequence)

In [7]:
#data loader
class MyDataset(tud.Dataset):
    def __init__(self, text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence):
        super(MyDataset, self).__init__()
        self.text = text
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.window_size = WINDOW_SIZE
        self.k = K
        self.device = device
        self.frequence = torch.FloatTensor(frequence)
        self.word_encode = torch.LongTensor([self.word_to_idx.get(word, self.word_to_idx['UNK']) \
                                             for word in self.text])
        
        
    def __len__(self):
        return len(self.word_encode)
    
    def __getitem__(self, idx):
        center_word = self.word_encode[idx]
        pos_index = list(range(idx - self.window_size, idx)) + list(range(idx + 1, idx + self.window_size + 1))
        pos_index = [idx % len(self.word_encode) for idx in pos_index]
        pos_words = self.word_encode[pos_index]
        neg_words = torch.multinomial(self.frequence, self.k * self.window_size * 2, replacement = True)
        return center_word, pos_words, neg_words

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = MyDataset(text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence)  
dataloader = tud.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)  
for i, (center_word, pos_words, neg_words) in enumerate(dataloader):
    print("iter: {}".format(i))
    print(center_word.shape)
    print(pos_words.shape)
    print(neg_words.shape)
    break

iter: 0
torch.Size([128])
torch.Size([128, 6])
torch.Size([128, 600])


In [3]:
#model
class MyModel(nn.Module):
    def __init__(self, VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE):
        super(MyModel, self).__init__()
        self.in_embedding = nn.Embedding(VOCAB_SIZE, IN_EMBED_SIZE)
        initrange = 0.5 / IN_EMBED_SIZE
        self.in_embedding.weight.data.uniform_(-initrange, initrange)
        self.out_embedding = nn.Embedding(VOCAB_SIZE, OUT_ENBED_SIZE)
        self.out_embedding.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, center_word, pos_words, neg_words):
        center_embeddings = self.in_embedding(center_word)  #[batch-size, embed_size]
        pos_embeddings = self.in_embedding(pos_words)      #[batch-size, window_size * 2, embed_size]
        neg_embeddings = self.out_embedding(neg_words)      #[batch-size, window_size * 2 * K, embed_size]
        center_embeddings = center_embeddings.unsqueeze(2)                     #[batch-size, embed_size, 1]
        
        pos_dot = pos_embeddings
        
        pos_dot = torch.bmm(pos_embeddings, center_embeddings).squeeze(2)  #[batch-size, window_size * 2]
        neg_dot = torch.bmm(neg_embeddings, -center_embeddings).squeeze(2)  #[batch-size, window_size * 2 * K]
#         print("shape of pos_dot: {}".format(pos_dot.shape))
#         print("shape of neg_dot: {}".format(neg_dot.shape))
        
        pos_loss = F.logsigmoid(pos_dot).sum(1)
        neg_loss = F.logsigmoid(neg_dot).sum(1)
        return - (pos_loss + neg_loss)
    
    def input_embeddings(self):
        return self.in_embedding.data.cpu().numpy()

In [4]:
model = MyModel(len(idx_to_word), IN_EMBED_SIZE, OUT_ENBED_SIZE)
learning_rate = 4e-4
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [6]:
optimizer.param_groups

[{'params': [Parameter containing:
   tensor([[-4.0488e-03, -2.0225e-03,  2.6374e-03,  ..., -1.8155e-03,
            -3.1009e-03,  1.0922e-03],
           [-3.6676e-03,  2.2173e-03,  3.4155e-03,  ..., -2.1890e-03,
             1.0874e-03,  3.0897e-03],
           [ 1.0879e-03, -2.6146e-03, -4.3086e-04,  ..., -3.7155e-03,
             3.4755e-03, -3.5009e-03],
           ...,
           [ 3.4506e-03,  4.5707e-03, -4.9445e-03,  ...,  2.0560e-04,
            -1.7338e-03,  3.9236e-03],
           [-3.4154e-03, -9.6818e-04, -4.0330e-03,  ..., -1.3523e-03,
             2.4831e-05,  2.2506e-03],
           [ 3.7988e-03, -4.3737e-03, -9.2461e-04,  ..., -3.8285e-03,
             3.5673e-03, -3.3511e-03]], requires_grad=True),
   Parameter containing:
   tensor([[ 2.2757e-04, -3.1509e-03, -2.5605e-03,  ...,  4.8006e-03,
            -3.3570e-03, -4.6002e-03],
           [ 3.2799e-03, -1.2792e-03, -3.6907e-03,  ...,  4.4902e-03,
            -2.9896e-03, -3.6164e-03],
           [ 2.4741e-03,  3.21

In [None]:
#train
model = MyModel(len(idx_to_word), IN_EMBED_SIZE, OUT_ENBED_SIZE)
learning_rate = 4e-4
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

for epoch in range(2):
    for iter, (center_word, pos_words, neg_words) in enumerate(dataloader):
        loss = model(center_word, pos_words, neg_words).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if iter % 1000 == 0:
            print('epoch: {}, iter: {}, loss: {}'.format(epoch, iter, loss))

input_embedding = model.input_embeddings()
with open('input_embedding', 'w') as fw:
    fw.writelines(input_embedding)
            

epoch: 0, iter: 0, loss: 420.0423889160156
epoch: 0, iter: 1000, loss: 4.5623908042907715
epoch: 0, iter: 2000, loss: 0.971217155456543
epoch: 0, iter: 3000, loss: 0.34077128767967224
epoch: 0, iter: 4000, loss: 0.16006360948085785
epoch: 0, iter: 5000, loss: 0.07626575231552124
epoch: 0, iter: 6000, loss: 0.04320208728313446
epoch: 0, iter: 7000, loss: 0.02256978675723076
epoch: 0, iter: 8000, loss: 0.012630621902644634
epoch: 0, iter: 9000, loss: 0.007155262865126133
epoch: 0, iter: 10000, loss: 0.003934438340365887
epoch: 0, iter: 11000, loss: 0.0023107498418539762
epoch: 0, iter: 12000, loss: 0.0013333435636013746
epoch: 0, iter: 13000, loss: 0.0007793003460392356
epoch: 0, iter: 14000, loss: 0.0004605937283486128
epoch: 0, iter: 15000, loss: 0.0002624643675517291
epoch: 0, iter: 16000, loss: 0.0001480951759731397
epoch: 0, iter: 17000, loss: 9.69022439676337e-05
epoch: 0, iter: 18000, loss: 4.51812484243419e-05
epoch: 0, iter: 19000, loss: 1.8502583770896308e-05
epoch: 0, iter: 20