In [53]:
import sys
from pathlib import Path

SRC = Path().resolve() / "src"
if str(SRC) not in sys.path:
    sys.path.append(str(SRC))

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
import torch
import torch.nn as nn

In [43]:
from device_utils import get_device

# Resolve the compute device once; the training, checkpoint-loading and GPT-2
# cells all reuse this `device`.
device = get_device()
print(f"Using device: {device}")

Using device: mps


In [44]:
# --- Data ---
# Passed to truncate_ratio_and_clear for every split; presumably the fraction
# of each split to keep (1 = everything) — confirm in data_utils.
dataset_usage = 1
dataset_regenerate = False  # True: rebuild the processed CSV from data/raw_dataset.csv first

seq_size = 16  # number of tokens in the dataset's sliding window
seq_stride = 1  # step between consecutive windows

# --- Model / optimiser ---
emb_dim = 256
hidden_dim = 512
num_layers = 2
dropout_p = 0.3
learning_rate = 1e-3
weight_decay = 0.01

# --- Run mode ---
model_train = False  # True: train and save a new checkpoint; False: load model_load_path
model_load_path = "models/next_token_16_250822_212717.pth"

# --- Reproducibility ---
# Seeds torch's global RNG, which drives weight init and DataLoader shuffling.
seed = 42
torch.manual_seed(seed)

In [45]:
from data_utils import process_dataset, read_splits, truncate_ratio_and_clear

if dataset_regenerate:
    # Presumably (re)writes data/dataset_processed.csv from the raw dump — confirm in data_utils.
    process_dataset("data/raw_dataset.csv")

train_df, val_df, test_df = read_splits("data/dataset_processed.csv")

# Each split is reduced/cleaned by truncate_ratio_and_clear using `dataset_usage`.
train_texts = truncate_ratio_and_clear(list(train_df['text']), dataset_usage)
val_texts = truncate_ratio_and_clear(list(val_df['text']), dataset_usage)
test_texts = truncate_ratio_and_clear(list(test_df['text']), dataset_usage)

# Legacy misspelled name, still referenced by the cell that builds the datasets.
val_texsts = val_texts

print(len(train_texts))

1277123


In [46]:
from transformers import BertTokenizerFast
from tokenizer_utils import resolve_eos_token_id

# The LSTM works on BERT's uncased vocabulary. The id resolved here is what the
# dataset and autocomplete helpers receive as `eos_id`.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
eos_token_id = resolve_eos_token_id(tokenizer)
print(f"eos_token_id: {eos_token_id}")

eos_token_id: 102


In [10]:

from torch.utils.data import DataLoader
from next_token_dataset import NextTokenDataset

# Both splits are windowed with the same EOS id, window size and stride.
window_kwargs = dict(eos_id=eos_token_id, seq_size=seq_size, stride=seq_stride)

train_dataset = NextTokenDataset(train_texts, tokenizer, **window_kwargs)
val_dataset = NextTokenDataset(val_texsts, tokenizer, **window_kwargs)

for split_name, split_ds in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_name} samples count: {len(split_ds)}")

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)

  0%|          | 0/1280000 [00:00<?, ?it/s]

100%|██████████| 1280000/1280000 [02:00<00:00, 10614.03it/s]
100%|██████████| 160000/160000 [00:15<00:00, 10202.13it/s]

Train samples count: 4158366
Val samples count: 517103





In [47]:
from lstm_model import NextTokenLSTM

# vocab_size comes from the BERT tokenizer; all other sizes come from the config cell.
lstm_config = dict(
    vocab_size=tokenizer.vocab_size,
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    dropout_p=dropout_p,
)
model = NextTokenLSTM(**lstm_config)

In [None]:
from datetime import datetime
from pathlib import Path

from lstm_train import train_next_token

if model_train:
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): training stops early on val PPL; confirm whether
    # train_next_token returns the best-PPL weights or the last epoch's.
    model = train_next_token(
        model,
        optimizer=optimizer,
        criterion=criterion,
        tokenizer=tokenizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=5,
    )

    # torch.save opens the target file directly and does not create parent
    # directories, so make sure models/ exists (e.g. on a fresh checkout).
    models_dir = Path("models")
    models_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    checkpoint_path = models_dir / f"next_token_{seq_size}_{timestamp}.pth"
    torch.save(model.state_dict(), checkpoint_path)
    # Copy this path into model_load_path in the config cell to reuse it without retraining.
    print("Saved checkpoint:", checkpoint_path)

Epoch: 1: 100%|██████████| 32488/32488 [39:59<00:00, 13.54it/s]
Eval: 100%|██████████| 4040/4040 [04:59<00:00, 13.50it/s]


Epoch 01 | Train Loss: 4.8280 | Val Loss: 4.8404 | Val PPL: 126.518 | Val Token Acc: 21.86% | ROUGE-1/2/L(F1): 0.239/0.039/0.214


Epoch: 2: 100%|██████████| 32488/32488 [40:03<00:00, 13.52it/s]
Eval: 100%|██████████| 4040/4040 [04:58<00:00, 13.52it/s]


Epoch 02 | Train Loss: 4.4323 | Val Loss: 4.8748 | Val PPL: 130.949 | Val Token Acc: 22.00% | ROUGE-1/2/L(F1): 0.241/0.040/0.215


Epoch: 3: 100%|██████████| 32488/32488 [40:01<00:00, 13.53it/s]
Eval: 100%|██████████| 4040/4040 [04:59<00:00, 13.50it/s]


Epoch 03 | Train Loss: 4.2901 | Val Loss: 4.9249 | Val PPL: 137.679 | Val Token Acc: 21.94% | ROUGE-1/2/L(F1): 0.240/0.040/0.214


Epoch: 4: 100%|██████████| 32488/32488 [40:01<00:00, 13.53it/s]
Eval: 100%|██████████| 4040/4040 [04:59<00:00, 13.49it/s]


Epoch 04 | Train Loss: 4.2152 | Val Loss: 4.9620 | Val PPL: 142.884 | Val Token Acc: 21.84% | ROUGE-1/2/L(F1): 0.239/0.040/0.214
Early stopping: no PPL improvement for 3 epoch(s). Best Val PPL: 126.518


In [48]:
if not model_train:
    # A state_dict contains only tensors and plain containers; weights_only=True
    # refuses arbitrary pickled objects, so a tampered file cannot execute code.
    state = torch.load(model_load_path, map_location=device, weights_only=True)
    model.load_state_dict(state)

# Inference mode for the generation / ROUGE cells below: disables dropout
# (the model was built with dropout_p > 0), whichever branch produced the weights.
model.eval()

In [49]:
# Hand-written prompts used to compare LSTM and GPT-2 completions side by side.
prompts = [
    "I really want to",
    "Can you please tell me",
    "The main reason is",
    "When I was a child",
    "I think it would be better",
    "Let's go to the",
    "It is important because",
    "In the future I would like to",
    "Do you know how to",
    "One of my favorite things",
]

In [70]:
from autocomplete import autocomplete_text

# Print the LSTM completion of every demo prompt, one per separator-delimited entry.
separator = "=" * 50
for prompt in prompts:
    completion = autocomplete_text(
        model=model,
        tokenizer=tokenizer,
        eos_id=eos_token_id,
        seq_size=seq_size,
        text=prompt,
    )
    print(f"{prompt} -> {completion}")
    print(separator)

I really want to -> go to the beach
Can you please tell me -> 
The main reason is -> that i m not a fan of the twitter
When I was a child -> 
I think it would be better -> 
Let's go to the -> beach
It is important because -> i m not a fan of the twitter
In the future I would like to -> be a fan of yours
Do you know how to -> use it
One of my favorite things -> 


In [69]:
from eval_transformer_pipeline import build_gpt2_pipeline_fn

# Sampling-based distilgpt2 baseline. `tok` (its tokenizer) is reused by the ROUGE cell.
gpt2_settings = dict(
    model_name="distilgpt2",
    max_new_tokens=32,
    temperature=0.8,
    top_p=0.95,
    device=device,
)
gpt2_generate, tok = build_gpt2_pipeline_fn(**gpt2_settings)

# map() is lazy, so each prompt is generated right before its result is printed.
for prompt, completion in zip(prompts, map(gpt2_generate, prompts)):
    print(f"{prompt} -> {completion}")
    print("=" * 50)

Device set to use mps


I really want to -> see something that doesn't fall apart," he said. "It's something that's very challenging. I'd love to see that happen."
Can you please tell me -> what you think of the situation. We will try to provide a good way to answer your questions and get the answers right.



We will be
The main reason is -> to avoid having the data access in our database. But it is important to know that we don't do so with our database. In our database, there are
When I was a child -> , I was a little girl. But I grew up playing basketball and tennis. I wanted to be a kid. I had to play basketball and tennis, so
I think it would be better -> to do something else than do it myself."
Let's go to the -> bottom of this post with a few key points. First, there is no way to fix the problem in the current way.

The problem is that we
It is important because -> a system of governance that respects the interests of the public is also important because it provides information and information about h

In [None]:
from autocomplete import autocomplete_text
from rouge_autocomplete_tester import PairingConfig, RougeAutocompleteTester

# Number of test texts scored; GPT-2 generation dominates the runtime of this cell.
n_rouge_texts = 1000
rouge_texts = test_texts[:n_rouge_texts]


def lstm_generate(prompt: str) -> str:
    """Return the LSTM continuation of ``prompt``.

    Adapts ``autocomplete_text`` to the single-argument ``generator_fn``
    interface taken by ``RougeAutocompleteTester.evaluate``.
    """
    return autocomplete_text(
        model=model,
        tokenizer=tokenizer,
        eos_id=eos_token_id,
        seq_size=seq_size,
        text=prompt,
    )


# One tester (built on the GPT-2 tokenizer `tok`) scores both models with the
# same pairing config. prompt_len / cont_len / max_text_tokens are presumably
# counted in `tok` tokens — TODO confirm in rouge_autocomplete_tester.
tester = RougeAutocompleteTester(
    tokenizer=tok,
    pairing=PairingConfig(prompt_len=16, cont_len=32, max_text_tokens=256),
    use_stemmer=True,
)

print("GPT2 Transformer ROUGE Testing:")
gpt_result = tester.evaluate(rouge_texts, generator_fn=gpt2_generate, limit_pairs=None)
print(gpt_result)

print("=" * 50)

print("RNN LSTM ROUGE Testing:")
rnn_result = tester.evaluate(rouge_texts, generator_fn=lstm_generate, limit_pairs=None)
print(rnn_result)

Device set to use mps


GPT2 Transformer ROUGE Testing:


100%|██████████| 413/413 [02:36<00:00,  2.63it/s]


{'rouge1': 0.056100306716612376, 'rouge2': 0.005075522669654817, 'rougeL': 0.05145738091082816, 'elapsed_ms': 157026.1751250364}
RRN LSTM ROUGE Testing:


100%|██████████| 413/413 [00:04<00:00, 91.63it/s] 


{'rouge1': 0.07222530592715068, 'rouge2': 0.010712582413550938, 'rougeL': 0.07018916904811688, 'elapsed_ms': 4560.19779201597}


## Подходы к автокомплиту
- **RNN (LSTM):**  
  Использован рекуррентный метод. На каждом шаге из выходной последовательности берутся логиты, соответствующие последнему токену, из них с помощью `argmax` выбирается наиболее вероятный токен, который затем подставляется в конец входной последовательности на следующем шаге.

- **Transformer (GPT-2):**  
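  Использована стандартная авторегрессионная генерация (`distilgpt2` через pipeline, сэмплирование с `temperature=0.8` и `top_p=0.95`, не более 32 новых токенов): каждый следующий токен предсказывается по промпту и ранее сгенерированным токенам.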

## Метрики
Для оценки использовалась метрика **ROUGE**:

1. **Во время обучения:**  
   Сравнение предсказаний по скользящему окну (teacher forcing).

2. **На тестовых данных:**  
   Тексты делились на две части:
   - `prompt` — начало текста.  
   - `continuation` — эталонное продолжение.  
   Считался ROUGE между сгенерированным продолжением и эталонным.

Метрика ROUGE по разбиению текстов на 2 части применялась как для RNN, так и для GPT-2.

## Результаты
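- Результаты оказались **сравнимыми**, хотя абсолютные значения ROUGE у обеих моделей низкие.  
- LSTM показала **немного более высокие значения ROUGE** (ROUGE-1 ≈ 0.072, ROUGE-L ≈ 0.070 против ≈ 0.056 и ≈ 0.051 у GPT-2). Вероятная причина в том, что LSTM обучалась на текстах того же домена (Twitter), что и тестовая выборка, а GPT-2 на этих данных не дообучалась.  
- **RNN работает значительно быстрее**: ≈ 4.6 с против ≈ 157 с на тех же 413 парах. Это особенно важно при ограниченных ресурсах.  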
- **Качество текста:**
  - GPT-2 генерирует более богатые и осмысленные тексты.  
  - LSTM даёт короткие шаблонные ответы и часто не попадает в контекст, а для части промптов возвращает пустое продолжение.

## Выводы
- Для простого автокомплита, где **не требуется глубокий контекст** и важна работа на мобильных устройствах с ограниченной памятью и вычислительными ресурсами, больше подходит **RNN (LSTM)**.  
- Для более содержательных и осмысленных продолжений текста лучше использовать **Transformer (GPT-2)**.