Needs 3 text files with data, dataset.py, dataloader.py, and models (feedforward.py)

In [None]:
! pip install --quiet "torchvision" "torch>=1.6, <1.9" "torchmetrics>=0.3" "ipython[notebook]" "pytorch-lightning>=1.3" "torchtext"
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install unidecode
!pip install matplotlib>=3.3.2

## Install NeMo
BRANCH = 'r1.6.1'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

!mkdir configs
!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np 
from torch.utils.data import DataLoader
import torch.nn.functional as F
from dataset import wiki_dataset
from dataloader import wiki_dataloader
from rnn import rnn
import torchmetrics
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import matplotlib.pyplot as plt
import nltk
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!




[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
  # Load datasets and dataloader - RNN 
  train = wiki_dataset('wiki.train.txt', training=True, token_map='create', window=30)
  valid = wiki_dataset('wiki.valid.txt', training=False, token_map=train.token_map, window=30)
  test = wiki_dataset('wiki.test.txt', training=False, token_map=train.token_map, window=30)
  datasets = [train,valid,test]

  dataloader = wiki_dataloader(datasets=datasets, batch_size=20)

In [None]:
# RNN
model = rnn(n_vocab = len(train.unique_tokens), embedding_size=100, hidden_size=100, num_layers=2, dropout=0, lr=1e-3, trainweights=torch.log(1. / train.token_count()))
tb_logger = pl_loggers.TensorBoardLogger("RNN_logs/", name="rnn")
trainer = pl.Trainer(logger=tb_logger, max_epochs=20, tpu_cores=8, gpus=None, callbacks=[EarlyStopping(monitor='val loss')])
trainer.fit(model, dataloader)
result = trainer.test(model, dataloader)
print(result)

GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
Missing logger folder: RNN_logs/rnn

  | Name     | Type             | Params
----------------------------------------------
0 | embed    | Embedding        | 2.7 M 
1 | rnn      | RNN              | 40.4 K
2 | fc       | Linear           | 2.8 M 
3 | loss     | CrossEntropyLoss | 0     
4 | viewloss | CrossEntropyLoss | 0     
----------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.105    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

In [None]:
for idx in np.random.randint(0, 1000, size=10):
    features, groundTruth = test[idx]
    fpass = model.forward(features.unsqueeze(dim=0))
    pred = np.argmax(torch.softmax(fpass.detach().squeeze(dim=0), 0))
    sentence = ''.join([test.decode_int(i) for i in features])
    nextword = test.decode_int(groundTruth)
    nextpred = test.decode_int(pred)
    print('{} ({}) [{}]'.format(sentence, nextword, nextpred))

calledgreatestchinesepoetsgreatestambitionservecountrysuccessfulcivilservantprovedunablemakenecessaryaccommodationslifelikewholecountrydevastated<unk>rebellion<unk>last<integer>yearstimealmostconstant (unrest) [<unk>]
makenecessaryaccommodationslifelikewholecountrydevastated<unk>rebellion<unk>last<integer>yearstimealmostconstantunrestalthoughinitiallylittleknownwritersworkscamehugelyinfluentialchinesejapaneseliterary (culture) [<unk>]
resulteither<unk>poemfailunderstandaltogetherstephenowensuggeststhirdfactorparticulardufuarguingvarietypoetworkrequiredconsiderationwholeliferather<unk><unk>usedlimitedpoetsearly (years) [<year>]
film<unk>directedparis<unk>filmographyfilmtelevisiontheatredufudufuwadegilestufuchinese<unk><unk><integer>prominentchinesepoettangdynastyalongli<unk>li (po) [<unk>]
charactertobysteelebillrecurringrole<year>twoepisodesbillcharacterconnorprice<year><unk>landedrolecraigepisodeteddystorytelevisionserieslongfirmstarredalongsideactorsmarkstrong (derek) [<unk>]
<year>p

In [None]:
%load_ext tensorboard
# %reload_ext tensorboard
%tensorboard --logdir ./RNN_logs/

Reusing TensorBoard on port 6006 (pid 76338), started 0:33:34 ago. (Use '!kill 76338' to kill it.)

<IPython.core.display.Javascript object>

In [None]:
rnn.test_hparam('dropout', values=[0.2, 0.5, 0.8], tpu_cores=8, gpus=None)