# Генерация текста с помощью LSTM-сетей

Сеть способна выучить распределение символов в последовательностях


Датасет формируем, проходя скользящим окном по текстовому корпусу; задача сети — предсказывать следующий символ на основании нескольких предыдущих.
Данный подход можно улучшить, используя только отдельные предложения с паддингами.

In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import pandas as pd
import numpy as np
from tqdm import tqdm

### 0. Получение данных для обучения

Для обучения используется датасет российских новостей, который был сохранён в файл `text_corpus.parquet` со следующими параметрами:

In [12]:
# data.to_parquet("text_corpus.parquet", engine="pyarrow", compression="gzip", index=False)

In [4]:
# Load the news corpus from the parquet file using the pyarrow engine.
data = pd.read_parquet("text_corpus.parquet", engine="pyarrow")

In [5]:
# Number of rows in the loaded corpus.
len(data)

50000

In [6]:
# Once the tokenizer has been trained, the training sample can be made smaller,
# but remember to rebuild the `corpus` variable afterwards.
data = data.sample(n=10000)

### 1. Вспомогательные функции:
+ Визуализация процесса обучения
    + Сможем посмотреть, как меняется качество с течением времени.
+ Коллбек ModelCheckpoint
    + Процесс обучения LSTM сетей достаточно длительный. Будет обидно, если из-за непредвиденного сбоя потеряется прогресс за многие часы обучения.
+ Коллбек динамической подстройки размера батча и learning rate
    + Подстраивать LR это уже стандартная практика, а я хочу ещё и размер батча менять: предположу, что большой батч позволит дать некое "обобщённое" представление о распределении токенов, а маленький батч улучшит "грамотность".



In [6]:
import matplotlib.pyplot as plt


def plot_graphs(history, string):
    """Plot one recorded metric and show the figure.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            sequences of values (presumably a Keras-style ``History`` —
            TODO confirm against the caller).
        string: Name of the metric to plot; also used as the y-axis label.
    """
    metric_values = history.history[string]
    plt.plot(metric_values)
    plt.ylabel(string)
    plt.xlabel("Epochs")
    plt.show()

In [9]:
# весь текст одной "портянкой", чтобы заранее оценить, какие символы могут нам попадаться
# raw_text = " ".join(data.text)
# chars = sorted(list(set(raw_text)))
# chars

### 3. Предобработка и создание датасета

Для тренировки LSTM модели понадобится немного поработать с форматами

In [7]:
import re

In [8]:
# Tokens kept by the pattern: runs of word characters/apostrophes, or a single
# one of . , ! ? ; or a newline. Any other character is dropped.
token_pattern = re.compile(r"[\w']+|[.,!?;\n]")
# Join all articles into one string, separated by " \n".
corpus = " \n".join(data.text.tolist())
# Lowercase, then put a space between every kept token so that punctuation
# is detached from the neighbouring words.
corpus = " ".join(token_pattern.findall(corpus.lower()))

In [42]:
# Vocabulary size: passed to the BPE trainer as vocab_size and used as the
# default number of output classes of the model.
total_words = 800
# Number of tokens in each input window fed to the model.
max_sequence_length = 80



#### 3.1 Токенизация BPE 

BPE токенизация посредством yttm эффективна, но потребуется поработать с файлом

In [10]:
import youtokentome as yttm

In [11]:
# Path of the BPE model file: passed to the trainer as `model` and used when loading.
bpe_model_path = "bpe.yttm"

##### 3.2 Создаём токенизатор BPE и обучаем его

In [12]:
def create_bpe_tokenizer_from_scratch(
    corpus,
    train_data_path="yttm_train_data.txt",
    vocab_size=None,
    model_path=None,
):
    """Write the corpus to a text file and train a YouTokenToMe BPE model on it.

    Args:
        corpus: Training text, either a single string or an iterable of
            strings (written one after another without added separators).
        train_data_path: Path of the text file the corpus is written to.
        vocab_size: Vocabulary size for the BPE model. Defaults to the
            module-level ``total_words``.
        model_path: Path passed to the trainer as ``model``. Defaults to the
            module-level ``bpe_model_path``.

    Returns:
        Whatever ``yttm.BPE.train`` returns (presumably the trained tokenizer —
        confirm against the YouTokenToMe documentation).
    """
    if vocab_size is None:
        vocab_size = total_words
    if model_path is None:
        model_path = bpe_model_path

    # Explicit UTF-8: the platform default encoding may not be able to
    # represent the Cyrillic corpus.
    with open(train_data_path, "w", encoding="utf-8") as train_file:
        if isinstance(corpus, str):
            # A single write; writelines() on a str would iterate it
            # character by character.
            train_file.write(corpus)
        else:
            train_file.writelines(corpus)

    # Training model
    # (data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
    return yttm.BPE.train(data=train_data_path, vocab_size=vocab_size, model=model_path)

In [21]:
# Train a new BPE model on the corpus; bpe_model_path is passed to the trainer as the model path.
bpe = create_bpe_tokenizer_from_scratch(corpus)

In [12]:
# Load a BPE model from the file at bpe_model_path.
bpe = yttm.BPE(model=bpe_model_path)

In [13]:
# Show the first 300 characters of the space-joined BPE vocabulary.
print(' '.join(bpe.vocab())[:300])

<PAD> <UNK> <BOS> <EOS> ▁ о е и а н т с р в л к п д м у я ы г з б , ь ч й . х ж ' ц ю ш ф щ э ъ ? ё ! ; _ ▁п ▁с ▁в ▁, ст ни ра ро но ре на ▁о ко то ▁. ▁и ▁по го не де те ли ва ▁м за ны ▁на ль ка ри та ле ла ▁д во ве ▁б ти ци ▁со ви ▁ч ки ло ▁у ▁за ▁' да ть ен ми ▁а ▁не ▁ко сс ▁пре ет ру ся ди ▁про н


In [102]:
# encode(self, 
#     sentences, 
#     output_type=yttm.OutputType.ID, 
#     bos=False, 
#     eos=False, 
#     reverse=False, 
#     dropout_prob=0)

In [14]:
# Encode the whole corpus with the BPE model and store the result as a NumPy array
# (presumably a 1-D array of token ids, since `corpus` is a single string — TODO confirm).
encoded_corpus = np.array(bpe.encode(corpus))

# sequences = sequence[:-(len(sequence)%max_sequence_length)].reshape((len(sequence)//max_sequence_length, max_sequence_length))

In [33]:
from torch.utils.data import Dataset, DataLoader

In [30]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D array of token ids.

    Sample ``i`` is the pair ``(X[i : i + L], X[i + L])``, where ``L`` is
    ``sequence_length``: a window of ``L`` consecutive tokens and the single
    token that immediately follows it.

    Args:
        encoded_corpus: 1-D NumPy integer array of token ids.
        sequence_length: Number of tokens in each input window.
    """

    def __init__(self, encoded_corpus, sequence_length=80):
        self.sequence_length = sequence_length
        # Use the constructor argument (not a module-level constant) so the
        # sample count always matches the window length of this instance.
        # The last valid window starts at len - L - 1, leaving one token
        # after it to serve as the target.
        self.n_samples = max(0, encoded_corpus.shape[0] - sequence_length)
        # int64 ids are what nn.Embedding and cross_entropy targets expect.
        self.X = torch.from_numpy(encoded_corpus).long()

    def __getitem__(self, index):
        """Return ``(window, target)`` for sample ``index``.

        Returns:
            A tuple of a tensor with ``sequence_length`` token ids and a
            tensor of shape ``(1,)`` holding the id of the next token.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.n_samples
        if not 0 <= index < self.n_samples:
            raise IndexError("TextDataset index out of range")
        window_end = index + self.sequence_length
        # The target is the token right after the window: index `window_end`,
        # not `window_end + 1`.
        return self.X[index:window_end], self.X[window_end].view(1)

    def __len__(self):
        return self.n_samples

In [43]:
# Sliding-window dataset over the encoded corpus, with windows of max_sequence_length tokens.
dataset = TextDataset(encoded_corpus, sequence_length=max_sequence_length)


## 4. Модель

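В качестве модели применяется сеть из двух рекуррентных слоёв: двунаправленного LSTM (Bi-LSTM) и обычного LSTM поверх него; последнее скрытое состояние подаётся на линейный слой.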

TODO
+ Gradient clipping
+ More layers?

In [83]:
class LSTMModel(nn.Module):
    """Next-token classifier built from manually unrolled LSTM cells.

    Pipeline: token embedding -> bidirectional LSTM layer (two ``nn.LSTMCell``
    modules, one per direction) -> a second LSTM layer over the concatenated
    directions -> linear projection of the last hidden state to one logit per
    vocabulary entry.

    Args:
        input_size: Nominal number of tokens per input window; stored as
            ``self.input_size``. ``forward`` unrolls over the actual length
            of its input.
        num_classes: Vocabulary size: number of embedding rows and of output
            logits.
        hidden_dim: Embedding width and hidden size of each direction of the
            bidirectional layer; the second LSTM layer is twice as wide.
        num_layers: Not used by this implementation; kept so existing
            constructor calls keep working.
        batch_size: Not used by this implementation; the batch size is read
            from the input tensor.
    """

    def __init__(
            self,
            input_size=max_sequence_length,
            num_classes=total_words,
            hidden_dim=64,
            num_layers=2,
            batch_size=128,
                ):
        super().__init__()
        self.input_size = input_size
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(0.25)

        # Embedding layer; index 0 is treated as padding.
        self.embedding = nn.Embedding(self.num_classes, self.hidden_dim, padding_idx=0)
        # Bi-LSTM: one cell per direction.
        self.lstm_cell_forward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.lstm_cell_backward = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        # Second LSTM layer over the concatenated forward/backward states.
        self.lstm_cell = nn.LSTMCell(self.hidden_dim * 2, self.hidden_dim * 2)

        self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)

    def forward(self, X):
        """Return next-token logits for a batch of token-id windows.

        Args:
            X: Integer tensor of token ids, indexed as ``(batch, position)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        batch_size, seq_len = X.size(0), X.size(1)

        # From idx to embedding: (batch, position, hidden_dim).
        out = self.embedding(X)

        def zero_state(width):
            # new_zeros keeps the states on the same device/dtype as the
            # embeddings, so the model also works after .to(device).
            return out.new_zeros(batch_size, width)

        # Initial states are zeros. They are activations, not weights, so
        # they must not be randomly re-initialised on every call: that would
        # make the output non-deterministic even in eval mode.
        hs_forward, cs_forward = zero_state(self.hidden_dim), zero_state(self.hidden_dim)
        hs_backward, cs_backward = zero_state(self.hidden_dim), zero_state(self.hidden_dim)
        hs_lstm, cs_lstm = zero_state(self.hidden_dim * 2), zero_state(self.hidden_dim * 2)

        # Unfolding Bi-LSTM
        # Forward direction: positions 0 .. seq_len - 1.
        forward = []
        for i in range(seq_len):
            hs_forward, cs_forward = self.lstm_cell_forward(out[:, i], (hs_forward, cs_forward))
            hs_forward = self.dropout(hs_forward)
            cs_forward = self.dropout(cs_forward)
            forward.append(hs_forward)

        # Backward direction: positions seq_len - 1 .. 0.
        backward = []
        for i in reversed(range(seq_len)):
            hs_backward, cs_backward = self.lstm_cell_backward(out[:, i], (hs_backward, cs_backward))
            hs_backward = self.dropout(hs_backward)
            cs_backward = self.dropout(cs_backward)
            backward.append(hs_backward)
        # The backward states were collected last-position-first; restore
        # position order so that forward[t] and backward[t] describe the
        # same token when they are concatenated below.
        backward.reverse()

        # Second LSTM layer over the per-position concatenation.
        for fwd, bwd in zip(forward, backward):
            input_tensor = torch.cat((fwd, bwd), 1)
            hs_lstm, cs_lstm = self.lstm_cell(input_tensor, (hs_lstm, cs_lstm))

        # Last hidden state is passed through a linear layer.
        return self.linear(hs_lstm)

In [90]:
learning_rate = 1e-5
num_epochs = 1
# batch_size is assigned before anything that reads it.
batch_size = 200
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
# Batches per epoch; len(dataloader) also counts the final partial batch.
num_batches = len(dataloader)
# Model initialization
model = LSTMModel(
            input_size=max_sequence_length,
            num_classes=total_words,
            hidden_dim=64,
            num_layers=2,
            batch_size=batch_size)

In [85]:
# Smoke test: take one batch from the loader and run it through the model.
first_batch = next(iter(dataloader))
X, y = first_batch
y_ = model(X)

In [91]:
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
model.train()
# One loss value per optimisation step.
loss_history = []
# Training phase
for epoch in range(num_epochs):

    # Mini batches; tqdm takes the total from len(dataloader).
    for X, y in tqdm(dataloader):
        # Feed the model
        y_pred = model(X)

        # Loss calculation. Targets arrive as (batch, 1); flatten them to
        # (batch,). reshape(-1) keeps the batch dimension even for a batch
        # of one sample, which a bare squeeze() would collapse to a scalar.
        loss = F.cross_entropy(y_pred, y.reshape(-1))

        # Clean gradients
        optimizer.zero_grad()

        # Calculate gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        loss_history.append(loss.item())
        print("Epoch: %d ,  loss: %.5f " % (epoch, loss.item()))

  0%|          | 1/153987 [00:00<25:03:43,  1.71it/s]

Epoch: 0 ,  loss: 6.68149 


  0%|          | 2/153987 [00:01<21:59:45,  1.94it/s]

Epoch: 0 ,  loss: 6.67743 


  0%|          | 3/153987 [00:01<20:49:47,  2.05it/s]

Epoch: 0 ,  loss: 6.68391 


  0%|          | 4/153987 [00:01<19:45:26,  2.16it/s]

Epoch: 0 ,  loss: 6.67764 


  0%|          | 5/153987 [00:02<18:59:03,  2.25it/s]

Epoch: 0 ,  loss: 6.68883 


  0%|          | 6/153987 [00:02<18:52:48,  2.27it/s]

Epoch: 0 ,  loss: 6.68083 


  0%|          | 7/153987 [00:03<20:29:29,  2.09it/s]

Epoch: 0 ,  loss: 6.68436 


  0%|          | 8/153987 [00:03<19:45:01,  2.17it/s]

Epoch: 0 ,  loss: 6.68712 


  0%|          | 9/153987 [00:04<19:45:37,  2.16it/s]

Epoch: 0 ,  loss: 6.67680 


  0%|          | 10/153987 [00:04<19:32:33,  2.19it/s]

Epoch: 0 ,  loss: 6.67562 


  0%|          | 11/153987 [00:05<19:38:55,  2.18it/s]

Epoch: 0 ,  loss: 6.67980 


  0%|          | 12/153987 [00:05<23:09:57,  1.85it/s]

Epoch: 0 ,  loss: 6.68308 


  0%|          | 13/153987 [00:06<24:58:34,  1.71it/s]

Epoch: 0 ,  loss: 6.67751 


  0%|          | 14/153987 [00:07<24:24:33,  1.75it/s]

Epoch: 0 ,  loss: 6.68549 


  0%|          | 15/153987 [00:07<22:21:09,  1.91it/s]

Epoch: 0 ,  loss: 6.69014 


  0%|          | 16/153987 [00:07<22:01:12,  1.94it/s]

Epoch: 0 ,  loss: 6.68752 


  0%|          | 17/153987 [00:08<20:57:05,  2.04it/s]

Epoch: 0 ,  loss: 6.69177 


  0%|          | 18/153987 [00:09<23:29:51,  1.82it/s]

Epoch: 0 ,  loss: 6.68762 


  0%|          | 19/153987 [00:09<22:56:54,  1.86it/s]

Epoch: 0 ,  loss: 6.68187 


  0%|          | 20/153987 [00:10<21:34:24,  1.98it/s]

Epoch: 0 ,  loss: 6.67969 


  0%|          | 21/153987 [00:10<22:11:25,  1.93it/s]

Epoch: 0 ,  loss: 6.68512 


  0%|          | 22/153987 [00:11<23:04:22,  1.85it/s]

Epoch: 0 ,  loss: 6.68725 


  0%|          | 23/153987 [00:11<23:07:44,  1.85it/s]

Epoch: 0 ,  loss: 6.68889 


  0%|          | 24/153987 [00:12<21:59:14,  1.95it/s]

Epoch: 0 ,  loss: 6.68028 


  0%|          | 25/153987 [00:13<26:02:09,  1.64it/s]

Epoch: 0 ,  loss: 6.67512 


  0%|          | 26/153987 [00:13<25:06:03,  1.70it/s]

Epoch: 0 ,  loss: 6.67841 


  0%|          | 27/153987 [00:13<22:53:14,  1.87it/s]

Epoch: 0 ,  loss: 6.68118 


  0%|          | 28/153987 [00:14<21:33:53,  1.98it/s]

Epoch: 0 ,  loss: 6.68327 


  0%|          | 29/153987 [00:14<21:06:08,  2.03it/s]

Epoch: 0 ,  loss: 6.68198 


  0%|          | 30/153987 [00:15<21:18:37,  2.01it/s]

Epoch: 0 ,  loss: 6.68240 


  0%|          | 31/153987 [00:15<20:52:53,  2.05it/s]

Epoch: 0 ,  loss: 6.69246 


  0%|          | 32/153987 [00:16<20:51:55,  2.05it/s]

Epoch: 0 ,  loss: 6.68295 


  0%|          | 33/153987 [00:16<20:04:27,  2.13it/s]

Epoch: 0 ,  loss: 6.68039 


  0%|          | 34/153987 [00:17<19:17:08,  2.22it/s]

Epoch: 0 ,  loss: 6.68486 


  0%|          | 35/153987 [00:17<18:40:50,  2.29it/s]

Epoch: 0 ,  loss: 6.69505 


  0%|          | 36/153987 [00:17<18:21:52,  2.33it/s]

Epoch: 0 ,  loss: 6.67902 


  0%|          | 37/153987 [00:18<19:56:06,  2.15it/s]

Epoch: 0 ,  loss: 6.68619 


  0%|          | 38/153987 [00:18<19:11:25,  2.23it/s]

Epoch: 0 ,  loss: 6.68533 


  0%|          | 39/153987 [00:19<19:56:55,  2.14it/s]

Epoch: 0 ,  loss: 6.68184 


  0%|          | 40/153987 [00:19<19:05:44,  2.24it/s]

Epoch: 0 ,  loss: 6.67284 


  0%|          | 41/153987 [00:20<19:07:26,  2.24it/s]

Epoch: 0 ,  loss: 6.67223 


  0%|          | 42/153987 [00:21<23:13:44,  1.84it/s]

Epoch: 0 ,  loss: 6.68036 


  0%|          | 43/153987 [00:21<23:48:30,  1.80it/s]

Epoch: 0 ,  loss: 6.68259 


  0%|          | 44/153987 [00:22<21:53:41,  1.95it/s]

Epoch: 0 ,  loss: 6.67310 


  0%|          | 45/153987 [00:22<24:40:02,  1.73it/s]

Epoch: 0 ,  loss: 6.67790 


  0%|          | 46/153987 [00:23<22:39:16,  1.89it/s]

Epoch: 0 ,  loss: 6.67803 


  0%|          | 47/153987 [00:23<21:08:31,  2.02it/s]

Epoch: 0 ,  loss: 6.68661 


  0%|          | 48/153987 [00:24<23:28:56,  1.82it/s]

Epoch: 0 ,  loss: 6.67695 


  0%|          | 49/153987 [00:25<25:58:13,  1.65it/s]

Epoch: 0 ,  loss: 6.67843 


  0%|          | 50/153987 [00:25<29:56:36,  1.43it/s]

Epoch: 0 ,  loss: 6.67896 


  0%|          | 51/153987 [00:26<28:01:38,  1.53it/s]

Epoch: 0 ,  loss: 6.67446 


  0%|          | 52/153987 [00:26<25:06:33,  1.70it/s]

Epoch: 0 ,  loss: 6.68596 


  0%|          | 53/153987 [00:27<24:01:11,  1.78it/s]

Epoch: 0 ,  loss: 6.68230 


  0%|          | 54/153987 [00:27<22:50:15,  1.87it/s]

Epoch: 0 ,  loss: 6.67542 


  0%|          | 55/153987 [00:28<23:33:23,  1.82it/s]

Epoch: 0 ,  loss: 6.68294 


  0%|          | 56/153987 [00:29<25:14:21,  1.69it/s]

Epoch: 0 ,  loss: 6.67939 


  0%|          | 57/153987 [00:29<27:03:19,  1.58it/s]

Epoch: 0 ,  loss: 6.68044 


  0%|          | 58/153987 [00:30<25:14:57,  1.69it/s]

Epoch: 0 ,  loss: 6.68826 


  0%|          | 59/153987 [00:30<23:43:19,  1.80it/s]

Epoch: 0 ,  loss: 6.68121 


  0%|          | 60/153987 [00:31<23:13:59,  1.84it/s]

Epoch: 0 ,  loss: 6.67123 


  0%|          | 61/153987 [00:31<23:50:18,  1.79it/s]

Epoch: 0 ,  loss: 6.67732 


  0%|          | 62/153987 [00:32<24:42:26,  1.73it/s]

Epoch: 0 ,  loss: 6.67387 


  0%|          | 63/153987 [00:33<25:11:54,  1.70it/s]

Epoch: 0 ,  loss: 6.67818 


  0%|          | 64/153987 [00:33<24:16:56,  1.76it/s]

Epoch: 0 ,  loss: 6.67644 


  0%|          | 65/153987 [00:34<23:49:24,  1.79it/s]

Epoch: 0 ,  loss: 6.67270 


  0%|          | 66/153987 [00:34<22:51:27,  1.87it/s]

Epoch: 0 ,  loss: 6.68142 


  0%|          | 67/153987 [00:35<21:56:58,  1.95it/s]

Epoch: 0 ,  loss: 6.68029 


  0%|          | 68/153987 [00:35<20:56:02,  2.04it/s]

Epoch: 0 ,  loss: 6.67054 


  0%|          | 69/153987 [00:36<23:14:40,  1.84it/s]

Epoch: 0 ,  loss: 6.67719 


  0%|          | 70/153987 [00:36<22:38:00,  1.89it/s]

Epoch: 0 ,  loss: 6.67953 


  0%|          | 71/153987 [00:37<21:13:52,  2.01it/s]

Epoch: 0 ,  loss: 6.67127 


  0%|          | 72/153987 [00:37<21:01:03,  2.03it/s]

Epoch: 0 ,  loss: 6.67767 


  0%|          | 73/153987 [00:38<20:02:58,  2.13it/s]

Epoch: 0 ,  loss: 6.67984 


  0%|          | 74/153987 [00:38<24:23:45,  1.75it/s]

Epoch: 0 ,  loss: 6.68615 


  0%|          | 75/153987 [00:39<27:02:24,  1.58it/s]

Epoch: 0 ,  loss: 6.67699 


  0%|          | 76/153987 [00:40<27:27:07,  1.56it/s]

Epoch: 0 ,  loss: 6.67757 


  0%|          | 77/153987 [00:41<34:26:08,  1.24it/s]

Epoch: 0 ,  loss: 6.67669 


  0%|          | 78/153987 [00:42<30:48:18,  1.39it/s]

Epoch: 0 ,  loss: 6.67970 


  0%|          | 79/153987 [00:42<27:19:46,  1.56it/s]

Epoch: 0 ,  loss: 6.67105 


  0%|          | 80/153987 [00:43<25:37:37,  1.67it/s]

Epoch: 0 ,  loss: 6.67617 


  0%|          | 81/153987 [00:43<23:59:38,  1.78it/s]

Epoch: 0 ,  loss: 6.68061 


  0%|          | 82/153987 [00:44<22:58:25,  1.86it/s]

Epoch: 0 ,  loss: 6.66734 





KeyboardInterrupt: 

In [None]:
# Save weights
import os

weights_path = 'weights/lstm__model.pt'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)

## 5. Инференс полученной модели

In [102]:
def generator(model, sequences, idx_to_char, n_chars):
    """Greedily extend a randomly chosen seed sequence and print the result.

    A seed row is picked uniformly at random from ``sequences``. The model is
    then called ``n_chars`` times on a sliding window of the same length as
    the seed; each time the highest-scoring index is appended and the window
    is shifted by one.

    Args:
        model: Module mapping a ``(1, window)`` LongTensor of ids to a
            ``(1, n_classes)`` tensor of scores.
        sequences: Indexable collection of equally typed id sequences
            (NumPy rows or anything ``np.asarray`` accepts).
        idx_to_char: Mapping from id to the text it stands for.
        n_chars: Number of ids to generate.

    Returns:
        NumPy array with the seed ids followed by the generated ids.
    """
    # Set the model in evaluation mode
    model.eval()

    # randint's upper bound is exclusive, so this covers every row,
    # including the last one, and also works for a single sequence.
    start = np.random.randint(0, len(sequences))

    # The seed pattern is the randomly selected row.
    pattern = np.asarray(sequences[start])

    print("\nPattern: \n")
    print(''.join([idx_to_char[value] for value in pattern]), "\"")

    # full_prediction accumulates the seed plus everything generated.
    full_prediction = pattern.copy()

    # No gradients are needed for inference.
    with torch.no_grad():
        for _ in range(n_chars):
            window = torch.from_numpy(pattern).type(torch.LongTensor).view(1, -1)

            # Softmax is monotonic, so the arg-max of the raw scores is the
            # same index the arg-max of the probabilities would give.
            scores = model(window)
            next_idx = int(torch.argmax(scores, dim=1).item())

            # Slide the window one position to the right.
            pattern = np.append(pattern[1:], next_idx)
            full_prediction = np.append(full_prediction, next_idx)

    print("Prediction: \n")
    print(''.join([idx_to_char[value] for value in full_prediction]), "\"")
    return full_prediction

приветствие затянулось на несколько вродители,
им зае призиденторта согорами в сосесуми рроведе с сосрийской ке просева об пнобо с общела он вобороднтое влодедать в инсотия чесьдае в сосрий сазорунотти в соссий сакали проведа пода саков


In [120]:
# Inspect the most recently decoded token (output_character is assigned inside the generation loops further down).
output_character

['<PAD>']

In [72]:
# Decode row 522 of X back to text with the BPE model.
# NOTE(review): the reshape to max_sequence_length-1 elements does not match the
# max_sequence_length-token windows produced by TextDataset; this cell looks like
# it belongs to an earlier version of the notebook — confirm before relying on it.
bpe.decode([list(X[522].reshape((max_sequence_length-1)))])

["не планирует передислоцировать наблюдательные пункты в идлибской зоне деэскалации , при этом турция продолжит отправлять военных и бронетехнику в этот район 'в целях защиты мирного населения' . ка"]

In [28]:
# For BPE: greedy generation, appending the arg-max prediction to the text on every step.
# NOTE(review): pad_sequences and model2 are not defined in the visible part of this
# notebook; the cell appears to come from an earlier Keras-based version — confirm
# before running it against the PyTorch model.
composition = "не планирует передислоцировать наблюдательные пункты в идлибской зоне деэскалации , при этом турция продолжит отправлять военных и бронетехнику в этот район 'в целях защиты мирного населения' . ка"
next_words = 200
  
for _ in range(next_words):
    # Re-encode the whole text so far, then keep/pad to the last max_sequence_length-1 ids.
    token_list = bpe.encode(composition)
    token_list = pad_sequences([token_list], maxlen=max_sequence_length-1, padding='pre', truncating="pre")
    token_list = token_list.reshape((1,max_sequence_length-1,1))
    # Highest-scoring index along the last axis of the model output.
    predicted = np.argmax(model2.predict(token_list), axis=-1)
    output_character = bpe.decode([predicted])[0]
    composition += output_character
print(composition)

не планирует передислоцировать наблюдательные пункты в идлибской зоне деэскалации , при этом турция продолжит отправлять военных и бронетехнику в этот район 'в целях защиты мирного населения' . ка,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [57]:
# Encode a sample prompt to inspect its BPE ids.
bpe.encode(["приветствие затянулось на несколько "])

[[126, 341, 322, 6, 90, 10, 20, 156, 603, 71, 97, 109, 448]]

In [44]:
import random

In [45]:
def return_ordered_indices(ar):
    """Return the positions of ``ar`` ordered from largest to smallest value.

    Positions holding equal values keep their original relative order.
    """
    values = list(ar)
    return sorted(range(len(values)), key=values.__getitem__, reverse=True)

In [46]:
# Nudges the first entry so that the entries add up to 1.
def normalize_softmax(ar):
    """Add the residual ``1 - sum(ar)`` to ``ar[0]`` in place and return ``ar``.

    The input is returned untouched when its sum already equals 1.
    """
    total = sum(ar)
    if total == 1:
        return ar
    ar[0] += 1 - total
    return ar
        

In [47]:
# Generation with sampling: on every step one of the T highest-scoring tokens is drawn at random.
# NOTE(review): tokenizer, pad_sequences, window_length, model2 and tf are not defined in the
# visible part of this notebook; the cell appears to come from an earlier Keras/TensorFlow
# version — confirm before running it against the PyTorch model.
composition = "как было сказано "
next_words = 200
T = 2 # the next token is sampled at random from the top-T candidates
temperature = 1 # temperature that smooths the distribution over the selected tokens
  
for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([composition])[0]
    token_list = pad_sequences([token_list], maxlen=window_length-1, padding='pre')
    token_list = token_list.reshape((1,window_length-1,1))
    
    output = model2.predict(token_list)
    # Indices of the T largest scores, best first.
    topmost_indicies = return_ordered_indices(output[0, :])[:T]
    # Softmax over the selected scores only, scaled by the temperature.
    probs = tf.nn.softmax(output[0, topmost_indicies] /  temperature).numpy()
    # Absorb any rounding residual into the first probability before sampling.
    probs = normalize_softmax(probs)
    predicted = np.random.choice(topmost_indicies, p=probs)
#     predicted = topmost_indicies[0]
    output_character = tokenizer.sequences_to_texts([[predicted]])[0]
    composition += output_character
print(composition)

как было сказано осенее с онкет пантоти и ресетиеее ес о осессон в паттеле онти портовония сакомо презедлитеви презисселе нерари соргорам сосриий,
подономо по накраветения подоваи полодать по накогния сообщал возении,
