# Контекст

- Создать модель генерации текста на основе RNN LSTM
- Проверить модель с помощью класса Categorical для большего разнообразия текста

# Импорты

In [23]:
# база
import numpy as np

# troch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions.categorical import Categorical


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Чтение

In [2]:
with open('Жюль_Верн.txt', 'r', encoding = 'utf-8') as file:
    text = file.read()
    delete_tokens = ['\n', '\t', '\r']
    text = ''.join([token for token in text if token not in delete_tokens])

char_set = set(text)
char_set

{' ',
 '!',
 '&',
 '(',
 ')',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 ';',
 '=',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '‘',
 '’',
 '“',
 '”'}

In [3]:
print('Вся длина текста:  ', len(text))
print('Количество уникальных символов:  ', len(char_set))

Вся длина текста:   1089154
Количество уникальных символов:   78


# Словари

In [4]:
char2int = {char:i for i, char in enumerate(sorted(char_set))}
int2char = {char:i for char, i in enumerate(char2int)}

int2char

{0: ' ',
 1: '!',
 2: '&',
 3: '(',
 4: ')',
 5: ',',
 6: '-',
 7: '.',
 8: '/',
 9: '0',
 10: '1',
 11: '2',
 12: '3',
 13: '4',
 14: '5',
 15: '6',
 16: '7',
 17: '8',
 18: '9',
 19: ':',
 20: ';',
 21: '=',
 22: '?',
 23: 'A',
 24: 'B',
 25: 'C',
 26: 'D',
 27: 'E',
 28: 'F',
 29: 'G',
 30: 'H',
 31: 'I',
 32: 'J',
 33: 'K',
 34: 'L',
 35: 'M',
 36: 'N',
 37: 'O',
 38: 'P',
 39: 'Q',
 40: 'R',
 41: 'S',
 42: 'T',
 43: 'U',
 44: 'V',
 45: 'W',
 46: 'Y',
 47: 'Z',
 48: 'a',
 49: 'b',
 50: 'c',
 51: 'd',
 52: 'e',
 53: 'f',
 54: 'g',
 55: 'h',
 56: 'i',
 57: 'j',
 58: 'k',
 59: 'l',
 60: 'm',
 61: 'n',
 62: 'o',
 63: 'p',
 64: 'q',
 65: 'r',
 66: 's',
 67: 't',
 68: 'u',
 69: 'v',
 70: 'w',
 71: 'x',
 72: 'y',
 73: 'z',
 74: '‘',
 75: '’',
 76: '“',
 77: '”'}

In [5]:
text_encoded = np.array([char2int[char] for char in text])
len(text_encoded)

print(text[:21], ' ====> ', text_encoded[:10])
print(text_encoded[100:124], ' ====> ', ''.join([int2char[index] for index in text_encoded[100:124]]))

THE MYSTERIOUS ISLAND  ====>  [42 30 27  0 35 46 41 42 27 40]
[77  0 76 36 62  7  0 37 61  0 67 55 52  0 50 62 61 67 65 48 65 72  7 77]  ====>  ” “No. On the contrary.”


In [27]:
np.array(sorted(char_set))

array([' ', '!', '&', '(', ')', ',', '-', '.', '/', '0', '1', '2', '3',
       '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', 'A', 'B', 'C',
       'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
       'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', 'a', 'b', 'c', 'd',
       'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
       'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '‘', '’', '“', '”'],
      dtype='<U1')

# Создание последовательностей

In [6]:
seq_length = 40
chunk_size = seq_length + 1
text_chunks = [text_encoded[i:i+chunk_size] for i in range(len(text_encoded)-chunk_size)]

text_chunks

[array([42, 30, 27,  0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41,
        34, 23, 36, 26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61,
        52, 10, 17, 16, 13, 38, 23]),
 array([30, 27,  0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34,
        23, 36, 26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52,
        10, 17, 16, 13, 38, 23, 40]),
 array([27,  0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34, 23,
        36, 26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52, 10,
        17, 16, 13, 38, 23, 40, 42]),
 array([ 0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34, 23, 36,
        26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52, 10, 17,
        16, 13, 38, 23, 40, 42,  0]),
 array([35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34, 23, 36, 26,
        49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52, 10, 17, 16,
        13, 38, 23, 40, 42,  0, 10]),
 array([46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31

# Dataset

In [7]:
class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()
    

seq_dataset = TextDataset(torch.tensor(text_chunks))

  seq_dataset = TextDataset(torch.tensor(text_chunks))


In [8]:
seq_dataset[0]

(tensor([42, 30, 27,  0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34,
         23, 36, 26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52, 10,
         17, 16, 13, 38]),
 tensor([30, 27,  0, 35, 46, 41, 42, 27, 40, 31, 37, 43, 41,  0, 31, 41, 34, 23,
         36, 26, 49, 72,  0, 32, 68, 59, 52, 66,  0, 44, 52, 65, 61, 52, 10, 17,
         16, 13, 38, 23]))

# Dataloader

In [9]:
torch.manual_seed(42)

dataloader = DataLoader(
    dataset = seq_dataset,
    batch_size = 64,
    shuffle = True,
    drop_last = True
)

In [10]:
x_in = torch.zeros((64, 40)).long()
y_in = torch.zeros((64, 40)).long()
for i in range(0,64):
    values = dataloader.dataset[i][0]
    labels = dataloader.dataset[i][1]
    x_in[i] = values
    y_in[i] = labels

In [11]:
x_in

tensor([[42, 30, 27,  ..., 16, 13, 38],
        [30, 27,  0,  ..., 13, 38, 23],
        [27,  0, 35,  ..., 38, 23, 40],
        ...,
        [30, 27,  0,  ..., 61, 22, 77],
        [27,  0, 25,  ..., 22, 77,  0],
        [ 0, 25, 34,  ..., 77,  0, 76]])

# Model

In [12]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings = vocab_size,
                                      embedding_dim = embed_dim)
        
        self.rnn_hidden_size = rnn_hidden_size

        self.rnn = nn.LSTM(input_size = embed_dim,
                           hidden_size = rnn_hidden_size,
                           batch_first = True)
        
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden, cell 

In [13]:
vocab_size = len(char2int)
embed_dim = 256
rnn_hidden_size = 512

torch.manual_seed(42)
model = RNN(vocab_size, embed_dim, rnn_hidden_size).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model

RNN(
  (embedding): Embedding(78, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=78, bias=True)
)

# Training loop 

In [14]:
num_epochs = 10000
torch.manual_seed(42)

for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(64)
    hidden = hidden.to(device)
    cell = cell.to(device)
    seq_batch, target_batch = next(iter(dataloader))
    seq_batch = seq_batch.to(device)
    target_batch = target_batch.to(device)
    optimizer.zero_grad()
    loss = 0
    for c in range(seq_length):
        pred, hidden, cell = model(seq_batch[:, c], hidden, cell)
        loss += loss_fn(pred, target_batch[:, c])
    loss.backward()
    optimizer.step()
    loss = loss.item()/seq_length
    if epoch % 500 == 0:
        print(f'Epoch {epoch} loss: {loss:.4f}') 

Epoch 0 loss: 4.3584
Epoch 500 loss: 1.5885
Epoch 1000 loss: 1.3919
Epoch 1500 loss: 1.3582
Epoch 2000 loss: 1.2324
Epoch 2500 loss: 1.2959
Epoch 3000 loss: 1.1964
Epoch 3500 loss: 1.1376
Epoch 4000 loss: 1.2020
Epoch 4500 loss: 1.1001
Epoch 5000 loss: 1.1422
Epoch 5500 loss: 1.1292
Epoch 6000 loss: 1.1222
Epoch 6500 loss: 1.0985
Epoch 7000 loss: 1.0398
Epoch 7500 loss: 1.1102
Epoch 8000 loss: 1.0839
Epoch 8500 loss: 1.0589
Epoch 9000 loss: 1.0789
Epoch 9500 loss: 1.0420


# Eval (Categorical)

жесткая привязки вероятности после softmax может привести к ограниченной разнообразности в генерируемых текстах. Это может вызвать повторение фраз, циклические шаблоны или ограниченный набор различий между разными генерируемыми последовательностями.

Используем класс Categorical

In [40]:
def sample(
        model,
        starting_str,
        len_generated_text = 40,
        scale_factor = 1.0
):
    # кодируем строку
    encoded_input = torch.tensor(
        [char2int[token] for token in starting_str])
    # print('Кодированный вход: ', encoded_input)
    encoded_input = encoded_input.reshape(1,-1)
    # print('Reshape кодированный вход: ', encoded_input)
    generated_str = starting_str
    # print('Строка начала: ', generated_str)

    model = model.cpu()
    model.eval()
    hidden, cell = model.init_hidden(1)
    for i in range(len(starting_str) - 1):
        _, hidden, cell = model(encoded_input[:,i].view(1),
                                hidden,
                                cell)
    
    last_token = encoded_input[:,-1]
    # print('последний токена начального текста: ', last_token)
    for j in range(len_generated_text):
        logits, hidden, cell = model(last_token.view(1),
                                     hidden,
                                     cell)
        logits = torch.squeeze(logits, 0)
        scaled_logits = logits*scale_factor
        m = Categorical(logits = scaled_logits)
        # print('Класс Categorical: ', m)
        last_token = m.sample()
        # print('Последний токен: ', last_token)
        generated_str += str(np.array(list(char2int))[last_token])
        print(generated_str)
        # print('____________________________')

    return generated_str

In [41]:
torch.manual_seed(42)
print(sample(model, starting_str = 'The island'))

The island 
The island b
The island be
The island bec
The island beca
The island becam
The island became
The island became 
The island became a
The island became a 
The island became a t
The island became a tr
The island became a tre
The island became a trea
The island became a treas
The island became a treasu
The island became a treasur
The island became a treasure
The island became a treasure 
The island became a treasure h
The island became a treasure ha
The island became a treasure had
The island became a treasure had 
The island became a treasure had o
The island became a treasure had on
The island became a treasure had onl
The island became a treasure had only
The island became a treasure had only 
The island became a treasure had only c
The island became a treasure had only cl
The island became a treasure had only cla
The island became a treasure had only clas
The island became a treasure had only class
The island became a treasure had only classe
The island became a treasure ha