# Text Generation

What do we want from a generative model?

We will focus on the probabilistic formulation. Assume that there is some true distribution $P^*(X)$ over the data space $X$. A generative model approximates this distribution by maximizing likelihood. We then want to be able to sample new examples from the distribution it has learned.

Today, the most mainstream approach to text generation is sampling from language models (GPT-3, T5, etc.).

We represent a text as a sequence of tokens: $x = [x_1, x_2, ..., x_N]$. By the chain rule of probability, its likelihood factorizes into next-token conditionals:

$$p([x_1, x_2, x_3]) = p(x_1) \cdot p(x_2 | x_1) \cdot p(x_3|x_1, x_2)$$
$$p(x) = \prod_{i=1}^{N}p(x_i|x_1, ..., x_{i-1})$$
$$\log p(x) = \sum_{i=1}^{N}\log p(x_i|x_1, ..., x_{i-1})$$ 
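As a small numeric illustration of the factorization above, the sketch below sums log conditional probabilities for a three-token sequence. The probability values are made up for the example, and `sequence_log_prob` is a hypothetical helper, not part of the model built later.

```python
import math

def sequence_log_prob(cond_probs):
    """Return log p(x) as the sum of log p(x_i | x_1, ..., x_{i-1})."""
    return sum(math.log(p) for p in cond_probs)

# Hypothetical values for p(x_1), p(x_2 | x_1), p(x_3 | x_1, x_2)
cond_probs = [0.5, 0.25, 0.8]

log_p_x = sequence_log_prob(cond_probs)
p_x = math.exp(log_p_x)  # equals the product of the three conditionals
print(log_p_x, p_x)
```

Working in log space turns the product into a sum, which is what a per-token cross-entropy loss computes (up to sign) during training.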

We will build a model with a GPT-like architecture: several Transformer decoder layers.

For the data, we will use a dataset of English poetry: the Project Gutenberg Poetry Corpus. To tokenize it, we will train a byte-level BPE from the tokenizers library with a fairly large vocabulary.
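To make the idea behind BPE concrete, here is a minimal sketch of a single merge step on a toy corpus. It works on characters rather than bytes and is only an illustration under that assumption; the tokenizers library is far more optimized. The helper names and the toy word counts are hypothetical.

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs, weighted by word frequency, and return the top one."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get)

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

# Toy corpus: each word is split into symbols and mapped to its frequency
corpus = {tuple("low"): 5, tuple("lot"): 2}
pair = most_frequent_pair(corpus)
corpus = merge_pair(corpus, pair)
print(pair, corpus)
```

Training repeats this step until the vocabulary reaches the requested size, so a larger `vocab_size` means more merges and longer tokens.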

In [None]:
!curl -O http://static.decontextualize.com/gutenberg-poetry-v001.ndjson.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 52.2M  100 52.2M    0     0  32.9M      0  0:00:01  0:00:01 --:--:-- 32.9M


In [None]:
import gzip, json

from tqdm import tqdm

lines = []
for line in tqdm(gzip.open("gutenberg-poetry-v001.ndjson.gz")):
    lines.append(json.loads(line.strip()))

3085117it [00:29, 102959.22it/s]


In [None]:
lines[18000]

{'s': 'For regal scepter then no more shall need,', 'gid': '26'}

In [None]:
!pip install tokenizers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tokenizers
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Installing collected packages: tokenizers
Successfully installed tokenizers-0.13.2


In [None]:
import torch

import random
import re
from tokenizers import ByteLevelBPETokenizer

def get_data_poems(lines, vocab_size):
  tokenizer = ByteLevelBPETokenizer(dropout=0.1, lowercase=True)

  tokenizer.train_from_iterator([line['s'] + '\n' for line in lines], vocab_size=vocab_size)

  tokenizer.add_special_tokens(["[SOS]", "[EOS]", "[PAD]"])

  SOS_id = tokenizer.token_to_id("[SOS]")
  EOS_id = tokenizer.token_to_id("[EOS]")

  nl_id = tokenizer.encode("\n").ids[0]

  last_poem_id = -1
  chunk = []
  train_chunks = []
  val_chunks = []
  for line in tqdm(lines):
    poem_id = line['gid']

    line_ids = tokenizer.encode(line['s']).ids

    if len(chunk) + len(line_ids) < 64 and poem_id == last_poem_id:
      chunk.extend([nl_id] + line_ids)
    else:
      if chunk:
          if random.random() > 0.01:
              train_chunks.append([SOS_id] + chunk + [EOS_id])
          else:
              val_chunks.append([SOS_id] + chunk + [EOS_id])

      if len(line_ids) < 64:
          chunk = line_ids
      else:
          chunk = []
    
    last_poem_id = poem_id

  return LMDataset(train_chunks), LMDataset(val_chunks), tokenizer

class LMDataset(torch.utils.data.Dataset):
    def __init__(self, chunks):
        super().__init__()
        self.data = chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn_lm(PAD_id, samples):
    # samples: list of token-id lists of varying length
    batch_size = len(samples)

    max_len = max(len(sample) for sample in samples)

    # (batch_size, max_len) tensor pre-filled with the padding id
    src_tensor = torch.full((batch_size, max_len), PAD_id, dtype=torch.long)

    lengths = []
    for (batch_id, s) in enumerate(samples):
        length = len(s)

        src_tensor[batch_id][:length] = torch.tensor(s)

        lengths.append(length)

    return src_tensor, torch.tensor(lengths)


In [None]:
train_dataset, val_dataset, tokenizer = list(get_data_poems(lines, 8192))

SOS_id = tokenizer.token_to_id("[SOS]")
EOS_id = tokenizer.token_to_id("[EOS]")
PAD_id = tokenizer.token_to_id("[PAD]")

100%|██████████| 3085117/3085117 [02:04<00:00, 24778.62it/s]


In [None]:
print(f"{len(train_dataset)} poems")
print("Example:\n")

print(tokenizer.decode(train_dataset[18]))

587887 poems
Example:

and beyond them stood the forest,
stood the groves of singing pine-trees,
green in summer, white in winter,
ever sighing, ever singing.
"and the pleasant water-courses,
you could trace them through the valley,
by the rushing in the spring-time,


Let us define our model. Like the models of the GPT family, it is simply several Transformer decoder layers.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, hidden_size)
        
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
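        # Shape (1, max_len, hidden_size): broadcasts over the batch dimension
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is batch-first, (batch, seq_len, hidden_size), so positions are indexed along dim 1
        x = x + self.pe[:, :x.size(1)]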
        
        return self.dropout(x)

class Model(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_heads, n_layers, dropout):
        super(Model, self).__init__()

        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, hidden_size)

        self.pos_emb = PositionalEncoding(hidden_size)
 
        layer = TransformerEncoderLayer(hidden_size, n_heads, hidden_size, dropout)

        self.layers = TransformerEncoder(layer, n_layers)

        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x_len = x.size(1)

        padding_mask = (x == PAD_id)

        x = self.pos_emb(self.emb(x) * math.sqrt(self.vocab_size))

        attn_mask = nn.Transformer.generate_square_subsequent_mask(x_len).to(device)

        out = self.layers(x.transpose(0, 1), attn_mask, padding_mask).transpose(0, 1)

        out = self.out(out)

        return out

Note, however, that the code above uses the PyTorch module called TransformerEncoder. There is some confusion about what to call a Transformer decoder. In the original Transformer paper, https://arxiv.org/abs/1706.03762, the decoder is used for translation and has two attention blocks: self-attention, and an attention block that "looks" at the encoder outputs. GPT and similar models, by contrast, use a single self-attention block. The difference from an encoder here is the autoregressive attention mask, which forces the model to look only at previous tokens.
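The effect of the autoregressive mask can be sketched without any framework. The example below builds an additive mask (0 where attention is allowed, negative infinity where it is not) and applies it before a softmax over toy attention scores. The helper names and the all-zero scores are assumptions made for illustration only.

```python
import math

def causal_mask(n):
    """mask[i][j] is 0.0 if position i may attend to position j (j <= i), else -inf."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask):
    """Row-wise softmax of scores + mask; masked positions get weight exactly 0."""
    weights = []
    for score_row, mask_row in zip(scores, mask):
        logits = [s + m for s, m in zip(score_row, mask_row)]
        mx = max(logits)
        exps = [math.exp(l - mx) for l in logits]
        z = sum(exps)
        weights.append([e / z for e in exps])
    return weights

# Toy attention scores: every pair of positions scores 0
scores = [[0.0] * 3 for _ in range(3)]
weights = masked_softmax(scores, causal_mask(3))
for row in weights:
    print(row)
```

Each row only distributes weight over positions up to and including its own, so the prediction at position i cannot depend on later tokens.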

*For the tasks below you may use a larger model and train it longer (but not a smaller one or for less time).*

In [None]:
from torch.utils.data import DataLoader
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.get_vocab_size()
hidden_size = 512
n_layers = 5
n_heads = 8
dropout = 0.1

batch_size = 128
epochs = 3

model = Model(vocab_size, hidden_size, n_heads, n_layers, dropout).to(device)

train_loader = DataLoader(
    train_dataset
    , batch_size=batch_size
    , shuffle=True
    , collate_fn=partial(collate_fn_lm, PAD_id)
)

val_loader = DataLoader(
    val_dataset
    , batch_size=batch_size
    , shuffle=False
    , collate_fn=partial(collate_fn_lm, PAD_id)
    , drop_last=True
)

criterion = nn.CrossEntropyLoss(reduction='none')
lr = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
from tqdm import tqdm

def train(model, train_loader, val_loader, epochs, val_each=100):
    for epoch in range(1, epochs+1):
      for idx, (batch, _) in enumerate(tqdm(train_loader)):
          batch = batch.to(device)
          src = batch[:, :-1]
          tar = batch[:, 1:]

          optimizer.zero_grad()

          out = model(src)

          # Average the per-token loss over real targets only (skip padded target positions)
          loss = criterion(out.transpose(-2, -1), tar)[tar != PAD_id].mean()

          loss.backward()
          grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

          optimizer.step()

          if (idx + 1) % val_each == 0:
            total_loss = 0.0
            n = 0
            model.eval()
            # No gradients are needed for validation
            with torch.no_grad():
              for batch, _ in val_loader:
                batch = batch.to(device)
                src = batch[:, :-1]
                tar = batch[:, 1:]

                out = model(src)

                loss = criterion(out.transpose(-2, -1), tar)[tar != PAD_id].mean()

                total_loss += loss.item()
                n += 1
            model.train()

            print(f"Val loss: {total_loss/n:.2f}")
  
    return model

In [None]:
model = train(model, train_loader, val_loader, epochs)

model.eval()

print("OK")

  2%|▏         | 100/4593 [00:28<1:39:16,  1.33s/it]

Val loss: 6.08


  4%|▍         | 200/4593 [00:55<1:39:48,  1.36s/it]

Val loss: 5.83


  7%|▋         | 300/4593 [01:22<1:39:37,  1.39s/it]

Val loss: 5.64


  9%|▊         | 400/4593 [01:50<1:40:32,  1.44s/it]

Val loss: 5.54


 11%|█         | 500/4593 [02:19<1:38:07,  1.44s/it]

Val loss: 5.46


 13%|█▎        | 600/4593 [02:48<1:35:24,  1.43s/it]

Val loss: 5.36


 15%|█▌        | 700/4593 [03:16<1:32:24,  1.42s/it]

Val loss: 5.31


 17%|█▋        | 800/4593 [03:45<1:31:33,  1.45s/it]

Val loss: 5.23


 20%|█▉        | 900/4593 [04:13<1:29:48,  1.46s/it]

Val loss: 5.19


 22%|██▏       | 1000/4593 [04:41<1:25:20,  1.43s/it]

Val loss: 5.12


 24%|██▍       | 1100/4593 [05:10<1:22:50,  1.42s/it]

Val loss: 5.10


 26%|██▌       | 1200/4593 [05:38<1:20:55,  1.43s/it]

Val loss: 5.04


 28%|██▊       | 1300/4593 [06:06<1:18:06,  1.42s/it]

Val loss: 5.00


 30%|███       | 1400/4593 [06:34<1:15:48,  1.42s/it]

Val loss: 4.97


 33%|███▎      | 1500/4593 [07:03<1:13:51,  1.43s/it]

Val loss: 4.94


 35%|███▍      | 1600/4593 [07:31<1:10:59,  1.42s/it]

Val loss: 4.92


 37%|███▋      | 1700/4593 [07:59<1:08:39,  1.42s/it]

Val loss: 4.89


 39%|███▉      | 1800/4593 [08:27<1:06:40,  1.43s/it]

Val loss: 4.86


 41%|████▏     | 1900/4593 [08:56<1:03:56,  1.42s/it]

Val loss: 4.83


 44%|████▎     | 2000/4593 [09:24<1:01:34,  1.42s/it]

Val loss: 4.82


 46%|████▌     | 2100/4593 [09:52<59:25,  1.43s/it]

Val loss: 4.80


 48%|████▊     | 2200/4593 [10:20<56:50,  1.43s/it]

Val loss: 4.78


 50%|█████     | 2300/4593 [10:48<54:30,  1.43s/it]

Val loss: 4.79


 52%|█████▏    | 2400/4593 [11:17<52:17,  1.43s/it]

Val loss: 4.75


 54%|█████▍    | 2500/4593 [11:45<49:37,  1.42s/it]

Val loss: 4.72


 57%|█████▋    | 2600/4593 [12:13<47:14,  1.42s/it]

Val loss: 4.71


 59%|█████▉    | 2700/4593 [12:41<45:08,  1.43s/it]

Val loss: 4.74


 61%|██████    | 2800/4593 [13:10<42:33,  1.42s/it]

Val loss: 4.67


 63%|██████▎   | 2900/4593 [13:38<40:06,  1.42s/it]

Val loss: 4.69


 65%|██████▌   | 3000/4593 [14:06<38:00,  1.43s/it]

Val loss: 4.65


 67%|██████▋   | 3100/4593 [14:34<35:26,  1.42s/it]

Val loss: 4.66


 70%|██████▉   | 3200/4593 [15:02<33:04,  1.42s/it]

Val loss: 4.64


 72%|███████▏  | 3300/4593 [15:31<30:45,  1.43s/it]

Val loss: 4.64


 74%|███████▍  | 3400/4593 [15:59<28:15,  1.42s/it]

Val loss: 4.62


 76%|███████▌  | 3500/4593 [16:27<25:57,  1.42s/it]

Val loss: 4.61


 78%|███████▊  | 3600/4593 [16:55<23:39,  1.43s/it]

Val loss: 4.59


 81%|████████  | 3700/4593 [17:24<21:11,  1.42s/it]

Val loss: 4.58


 83%|████████▎ | 3800/4593 [17:52<18:49,  1.42s/it]

Val loss: 4.58


 85%|████████▍ | 3900/4593 [18:20<16:30,  1.43s/it]

Val loss: 4.58


 87%|████████▋ | 4000/4593 [18:48<14:03,  1.42s/it]

Val loss: 4.61


 89%|████████▉ | 4100/4593 [19:16<11:39,  1.42s/it]

Val loss: 4.56


 91%|█████████▏| 4200/4593 [19:45<09:22,  1.43s/it]

Val loss: 4.55


 94%|█████████▎| 4300/4593 [20:13<06:57,  1.43s/it]

Val loss: 4.56


 96%|█████████▌| 4400/4593 [20:41<04:34,  1.42s/it]

Val loss: 4.53


 98%|█████████▊| 4500/4593 [21:09<02:12,  1.43s/it]

Val loss: 4.52


100%|██████████| 4593/4593 [21:32<00:00,  3.55it/s]
  2%|▏         | 100/4593 [00:28<1:46:47,  1.43s/it]

Val loss: 4.54


  4%|▍         | 200/4593 [00:56<1:44:03,  1.42s/it]

Val loss: 4.52


  7%|▋         | 300/4593 [01:24<1:42:10,  1.43s/it]

Val loss: 4.51


  9%|▊         | 400/4593 [01:52<1:39:46,  1.43s/it]

Val loss: 4.54


 11%|█         | 500/4593 [02:21<1:37:04,  1.42s/it]

Val loss: 4.49


 13%|█▎        | 600/4593 [02:49<1:35:12,  1.43s/it]

Val loss: 4.52


 15%|█▌        | 700/4593 [03:17<1:32:31,  1.43s/it]

Val loss: 4.47


 17%|█▋        | 800/4593 [03:45<1:30:08,  1.43s/it]

Val loss: 4.47


 20%|█▉        | 900/4593 [04:14<1:28:05,  1.43s/it]

Val loss: 4.47


 22%|██▏       | 1000/4593 [04:42<1:25:22,  1.43s/it]

Val loss: 4.47


 24%|██▍       | 1100/4593 [05:10<1:22:48,  1.42s/it]

Val loss: 4.48


 26%|██▌       | 1200/4593 [05:38<1:20:50,  1.43s/it]

Val loss: 4.47


 28%|██▊       | 1300/4593 [06:06<1:18:08,  1.42s/it]

Val loss: 4.44


 30%|███       | 1400/4593 [06:34<1:15:45,  1.42s/it]

Val loss: 4.45


 33%|███▎      | 1500/4593 [07:03<1:13:44,  1.43s/it]

Val loss: 4.43


 35%|███▍      | 1600/4593 [07:31<1:11:01,  1.42s/it]

Val loss: 4.44


 37%|███▋      | 1700/4593 [07:59<1:08:41,  1.42s/it]

Val loss: 4.43


 39%|███▉      | 1800/4593 [08:27<1:06:31,  1.43s/it]

Val loss: 4.43


 41%|████▏     | 1900/4593 [08:55<1:03:49,  1.42s/it]

Val loss: 4.43


 44%|████▎     | 2000/4593 [09:23<1:01:24,  1.42s/it]

Val loss: 4.43


 46%|████▌     | 2100/4593 [09:52<59:19,  1.43s/it]

Val loss: 4.42


 48%|████▊     | 2200/4593 [10:20<56:49,  1.42s/it]

Val loss: 4.46


 50%|█████     | 2300/4593 [10:48<54:28,  1.43s/it]

Val loss: 4.41


 52%|█████▏    | 2400/4593 [11:16<52:12,  1.43s/it]

Val loss: 4.40


 54%|█████▍    | 2500/4593 [11:45<49:43,  1.43s/it]

Val loss: 4.42


 57%|█████▋    | 2600/4593 [12:13<47:14,  1.42s/it]

Val loss: 4.38


 59%|█████▉    | 2700/4593 [12:41<45:05,  1.43s/it]

Val loss: 4.41


 61%|██████    | 2800/4593 [13:09<42:23,  1.42s/it]

Val loss: 4.40


 63%|██████▎   | 2900/4593 [13:37<40:07,  1.42s/it]

Val loss: 4.38


 65%|██████▌   | 3000/4593 [14:06<37:52,  1.43s/it]

Val loss: 4.38


 67%|██████▋   | 3100/4593 [14:34<35:23,  1.42s/it]

Val loss: 4.40


 70%|██████▉   | 3200/4593 [15:02<32:56,  1.42s/it]

Val loss: 4.38


 72%|███████▏  | 3300/4593 [15:30<30:44,  1.43s/it]

Val loss: 4.43


 74%|███████▍  | 3400/4593 [15:58<28:20,  1.43s/it]

Val loss: 4.36


 76%|███████▌  | 3500/4593 [16:26<25:58,  1.43s/it]

Val loss: 4.36


 78%|███████▊  | 3600/4593 [16:55<23:43,  1.43s/it]

Val loss: 4.36


 81%|████████  | 3700/4593 [17:23<21:11,  1.42s/it]

Val loss: 4.37


 83%|████████▎ | 3800/4593 [17:51<18:48,  1.42s/it]

Val loss: 4.38


 85%|████████▍ | 3900/4593 [18:19<16:29,  1.43s/it]

Val loss: 4.36


 87%|████████▋ | 4000/4593 [18:47<14:03,  1.42s/it]

Val loss: 4.36


 89%|████████▉ | 4100/4593 [19:15<11:41,  1.42s/it]

Val loss: 4.36


 91%|█████████▏| 4200/4593 [19:44<09:20,  1.43s/it]

Val loss: 4.35


 94%|█████████▎| 4300/4593 [20:12<06:57,  1.42s/it]

Val loss: 4.35


 96%|█████████▌| 4400/4593 [20:40<04:34,  1.42s/it]

Val loss: 4.33


 98%|█████████▊| 4500/4593 [21:08<02:12,  1.43s/it]

Val loss: 4.34


100%|██████████| 4593/4593 [21:31<00:00,  3.56it/s]
  2%|▏         | 100/4593 [00:28<1:46:31,  1.42s/it]

Val loss: 4.34


  4%|▍         | 200/4593 [00:56<1:44:11,  1.42s/it]

Val loss: 4.36


  7%|▋         | 300/4593 [01:24<1:42:21,  1.43s/it]

Val loss: 4.33


  9%|▊         | 400/4593 [01:52<1:39:18,  1.42s/it]

Val loss: 4.32


 11%|█         | 500/4593 [02:20<1:36:48,  1.42s/it]

Val loss: 4.34


 13%|█▎        | 600/4593 [02:49<1:35:00,  1.43s/it]

Val loss: 4.31


 15%|█▌        | 700/4593 [03:17<1:32:15,  1.42s/it]

Val loss: 4.32


 17%|█▋        | 800/4593 [03:45<1:29:44,  1.42s/it]

Val loss: 4.35


 20%|█▉        | 900/4593 [04:13<1:27:47,  1.43s/it]

Val loss: 4.31


 22%|██▏       | 1000/4593 [04:41<1:25:15,  1.42s/it]

Val loss: 4.33


 24%|██▍       | 1100/4593 [05:09<1:22:59,  1.43s/it]

Val loss: 4.32


 26%|██▌       | 1200/4593 [05:37<1:20:41,  1.43s/it]

Val loss: 4.34


 28%|██▊       | 1300/4593 [06:06<1:18:06,  1.42s/it]

Val loss: 4.32


 30%|███       | 1400/4593 [06:34<1:15:39,  1.42s/it]

Val loss: 4.30


 33%|███▎      | 1500/4593 [07:02<1:13:37,  1.43s/it]

Val loss: 4.29


 35%|███▍      | 1600/4593 [07:30<1:10:58,  1.42s/it]

Val loss: 4.32


 37%|███▋      | 1700/4593 [07:58<1:08:44,  1.43s/it]

Val loss: 4.29


 39%|███▉      | 1800/4593 [08:27<1:06:30,  1.43s/it]

Val loss: 4.30


 41%|████▏     | 1900/4593 [08:55<1:03:56,  1.42s/it]

Val loss: 4.29


 44%|████▎     | 2000/4593 [09:23<1:01:27,  1.42s/it]

Val loss: 4.29


 46%|████▌     | 2100/4593 [09:51<59:12,  1.43s/it]

Val loss: 4.30


 48%|████▊     | 2200/4593 [10:19<56:42,  1.42s/it]

Val loss: 4.28


 50%|█████     | 2300/4593 [10:48<54:20,  1.42s/it]

Val loss: 4.27


 52%|█████▏    | 2400/4593 [11:16<52:03,  1.42s/it]

Val loss: 4.28


 54%|█████▍    | 2500/4593 [11:44<49:33,  1.42s/it]

Val loss: 4.28


 57%|█████▋    | 2600/4593 [12:12<47:14,  1.42s/it]

Val loss: 4.28


 59%|█████▉    | 2700/4593 [12:40<45:05,  1.43s/it]

Val loss: 4.29


 61%|██████    | 2800/4593 [13:08<42:27,  1.42s/it]

Val loss: 4.29


 63%|██████▎   | 2900/4593 [13:37<40:02,  1.42s/it]

Val loss: 4.30


 65%|██████▌   | 3000/4593 [14:05<37:49,  1.42s/it]

Val loss: 4.25


 67%|██████▋   | 3100/4593 [14:33<35:22,  1.42s/it]

Val loss: 4.28


 70%|██████▉   | 3200/4593 [15:01<33:03,  1.42s/it]

Val loss: 4.27


 72%|███████▏  | 3300/4593 [15:29<30:47,  1.43s/it]

Val loss: 4.28


 74%|███████▍  | 3400/4593 [15:58<28:18,  1.42s/it]

Val loss: 4.26


 76%|███████▌  | 3500/4593 [16:26<25:55,  1.42s/it]

Val loss: 4.24


 78%|███████▊  | 3600/4593 [16:54<23:39,  1.43s/it]

Val loss: 4.29


 81%|████████  | 3700/4593 [17:22<21:10,  1.42s/it]

Val loss: 4.26


 83%|████████▎ | 3800/4593 [17:50<18:47,  1.42s/it]

Val loss: 4.25


 85%|████████▍ | 3900/4593 [18:18<16:29,  1.43s/it]

Val loss: 4.27


 87%|████████▋ | 4000/4593 [18:47<14:02,  1.42s/it]

Val loss: 4.27


 89%|████████▉ | 4100/4593 [19:15<11:41,  1.42s/it]

Val loss: 4.25


 91%|█████████▏| 4200/4593 [19:43<09:21,  1.43s/it]

Val loss: 4.25


 94%|█████████▎| 4300/4593 [20:11<06:57,  1.42s/it]

Val loss: 4.25


 96%|█████████▌| 4400/4593 [20:40<04:34,  1.42s/it]

Val loss: 4.25


 98%|█████████▊| 4500/4593 [21:08<02:12,  1.43s/it]

Val loss: 4.26


100%|██████████| 4593/4593 [21:30<00:00,  3.56it/s]

OK





# On Language Models

### Successes:

- T5: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer https://arxiv.org/pdf/1910.10683.pdf

- GPT-3: Language Models are Few-Shot Learners https://arxiv.org/pdf/2005.14165.pdf

- ChatGPT: the same model, fine-tuned for dialogue.

- Balaboba

### Problems:
- The Curious Case of Neural Text De-Generation https://openreview.net/pdf?id=rygGQyrFvH
- A Theoretical Analysis of the Repetition Problem in Text Generation https://arxiv.org/pdf/2012.14660.pdf

### Alternatives:
- INSNET: An Efficient, Flexible, and Performant Insertion-based Text Generation Model https://arxiv.org/pdf/2102.11008.pdf

- Structured Denoising Diffusion Models in Discrete State-Spaces https://arxiv.org/pdf/2107.03006.pdf (non-autoregressive discrete diffusion)

### Metrics:

- MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers https://arxiv.org/pdf/2102.01454.pdf

Let us return to our network. How do we now generate new texts from it? Since the network outputs a distribution over tokens at every step, we can sample the next token according to that distribution:

In [None]:
def sample_generate(model, ids, max_len, EOS_id):
    for j in range(len(ids), max_len):
      x = torch.tensor(ids).unsqueeze(0).to(device)

      x_len = x.size(1)

      out = model(x)

      dist = torch.distributions.categorical.Categorical(logits=out[0][x_len-1])

      next_id = dist.sample().item()

      if next_id == EOS_id:
        break

      ids.append(next_id)

    return ids

In [None]:
model.eval()

start_ids = [SOS_id]

sample_ids = sample_generate(model, start_ids, 100, EOS_id)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

must play you now too parched to dance;
raves yours in a porter's glass our coffin round,
and no more to be admiringly supremings
no lessons may digress its wagan parker should screech roar,
and keep the bou where the worst may even itself?"


We can see nonexistent words, grammar problems, and so on. This happens because sampling can pick a token that has low probability according to the model. A simple way to address this is to restrict sampling to the top-k tokens with the highest probability.
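The truncation step on its own can be sketched on a toy vocabulary, independent of the model. The logits below are made up, and `top_k_sample` is a hypothetical helper written with the standard library only.

```python
import math
import random

def softmax(logits):
    mx = max(logits)
    exps = [math.exp(l - mx) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def top_k_sample(logits, k, rng=random):
    """Sample a token id from the k highest-scoring tokens only."""
    # Indices of the k largest logits
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Renormalize over the kept tokens; everything else has probability 0
    probs = softmax([logits[i] for i in top])
    return rng.choices(top, weights=probs, k=1)[0]

logits = [2.0, 0.5, -1.0, 3.0, -4.0]  # hypothetical scores for a 5-token vocabulary
rng = random.Random(0)
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
greedy = top_k_sample(logits, k=1, rng=rng)  # k=1 always returns the argmax
print(set(samples), greedy)
```

With `k=2` only tokens 0 and 3 can ever be drawn, and with `k=1` the procedure reduces to greedy decoding.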

In [None]:
def top_k_generate(model, ids, max_len, EOS_id, k):
    for j in range(len(ids), max_len):
      x = torch.tensor(ids).unsqueeze(0).to(device)

      x_len = x.size(1)

      out = model(x)

      topv, topi = out[0][-1].topk(k)

      dist = torch.distributions.categorical.Categorical(logits=topv)

      next_id = topi[dist.sample().item()].item()

      if next_id == EOS_id:
        break

      ids.append(next_id)

    return ids

In [None]:
start_ids = [SOS_id]

sample_ids = top_k_generate(model, start_ids, 100, EOS_id, 100)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

the first rose-flower in their snowy branches dressed,
and the green petals with grass of spring was fed.
and the young birds sang at last
and the summer was born of all the woods among.
the beautiful things of that life,
where all the garden-bird, flowers singing,


In [None]:
start_ids = [SOS_id] + tokenizer.encode("Old McDonald had a farm,\n").ids

sample_ids = top_k_generate(model, start_ids, 100, EOS_id, 1)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

old mcdonald had a farm,
and the old man's wife was a-day,
and the old man's wife was a-day,
and the old man's wife was a-day,
and the old man's wife was a-day,
and the old man's wife was a-day,


The nonexistent words are gone, and the sentences are more grammatical. With $k=1$ (greedy decoding), however, the model gets stuck repeating the same line.

# Task 1

Implement Nucleus Sampling from the paper The Curious Case of Neural Text *De*-Generation: https://openreview.net/pdf?id=rygGQyrFvH

Test the generation quality on the model above by sampling poems with the new method. Try different values of $p$ and find the one you consider best.
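Before the full implementation, the core of nucleus (top-p) filtering can be sketched on a toy distribution: keep the smallest set of most probable tokens whose total mass reaches $p$, then renormalize. The probabilities and the helper name `nucleus_filter` are assumptions for illustration.

```python
def nucleus_filter(probs, p):
    """Return {token_id: renormalized prob} for the smallest top set whose mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:  # the token that crosses the threshold is still kept
            break
    return {i: probs[i] / mass for i in kept}

probs = [0.5, 0.3, 0.15, 0.05]  # hypothetical next-token distribution
nucleus = nucleus_filter(probs, p=0.7)
print(nucleus)
```

Unlike top-k, the number of kept tokens adapts to the shape of the distribution: a peaked distribution keeps few tokens, a flat one keeps many.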

In [None]:
import numpy as np

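def nucleus_sampling(model, ids, max_len, EOS_id, temp=None, k=None, p=None, m=None):
    # temp: softmax temperature; k: top-k cutoff (takes precedence over p);
    # p: nucleus (top-p) threshold; m: weight for mixing the original distribution back in
    ids = list(ids)  # copy, so the caller's prompt list is not mutated between calls

    context = torch.tensor(ids).unsqueeze(0).to(device)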
    
    for i in range(len(ids), max_len):
        logits = model(context)
        logits = logits[:,-1,:]

        probs = F.softmax(logits, dim=-1)
        logprobs = F.log_softmax(logits, dim=-1)

        if temp is not None:
            samp_probs = F.softmax(logits.div_(temp), dim=-1)
        else:
            samp_probs = probs.clone()

        if k is not None:
            indices_to_remove = samp_probs < torch.topk(samp_probs, k)[0][..., -1, None]
            samp_probs[indices_to_remove] = 0
            if m is not None:
                samp_probs.div_(samp_probs.sum(1).unsqueeze(1))
                samp_probs.mul_(1-m)
                samp_probs.add_(probs.mul(m))
            next_tokens = samp_probs.multinomial(1)
            next_logprobs = samp_probs.gather(1, next_tokens.view(-1, 1)).log()
        elif p is not None: 
            sorted_probs, sorted_indices = torch.sort(samp_probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumulative_probs > p
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = 0
            sorted_samp_probs = sorted_probs.clone()
            sorted_samp_probs[sorted_indices_to_remove] = 0
            if m is not None:
                sorted_samp_probs.div_(sorted_samp_probs.sum(1).unsqueeze(1))
                sorted_samp_probs.mul_(1-m)
                sorted_samp_probs.add_(sorted_probs.mul(m))
            sorted_next_indices = sorted_samp_probs.multinomial(1).view(-1, 1)
            next_tokens = sorted_indices.gather(1, sorted_next_indices)
            next_logprobs = sorted_samp_probs.gather(1, sorted_next_indices).log()
        else:
            if m is not None:
                samp_probs.div_(samp_probs.sum(1).unsqueeze(1))
                samp_probs.mul_(1-m)
                samp_probs.add_(probs.mul(m))
            next_tokens = samp_probs.multinomial(1)
            next_logprobs = samp_probs.gather(1, next_tokens.view(-1, 1)).log()

        next_cat = next_tokens
        next_tokens, next_logprobs = next_tokens.cpu(), next_logprobs.cpu()

        v = next_tokens[0].item()
        logprob = next_logprobs[0].item()

        if v == EOS_id:
          break

        ids.append(v)
        context = torch.cat([context, next_cat], dim=1)


    return ids

Of all the values of $p$ tried, $p=0.5$ gives the most plausible poems.

In [None]:
start_ids = [SOS_id]

sample_ids = nucleus_sampling(model, start_ids, 100, EOS_id, 1, None, 0.9)

sent = tokenizer.decode(sample_ids)
print(sent)

last night of faults its taste that is still,
i'd drop the brow of beauty that rose with dew,
and touch its soul so precious to see;
like foolish longing in its vain caress,
peasless of celestial bliss, to hold,
when first delight's tears met down in your eyes


In [None]:
start_ids = [SOS_id]

sample_ids = nucleus_sampling(model, start_ids, 100, EOS_id, 1, None, 0.7)

sent = tokenizer.decode(sample_ids)
print(sent)

the stream'd its billow: then i knew not where i
was absent; yet i saw no more the hue
that, which my eye was pierced, i seem'd; nor i,
hadst to her of some cognom'd sprite
with beatrice.  the spirit of the gate


In [None]:
start_ids = [SOS_id]

sample_ids = nucleus_sampling(model, start_ids, 100, EOS_id, 1, None, 0.5)

sent = tokenizer.decode(sample_ids)
print(sent)

there's a boat that sings and sings,
and the birds that hear the music
when they sing to the song of the trees.
there's a sound of drums that fill the breeze
and a drum of wind and a whipper
for a whisper of the blast
that is coming from the shore.


In [None]:
start_ids = [SOS_id]

sample_ids = nucleus_sampling(model, start_ids, 100, EOS_id, 1, None, 0.5)

sent = tokenizer.decode(sample_ids)
print(sent)

then suddenly they fell,
as the gray wind swept
through the sand-hills,
on the stranded cliff,
underneath the ice-wood
the wind blew
and the long
that, long ago,
in the sun,
out of the sea,
the sea,


# Task 2

For each sampling method, generate 1000 examples and compare them with 1000 examples from the validation data using the MAUVE metric https://github.com/krishnap25/mauve

Try different values of k for top-k and of p for nucleus sampling. Also measure MAUVE between two chunks of 1000 examples each from the validation set. Draw conclusions.
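For intuition about what MAUVE measures, the sketch below computes points of a divergence frontier between two small histograms: for mixtures $R_\lambda = \lambda P + (1-\lambda) Q$ it evaluates $(e^{-c\,KL(Q\|R_\lambda)}, e^{-c\,KL(P\|R_\lambda)})$. This is only a rough illustration, not the library's implementation: the real metric first embeds the texts with a language model and quantizes the embeddings into histograms, and then summarizes the curve by its area. The scaling constant `c=5.0`, the number of points, and the toy histograms are assumptions.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete histograms; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def divergence_frontier(p, q, c=5.0, num_points=9):
    """Frontier points for mixtures R = lam * P + (1 - lam) * Q, lam strictly in (0, 1)."""
    points = []
    for step in range(1, num_points + 1):
        lam = step / (num_points + 1)
        r = [lam * pi + (1 - lam) * qi for pi, qi in zip(p, q)]
        points.append((math.exp(-c * kl(q, r)), math.exp(-c * kl(p, r))))
    return points

same = divergence_frontier([0.2, 0.3, 0.5], [0.2, 0.3, 0.5])
diff = divergence_frontier([0.9, 0.1, 0.0], [0.0, 0.1, 0.9])
print(same[0], diff[4])
```

Identical histograms give points at (1, 1), while histograms with little overlap push the points toward 0; a score close to 1 therefore means the generated and reference texts are hard to tell apart at the histogram level.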


In [None]:
!pip install transformers faiss-gpu mauve-text

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Collecting mauve-text
  Downloading mauve_text-0.3.0-py3-none-any.whl (20 kB)
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
Collecting faiss-cpu>=1.7.0
  Downloading faiss_cpu-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.0 MB)

In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []
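for i in range(NUM_SAMPLES):
  # Pass a fresh prompt each time: a sampler that appends to the list it is given
  # would otherwise continue the previous sample instead of starting a new one
  sample_ids = nucleus_sampling(model, [SOS_id], 100, EOS_id, 1, None, 0.9)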
  p_sent = tokenizer.decode(sample_ids)
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

import mauve

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

0.004261573245065193


In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []
for i in range(NUM_SAMPLES):
  # Fresh prompt per call (see above); here temp=4 and p=0.6
  sample_ids = nucleus_sampling(model, [SOS_id], 100, EOS_id, 4, None, 0.6)
  p_sent = tokenizer.decode(sample_ids)
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

0.0040720962619612555


In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []
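for i in range(NUM_SAMPLES):
  # Fresh prompt per call; k=5 takes precedence over p inside nucleus_sampling,
  # so this run is effectively top-k sampling with k=5
  sample_ids = nucleus_sampling(model, [SOS_id], 100, EOS_id, 1, 5, 0.5)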
  p_sent = tokenizer.decode(sample_ids)
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

0.0043468000469355055


Quality with top-k, k = 5

In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []

for i in range(NUM_SAMPLES):
  sample_ids = top_k_generate(model, [SOS_id], 100, EOS_id, 5)
  p_sent = tokenizer.decode(sample_ids[1:])
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

0.05990003534586924


Quality with top-k, k = 7

In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []

for i in range(NUM_SAMPLES):
  sample_ids = top_k_generate(model, [SOS_id], 100, EOS_id, 7)
  p_sent = tokenizer.decode(sample_ids[1:])
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

0.11319530952344325


Similarity between the first thousand poems and the next thousand poems in the validation set

In [None]:
NUM_SAMPLES = 1000
p_text = []
q_text = []

for i in range(NUM_SAMPLES):
  p_sent = tokenizer.decode(val_dataset[NUM_SAMPLES + i])
  p_text.append(p_sent)
  q_sent = tokenizer.decode(val_dataset[i])
  q_text.append(q_sent)

out = mauve.compute_mauve(p_text=p_text, q_text=q_text, device_id=0, max_text_length=100, verbose=False)
print(out.mauve)

Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Featurizing q:   0%|          | 0/1000 [00:00<?, ?it/s]

0.7990969173007928
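
The evaluation cells above differ only in the call that produces `sample_ids`. A minimal sketch of a shared helper that collects the two text lists is shown below. The name `collect_text_pairs` and the toy vocabulary are hypothetical stand-ins so that the block runs without the model, the tokenizer, or the `mauve` package.

```python
from typing import Callable, List, Sequence, Tuple


def collect_text_pairs(
    generate_ids: Callable[[], Sequence[int]],
    decode: Callable[[Sequence[int]], str],
    references: Sequence[Sequence[int]],
    num_samples: int,
) -> Tuple[List[str], List[str]]:
    """Return (generated texts, reference texts), num_samples of each."""
    if num_samples > len(references):
        raise ValueError("not enough reference samples")
    p_text = [decode(generate_ids()) for _ in range(num_samples)]
    q_text = [decode(references[i]) for i in range(num_samples)]
    return p_text, q_text


# Toy stand-ins (assumptions for illustration only).
toy_vocab = {1: "roses", 2: "are", 3: "red"}


def toy_decode(ids: Sequence[int]) -> str:
    return " ".join(toy_vocab[i] for i in ids)


def toy_generate() -> List[int]:
    return [1, 2]


toy_references = [[1, 2, 3], [3, 2, 1]]

p_text, q_text = collect_text_pairs(toy_generate, toy_decode, toy_references, 2)
print(p_text)
print(q_text)
```

In the notebook, `generate_ids` could be `lambda: top_k_generate(model, [SOS_id], 100, EOS_id, 5)[1:]`, `decode` would be `tokenizer.decode`, and the two returned lists would go to `mauve.compute_mauve` exactly as in the cells above.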


# Task 3

Download the dataset https://mydata.biz/storage/download/ebcdfe6fb2d546398010e0d6564a79bb/names.zip. It contains lists of first names and surnames in CSV format. Preprocess the data.

Choose model parameters that suit the task (including the tokenization parameters).

Train the model, generate several new examples, and assess their quality (by eye).

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!unzip "/content/drive/My Drive/ABBYY Homeworks/data_names_csv.zip"

Mounted at /content/drive
Archive:  /content/drive/My Drive/ABBYY Homeworks/data_names_csv.zip
  inflating: _readme.txt             
  inflating: foreign_names.csv       
  inflating: russian_names.csv       
  inflating: russian_surnames.csv    


In [None]:
!cat _readme.txt

EN	This file downloaded from http://mydata.biz/ - databases for business.

RU	Этот файл скачан с http://mydata.biz/ru - базы данных для бизнеса.

ES	Este archivo descargado desde http://mydata.biz/es - la base de datos para el negocio.

DE	Diese Datei hier downloaden http://mydata.biz/de - Datenbank für Unternehmen.

FR	Ce fichier est téléchargé à partir de http://mydata.biz/fr - la base de données pour les entreprises.

ZH-CN	这份文件下载 http://mydata.biz/zh_cn -一个数据库业务

In [None]:
import pandas as pd
from random import shuffle

data = pd.read_csv('russian_names.csv', delimiter= ';')
names = data['Name'].values
shuffle(names)

print(names)

['Прентайс' 'Асхар' 'Ботолф' ... 'Дзвенимыра' 'Алшалада' 'Улвые']


In [None]:
from nltk.tokenize import word_tokenize

In [None]:
import torch

import random
import re
from tokenizers import CharBPETokenizer

def get_data_names(names, vocab_size):
  tokenizer = CharBPETokenizer(dropout=0.1, lowercase=True)

  tokenizer.train_from_iterator([name for name in names], vocab_size=vocab_size)

  tokenizer.add_special_tokens(["[SOS]", "[EOS]", "[PAD]"])

  SOS_id = tokenizer.token_to_id("[SOS]")
  EOS_id = tokenizer.token_to_id("[EOS]")

  train_chunks = []
  val_chunks = []
  for name in tqdm(names):

    name_ids = tokenizer.encode(name).ids

    # Hold out roughly 1% of the names for validation.
    if random.random() > 0.01:
        train_chunks.append([SOS_id] + name_ids + [EOS_id])
    else:
        val_chunks.append([SOS_id] + name_ids + [EOS_id])


  return LMDataset(train_chunks), LMDataset(val_chunks), tokenizer

class LMDataset(torch.utils.data.Dataset):
    def __init__(self, chunks):
        # super(LMDataset) without an instance does not call the parent initializer.
        super().__init__()
        self.data = chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn_lm(PAD_id, samples):
    batch_size = len(samples)

    max_len = max(len(sample) for sample in samples)

    src_tensor = torch.ones((batch_size, max_len), dtype=torch.long) * PAD_id

    lengths = []
    for (batch_id, s) in enumerate(samples):
        length = len(s)

        src_tensor[batch_id][:length] = torch.tensor(s)

        lengths.append(length)

    return src_tensor, torch.tensor(lengths)

In [None]:
# A small vocabulary (100 tokens) keeps the units close to characters, which suits short names.
train_dataset, val_dataset, tokenizer = get_data_names(names, 100)

SOS_id = tokenizer.token_to_id("[SOS]")
EOS_id = tokenizer.token_to_id("[EOS]")
PAD_id = tokenizer.token_to_id("[PAD]")

100%|██████████| 51529/51529 [00:00<00:00, 81067.98it/s]
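
`collate_fn_lm`, defined above, right-pads every sequence in a batch to the length of the longest one with `PAD_id` and returns the true lengths alongside. The sketch below reproduces the same layout with plain Python lists so that it runs without torch. The function name `pad_batch` and the token ids are hypothetical.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    samples: Sequence[Sequence[int]], pad_id: int
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad each sample to the longest length; also return original lengths."""
    max_len = max(len(s) for s in samples)
    padded = [list(s) + [pad_id] * (max_len - len(s)) for s in samples]
    lengths = [len(s) for s in samples]
    return padded, lengths


# Hypothetical ids standing in for [SOS], [EOS] and [PAD].
SOS, EOS, PAD = 100, 101, 102
batch = [[SOS, 5, 7, EOS], [SOS, 9, EOS]]

padded, lengths = pad_batch(batch, PAD)
print(padded)   # [[100, 5, 7, 101], [100, 9, 101, 102]]
print(lengths)  # [4, 3]
```

The lengths are what allows a loss computed with `reduction='none'` to be masked so that padded positions do not contribute.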


In [None]:
print(f"{len(train_dataset)} names")
print("Example:\n")

print(tokenizer.decode(train_dataset[18]))

51018 names
Example:

арестик


In [None]:
len(val_dataset)

511
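
The 51018 / 511 split reported above comes from the unseeded `random.random() > 0.01` test in `get_data_names`, so validation membership changes from run to run. If reproducibility matters, one option is a private seeded generator. This is a minimal sketch; `split_names`, `val_fraction`, and `seed` are hypothetical names that do not appear in the notebook.

```python
import random
from typing import List, Sequence, Tuple


def split_names(
    names: Sequence[str], val_fraction: float = 0.01, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Assign each name to train or validation using a private seeded generator."""
    rng = random.Random(seed)  # does not touch the global random state
    train: List[str] = []
    val: List[str] = []
    for name in names:
        if rng.random() < val_fraction:
            val.append(name)
        else:
            train.append(name)
    return train, val


# Toy data (assumption): placeholder strings instead of the real names.
toy_names = [f"name{i}" for i in range(1000)]
train_a, val_a = split_names(toy_names, val_fraction=0.1, seed=42)
train_b, val_b = split_names(toy_names, val_fraction=0.1, seed=42)

print(len(train_a), len(val_a))
print(val_a == val_b)  # the same seed gives the same split
```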

In [None]:
from torch.utils.data import DataLoader
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.get_vocab_size()
hidden_size = 512
n_layers = 5
n_heads = 8
dropout = 0.1

batch_size = 128
epochs = 5

model = Model(vocab_size, hidden_size, n_heads, n_layers, dropout).to(device)

train_loader = DataLoader(
    train_dataset
    , batch_size=batch_size
    , shuffle=True
    , collate_fn=partial(collate_fn_lm, PAD_id)
)

val_loader = DataLoader(
    val_dataset
    , batch_size=batch_size
    , shuffle=False
    , collate_fn=partial(collate_fn_lm, PAD_id)
    , drop_last=True
)

criterion = nn.CrossEntropyLoss(reduction='none')
lr = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
model = train(model, train_loader, val_loader, epochs)

model.eval()

print("OK")

 26%|██▌       | 104/399 [00:03<00:09, 32.01it/s]

Val loss: 2.38


 51%|█████     | 204/399 [00:06<00:06, 29.90it/s]

Val loss: 2.32


 76%|███████▌  | 304/399 [00:10<00:04, 23.17it/s]

Val loss: 2.29


100%|██████████| 399/399 [00:16<00:00, 24.73it/s]
 26%|██▋       | 105/399 [00:04<00:12, 24.27it/s]

Val loss: 2.27


 51%|█████     | 204/399 [00:07<00:06, 29.93it/s]

Val loss: 2.26


 76%|███████▌  | 304/399 [00:10<00:03, 31.64it/s]

Val loss: 2.25


100%|██████████| 399/399 [00:14<00:00, 28.33it/s]
 26%|██▋       | 105/399 [00:03<00:09, 31.34it/s]

Val loss: 2.22


 51%|█████▏    | 205/399 [00:06<00:06, 31.94it/s]

Val loss: 2.22


 77%|███████▋  | 307/399 [00:09<00:02, 33.02it/s]

Val loss: 2.21


100%|██████████| 399/399 [00:12<00:00, 31.12it/s]
 26%|██▌       | 104/399 [00:03<00:09, 31.22it/s]

Val loss: 2.20


 51%|█████     | 204/399 [00:06<00:05, 33.63it/s]

Val loss: 2.20


 76%|███████▌  | 304/399 [00:09<00:02, 32.80it/s]

Val loss: 2.20


100%|██████████| 399/399 [00:12<00:00, 32.26it/s]
 26%|██▌       | 103/399 [00:03<00:11, 26.84it/s]

Val loss: 2.20


 51%|█████▏    | 205/399 [00:06<00:06, 32.00it/s]

Val loss: 2.19


 76%|███████▌  | 304/399 [00:10<00:02, 32.30it/s]

Val loss: 2.16


100%|██████████| 399/399 [00:13<00:00, 29.81it/s]

OK





In [None]:
print("Generated names:")
for _ in range(20):
  start_ids = [SOS_id]
  sample_ids = top_k_generate(model, start_ids, 100, EOS_id, 5)
  sent = tokenizer.decode(sample_ids[1:])
  # The tokenizer lowercases its input, so restore the capital letter before the lookup.
  if sent[0].upper() + sent[1:] not in names:
    print("Not existing name:", sent)
  else:
    print("Existing name:", sent)

Generated names:
Not existing name: брагимир
Existing name: мартон
Existing name: агнеса
Existing name: кирстин
Not existing name: магомедрухад
Not existing name: капет
Not existing name: бахрет
Not existing name: добрий
Existing name: карита
Existing name: димитрий
Not existing name: джуманда
Not existing name: киполинн
Not existing name: катиса
Not existing name: михрум
Not existing name: капура
Not existing name: корон
Existing name: бори
Existing name: мирослава
Not existing name: кират
Not existing name: магдура


Some of the generated names (7 of the 20 above) coincide with names already in the dataset. The new ones look quite plausible.
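
The share of new names can also be counted rather than read off the printout. The sketch below uses the twenty samples printed above. As an assumption for illustration, `known` holds only the seven names flagged as existing; in the notebook it would be built from the full `names` array. The function name `novelty_rate` is hypothetical.

```python
from typing import Iterable, List, Tuple

generated = [
    "брагимир", "мартон", "агнеса", "кирстин", "магомедрухад",
    "капет", "бахрет", "добрий", "карита", "димитрий",
    "джуманда", "киполинн", "катиса", "михрум", "капура",
    "корон", "бори", "мирослава", "кират", "магдура",
]

# Assumption: only the names reported as "Existing name" in the output above.
known = {"Мартон", "Агнеса", "Кирстин", "Карита", "Димитрий", "Бори", "Мирослава"}


def novelty_rate(
    samples: List[str], known_names: Iterable[str]
) -> Tuple[float, List[str]]:
    """Return the fraction of samples absent from known_names, and those samples."""
    known_lower = {name.lower() for name in known_names}  # set gives O(1) lookups
    novel = [s for s in samples if s.lower() not in known_lower]
    return len(novel) / len(samples), novel


rate, novel = novelty_rate(generated, known)
print(f"{len(novel)} of {len(generated)} generated names are new ({rate:.0%})")
```

Comparing lowercased strings against a set avoids both the manual re-capitalization and the linear scan of the array on every lookup.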