In [27]:
import sys
from pathlib import Path

# Make the project's src/ modules (data_utils, lstm_model, autocomplete, ...) importable.
# Path() is the kernel's current working directory, so this expects the notebook
# to be started from the directory that contains src/.
SRC = Path().resolve() / "src"
if str(SRC) not in sys.path:
    # Prepend rather than append: these modules have generic names, and an
    # installed package with the same name would otherwise win the import.
    sys.path.insert(0, str(SRC))

# IPython magics: pick up edits to the src/ modules live without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
import random

import torch

# --- Data ---
dataset_usage = 0.25        # second argument to truncate() for each split (presumably the fraction kept)
dataset_regenerate = False  # True -> run process_dataset() on data/raw_dataset.csv before reading splits

# --- Sliding-window samples (NextTokenDataset) ---
seq_size = 8    # number of tokens in each sliding window of the dataset
seq_stride = 1  # stride passed to NextTokenDataset (step between windows)

# --- Model (NextTokenLSTM) ---
emb_dim = 256
hidden_dim = 512
num_layers = 2
dropout_p = 0.3

# --- Optimizer (AdamW) ---
learning_rate = 1e-3
weight_decay = 0.01

# --- Train vs. load ---
model_train = True  # True -> train and save a new checkpoint; False -> load model_load_path
model_load_path = "models/next_token_8_250820_184932.pth"

# --- Reproducibility ---
# Weight init, dropout and DataLoader(shuffle=True) all draw from torch's global RNG;
# seed it (and Python's `random`) before anything stochastic runs.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [29]:
from data_utils import process_dataset, read_splits, truncate

# Optionally re-run preprocessing on the raw CSV before reading the processed one.
if dataset_regenerate:
    process_dataset("data/raw_dataset.csv")

train_df, val_df, test_df = read_splits("data/dataset_processed.csv")

# Down-sample each split with truncate(), controlled by `dataset_usage`.
train_texts = truncate(list(train_df['text']), dataset_usage)
val_texts = truncate(list(val_df['text']), dataset_usage)
# Backward-compatible alias: downstream cells were written against the misspelled name.
val_texsts = val_texts

# Fail fast here rather than deep inside NextTokenDataset / DataLoader.
assert len(train_texts) > 0, "train split is empty after truncation - check dataset_usage"
assert len(val_texts) > 0, "val split is empty after truncation - check dataset_usage"

print(f"train texts: {len(train_texts)} | val texts: {len(val_texts)}")

320000


In [30]:
from tokenizer_utils import resolve_eos_token_id
from transformers import BertTokenizerFast

# Pretrained uncased BERT WordPiece tokenizer, shared by the dataset, model sizing and generation.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

# Id appended/used as the end-of-sequence marker downstream (chosen by resolve_eos_token_id).
eos_token_id = resolve_eos_token_id(tokenizer)
print("eos_token_id:", eos_token_id)

eos_token_id: 102


In [31]:

from next_token_dataset import NextTokenDataset
from torch.utils.data import DataLoader


def _make_next_token_dataset(texts):
    """Build a NextTokenDataset over `texts` using the notebook-wide tokenizer and window config."""
    return NextTokenDataset(
        texts,
        tokenizer,
        eos_id=eos_token_id,
        seq_size=seq_size,
        stride=seq_stride,
    )


# Train is built first, then val (same order as the progress bars below).
train_dataset = _make_next_token_dataset(train_texts)
val_dataset = _make_next_token_dataset(val_texsts)

print(f"Train samples count: {len(train_dataset)}")
print(f"Val samples count: {len(val_dataset)}")

# Only the training loader shuffles; validation is iterated in a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=128)
val_loader = DataLoader(dataset=val_dataset, batch_size=128)

100%|██████████| 320000/320000 [00:12<00:00, 24743.77it/s]
100%|██████████| 40000/40000 [00:01<00:00, 31829.27it/s]

Train samples count: 2613725
Val samples count: 326566





In [32]:
from lstm_model import NextTokenLSTM
from device_utils import get_device

import torch
import torch.nn as nn

device = get_device()
print("Using device:", device)

# len(tokenizer) counts the base vocabulary plus any added special tokens, while
# tokenizer.vocab_size counts the base vocabulary only. Sizing by len() keeps every id
# the tokenizer can emit inside the embedding/output range. For stock bert-base-uncased
# the two values are identical, so existing checkpoints still load.
model = NextTokenLSTM(
    vocab_size=len(tokenizer),
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    dropout_p=dropout_p,
).to(device)
# Moving to `device` here (instead of relying on the training helper) matters for the
# load-only path: load_state_dict() copies into the existing parameters, so without
# this the model would stay on CPU even when the checkpoint is mapped to `device`.

Using device: mps


In [33]:
from datetime import datetime
from pathlib import Path

from lstm_train import train_next_token


def _save_checkpoint(model, seq_size):
    """Save `model.state_dict()` to models/next_token_<seq_size>_<timestamp>.pth.

    Creates the models/ directory if needed (otherwise torch.save fails only after
    training has finished). Returns the path written, as a string.
    """
    ckpt_dir = Path("models")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    ckpt_path = ckpt_dir / f"next_token_{seq_size}_{timestamp}.pth"
    torch.save(model.state_dict(), ckpt_path)
    return str(ckpt_path)


if model_train:
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    try:
        model = train_next_token(
            model,
            optimizer=optimizer,
            criterion=criterion,
            tokenizer=tokenizer,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            epochs=5,
        )
    except KeyboardInterrupt:
        # A manual stop of a multi-hour run should not throw away the epochs already done.
        # The optimizer updates `model`'s parameters in place, so persist what it holds now,
        # then re-raise so "Run All" still halts.
        print(f"Training interrupted; weights saved to {_save_checkpoint(model, seq_size)}")
        raise
    print(f"Weights saved to {_save_checkpoint(model, seq_size)}")
else:
    # The checkpoint is a plain state_dict of tensors, so the restricted (non-pickle-code) loader suffices.
    state = torch.load(model_load_path, map_location=device, weights_only=True)
    model.load_state_dict(state)

# The next cells only run inference: make sure dropout is disabled on both paths.
model.eval()

Epoch: 1: 100%|██████████| 20420/20420 [14:38<00:00, 23.23it/s]
Eval: 100%|██████████| 2552/2552 [00:55<00:00, 46.25it/s]


Epoch 01 | Train Loss: 5.0483 | Val Loss: 5.1035 | Val PPL: 164.589 | Val Token Acc: 19.83% | ROUGE-1/2/L(F1): 0.195/0.031/0.187


Epoch: 2: 100%|██████████| 20420/20420 [14:18<00:00, 23.80it/s]
Eval: 100%|██████████| 2552/2552 [00:54<00:00, 46.99it/s]


Epoch 02 | Train Loss: 4.5420 | Val Loss: 5.1715 | Val PPL: 176.175 | Val Token Acc: 19.99% | ROUGE-1/2/L(F1): 0.197/0.032/0.189


Epoch: 3: 100%|██████████| 20420/20420 [14:17<00:00, 23.81it/s]
Eval: 100%|██████████| 2552/2552 [00:54<00:00, 46.49it/s]


Epoch 03 | Train Loss: 4.3162 | Val Loss: 5.2448 | Val PPL: 189.585 | Val Token Acc: 19.92% | ROUGE-1/2/L(F1): 0.196/0.032/0.188


Epoch: 4:   3%|▎         | 691/20420 [00:29<14:04, 23.36it/s]


KeyboardInterrupt: 

In [41]:
from autocomplete import autocomplete_text

# Continue a short prompt with the model; generation settings mirror the training windows.
comp_text = autocomplete_text(
    text="my friend is",
    model=model,
    tokenizer=tokenizer,
    seq_size=seq_size,
    eos_id=eos_token_id,
)
comp_text

'going to be able to get a new one'