In [27]:
import sys
from pathlib import Path

SRC = Path().resolve() / "src"
if str(SRC) not in sys.path:
    sys.path.append(str(SRC))

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
from data_utils import read_splits, truncate

# One-off preprocessing step, kept for reference.
# NOTE(review): presumably this is what produces the processed CSV read below — confirm.
# from data_utils import process_dataset
# process_dataset("data/raw_dataset.csv")

# Load the three splits; only the train and validation frames are used in this cell.
train_df, val_df, test_df = read_splits("data/dataset_processed.csv")

# Ratio handed to truncate() for both splits.
truncate_ratio = 1.0

train_texts = truncate(list(train_df['text']), truncate_ratio)
val_texts = truncate(list(val_df['text']), truncate_ratio)
# Backward-compatible alias: the original variable name was misspelled and
# later cells still refer to it, so keep both names bound to the same list.
val_texsts = val_texts

print(len(train_texts))

64000


In [71]:
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
from next_token_dataset import NextTokenDataset

# Pretrained BERT tokenizer; its vocabulary size is also what the model cell
# uses to size the network.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Dataset / batching hyper-parameters, kept together so they are tuned in one place.
# NOTE(review): seq_size and stride are passed straight to NextTokenDataset;
# presumably the window length and the step between windows — confirm there.
seq_size = 8
stride = 1
batch_size = 128  # shared by the train and validation loaders

# `val_texsts` (sic) is the name the data-loading cell binds the validation texts to.
train_dataset = NextTokenDataset(train_texts, tokenizer, seq_size=seq_size, stride=stride)
val_dataset = NextTokenDataset(val_texsts, tokenizer, seq_size=seq_size, stride=stride)

print(f"Train samples count: {len(train_dataset)}")
print(f"Val samples count: {len(val_dataset)}")

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

100%|██████████| 64000/64000 [00:07<00:00, 8239.36it/s] 
100%|██████████| 8000/8000 [00:00<00:00, 29463.77it/s]


Train samples count: 523743
Val samples count: 64180


In [74]:
from lstm_model import NextTokenLSTM
from lstm_train import train_next_token
from device_utils import get_device

import torch
import torch.nn as nn

device = get_device()
print("Using device:", device)

# Seed PyTorch's global RNG before building the model so that anything drawing
# from it afterwards (parameter initialisation, DataLoader shuffling) is
# repeatable across kernel restarts.
# NOTE(review): some backends may still have non-deterministic kernels.
SEED = 42
torch.manual_seed(SEED)

model = NextTokenLSTM(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    hidden_dim=512,
    num_layers=2,
    # Use the tokenizer's pad id; fall back to 0 only when none is defined.
    pad_idx=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0,
    dropout_p=0.5,
)

# AdamW with decoupled weight decay; plain cross-entropy as the training objective.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

Using device: mps


In [75]:
# Train for a fixed number of epochs on the loaders built above and rebind
# `model` to whatever train_next_token returns.
model = train_next_token(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=5,
)

Epoch: 1: 100%|██████████| 4092/4092 [02:56<00:00, 23.25it/s]
100%|██████████| 502/502 [00:06<00:00, 74.78it/s]


Epoch 01 | Train Loss: 5.4852 | Val Loss: 5.5891 | Val PPL: 267.508 | Val Token Acc: 16.88%


Epoch: 2: 100%|██████████| 4092/4092 [02:55<00:00, 23.33it/s]
100%|██████████| 502/502 [00:06<00:00, 76.97it/s]


Epoch 02 | Train Loss: 4.6305 | Val Loss: 5.7347 | Val PPL: 309.431 | Val Token Acc: 17.03%


Epoch: 3: 100%|██████████| 4092/4092 [02:53<00:00, 23.52it/s]
100%|██████████| 502/502 [00:06<00:00, 76.35it/s]


Epoch 03 | Train Loss: 4.1287 | Val Loss: 5.9552 | Val PPL: 385.772 | Val Token Acc: 16.98%


Epoch: 4:   6%|▌         | 254/4092 [00:11<02:46, 23.08it/s]


KeyboardInterrupt: 

In [None]:
from datetime import datetime
from pathlib import Path

# Checkpoints are written under models/; create the directory on first use so
# torch.save does not fail on a fresh checkout where it does not exist yet.
models_dir = Path("models")
models_dir.mkdir(parents=True, exist_ok=True)

# The yymmdd_HHMM timestamp keeps runs started in different minutes from
# overwriting each other's checkpoints.
timestamp = datetime.now().strftime("%y%m%d_%H%M")
checkpoint_path = models_dir / f"next_token_{seq_size}_{timestamp}.pth"

# Persist only the parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Optional: load previously saved weights into `model`. Uncomment both lines
# and point the path at an existing checkpoint file to use it.
# NOTE(review): the save cell in this notebook writes files named
# models/next_token_{seq_size}_{timestamp}.pth, so the path below probably
# needs adjusting to a real file name — confirm before uncommenting.
# state = torch.load("models/next_token_16.pth", map_location=device)
# model.load_state_dict(state)

<All keys matched successfully>

In [None]:
from eval_lstm import autocomplete_text

# Ask the current model to continue a short prompt. The bare name on the last
# line makes the notebook display the value returned by autocomplete_text.
comp_text = autocomplete_text(
    text="she is",
    model=model,
    tokenizer=tokenizer,
    seq_size=seq_size,
)
comp_text

'a great woman and i love her and i love her so much i love her so much i love her so much i love her so much i love her so much i love her so much i love her so much i love her so much i love'