In [3]:
import torch
import numpy as np
import torch.functional as F
import torch.nn as nn
from nltk import ngrams
from IPython.display import display
import pandas as pd
from tqdm import tqdm

from nltk.tokenize import sent_tokenize, wordpunct_tokenize, word_tokenize
from gensim.models import KeyedVectors
from utils import get_distinct_words, read_corpus
from itertools import chain

In [4]:
print(torch.device('cuda:1'))
print(torch.cuda.device(0))

cuda:1
<torch.cuda.device object at 0x7fe36c86c9a0>


In [5]:
min_count = 2
ru_corpus_cp = read_corpus("ru_copy")
index_to_key, word_counter = get_distinct_words(ru_corpus_cp, min_count=min_count)
index_to_key = ["UNK", "PAD"] + index_to_key
key_to_index = {word: i for i, word in enumerate(index_to_key)}

In [6]:
len(ru_corpus_cp), ru_corpus_cp[:2]

(8,
 [['кстати',
   'как',
   'неожиданно',
   'кпрф',
   'становиться',
   'не',
   'все',
   'равный',
   'на',
   'судьба',
   'фермер',
   'именно',
   'накануне',
   'выборы'],
  ['можно',
   'и',
   'по',
   'другому',
   'сказать',
   'убогий',
   'клоунада',
   'кпрф',
   'это',
   'попытка',
   'отвечать',
   'на',
   'запрос',
   'молодой',
   'поколение',
   'левый',
   'не',
   'питать',
   'иллюзия',
   'по',
   'повод',
   'коммунистический',
   'номенклатура',
   'советский',
   'образец',
   'но',
   'в',
   'сила',
   'свой',
   'положение',
   'под',
   'давление',
   'вызов',
   'время',
   'они',
   'вынуждать',
   'быть',
   'меняться']])

In [7]:
def as_matrix(sequences, key_to_index, UNK="UNK", PAD="PAD", max_len=None):
    """ Convert a list of tokens into a matrix with padding """
    if isinstance(sequences[0], str):
        sequences = [x.split() for x in sequences]

    max_sequence_len = max([len(x) for x in sequences])
    if max_len is not None and max_sequence_len > max_len :
        max_sequence_len = max_len

    matrix = np.full((len(sequences), max_sequence_len), np.int32(key_to_index[PAD]))
    for i, seq in enumerate(sequences):
        row_ix = [key_to_index.get(word, key_to_index[UNK]) for word in seq[:max_sequence_len]]
        matrix[i, :len(row_ix)] = row_ix

    return matrix

In [8]:
display(len(ru_corpus_cp))
# display(as_matrix(ru_corpus_cp, key_to_index, max_len=10))
len(list(chain.from_iterable(ru_corpus_cp)))

8

309

In [9]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def pad_text(text: list, window_size: int, pad: str):
    appendix = [pad] * window_size

    return appendix + text + appendix

In [10]:
class BaseEmbeddings(KeyedVectors):
    def __init__(self, corpus, distinct_words=None, word_counter=None, vector_size=100, min_count=10):
        super().__init__(vector_size=vector_size)
        
        self.index_to_key = distinct_words
        self.word_counter = word_counter
        if distinct_words is None or word_counter is None:
            self.index_to_key, self.word_counter = get_distinct_words(corpus, min_count=min_count)
    
        self.key_to_index = {word: i for i, word in enumerate(self.index_to_key)}

In [11]:
np.random.choice(50, 20)


array([19, 15, 12, 45,  1, 26, 43,  6,  9, 29, 28, 30,  4, 35, 32, 13,  5,
        9, 41, 14])

In [45]:
def softmax(u):
    return torch.tensor([torch.exp(u_j) / torch.sum(torch.exp(u)) for u_j in u])

In [59]:
class Word2Vec(BaseEmbeddings):
    def __init__(self, corpus, distinct_words=None, vector_size=100, window_size=5,
                 min_count=10, batch_size=None, n_negative=5, n_epoches=5):
        super().__init__(corpus, vector_size=vector_size, distinct_words=distinct_words, min_count=min_count)

        self.W1 = torch.randn((len(self.index_to_key), vector_size))  # vocab_size, vector_size    #, device=torch.cuda.device(0))
        self.W2 = torch.randn((vector_size, len(self.index_to_key)))  # vector_size, vocab_size    #, device=torch.cuda.device(0))
        # self.emb = nn.Embedding(len(self.index_to_key), vector_size)
        # self.exp = nn.Linear(vector_size, len(self.index_to_key))

        self.corpus = corpus
        self.window_size = window_size
        self.batch_size = batch_size
        if batch_size is None:
            self.batch_size = np.max([len(text) for text in corpus])
        self.n_negative = n_negative
        self.alpha = 0.001
        
        self.train(n_epoches)
        self.vectors = self.W1

    def one_hot_vector(self, word: str):
        vector = torch.zeros(len(self.index_to_key))
        vector[self.key_to_index[word]] = 1

        return vector
    

    def train(self, n_epoches=5):
        """
        trains self.center_W and self.context_W matrices
        """
        for epoch in tqdm(range(n_epoches)):
            for text in self.corpus:
                for batch in chunks(text, self.batch_size):
                    batch_loss = 0

                    for j, center in enumerate(batch):
                        if center in self.index_to_key:

                            window = [batch[i] for i in range(-self.window_size, self.window_size + 1, 1) if i + j >= 0 and i + j < len(batch) and j != 0 and batch[i] in self.index_to_key] 
                            
                            h = self.one_hot_vector(center) @ self.W1      # 1, vocab_size x vocab_size, vec_size = 1, vec_size
                            # print("h", h.shape)
                            u = h @ self.W2                                # 1, vec_size x vec_size, vocab_size = 1, voc_size
                            # print("u", u.shape)
                            y = softmax(u)
                            # print("y", y.shape)
                            k_neg = np.random.choice(len(self.index_to_key), self.n_negative)
                            W2_neg = self.W2[:, k_neg]                     # vec_size, k_neg  
                            # print("W2_neg", W2_neg.shape)
                            neg_sum = torch.sum(torch.exp(h @ W2_neg))
                            # print("neg_sum", neg_sum)
                            # print(window)
                            u_c = torch.sum(torch.tensor([h @ self.W2[:, self.key_to_index[context]] for context in window]))  # 1, vec_size x vec_size, 1

                            batch_loss = -u_c + self.n_negative * neg_sum
                            print("batch_loss", batch_loss)

                            t = torch.zeros_like(y)
                            t[k_neg] = 1
                            # print(t)
                            dl_du = -t + y  # 1, voc_size
                            dl_dW2 = h.view(-1, 1) @ dl_du.view(1, -1)  # vec_size, 1 x 1, voc_size
                    
                            dl_dh = dl_du.view(1, -1) @ self.W2.T  # 1, voc_size x vec_size, voc_size
                            dl_dW1 = self.one_hot_vector(center).view(-1, 1) @ dl_dh  # voc_size, 1 x 1, vec_size

                            # print(dl_dW1.shape, dl_dW2.shape)


                            # e = np.zeros(len(self.index_to_key))
                            

                            # dot_U_V = self.U[self.key_to_index[context]] @ self.V[self.key_to_index[center]]  # vec_size x vec_size = 1
                            # dot_U_V = 
                            # print(dot_U_V)
                    
                            # k_neg = np.random.choice(len(self.index_to_key), self.n_negative)
                            # dot_Uneg_V = self.U[k_neg] @ self.V[self.key_to_index[center]]  # k_neg, vec_size x vec_size, 1 = k_neg, 1
                            # exp_Uneg_V = torch.expm1(dot_Uneg_V)
                            # print(exp_Uneg_V)
                            # loss_ij = -dot_U_V + torch.sum(exp_Uneg_V, axis=0)
                        
                            # batch_loss += loss_ij
                            # print(batch_loss)
                            
                            # dL_dV = torch.sum(-self.U[self.key_to_index[context]] + torch.sum(exp_Uneg_V.view(-1, 1) * self.U[self.key_to_index[center]], axis=0))
                            # dL_dU = torch.sum(-self.V[self.key_to_index[center]] + torch.sum(exp_Uneg_V.view(-1, 1) * self.V[self.key_to_index[center]], axis=0))

                            self.W1 -= self.alpha * dl_dW1
                            self.W2 -= self.alpha * dl_dW2

                    # print(batch_loss)

                # break
            # break


w2v = Word2Vec(ru_corpus_cp, min_count=2)

  0%|          | 0/5 [00:00<?, ?it/s]

batch_loss tensor(8.8771e+11)
batch_loss tensor(244105.3125)
batch_loss tensor(882273.8125)
batch_loss tensor(10880132.)
batch_loss tensor(17419.5645)
batch_loss tensor(1232314.6250)
batch_loss tensor(20.1731)
batch_loss tensor(478.6078)
batch_loss tensor(10292556.)
batch_loss tensor(11262.0527)
batch_loss tensor(4530980.5000)
batch_loss tensor(3.5248e+08)
batch_loss tensor(6.0777e+09)
batch_loss tensor(19.7348)
batch_loss tensor(1.3265e+08)
batch_loss tensor(7652.3301)
batch_loss tensor(167474.5000)
batch_loss tensor(821608.8750)
batch_loss tensor(2.3367e+08)
batch_loss tensor(90153680.)
batch_loss tensor(3026.3611)
batch_loss tensor(9449.4238)
batch_loss tensor(2982.3794)
batch_loss tensor(13302.7979)
batch_loss tensor(428.6572)
batch_loss tensor(3560.0234)
batch_loss tensor(966033.3750)
batch_loss tensor(153.6170)
batch_loss tensor(25465.2324)
batch_loss tensor(188991.0156)
batch_loss tensor(3167967.)
batch_loss tensor(5886314.)
batch_loss tensor(88146.8281)
batch_loss tensor(741.40

 20%|██        | 1/5 [00:00<00:01,  3.11it/s]

tensor(1.0754e+08)
batch_loss tensor(282485.9062)
batch_loss tensor(1944.5731)
batch_loss tensor(6006041.)
batch_loss tensor(1.1148e+08)
batch_loss tensor(7603505.)
batch_loss tensor(2.7608e+08)
batch_loss tensor(1.8910e+09)
batch_loss tensor(2340993.7500)
batch_loss tensor(585.4722)
batch_loss tensor(6392131.)
batch_loss tensor(1.4267e+16)
batch_loss tensor(7716685.5000)
batch_loss tensor(2.0248e+08)
batch_loss tensor(7373.6377)
batch_loss tensor(144343.9531)
batch_loss tensor(1.1595e+10)
batch_loss tensor(2998948.5000)
batch_loss tensor(6.7255e+10)
batch_loss tensor(8.8092e+08)
batch_loss tensor(16800.5801)
batch_loss tensor(835714.5625)
batch_loss tensor(6.0344e+10)
batch_loss tensor(12754989.)
batch_loss tensor(26.0321)
batch_loss tensor(313218.2188)
batch_loss tensor(8593483.)
batch_loss tensor(20280240.)
batch_loss tensor(453090.9375)
batch_loss tensor(28507594.)
batch_loss tensor(2.9762e+08)
batch_loss tensor(4611320.5000)
batch_loss tensor(204.0508)
batch_loss tensor(91006.3828

 40%|████      | 2/5 [00:00<00:00,  3.18it/s]

batch_loss tensor(2.2951e+09)
batch_loss tensor(9699.6426)
batch_loss tensor(1.7965e+08)
batch_loss tensor(789416.3125)
batch_loss tensor(-8.8743)
batch_loss tensor(230.2584)
batch_loss tensor(1320636.2500)
batch_loss tensor(-24.8871)
batch_loss tensor(74949400.)
batch_loss tensor(194850.2969)
batch_loss tensor(55257536.)
batch_loss tensor(171569.0156)
batch_loss tensor(7151126.)
batch_loss tensor(15727734.)
batch_loss tensor(42583.9609)
batch_loss tensor(8.0454e+08)
batch_loss tensor(22.2901)
batch_loss tensor(1.2108e+09)
batch_loss tensor(1.4062e+12)
batch_loss tensor(157913.7500)
batch_loss tensor(16131515.)
batch_loss tensor(1.1990e+08)
batch_loss tensor(208952.2344)
batch_loss tensor(283063.1875)
batch_loss tensor(77199080.)
batch_loss tensor(14821.4453)
batch_loss tensor(16995710.)
batch_loss tensor(11132824.)
batch_loss tensor(47326.3203)
batch_loss tensor(7.7163e+08)
batch_loss tensor(1258.5000)
batch_loss tensor(233924.4219)
batch_loss tensor(7561.5356)
batch_loss tensor(29150

 60%|██████    | 3/5 [00:00<00:00,  2.95it/s]

batch_loss tensor(370.2365)
batch_loss tensor(23559.2793)
batch_loss tensor(27.8494)
batch_loss tensor(-3.7296)
batch_loss tensor(1.3612e+09)
batch_loss tensor(238.9479)
batch_loss tensor(6.2501e+09)
batch_loss tensor(5.5691)
batch_loss tensor(396312.9062)
batch_loss tensor(1406.6423)
batch_loss tensor(247832.2344)
batch_loss tensor(26778908.)
batch_loss tensor(3.9276e+08)
batch_loss tensor(201.0893)
batch_loss tensor(8451.4395)
batch_loss tensor(2381.2583)
batch_loss tensor(248938.3281)
batch_loss tensor(220.0569)
batch_loss tensor(574.6126)
batch_loss tensor(469718.3750)
batch_loss tensor(59684820.)
batch_loss tensor(80.1669)
batch_loss tensor(4478171.)
batch_loss tensor(4087028.7500)
batch_loss tensor(5.5426e+09)
batch_loss tensor(2555.9194)
batch_loss tensor(42.9979)
batch_loss tensor(1.6584e+08)
batch_loss tensor(1230.5305)
batch_loss tensor(1979324.5000)
batch_loss tensor(49438.2188)
batch_loss tensor(3993263.)
batch_loss tensor(8.5419)
batch_loss tensor(60192568.)
batch_loss ten

 80%|████████  | 4/5 [00:01<00:00,  3.01it/s]

tensor(1.1215e+08)
batch_loss tensor(2471.8960)
batch_loss tensor(44145.7344)
batch_loss tensor(39.5793)
batch_loss tensor(259235.6562)
batch_loss tensor(85375.4375)
batch_loss tensor(9.9356e+13)
batch_loss tensor(2043103.2500)
batch_loss tensor(7430732.)
batch_loss tensor(656662.7500)
batch_loss tensor(23.8300)
batch_loss tensor(3078.3796)
batch_loss tensor(36.5990)
batch_loss tensor(1578344.2500)
batch_loss tensor(1.8300e+10)
batch_loss tensor(3301860.2500)
batch_loss tensor(3325257.2500)
batch_loss tensor(1323.3457)
batch_loss tensor(8.8730e+09)
batch_loss tensor(2511521.5000)
batch_loss tensor(332.4497)
batch_loss tensor(5.5875e+09)
batch_loss tensor(2.0446e+10)
batch_loss tensor(1358547.1250)
batch_loss tensor(4093.9187)
batch_loss tensor(2416478.)
batch_loss tensor(949171.1875)
batch_loss tensor(32138978.)
batch_loss tensor(294.6530)
batch_loss tensor(19428034.)
batch_loss tensor(0.3343)
batch_loss tensor(763.2670)
batch_loss tensor(10989370.)
batch_loss tensor(323.8268)
batch_lo

100%|██████████| 5/5 [00:01<00:00,  3.04it/s]

batch_loss tensor(24203798.)
batch_loss tensor(63512180.)
batch_loss tensor(62321172.)
batch_loss tensor(227462.4531)
batch_loss tensor(1729.5999)
batch_loss tensor(3090278.5000)





In [65]:
U = torch.randn((len(index_to_key[:10]), 5))
V = torch.randn((len(index_to_key[:10]), 5))
neg = np.random.choice(len(index_to_key[:10]), 3)
print(U)
print(neg)
print(U[neg] @ V[0])

tensor([[-0.2165, -1.4011, -0.5845, -0.3900,  0.4657],
        [-0.9166,  0.2588, -0.9368,  0.6109,  1.8453],
        [-0.8238,  0.1415, -0.8317, -0.0836,  0.3332],
        [ 0.6532,  0.8278,  1.7981, -0.5450,  0.7319],
        [-0.5642, -1.0860, -0.7013, -0.1075,  1.2076],
        [ 0.6756, -1.0019,  1.4159, -0.3987,  0.1787],
        [-0.4165, -1.3151, -0.9172,  0.3943, -0.0899],
        [-0.2833, -0.8074,  0.0502,  1.3422,  0.4793],
        [ 0.5246, -0.7201, -1.1037,  0.0930, -0.2626],
        [-0.5134,  1.1586, -1.5844,  1.5967, -0.0292]])
[6 9 3]
tensor([-2.6095, -2.1787,  3.7570])


In [73]:
torch.sum((torch.arange(5).view(-1, 1) * torch.arange(100)), axis=0)

tensor([  0,  10,  20,  30,  40,  50,  60,  70,  80,  90, 100, 110, 120, 130,
        140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270,
        280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410,
        420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550,
        560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690,
        700, 710, 720, 730, 740, 750, 760, 770, 780, 790, 800, 810, 820, 830,
        840, 850, 860, 870, 880, 890, 900, 910, 920, 930, 940, 950, 960, 970,
        980, 990])

array([8, 2, 4, 5, 1])

In [None]:
from torch.utils.data import Dataset, DataLoader

class Word2VecDataset(Dataset):
    """
    Takes a HuggingFace dataset as an input, to be used for a Word2Vec dataloader.
    """
    def __init__(self, dataset, vocab_size):
        self.dataset = dataset
        self.vocab_size = vocab_size
        self.data = [i for s in dataset['moving_window'] for i in s]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]