In [107]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np

In [110]:
WINDOW_SIZE, K = 2, 10
VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE = 2000, 512, 512 
BATCH_SIZE = 32

np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)

vocab_file = 'text8'
#prepare vocabulary
def tokenize(text):
    return text.split(' ')

text = None
with open(vocab_file) as fr:
    text = fr.readlines()[0]
    token_list = tokenize(text)
vocab =  Counter(token_list).most_common(VOCAB_SIZE - 1)
idx_to_word = [item[0] for item in vocab]
idx_to_word.append('UNK')
word_to_idx = {item: i for i, item in enumerate(idx_to_word)}
word_counts = [item[1] for item in vocab]
word_counts.append(len(text) - np.sum(word_counts))
frequence = word_counts / np.sum(word_counts)
frequence = frequence ** (3 / 4)
frequence = frequence / np.sum(frequence)

In [111]:
#data loader
class MyDataset(tud.Dataset):
    def __init__(self, text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence):
        super(MyDataset, self).__init__()
        self.text = text
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.window_size = WINDOW_SIZE
        self.k = K
        self.device = device
        self.frequence = torch.FloatTensor(frequence)
        self.word_encode = torch.LongTensor([self.word_to_idx.get(word, self.word_to_idx['UNK']) \
                                             for word in self.text])
        
        
    def __len__(self):
        return len(self.word_encode)
    
    def __getitem__(self, idx):
        center_word = self.word_encode[idx]
        pos_index = list(range(idx - self.window_size, idx)) + list(range(idx + 1, idx + self.window_size + 1))
        pos_index = [idx % len(self.word_encode) for idx in pos_index]
        pos_words = self.word_encode[pos_index]
        neg_words = torch.multinomial(self.frequence, self.k * self.window_size * 2, replacement = True)
        return center_word, pos_words, neg_words

In [112]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = MyDataset(text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence)  
dataloader = tud.DataLoader(dataset, batch_size = BATCH_SIZE,shuffle = True)  
for i, (center_word, pos_words, neg_words) in enumerate(dataloader):
    print("iter: {}".format(i))
    print(center_word.shape)
    print(pos_words.shape)
    print(neg_words.shape)
    break

iter: 0
torch.Size([32])
torch.Size([32, 4])
torch.Size([32, 40])


In [113]:
#model
class MyModel(nn.Module):
    def __init__(self, VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE):
        super(MyModel, self).__init__()
        self.in_embedding = nn.Embedding(VOCAB_SIZE, IN_EMBED_SIZE)
        initrange = 0.5 / IN_EMBED_SIZE
        self.in_embedding.weight.data.uniform_(-initrange, initrange)
        self.out_embedding = nn.Embedding(VOCAB_SIZE, OUT_ENBED_SIZE)
        self.out_embedding.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, center_word, pos_words, neg_words):
        center_embeddings = self.in_embedding(center_word)  #[batch-size, embed_size]
        pos_embeddings = self.in_embedding(pos_words)      #[batch-size, window_size * 2, embed_size]
        neg_embeddings = self.out_embedding(neg_words)      #[batch-size, window_size * 2 * K, embed_size]
        center_embeddings = center_embeddings.unsqueeze(2)                     #[batch-size, embed_size, 1]
        
        pos_dot = pos_embeddings
        
        pos_dot = torch.bmm(pos_embeddings, center_embeddings).squeeze(2)  #[batch-size, window_size * 2]
        neg_dot = torch.bmm(neg_embeddings, -center_embeddings).squeeze(2)  #[batch-size, window_size * 2 * K]
#         print("shape of pos_dot: {}".format(pos_dot.shape))
#         print("shape of neg_dot: {}".format(neg_dot.shape))
        
        pos_loss = F.logsigmoid(pos_dot).sum(1)
        neg_loss = F.logsigmoid(neg_dot).sum(1)
        return pos_loss + neg_loss

In [114]:
#train
model = MyModel(len(idx_to_word), IN_EMBED_SIZE, OUT_ENBED_SIZE)
learning_rate = 4e-4
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

for epoch in range(100):
    for iter, (center_word, pos_words, neg_words) in enumerate(dataloader):
        loss = model(center_word, pos_words, neg_words).mean()
        print('epoch: {}, iter: {}, loss: {}'.format(epoch, iter, loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

epoch: 0, iter: 0, loss: -294.1334228515625
epoch: 0, iter: 1, loss: -300.57476806640625
epoch: 0, iter: 2, loss: -336.6136169433594
epoch: 0, iter: 3, loss: -338.8317565917969
epoch: 0, iter: 4, loss: -359.3650207519531
epoch: 0, iter: 5, loss: -318.126708984375
epoch: 0, iter: 6, loss: -350.65216064453125
epoch: 0, iter: 7, loss: -332.4310302734375
epoch: 0, iter: 8, loss: -393.54791259765625
epoch: 0, iter: 9, loss: -295.2746276855469
epoch: 0, iter: 10, loss: -292.5108947753906
epoch: 0, iter: 11, loss: -339.0655517578125
epoch: 0, iter: 12, loss: -314.77655029296875
epoch: 0, iter: 13, loss: -303.7771301269531
epoch: 0, iter: 14, loss: -328.57135009765625
epoch: 0, iter: 15, loss: -359.9826965332031
epoch: 0, iter: 16, loss: -314.71478271484375
epoch: 0, iter: 17, loss: -325.1954040527344
epoch: 0, iter: 18, loss: -389.9200744628906
epoch: 0, iter: 19, loss: -335.7273254394531
epoch: 0, iter: 20, loss: -400.587158203125
epoch: 0, iter: 21, loss: -336.615234375
epoch: 0, iter: 22, 

KeyboardInterrupt: 