In [None]:
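## To run in Colab, uncomment the lines below: they clone the course repo, copy the LM helper modules into the working directory, and install the dependencies.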

# !git clone https://github.com/sergeychuvakin/advanced_nlp_course.git
# !mv advanced_nlp_course/LM/*.py ./
# !pip install loguru pydantic tokenizers

In [16]:
%reload_ext autoreload
%autoreload 2

import math
import sys

import torch
from loguru import logger
from torch.utils.data import DataLoader

from dependencies import corpus, tokenizer
from config import Config, LanguageModelConfig
from processing_utils import (
    clean_text,
    split_on_sequences,
    create_ngrams,
    create_to_x_and_y,
    word2int,
    create_vocab,
)
from model import LM_LSTM
from datahandler import LMDataset
from train_utils import train_model

config = Config()

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Drop all existing loguru handlers (including the default one) and re-add a
# stderr handler that only shows WARNING and above.
logger.remove()
logger.add(sys.stderr, level="WARNING");  # `;` hides the returned handler id

7

In [None]:
## Preprocessing. New names are used instead of rebinding the imported `corpus`,
## so this cell is idempotent: re-running it never re-cleans / re-splits text
## that was already processed.
sequences = split_on_sequences(clean_text(corpus))
tcorpus = tuple(tokenizer.encode(sequence) for sequence in sequences)

## create n-grams for each doc
sq = create_ngrams(tcorpus, config.N_GRAM)

## shift corpus to create x and y
x, y = create_to_x_and_y(sq)

token_id, id_token = create_vocab(tcorpus)
vocab_size = len(token_id)

## map tokens to integer ids for the DataLoader
x_int = [word2int(seq, token_id) for seq in x]
y_int = [word2int(seq, token_id) for seq in y]
assert len(x_int) == len(y_int), "inputs and targets must pair up one-to-one"

## contiguous train/test split, in corpus order
tradeoff_index = int(len(x_int) * config.TRAIN_PROPORTION)

x_train, x_test = x_int[:tradeoff_index], x_int[tradeoff_index:]
y_train, y_test = y_int[:tradeoff_index], y_int[tradeoff_index:]

logger.warning(f"Output shapes: x_train: {len(x_train)}, x_test: {len(x_test)}, y_train: {len(y_train)}, y_test: {len(y_test)}")

## load to dataset and dataloader
train_ds = LMDataset(x_train, y_train)
test_ds = LMDataset(x_test, y_test)

train_dl = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True)
# NOTE(review): test_dl is built but never used for evaluation in this notebook.
test_dl = DataLoader(test_ds, batch_size=config.BATCH_SIZE, shuffle=False)

## model and model config
EMB_SIZE = 300   # embedding size passed to LanguageModelConfig
EPOCHS = 30
GRAD_CLIP = 1    # forwarded to train_model as `clip`

model_config = LanguageModelConfig(vocab_size=vocab_size, emb_size=EMB_SIZE)
# pydantic v2 deprecates `.dict()` in favour of `.model_dump()`; support both.
model_kwargs = model_config.model_dump() if hasattr(model_config, "model_dump") else model_config.dict()
model = LM_LSTM(**model_kwargs, logger=logger)

optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr)
loss_func = torch.nn.CrossEntropyLoss()

# Cross-entropy of a uniform guess over the vocabulary; the per-batch loss
# printed during training should fall clearly below this value.
# NOTE(review): in the recorded run the loss starts at ~9.45 and plateaus near
# ~9.38. If that is close to this baseline, the model is barely learning.
# TODO: check the lr and that LM_LSTM returns raw logits (CrossEntropyLoss
# applies log-softmax itself).
logger.warning(f"Uniform-prediction baseline loss ln(vocab_size): {math.log(vocab_size):.4f}")

## train model
tmodel = train_model(model,
                     train_dl,
                     optimizer=optimizer,
                     loss_func=loss_func,
                     batch_size=config.BATCH_SIZE,
                     epochs=EPOCHS,
                     clip=GRAD_CLIP)

# torch.save(model.state_dict(), config.SAVE_MODEL_FNAME)

100%|██████████| 52810/52810 [00:00<00:00, 63444.36it/s]
100%|██████████| 52810/52810 [00:00<00:00, 145790.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]
  0%|          | 0/970 [00:00<?, ?it/s][A
  0%|          | 1/970 [00:01<21:17,  1.32s/it][A
  0%|          | 2/970 [00:02<21:18,  1.32s/it][A

Cross-entropy loss: 9.452272415161133



  0%|          | 3/970 [00:03<20:57,  1.30s/it][A

Cross-entropy loss: 9.452272415161133



  0%|          | 4/970 [00:05<20:50,  1.29s/it][A

Cross-entropy loss: 9.452272415161133



  1%|          | 5/970 [00:06<21:09,  1.32s/it][A

Cross-entropy loss: 9.452272415161133



  1%|          | 6/970 [00:07<20:55,  1.30s/it][A

Cross-entropy loss: 9.452272415161133



  1%|          | 7/970 [00:09<20:50,  1.30s/it][A

Cross-entropy loss: 9.452272415161133



  1%|          | 8/970 [00:10<20:55,  1.31s/it][A

Cross-entropy loss: 9.452272415161133



  1%|          | 9/970 [00:11<20:54,  1.31s/it][A

Cross-entropy loss: 9.45227336883545



  1%|          | 10/970 [00:13<20:50,  1.30s/it][A

Cross-entropy loss: 9.452274322509766



  1%|          | 11/970 [00:14<20:47,  1.30s/it][A

Cross-entropy loss: 9.452275276184082



  1%|          | 12/970 [00:15<20:47,  1.30s/it][A

Cross-entropy loss: 9.45227336883545



  1%|▏         | 13/970 [00:16<20:40,  1.30s/it][A

Cross-entropy loss: 9.452269554138184



  1%|▏         | 14/970 [00:18<20:59,  1.32s/it][A

Cross-entropy loss: 9.45226764678955



  2%|▏         | 15/970 [00:19<20:50,  1.31s/it][A

Cross-entropy loss: 9.45226764678955



  2%|▏         | 16/970 [00:20<20:41,  1.30s/it][A

Cross-entropy loss: 9.452266693115234



  2%|▏         | 17/970 [00:22<20:35,  1.30s/it][A

Cross-entropy loss: 9.452249526977539



  2%|▏         | 18/970 [00:23<20:39,  1.30s/it][A

Cross-entropy loss: 9.452239036560059



  2%|▏         | 19/970 [00:24<20:45,  1.31s/it][A

Cross-entropy loss: 9.45223617553711



  2%|▏         | 20/970 [00:26<20:54,  1.32s/it][A

Cross-entropy loss: 9.452239990234375



  2%|▏         | 21/970 [00:27<21:10,  1.34s/it][A

Cross-entropy loss: 9.452234268188477



  2%|▏         | 22/970 [00:28<21:04,  1.33s/it][A

Cross-entropy loss: 9.452217102050781



  2%|▏         | 23/970 [00:30<21:11,  1.34s/it][A

Cross-entropy loss: 9.452194213867188



  2%|▏         | 24/970 [00:31<21:15,  1.35s/it][A

Cross-entropy loss: 9.452159881591797



  3%|▎         | 25/970 [00:33<21:56,  1.39s/it][A

Cross-entropy loss: 9.452088356018066



  3%|▎         | 26/970 [00:34<21:52,  1.39s/it][A

Cross-entropy loss: 9.451889038085938



  3%|▎         | 27/970 [00:35<21:38,  1.38s/it][A

Cross-entropy loss: 9.451499938964844



  3%|▎         | 28/970 [00:37<22:03,  1.40s/it][A

Cross-entropy loss: 9.450617790222168



  3%|▎         | 29/970 [00:38<22:11,  1.42s/it][A

Cross-entropy loss: 9.448317527770996



  3%|▎         | 30/970 [00:40<21:55,  1.40s/it][A

Cross-entropy loss: 9.44389820098877



  3%|▎         | 31/970 [00:41<22:02,  1.41s/it][A

Cross-entropy loss: 9.43136215209961



  3%|▎         | 32/970 [00:42<21:36,  1.38s/it][A

Cross-entropy loss: 9.414605140686035



  3%|▎         | 33/970 [00:44<21:44,  1.39s/it][A

Cross-entropy loss: 9.397704124450684



  4%|▎         | 34/970 [00:45<21:24,  1.37s/it][A

Cross-entropy loss: 9.3936185836792



  4%|▎         | 35/970 [00:46<21:30,  1.38s/it][A

Cross-entropy loss: 9.385282516479492



  4%|▎         | 36/970 [00:48<21:16,  1.37s/it][A

Cross-entropy loss: 9.38392448425293



  4%|▍         | 37/970 [00:49<21:20,  1.37s/it][A

Cross-entropy loss: 9.382826805114746



  4%|▍         | 38/970 [00:50<20:45,  1.34s/it][A

Cross-entropy loss: 9.389179229736328



  4%|▍         | 39/970 [00:52<20:43,  1.34s/it][A

Cross-entropy loss: 9.384248733520508



  4%|▍         | 40/970 [00:53<21:08,  1.36s/it][A

Cross-entropy loss: 9.375192642211914



  4%|▍         | 41/970 [00:55<21:20,  1.38s/it][A

Cross-entropy loss: 9.383358001708984



  4%|▍         | 42/970 [00:56<22:16,  1.44s/it][A

Cross-entropy loss: 9.367729187011719



  4%|▍         | 43/970 [00:58<22:56,  1.49s/it][A

Cross-entropy loss: 9.379149436950684



  5%|▍         | 44/970 [00:59<23:05,  1.50s/it][A

Cross-entropy loss: 9.389540672302246



  5%|▍         | 45/970 [01:01<23:20,  1.51s/it][A

Cross-entropy loss: 9.384527206420898



  5%|▍         | 46/970 [01:02<23:07,  1.50s/it][A

Cross-entropy loss: 9.379509925842285



  5%|▍         | 47/970 [01:04<23:08,  1.50s/it][A

Cross-entropy loss: 9.39745044708252



  5%|▍         | 48/970 [01:05<23:06,  1.50s/it][A

Cross-entropy loss: 9.375962257385254



  5%|▌         | 49/970 [01:07<24:11,  1.58s/it][A

Cross-entropy loss: 9.374464988708496



  5%|▌         | 50/970 [01:09<24:45,  1.62s/it][A

Cross-entropy loss: 9.387435913085938



  5%|▌         | 51/970 [01:10<24:03,  1.57s/it][A

Cross-entropy loss: 9.376944541931152



  5%|▌         | 52/970 [01:12<25:21,  1.66s/it][A

Cross-entropy loss: 9.38393783569336



  5%|▌         | 53/970 [01:14<24:27,  1.60s/it][A

Cross-entropy loss: 9.379435539245605



  6%|▌         | 54/970 [01:15<23:19,  1.53s/it][A

Cross-entropy loss: 9.38742446899414



  6%|▌         | 55/970 [01:16<22:42,  1.49s/it][A

Cross-entropy loss: 9.382923126220703



  6%|▌         | 56/970 [01:18<22:08,  1.45s/it][A

Cross-entropy loss: 9.372926712036133



  6%|▌         | 57/970 [01:19<21:59,  1.44s/it][A

Cross-entropy loss: 9.38242244720459



  6%|▌         | 58/970 [01:20<21:24,  1.41s/it][A

Cross-entropy loss: 9.38041877746582



  6%|▌         | 59/970 [01:22<21:07,  1.39s/it][A

Cross-entropy loss: 9.378421783447266



  6%|▌         | 60/970 [01:23<21:18,  1.40s/it][A

Cross-entropy loss: 9.381917953491211



  6%|▋         | 61/970 [01:25<21:04,  1.39s/it][A

Cross-entropy loss: 9.389913558959961



  6%|▋         | 62/970 [01:26<21:17,  1.41s/it][A

Cross-entropy loss: 9.382913589477539



  6%|▋         | 63/970 [01:28<21:32,  1.42s/it][A

Cross-entropy loss: 9.384413719177246



  7%|▋         | 64/970 [01:29<21:29,  1.42s/it][A

Cross-entropy loss: 9.374917030334473



  7%|▋         | 65/970 [01:30<21:20,  1.42s/it][A

Cross-entropy loss: 9.380413055419922



  7%|▋         | 66/970 [01:32<21:04,  1.40s/it][A

Cross-entropy loss: 9.385912895202637



  7%|▋         | 67/970 [01:33<21:40,  1.44s/it][A

Cross-entropy loss: 9.376415252685547



  7%|▋         | 68/970 [01:35<21:07,  1.40s/it][A

Cross-entropy loss: 9.39391040802002



  7%|▋         | 69/970 [01:36<20:39,  1.38s/it][A

Cross-entropy loss: 9.378913879394531



  7%|▋         | 70/970 [01:37<20:11,  1.35s/it][A

Cross-entropy loss: 9.388912200927734



  7%|▋         | 71/970 [01:38<19:45,  1.32s/it][A

Cross-entropy loss: 9.37391185760498



  7%|▋         | 72/970 [01:40<19:32,  1.31s/it][A

Cross-entropy loss: 9.385910034179688



  8%|▊         | 73/970 [01:41<19:16,  1.29s/it][A

Cross-entropy loss: 9.381911277770996



  8%|▊         | 74/970 [01:42<19:24,  1.30s/it][A

Cross-entropy loss: 9.38340950012207



  8%|▊         | 75/970 [01:44<19:35,  1.31s/it][A

Cross-entropy loss: 9.373912811279297



  8%|▊         | 76/970 [01:45<19:45,  1.33s/it][A

Cross-entropy loss: 9.38241195678711



  8%|▊         | 77/970 [01:46<20:10,  1.36s/it][A

Cross-entropy loss: 9.379911422729492



  8%|▊         | 78/970 [01:48<20:21,  1.37s/it][A

Cross-entropy loss: 9.379410743713379



  8%|▊         | 79/970 [01:49<20:37,  1.39s/it][A

Cross-entropy loss: 9.376412391662598



  8%|▊         | 80/970 [01:51<20:40,  1.39s/it][A

Cross-entropy loss: 9.387909889221191



  8%|▊         | 81/970 [01:52<21:30,  1.45s/it][A

Cross-entropy loss: 9.373414039611816



  8%|▊         | 82/970 [01:54<21:29,  1.45s/it][A

Cross-entropy loss: 9.383410453796387



  9%|▊         | 83/970 [01:55<22:03,  1.49s/it][A

Cross-entropy loss: 9.378911972045898



  9%|▊         | 84/970 [01:57<22:01,  1.49s/it][A

Cross-entropy loss: 9.388410568237305



  9%|▉         | 85/970 [01:58<21:44,  1.47s/it][A

Cross-entropy loss: 9.381414413452148



  9%|▉         | 86/970 [01:59<21:03,  1.43s/it][A

Cross-entropy loss: 9.377411842346191



  9%|▉         | 87/970 [02:01<20:50,  1.42s/it][A

Cross-entropy loss: 9.387907981872559



  9%|▉         | 88/970 [02:02<20:27,  1.39s/it][A

Cross-entropy loss: 9.38191032409668



  9%|▉         | 89/970 [02:04<20:46,  1.42s/it][A

Cross-entropy loss: 9.382410049438477



  9%|▉         | 90/970 [02:05<20:27,  1.40s/it][A

Cross-entropy loss: 9.38291072845459



  9%|▉         | 91/970 [02:06<20:05,  1.37s/it][A

Cross-entropy loss: 9.382410049438477



  9%|▉         | 92/970 [02:08<19:40,  1.34s/it][A

Cross-entropy loss: 9.373912811279297



 10%|▉         | 93/970 [02:09<19:30,  1.33s/it][A

Cross-entropy loss: 9.388408660888672



 10%|▉         | 94/970 [02:10<19:08,  1.31s/it][A

Cross-entropy loss: 9.378911018371582



 10%|▉         | 95/970 [02:11<18:57,  1.30s/it][A

Cross-entropy loss: 9.381410598754883



 10%|▉         | 96/970 [02:13<18:44,  1.29s/it][A

Cross-entropy loss: 9.386911392211914



 10%|█         | 97/970 [02:14<18:38,  1.28s/it][A

Cross-entropy loss: 9.377411842346191



 10%|█         | 98/970 [02:15<18:34,  1.28s/it][A

Cross-entropy loss: 9.380910873413086



 10%|█         | 99/970 [02:17<18:29,  1.27s/it][A

Cross-entropy loss: 9.376412391662598



 10%|█         | 100/970 [02:18<18:29,  1.28s/it][A

Cross-entropy loss: 9.377911567687988



 10%|█         | 101/970 [02:19<18:27,  1.27s/it][A

Cross-entropy loss: 9.370914459228516



 11%|█         | 102/970 [02:20<18:40,  1.29s/it][A

Cross-entropy loss: 9.375911712646484



 11%|█         | 103/970 [02:22<18:42,  1.30s/it][A

Cross-entropy loss: 9.3709135055542



 11%|█         | 104/970 [02:23<18:35,  1.29s/it][A

Cross-entropy loss: 9.380411148071289



 11%|█         | 105/970 [02:24<18:18,  1.27s/it][A

Cross-entropy loss: 9.386908531188965



 11%|█         | 106/970 [02:25<18:09,  1.26s/it][A

Cross-entropy loss: 9.382908821105957



 11%|█         | 107/970 [02:27<18:01,  1.25s/it][A

Cross-entropy loss: 9.377412796020508


In [None]:
# Display the hyper-parameters the model was built with.
# pydantic v2 deprecates `.dict()` in favour of `.model_dump()`; support both.
model_config.model_dump() if hasattr(model_config, "model_dump") else model_config.dict()