In [1]:
import numpy as np 
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re
import string
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torch
nltk.download('stopwords')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Maxim\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Maxim\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
window_size = 3

epochs = 50
lr = 1e-2
batch_size = 16
embed_size = 300
path = 'trans.txt'
sw = stopwords.words('russian')

In [3]:
def preprocess_raw_text(path):
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read().replace('\n', ' ')

        text = text.lower()
        
        text = text.replace('\xa0', ' ')

        text = re.sub(r'спикер (\d+|\?):', '', text)

        for punct in string.punctuation +'«»':
            text = text.replace(punct, ' ')

        while '  ' in text:
            text = text.replace('  ', ' ')

        raw_tokens = word_tokenize(text)
        
        tokens = []

        for token in raw_tokens:
            if token not in sw and token.isnumeric() == False:
                tokens.append(token)

        return tokens
        
tokens = preprocess_raw_text(path)

In [4]:
class SkipGramDataset(Dataset):
    def __init__(self, tokens, window_size=window_size):
        self.pairs = []
        for i in range(window_size, len(tokens) - window_size):
            target = tokens[i]
            
            context = tokens[i-window_size:i + window_size + 1]

            context.remove(target)

            for word in context:
                self.pairs.append((target, word))
    
    def __len__(self):
        return len(self.pairs)   

    def __getitem__(self, i):
        return self.pairs[i]

    
dataset = SkipGramDataset(tokens=tokens )

In [5]:
vocab = sorted(set(tokens))
wtoi = {i : word for word, i in enumerate(vocab)}
itow = {word : i for word, i in enumerate(vocab)}
wtoi['крош'],itow[491]

(491, 'крош')

In [6]:
X, y = [], []
for target, context in dataset:
    X.append(wtoi[target])
    y.append(wtoi[context])

X_train = torch.LongTensor(X)
y_train = torch.LongTensor(y)

In [7]:
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embed_size ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.linear = nn.Linear(embed_size, vocab_size)
        self.softmax = nn.Softmax()

    def forward(self, x):
        return self.softmax(self.linear(self.embed(x)))

In [8]:
model = SkipGramModel(len(vocab), embed_size=embed_size)

loss_func = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=lr)

In [9]:
for epoch in range(epochs):
    total_loss = 0
    for i in range(0, len(X), batch_size):

        x = X_train[i:i+batch_size]
        y = y_train[i:i+batch_size]

        opt.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y.view(-1))
        loss.backward()
        opt.step()
        
        total_loss += loss.item()
    print(f'Epoch num: {epoch+1}, loss value: {total_loss:.3f}')

  return self._call_impl(*args, **kwargs)


Epoch num: 1, loss value: 6678.929
Epoch num: 2, loss value: 6629.781
Epoch num: 3, loss value: 6603.886
Epoch num: 4, loss value: 6599.707
Epoch num: 5, loss value: 6598.539
Epoch num: 6, loss value: 6598.410
Epoch num: 7, loss value: 6598.361
Epoch num: 8, loss value: 6598.067
Epoch num: 9, loss value: 6598.060
Epoch num: 10, loss value: 6597.998
Epoch num: 11, loss value: 6597.998
Epoch num: 12, loss value: 6597.992
Epoch num: 13, loss value: 6597.935
Epoch num: 14, loss value: 6597.811
Epoch num: 15, loss value: 6597.811
Epoch num: 16, loss value: 6597.936
Epoch num: 17, loss value: 6597.873
Epoch num: 18, loss value: 6597.935
Epoch num: 19, loss value: 6597.874
Epoch num: 20, loss value: 6598.066
Epoch num: 21, loss value: 6598.079
Epoch num: 22, loss value: 6598.059
Epoch num: 23, loss value: 6598.080
Epoch num: 24, loss value: 6598.091
Epoch num: 25, loss value: 6598.061
Epoch num: 26, loss value: 6597.873
Epoch num: 27, loss value: 6597.811
Epoch num: 28, loss value: 6597.826
E