In [1]:
## to run in colab

!git clone https://github.com/sergeychuvakin/advanced_nlp_course.git
!mv advanced_nlp_course/LM/*.py ./
!mv advanced_nlp_course/LM/*.json ./ ## 
!pip install loguru pydantic tokenizers
!pip install nltk==3.6.2
import nltk
nltk.download('stopwords')
nltk.download('punkt')

Cloning into 'advanced_nlp_course'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 98 (delta 45), reused 80 (delta 30), pack-reused 0[K
Unpacking objects: 100% (98/98), done.
Collecting loguru
  Downloading loguru-0.5.3-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 2.8 MB/s 
[?25hCollecting pydantic
  Downloading pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9 MB)
[K     |████████████████████████████████| 10.9 MB 9.5 MB/s 
[?25hCollecting tokenizers
  Downloading tokenizers-0.11.4-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 45.3 MB/s 
Installing collected packages: tokenizers, pydantic, loguru
Successfully installed loguru-0.5.3 pydantic-1.9.0 tokenizers-0.11.4
Collecting nltk==3.6.2
  Downloading nltk-3.6.2-py3-none-any.whl (1.5 MB)
[K     

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [2]:
!pip freeze | egrep "pydantic|torch|loguru|tokenizers|requests|nltk|tqdm"

loguru==0.5.3
nltk==3.6.2
pydantic==1.9.0
requests==2.23.0
requests-oauthlib==1.3.0
tokenizers==0.11.4
torch @ https://download.pytorch.org/whl/cu111/torch-1.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu111/torchaudio-0.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.11.0
torchvision @ https://download.pytorch.org/whl/cu111/torchvision-0.11.1%2Bcu111-cp37-cp37m-linux_x86_64.whl
tqdm==4.62.3


In [2]:
%reload_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
from loguru import logger
import sys
import json
from tqdm import tqdm


from dependencies import corpus, tokenizer
from config import Config, LanguageModelConfig
from processing_utils import (
    clean_text, 
    split_on_sequences, 
    create_ngrams, 
    create_ngrams1,
    create_to_x_and_y, 
    word2int,
    create_vocab,
    save_artifacts,
    remove_stopwords, 
    remove_speakers
)
from model import LM_LSTM
from datahandler import LMDataset
from train_utils import train_model

config = Config()

logger.remove()
logger.add(sys.stderr, level="WARNING")

1

In [3]:
from itertools import chain

## Clean the raw corpus into a NEW name so re-running this cell is idempotent
## (the `corpus` imported from `dependencies` is left untouched).
sequences = remove_speakers(corpus)
sequences = clean_text(sequences)
sequences = remove_stopwords(sequences)
sequences = split_on_sequences(sequences)

## tokenize every sequence
tcorpus = [tokenizer(sent) for sent in tqdm(sequences)]

## Vocabulary. Unique tokens are sorted so ids are stable across runs
## (a `set` of strings iterates in hash-seed-dependent order).
## Assumes tokens are strings, as the saved id_token values are.
## Special tokens are removed from the base vocab so they can never collide
## with a corpus token; they always take the last two ids.
special_tokens = {config.TOKEN_UNKNOWN, config.TOKEN_PADDING}
base_vocab = sorted(set(chain.from_iterable(tcorpus)) - special_tokens)
id_token = dict(enumerate(base_vocab))

## add special symbols
id_token[len(id_token)] = config.TOKEN_UNKNOWN  ## add UNK
id_token[len(id_token)] = config.TOKEN_PADDING  ## add PAD

token_id = {token: idx for idx, token in id_token.items()}
vocab_size = len(token_id)
assert vocab_size == len(id_token), "token_id / id_token are not a one-to-one mapping"

## n-grams of token ids for every sequence
sq = create_ngrams1(tcorpus, config.N_GRAM, token_id)


100%|██████████| 671517/671517 [00:00<00:00, 3040924.06it/s]
100%|██████████| 51324/51324 [00:04<00:00, 10769.42it/s]
100%|██████████| 51324/51324 [00:00<00:00, 265336.70it/s]
100%|██████████| 51324/51324 [00:00<00:00, 264418.59it/s]


In [37]:
## sanity check: both vocab maps must agree in size (no token collisions)
assert len(id_token) == len(token_id) == vocab_size
len(id_token)

48560

In [38]:
## shift corpus to create x (inputs) and y (targets)
x, y = create_to_x_and_y(sq)
assert len(x) == len(y), "inputs and targets must align one-to-one"

## split data: sequential (unshuffled) train/test split
tradeoff_index = int(len(x) * config.TRAIN_PROPORTION)
assert 0 < tradeoff_index < len(x), "TRAIN_PROPORTION leaves an empty train or test split"

x_train, x_test = x[:tradeoff_index], x[tradeoff_index:]
y_train, y_test = y[:tradeoff_index], y[tradeoff_index:]

## logged at WARNING because the handler configured above filters lower levels
logger.warning(
    f"""
    Output shapes: 
        x_train: {len(x_train)}, 
        x_test: {len(x_test)}, 
        y_train: {len(y_train)}, 
        y_test: {len(y_test)}
    """
              )

## load to dataset and dataloader
train_ds = LMDataset(x_train, y_train)
test_ds = LMDataset(x_test, y_test)

train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=config.BATCH_SIZE, shuffle=False)

## model and model config
EMB_SIZE = 300  ## embedding dimension
model_config = LanguageModelConfig(vocab_size=vocab_size, emb_size=EMB_SIZE)
model = LM_LSTM(**model_config.dict(), logger=logger)

## save artifacts (model config + both vocab maps)
save_artifacts(
    (model_config.dict(), config.SAVE_MODEL_CONFIG),
    (token_id, config.SAVE_TOKEN_ID),
    (id_token, config.SAVE_ID_TOKEN)
)

    Output shapes: 
        x_train: 206670, 
        x_test: 33645, 
        y_train: 206670, 
        y_test: 33645
    


In [39]:
## training hyper-parameters (previously inline magic numbers)
WEIGHT_DECAY = 1e-5
EPOCHS = 1
GRAD_CLIP = 1

optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr, weight_decay=WEIGHT_DECAY)
## padding positions (if they ever appear as targets) must not contribute to the loss
loss_func = torch.nn.CrossEntropyLoss(ignore_index=token_id[config.TOKEN_PADDING])

## NOTE(review): in the recorded run the loss never moves away from ~ln(vocab_size)
## (perplexity ~48k for a 48,560-token vocab), i.e. a uniform prediction.
## CrossEntropyLoss expects raw logits; if LM_LSTM.forward already applies softmax,
## the loss is effectively double-softmaxed and gradients vanish. TODO confirm in model.py.

# train model 
tmodel = train_model(model,
                     train_dl,
                     optimizer=optimizer,
                     loss_func=loss_func,
                     token_id=token_id,
                     epochs=EPOCHS, 
                     clip=GRAD_CLIP)

torch.save(tmodel.state_dict(), config.SAVE_MODEL_FNAME)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/207 [00:00<?, ?it/s][A
  0%|          | 1/207 [00:00<01:12,  2.84it/s][A
  1%|          | 2/207 [00:00<01:08,  2.98it/s][A

Perplexity loss: 48533.08203125



  1%|▏         | 3/207 [00:00<01:05,  3.12it/s][A

Perplexity loss: 48358.8984375



  2%|▏         | 4/207 [00:01<01:03,  3.19it/s][A

Perplexity loss: 48116.7890625



  2%|▏         | 5/207 [00:01<01:02,  3.25it/s][A

Perplexity loss: 48185.30078125



  3%|▎         | 6/207 [00:01<01:01,  3.28it/s][A

Perplexity loss: 47901.0



  3%|▎         | 7/207 [00:02<01:00,  3.31it/s][A

Perplexity loss: 48014.42578125



  4%|▍         | 8/207 [00:02<00:59,  3.33it/s][A

Perplexity loss: 47902.46484375



  4%|▍         | 9/207 [00:02<00:59,  3.35it/s][A

Perplexity loss: 48142.58203125



  5%|▍         | 10/207 [00:03<00:58,  3.35it/s][A

Perplexity loss: 47982.33984375



  5%|▌         | 11/207 [00:03<00:58,  3.37it/s][A

Perplexity loss: 47934.40625



  6%|▌         | 12/207 [00:03<00:57,  3.38it/s][A

Perplexity loss: 48062.390625



  6%|▋         | 13/207 [00:03<00:57,  3.39it/s][A

Perplexity loss: 48078.39453125



  7%|▋         | 14/207 [00:04<00:56,  3.39it/s][A

Perplexity loss: 48046.35546875



  7%|▋         | 15/207 [00:04<01:01,  3.10it/s][A

Perplexity loss: 48126.515625



  8%|▊         | 16/207 [00:04<00:59,  3.19it/s][A

Perplexity loss: 48190.72265625



  8%|▊         | 17/207 [00:05<00:58,  3.25it/s][A

Perplexity loss: 48046.35546875



  9%|▊         | 18/207 [00:05<00:57,  3.30it/s][A

Perplexity loss: 48110.453125



  9%|▉         | 19/207 [00:05<00:56,  3.33it/s][A

Perplexity loss: 48014.3359375



 10%|▉         | 20/207 [00:06<00:55,  3.35it/s][A

Perplexity loss: 47774.859375



 10%|█         | 21/207 [00:06<00:55,  3.36it/s][A

Perplexity loss: 47774.859375



 11%|█         | 22/207 [00:06<00:54,  3.38it/s][A

Perplexity loss: 47870.5390625



 11%|█         | 23/207 [00:06<00:54,  3.35it/s][A

Perplexity loss: 48014.3359375



 12%|█▏        | 24/207 [00:07<00:54,  3.37it/s][A

Perplexity loss: 47743.01953125



 12%|█▏        | 25/207 [00:07<00:53,  3.39it/s][A

Perplexity loss: 47918.41015625



 13%|█▎        | 26/207 [00:07<00:53,  3.39it/s][A

Perplexity loss: 47998.35546875



 13%|█▎        | 27/207 [00:08<00:53,  3.39it/s][A

Perplexity loss: 47918.41015625



 14%|█▎        | 28/207 [00:08<00:52,  3.40it/s][A

Perplexity loss: 48094.4453125



 14%|█▍        | 29/207 [00:08<00:52,  3.40it/s][A

Perplexity loss: 47998.35546875



 14%|█▍        | 30/207 [00:09<00:52,  3.39it/s][A

Perplexity loss: 47966.37109375



 15%|█▍        | 31/207 [00:09<00:51,  3.40it/s][A

Perplexity loss: 48030.3671875



 15%|█▌        | 32/207 [00:09<00:51,  3.40it/s][A

Perplexity loss: 47998.35546875



 16%|█▌        | 33/207 [00:09<00:51,  3.40it/s][A

Perplexity loss: 47934.40625



 16%|█▋        | 34/207 [00:10<00:50,  3.40it/s][A

Perplexity loss: 48030.3671875



 17%|█▋        | 35/207 [00:10<00:50,  3.40it/s][A

Perplexity loss: 48062.390625



 17%|█▋        | 36/207 [00:10<00:50,  3.40it/s][A

Perplexity loss: 48110.453125



 18%|█▊        | 37/207 [00:11<00:49,  3.40it/s][A

Perplexity loss: 48078.39453125



 18%|█▊        | 38/207 [00:11<00:49,  3.40it/s][A

Perplexity loss: 47950.36328125



 19%|█▉        | 39/207 [00:11<00:49,  3.41it/s][A

Perplexity loss: 47982.33984375



 19%|█▉        | 40/207 [00:11<00:49,  3.41it/s][A

Perplexity loss: 47902.46484375



 20%|█▉        | 41/207 [00:12<00:48,  3.41it/s][A

Perplexity loss: 47886.4765625



 20%|██        | 42/207 [00:12<00:48,  3.40it/s][A

Perplexity loss: 48030.3671875



 21%|██        | 43/207 [00:12<00:48,  3.40it/s][A

Perplexity loss: 48238.95703125



 21%|██▏       | 44/207 [00:13<00:47,  3.40it/s][A

Perplexity loss: 48158.609375



 22%|██▏       | 45/207 [00:13<00:47,  3.39it/s][A

Perplexity loss: 47902.46484375



 22%|██▏       | 46/207 [00:13<00:47,  3.39it/s][A

Perplexity loss: 48110.453125



 23%|██▎       | 47/207 [00:14<00:47,  3.39it/s][A

Perplexity loss: 47982.33984375



 23%|██▎       | 48/207 [00:14<00:46,  3.40it/s][A

Perplexity loss: 48062.390625



 24%|██▎       | 49/207 [00:14<00:46,  3.41it/s][A

Perplexity loss: 47950.36328125



 24%|██▍       | 50/207 [00:14<00:46,  3.41it/s][A

Perplexity loss: 48062.390625



 25%|██▍       | 51/207 [00:15<00:45,  3.40it/s][A

Perplexity loss: 47950.36328125



 25%|██▌       | 52/207 [00:15<00:45,  3.40it/s][A

Perplexity loss: 48062.390625



 26%|██▌       | 53/207 [00:15<00:45,  3.41it/s][A

Perplexity loss: 47886.4765625



 26%|██▌       | 54/207 [00:16<00:44,  3.40it/s][A

Perplexity loss: 47902.46484375



 27%|██▋       | 55/207 [00:16<00:44,  3.41it/s][A

Perplexity loss: 48062.390625



 27%|██▋       | 56/207 [00:16<00:44,  3.40it/s][A

Perplexity loss: 47918.41015625



 28%|██▊       | 57/207 [00:16<00:44,  3.40it/s][A

Perplexity loss: 47982.33984375



 28%|██▊       | 58/207 [00:17<00:43,  3.40it/s][A

Perplexity loss: 47918.41015625



 29%|██▊       | 59/207 [00:17<00:43,  3.37it/s][A

Perplexity loss: 48062.390625



 29%|██▉       | 60/207 [00:17<00:43,  3.38it/s][A

Perplexity loss: 48030.3671875



 29%|██▉       | 61/207 [00:18<00:43,  3.38it/s][A

Perplexity loss: 47966.37109375



 30%|██▉       | 62/207 [00:18<00:42,  3.38it/s][A

Perplexity loss: 47998.35546875



 30%|███       | 63/207 [00:18<00:42,  3.39it/s][A

Perplexity loss: 47998.35546875



 31%|███       | 64/207 [00:19<00:42,  3.39it/s][A

Perplexity loss: 48062.390625



 31%|███▏      | 65/207 [00:19<00:41,  3.39it/s][A

Perplexity loss: 47982.33984375



 32%|███▏      | 66/207 [00:19<00:45,  3.09it/s][A

Perplexity loss: 47918.41015625



 32%|███▏      | 67/207 [00:20<00:44,  3.18it/s][A

Perplexity loss: 47934.40625



 33%|███▎      | 68/207 [00:20<00:42,  3.24it/s][A

Perplexity loss: 47870.5390625



 33%|███▎      | 69/207 [00:20<00:42,  3.28it/s][A

Perplexity loss: 48030.3671875



 34%|███▍      | 70/207 [00:20<00:41,  3.31it/s][A

Perplexity loss: 48014.3359375



 34%|███▍      | 71/207 [00:21<00:40,  3.34it/s][A

Perplexity loss: 47918.41015625



 35%|███▍      | 72/207 [00:21<00:40,  3.36it/s][A

Perplexity loss: 47982.33984375



 35%|███▌      | 73/207 [00:21<00:39,  3.37it/s][A

Perplexity loss: 47934.40625



 36%|███▌      | 74/207 [00:22<00:39,  3.38it/s][A

Perplexity loss: 47743.01953125



 36%|███▌      | 75/207 [00:22<00:38,  3.39it/s][A

Perplexity loss: 47870.5390625



 37%|███▋      | 76/207 [00:22<00:38,  3.39it/s][A

Perplexity loss: 48062.390625



 37%|███▋      | 77/207 [00:22<00:38,  3.39it/s][A

Perplexity loss: 48110.453125



 38%|███▊      | 78/207 [00:23<00:38,  3.38it/s][A

Perplexity loss: 47854.56640625



 38%|███▊      | 79/207 [00:23<00:37,  3.39it/s][A

Perplexity loss: 48190.72265625



 39%|███▊      | 80/207 [00:23<00:37,  3.38it/s][A

Perplexity loss: 48046.35546875



 39%|███▉      | 81/207 [00:24<00:37,  3.37it/s][A

Perplexity loss: 47998.35546875



 40%|███▉      | 82/207 [00:24<00:37,  3.37it/s][A

Perplexity loss: 47998.35546875



 40%|████      | 83/207 [00:24<00:36,  3.37it/s][A

Perplexity loss: 47950.36328125



 41%|████      | 84/207 [00:25<00:36,  3.38it/s][A

Perplexity loss: 48158.609375



 41%|████      | 85/207 [00:25<00:36,  3.38it/s][A

Perplexity loss: 48126.515625



 42%|████▏     | 86/207 [00:25<00:35,  3.38it/s][A

Perplexity loss: 48030.3671875



 42%|████▏     | 87/207 [00:25<00:35,  3.38it/s][A

Perplexity loss: 47982.33984375



 43%|████▎     | 88/207 [00:26<00:35,  3.39it/s][A

Perplexity loss: 48014.3359375



 43%|████▎     | 89/207 [00:26<00:34,  3.39it/s][A

Perplexity loss: 48062.390625



 43%|████▎     | 90/207 [00:26<00:34,  3.40it/s][A

Perplexity loss: 47966.37109375



 44%|████▍     | 91/207 [00:27<00:34,  3.40it/s][A

Perplexity loss: 47934.40625



 44%|████▍     | 92/207 [00:27<00:33,  3.40it/s][A

Perplexity loss: 48014.3359375



 45%|████▍     | 93/207 [00:27<00:33,  3.39it/s][A

Perplexity loss: 48078.39453125



 45%|████▌     | 94/207 [00:27<00:33,  3.39it/s][A

Perplexity loss: 47934.40625



 46%|████▌     | 95/207 [00:28<00:33,  3.39it/s][A

Perplexity loss: 48030.3671875



 46%|████▋     | 96/207 [00:28<00:32,  3.39it/s][A

Perplexity loss: 47934.40625



 47%|████▋     | 97/207 [00:28<00:32,  3.39it/s][A

Perplexity loss: 48110.453125



 47%|████▋     | 98/207 [00:29<00:32,  3.39it/s][A

Perplexity loss: 47934.40625



 48%|████▊     | 99/207 [00:29<00:31,  3.38it/s][A

Perplexity loss: 48014.3359375



 48%|████▊     | 100/207 [00:29<00:31,  3.38it/s][A

Perplexity loss: 48158.609375



 49%|████▉     | 101/207 [00:30<00:31,  3.38it/s][A

Perplexity loss: 47966.37109375



 49%|████▉     | 102/207 [00:30<00:31,  3.38it/s][A

Perplexity loss: 47998.35546875



 50%|████▉     | 103/207 [00:30<00:30,  3.38it/s][A

Perplexity loss: 47982.33984375



 50%|█████     | 104/207 [00:30<00:30,  3.38it/s][A

Perplexity loss: 47982.33984375



 51%|█████     | 105/207 [00:31<00:30,  3.38it/s][A

Perplexity loss: 47902.46484375



 51%|█████     | 106/207 [00:31<00:29,  3.37it/s][A

Perplexity loss: 48030.3671875



 52%|█████▏    | 107/207 [00:31<00:29,  3.36it/s][A

Perplexity loss: 47934.40625



 52%|█████▏    | 108/207 [00:32<00:29,  3.37it/s][A

Perplexity loss: 48126.515625



 53%|█████▎    | 109/207 [00:32<00:29,  3.37it/s][A

Perplexity loss: 47982.33984375



 53%|█████▎    | 110/207 [00:32<00:28,  3.37it/s][A

Perplexity loss: 48046.35546875



 54%|█████▎    | 111/207 [00:33<00:28,  3.36it/s][A

Perplexity loss: 47854.56640625



 54%|█████▍    | 112/207 [00:33<00:28,  3.37it/s][A

Perplexity loss: 48062.390625



 55%|█████▍    | 113/207 [00:33<00:27,  3.36it/s][A

Perplexity loss: 47950.36328125



 55%|█████▌    | 114/207 [00:33<00:27,  3.37it/s][A

Perplexity loss: 47902.46484375



 56%|█████▌    | 115/207 [00:34<00:27,  3.38it/s][A

Perplexity loss: 48046.35546875



 56%|█████▌    | 116/207 [00:34<00:27,  3.35it/s][A

Perplexity loss: 47822.67578125



 57%|█████▋    | 117/207 [00:34<00:29,  3.07it/s][A

Perplexity loss: 48046.35546875



 57%|█████▋    | 118/207 [00:35<00:28,  3.12it/s][A

Perplexity loss: 48142.58203125



 57%|█████▋    | 119/207 [00:35<00:27,  3.19it/s][A

Perplexity loss: 47966.37109375



 58%|█████▊    | 120/207 [00:35<00:26,  3.24it/s][A

Perplexity loss: 48062.390625



 58%|█████▊    | 121/207 [00:36<00:26,  3.27it/s][A

Perplexity loss: 47950.36328125



 59%|█████▉    | 122/207 [00:36<00:25,  3.30it/s][A

Perplexity loss: 47966.37109375



 59%|█████▉    | 123/207 [00:36<00:25,  3.33it/s][A

Perplexity loss: 47998.35546875



 60%|█████▉    | 124/207 [00:36<00:24,  3.35it/s][A

Perplexity loss: 47982.33984375



 60%|██████    | 125/207 [00:37<00:24,  3.36it/s][A

Perplexity loss: 47966.37109375



 61%|██████    | 126/207 [00:37<00:24,  3.37it/s][A

Perplexity loss: 47998.35546875



 61%|██████▏   | 127/207 [00:37<00:23,  3.37it/s][A

Perplexity loss: 47870.5390625



 62%|██████▏   | 128/207 [00:38<00:23,  3.37it/s][A

Perplexity loss: 48174.6875



 62%|██████▏   | 129/207 [00:38<00:23,  3.38it/s][A

Perplexity loss: 48046.35546875



 63%|██████▎   | 130/207 [00:38<00:22,  3.38it/s][A

Perplexity loss: 48062.390625



 63%|██████▎   | 131/207 [00:39<00:22,  3.38it/s][A

Perplexity loss: 47838.640625



 64%|██████▍   | 132/207 [00:39<00:22,  3.38it/s][A

Perplexity loss: 47998.35546875



 64%|██████▍   | 133/207 [00:39<00:21,  3.38it/s][A

Perplexity loss: 48046.35546875



 65%|██████▍   | 134/207 [00:39<00:21,  3.39it/s][A

Perplexity loss: 47966.37109375



 65%|██████▌   | 135/207 [00:40<00:21,  3.38it/s][A

Perplexity loss: 48030.3671875



 66%|██████▌   | 136/207 [00:40<00:21,  3.37it/s][A

Perplexity loss: 47966.37109375



 66%|██████▌   | 137/207 [00:40<00:20,  3.37it/s][A

Perplexity loss: 48078.39453125



 67%|██████▋   | 138/207 [00:41<00:20,  3.37it/s][A

Perplexity loss: 47950.36328125



 67%|██████▋   | 139/207 [00:41<00:20,  3.36it/s][A

Perplexity loss: 48110.453125



 68%|██████▊   | 140/207 [00:41<00:19,  3.37it/s][A

Perplexity loss: 47982.33984375



 68%|██████▊   | 141/207 [00:42<00:19,  3.35it/s][A

Perplexity loss: 47918.41015625



 69%|██████▊   | 142/207 [00:42<00:19,  3.36it/s][A

Perplexity loss: 48078.39453125



 69%|██████▉   | 143/207 [00:42<00:19,  3.36it/s][A

Perplexity loss: 48030.3671875



 70%|██████▉   | 144/207 [00:42<00:18,  3.36it/s][A

Perplexity loss: 48142.58203125



 70%|███████   | 145/207 [00:43<00:18,  3.37it/s][A

Perplexity loss: 47886.4765625



 71%|███████   | 146/207 [00:43<00:18,  3.36it/s][A

Perplexity loss: 48222.85546875



 71%|███████   | 147/207 [00:43<00:17,  3.36it/s][A

Perplexity loss: 48094.4453125



 71%|███████▏  | 148/207 [00:44<00:17,  3.35it/s][A

Perplexity loss: 48062.390625



 72%|███████▏  | 149/207 [00:44<00:17,  3.35it/s][A

Perplexity loss: 47695.328125



 72%|███████▏  | 150/207 [00:44<00:17,  3.33it/s][A

Perplexity loss: 48030.3671875



 73%|███████▎  | 151/207 [00:45<00:16,  3.34it/s][A

Perplexity loss: 47934.40625



 73%|███████▎  | 152/207 [00:45<00:16,  3.35it/s][A

Perplexity loss: 48030.3671875



 74%|███████▍  | 153/207 [00:45<00:16,  3.35it/s][A

Perplexity loss: 48094.4453125



 74%|███████▍  | 154/207 [00:45<00:15,  3.36it/s][A

Perplexity loss: 48142.58203125



 75%|███████▍  | 155/207 [00:46<00:15,  3.37it/s][A

Perplexity loss: 47982.33984375



 75%|███████▌  | 156/207 [00:46<00:15,  3.35it/s][A

Perplexity loss: 47966.37109375



 76%|███████▌  | 157/207 [00:46<00:14,  3.35it/s][A

Perplexity loss: 47982.33984375



 76%|███████▋  | 158/207 [00:47<00:14,  3.36it/s][A

Perplexity loss: 48142.58203125



 77%|███████▋  | 159/207 [00:47<00:14,  3.36it/s][A

Perplexity loss: 48014.3359375



 77%|███████▋  | 160/207 [00:47<00:14,  3.35it/s][A

Perplexity loss: 47950.36328125



 78%|███████▊  | 161/207 [00:47<00:13,  3.35it/s][A

Perplexity loss: 47934.40625



 78%|███████▊  | 162/207 [00:48<00:13,  3.36it/s][A

Perplexity loss: 48126.515625



 79%|███████▊  | 163/207 [00:48<00:13,  3.37it/s][A

Perplexity loss: 47998.35546875



 79%|███████▉  | 164/207 [00:48<00:12,  3.37it/s][A

Perplexity loss: 47838.640625



 80%|███████▉  | 165/207 [00:49<00:12,  3.38it/s][A

Perplexity loss: 47998.35546875



 80%|████████  | 166/207 [00:49<00:12,  3.37it/s][A

Perplexity loss: 48062.390625



 81%|████████  | 167/207 [00:49<00:11,  3.36it/s][A

Perplexity loss: 47902.46484375



 81%|████████  | 168/207 [00:50<00:12,  3.07it/s][A

Perplexity loss: 47854.56640625



 82%|████████▏ | 169/207 [00:50<00:12,  3.15it/s][A

Perplexity loss: 47822.67578125



 82%|████████▏ | 170/207 [00:50<00:11,  3.21it/s][A

Perplexity loss: 48046.35546875



 83%|████████▎ | 171/207 [00:51<00:11,  3.26it/s][A

Perplexity loss: 47918.41015625



 83%|████████▎ | 172/207 [00:51<00:10,  3.29it/s][A

Perplexity loss: 47966.37109375



 84%|████████▎ | 173/207 [00:51<00:10,  3.31it/s][A

Perplexity loss: 47950.36328125



 84%|████████▍ | 174/207 [00:51<00:09,  3.32it/s][A

Perplexity loss: 48062.390625



 85%|████████▍ | 175/207 [00:52<00:09,  3.32it/s][A

Perplexity loss: 48062.390625



 85%|████████▌ | 176/207 [00:52<00:09,  3.33it/s][A

Perplexity loss: 47982.33984375



 86%|████████▌ | 177/207 [00:52<00:09,  3.32it/s][A

Perplexity loss: 48142.58203125



 86%|████████▌ | 178/207 [00:53<00:08,  3.34it/s][A

Perplexity loss: 47966.37109375



 86%|████████▋ | 179/207 [00:53<00:08,  3.35it/s][A

Perplexity loss: 48030.3671875



 87%|████████▋ | 180/207 [00:53<00:08,  3.35it/s][A

Perplexity loss: 47966.37109375



 87%|████████▋ | 181/207 [00:54<00:07,  3.34it/s][A

Perplexity loss: 47982.33984375



 88%|████████▊ | 182/207 [00:54<00:07,  3.34it/s][A

Perplexity loss: 48046.35546875



 88%|████████▊ | 183/207 [00:54<00:07,  3.35it/s][A

Perplexity loss: 47902.46484375



 89%|████████▉ | 184/207 [00:54<00:06,  3.35it/s][A

Perplexity loss: 47918.41015625



 89%|████████▉ | 185/207 [00:55<00:06,  3.35it/s][A

Perplexity loss: 48014.3359375



 90%|████████▉ | 186/207 [00:55<00:06,  3.34it/s][A

Perplexity loss: 48014.3359375



 90%|█████████ | 187/207 [00:55<00:05,  3.34it/s][A

Perplexity loss: 47934.40625



 91%|█████████ | 188/207 [00:56<00:05,  3.35it/s][A

Perplexity loss: 47998.35546875



 91%|█████████▏| 189/207 [00:56<00:05,  3.35it/s][A

Perplexity loss: 47982.33984375



 92%|█████████▏| 190/207 [00:56<00:05,  3.35it/s][A

Perplexity loss: 48126.515625



 92%|█████████▏| 191/207 [00:57<00:04,  3.34it/s][A

Perplexity loss: 48062.390625



 93%|█████████▎| 192/207 [00:57<00:04,  3.34it/s][A

Perplexity loss: 47982.33984375



 93%|█████████▎| 193/207 [00:57<00:04,  3.35it/s][A

Perplexity loss: 48094.4453125



 94%|█████████▎| 194/207 [00:57<00:03,  3.36it/s][A

Perplexity loss: 48014.3359375



 94%|█████████▍| 195/207 [00:58<00:03,  3.36it/s][A

Perplexity loss: 48014.3359375



 95%|█████████▍| 196/207 [00:58<00:03,  3.35it/s][A

Perplexity loss: 47998.35546875



 95%|█████████▌| 197/207 [00:58<00:02,  3.36it/s][A

Perplexity loss: 48062.390625



 96%|█████████▌| 198/207 [00:59<00:02,  3.36it/s][A

Perplexity loss: 47870.5390625



 96%|█████████▌| 199/207 [00:59<00:02,  3.36it/s][A

Perplexity loss: 47982.33984375



 97%|█████████▋| 200/207 [00:59<00:02,  3.36it/s][A

Perplexity loss: 47870.5390625



 97%|█████████▋| 201/207 [01:00<00:01,  3.35it/s][A

Perplexity loss: 48238.95703125



 98%|█████████▊| 202/207 [01:00<00:01,  3.34it/s][A

Perplexity loss: 47902.46484375



 98%|█████████▊| 203/207 [01:00<00:01,  3.32it/s][A

Perplexity loss: 48030.3671875



 99%|█████████▊| 204/207 [01:00<00:00,  3.33it/s][A

Perplexity loss: 47838.640625



 99%|█████████▉| 205/207 [01:01<00:00,  3.33it/s][A

Perplexity loss: 47982.33984375



100%|██████████| 207/207 [01:01<00:00,  3.36it/s]
100%|██████████| 1/1 [01:01<00:00, 61.55s/it]

Perplexity loss: 47966.37109375





In [42]:
import math

from validation_utils import val_metrics

## val_metrics yields two values, reported as (cross-entropy, perplexity)
val_ce, val_ppl = val_metrics(model, test_dl, token_id, loss_func)

## a model predicting uniformly over the vocabulary scores CE = ln(vocab_size);
## a value close to that baseline means nothing has been learned
logger.warning(
    """
    Cross-Entropy: %f
    Perplexity: %f
    Uniform-baseline Cross-Entropy: %f
    """ %
    (val_ce, val_ppl, math.log(vocab_size))
)

    Cross-Entropy: 10.779922
    Perpelxity: 48046.355469
    


### Inference

In [43]:

## Load the inference-time objects. These shadow the in-memory `model`,
## `token_id` and `id_token` defined above; presumably they are rebuilt from
## the saved artifacts (id_token keys are strings here, see the lookup below).
from inference import (
    model, 
    token_id, 
    id_token, 
    predict_one_word, 
    predict_sample
)

from loguru import logger
import sys

## re-apply the WARNING-only handler (other handlers appear to have been added
## since the first setup, judging by the returned handler id)
logger.remove()
logger.add(sys.stderr, level="WARNING");

9

In [44]:
## the reloaded id_token map is keyed by id strings, not ints
id_token["33274"]

'piercecorslet'

In [45]:
## deterministic (random=False) next-word prediction for a prompt padded with gibberish words
prompt = "one of the jns sjdnvjs aldsv "
predict_one_word(prompt, model, token_id, id_token, random=False)

'thou'

In [46]:
## the reloaded vocab maps must still be consistent with each other
assert len(token_id) == len(id_token)
len(token_id)

48560

In [49]:
## generate a 10-token continuation with random=False
## NOTE(review): the recorded output is a single token repeated ('thou thou ...'),
## consistent with the training loss staying at ~ln(vocab_size).
start_text = "how apply"
predict_sample(start_text, model, token_id, id_token, length_of_sample=10, random=False)

'thou thou thou thou thou thou thou thou thou thou'