# Text Generation with an LSTM Model

In [1]:
import pandas as pd

In [2]:
import torch
from torch import nn

In [3]:
# Release unoccupied cached memory held by PyTorch's CUDA caching allocator
# before building the model. Inspect usage afterwards with
# print(torch.cuda.memory_summary(abbreviated=False)).
torch.cuda.empty_cache()

In [4]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> vocabulary logits.

    The vocabulary size is ``len(dataset.uniq_words)``. All layers are created
    on the CPU; move the whole model at once with ``.cuda()`` / ``.to(device)``
    (e.g. ``Model(dataset).cuda()``) so every layer lives on the same device.

    NOTE(review): ``nn.LSTM`` uses its default ``batch_first=False``, so dim 0
    of ``x`` is treated as time and dim 1 as the batch. ``train()`` in this
    notebook feeds ``(batch_size, sequence_length)`` tensors from an unshuffled
    DataLoader, so the recurrence runs across consecutive windows and
    ``init_state`` is sized by ``sequence_length`` (the LSTM's batch dim here).
    Keep callers consistent with that layout before changing it.
    """

    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )

        # The LSTM consumes the embeddings, so its input width must equal the
        # embedding width (not the hidden size).
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )

        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one step of the model.

        Args:
            x: integer token indices, laid out as the LSTM expects
                (see the class note on ``batch_first``).
            prev_state: ``(h, c)`` tuple, e.g. from ``init_state``.

        Returns:
            ``(logits, (h, c))`` where ``logits`` has a trailing vocabulary dim.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(h0, c0)``, each ``(num_layers, sequence_length, lstm_size)``.

        The tensors are allocated on the same device as the model's parameters,
        so this works on CPU as well as after ``.cuda()``.
        """
        device = next(self.parameters()).device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))

In [5]:
from collections import Counter

In [6]:
class Dataset(torch.utils.data.Dataset):
    """Next-word prediction windows over space-split comment text from a CSV.

    ``args`` is a dict with:
      * ``sequence_length`` (required): length of each input/target window.
      * ``csv_path`` (optional, default ``'comments.csv'``): CSV with a
        ``comment_body`` column.
      * ``sample_size`` (optional, default 3000): number of rows sampled,
        capped at the number of rows in the CSV.
      * ``seed`` (optional, default None): ``random_state`` for the row
        sample; None draws a different sample on each run.
      * ``device`` (optional, default ``'cuda'``): device of returned tensors.
    """

    def __init__(self, args):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Sample comment rows and return their concatenated text split on single spaces."""
        train_df = pd.read_csv(self.args.get('csv_path', 'comments.csv'))
        # DataFrame.sample raises if n exceeds the row count, so cap it.
        n = min(self.args.get('sample_size', 3000), len(train_df))
        train_df = train_df.sample(n=n, random_state=self.args.get('seed'))
        text = train_df['comment_body'].str.cat(sep=' ')

        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Each item needs sequence_length + 1 consecutive tokens; a corpus
        # shorter than that yields an empty dataset rather than a negative length.
        return max(0, len(self.words_indexes) - self.args['sequence_length'])

    def __getitem__(self, index):
        """Return ``(x, y)`` index tensors where ``y`` is ``x`` shifted one token ahead."""
        seq_len = self.args['sequence_length']
        device = self.args.get('device', 'cuda')
        return (
            torch.tensor(self.words_indexes[index:index + seq_len], device=device),
            torch.tensor(self.words_indexes[index + 1:index + seq_len + 1], device=device),
        )

In [7]:
import argparse
import numpy as np
from torch import optim
from torch.utils.data import DataLoader

def train(dataset, model, args, lr=0.001):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy, logging every batch.

    Args:
        dataset: yields ``(x, y)`` index tensors where ``y`` is ``x`` shifted
            by one token.
        model: module with ``forward(x, state) -> (logits, state)`` and
            ``init_state(n) -> (h, c)``.
        args: dict with ``batch_size``, ``max_epochs`` and ``sequence_length``.
        lr: Adam learning rate (default 0.001).

    Prints ``{'epoch': ..., 'batch': ..., 'loss': ...}`` after every optimizer step.
    """
    model.train()
    device = next(model.parameters()).device

    # No shuffling: the hidden state is carried from one batch into the next,
    # which relies on batches arriving in corpus order.
    dataloader = DataLoader(dataset, batch_size=args['batch_size'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(args['max_epochs']):
        state_h, state_c = model.init_state(args['sequence_length'])
        for batch, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class (vocabulary) dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Truncated BPTT: keep the state values but cut the graph so the
            # next backward() does not reach into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})


In [8]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` words from ``model``.

    Args:
        dataset: provides ``word_to_index`` and ``index_to_word`` lookups.
        model: module with ``forward(x, state)``, ``init_state(n)`` and parameters.
        text: prompt split on single spaces; every word must be in the
            vocabulary (an unknown word raises ``KeyError``).
        next_words: number of words to sample.

    Returns:
        list[str]: the prompt words followed by the sampled words.
    """
    model.eval()
    device = next(model.parameters()).device

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the carried state links every step's
    # autograd graph to the previous one, so memory grows with next_words.
    with torch.no_grad():
        for i in range(next_words):
            # The window keeps the prompt's length: it drops the oldest word
            # and includes the word appended on the previous step.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]], device=device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [9]:
# Only query the CUDA allocator when a GPU is present; previously the result
# of is_available() was computed and thrown away.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print('CUDA is not available')

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------

In [10]:
# Hyperparameters as a plain dict (argparse does not fit a notebook cell).
args = {'max_epochs': 5, 'batch_size': 300, 'sequence_length': 15}

# Pinned memory is a DataLoader option, and the dataset already yields CUDA
# tensors, so no pinning setup is needed here.
dataset = Dataset(args)
# Move the whole model (embedding, LSTM and output layer) to the GPU.
model = Model(dataset).cuda()

train(dataset, model, args)

{'epoch': 0, 'batch': 0, 'loss': 9.662375450134277}
{'epoch': 0, 'batch': 1, 'loss': 9.650103569030762}
{'epoch': 0, 'batch': 2, 'loss': 9.647363662719727}
{'epoch': 0, 'batch': 3, 'loss': 9.636096000671387}
{'epoch': 0, 'batch': 4, 'loss': 9.634225845336914}
{'epoch': 0, 'batch': 5, 'loss': 9.626744270324707}
{'epoch': 0, 'batch': 6, 'loss': 9.607887268066406}
{'epoch': 0, 'batch': 7, 'loss': 9.59251594543457}
{'epoch': 0, 'batch': 8, 'loss': 9.572223663330078}
{'epoch': 0, 'batch': 9, 'loss': 9.537040710449219}
{'epoch': 0, 'batch': 10, 'loss': 9.459830284118652}
{'epoch': 0, 'batch': 11, 'loss': 9.403572082519531}
{'epoch': 0, 'batch': 12, 'loss': 9.199335098266602}
{'epoch': 0, 'batch': 13, 'loss': 8.984236717224121}
{'epoch': 0, 'batch': 14, 'loss': 8.877208709716797}
{'epoch': 0, 'batch': 15, 'loss': 8.769489288330078}
{'epoch': 0, 'batch': 16, 'loss': 8.489703178405762}
{'epoch': 0, 'batch': 17, 'loss': 8.310068130493164}
{'epoch': 0, 'batch': 18, 'loss': 8.056078910827637}
{'ep

{'epoch': 0, 'batch': 154, 'loss': 7.566999912261963}
{'epoch': 0, 'batch': 155, 'loss': 7.283870697021484}
{'epoch': 0, 'batch': 156, 'loss': 7.738811492919922}
{'epoch': 0, 'batch': 157, 'loss': 8.018782615661621}
{'epoch': 0, 'batch': 158, 'loss': 7.471030235290527}
{'epoch': 0, 'batch': 159, 'loss': 7.587064266204834}
{'epoch': 0, 'batch': 160, 'loss': 7.61276388168335}
{'epoch': 0, 'batch': 161, 'loss': 7.260257244110107}
{'epoch': 0, 'batch': 162, 'loss': 7.599643230438232}
{'epoch': 0, 'batch': 163, 'loss': 7.806021690368652}
{'epoch': 0, 'batch': 164, 'loss': 8.199081420898438}
{'epoch': 0, 'batch': 165, 'loss': 7.778714179992676}
{'epoch': 0, 'batch': 166, 'loss': 7.231417655944824}
{'epoch': 0, 'batch': 167, 'loss': 7.352672576904297}
{'epoch': 0, 'batch': 168, 'loss': 7.475770950317383}
{'epoch': 0, 'batch': 169, 'loss': 7.737814903259277}
{'epoch': 0, 'batch': 170, 'loss': 7.348082542419434}
{'epoch': 0, 'batch': 171, 'loss': 7.403979301452637}
{'epoch': 0, 'batch': 172, 'l

{'epoch': 0, 'batch': 306, 'loss': 7.474578857421875}
{'epoch': 0, 'batch': 307, 'loss': 7.974283695220947}
{'epoch': 0, 'batch': 308, 'loss': 7.770229339599609}
{'epoch': 0, 'batch': 309, 'loss': 7.375384330749512}
{'epoch': 0, 'batch': 310, 'loss': 7.409408092498779}
{'epoch': 0, 'batch': 311, 'loss': 6.999751091003418}
{'epoch': 0, 'batch': 312, 'loss': 7.077420234680176}
{'epoch': 0, 'batch': 313, 'loss': 7.16390323638916}
{'epoch': 0, 'batch': 314, 'loss': 7.284897804260254}
{'epoch': 0, 'batch': 315, 'loss': 7.508215427398682}
{'epoch': 0, 'batch': 316, 'loss': 7.741868019104004}
{'epoch': 0, 'batch': 317, 'loss': 7.66411018371582}
{'epoch': 0, 'batch': 318, 'loss': 7.994191646575928}
{'epoch': 0, 'batch': 319, 'loss': 8.160460472106934}
{'epoch': 0, 'batch': 320, 'loss': 7.256160736083984}
{'epoch': 0, 'batch': 321, 'loss': 7.827271938323975}
{'epoch': 0, 'batch': 322, 'loss': 7.394514560699463}
{'epoch': 0, 'batch': 323, 'loss': 7.31221342086792}
{'epoch': 0, 'batch': 324, 'los

{'epoch': 1, 'batch': 102, 'loss': 7.075099468231201}
{'epoch': 1, 'batch': 103, 'loss': 7.306151866912842}
{'epoch': 1, 'batch': 104, 'loss': 7.139932155609131}
{'epoch': 1, 'batch': 105, 'loss': 7.217019557952881}
{'epoch': 1, 'batch': 106, 'loss': 7.263237953186035}
{'epoch': 1, 'batch': 107, 'loss': 7.077888488769531}
{'epoch': 1, 'batch': 108, 'loss': 7.185816287994385}
{'epoch': 1, 'batch': 109, 'loss': 7.369277000427246}
{'epoch': 1, 'batch': 110, 'loss': 7.085216045379639}
{'epoch': 1, 'batch': 111, 'loss': 7.041200637817383}
{'epoch': 1, 'batch': 112, 'loss': 7.019938945770264}
{'epoch': 1, 'batch': 113, 'loss': 7.19562292098999}
{'epoch': 1, 'batch': 114, 'loss': 7.065494060516357}
{'epoch': 1, 'batch': 115, 'loss': 7.119272232055664}
{'epoch': 1, 'batch': 116, 'loss': 6.988043308258057}
{'epoch': 1, 'batch': 117, 'loss': 7.109309196472168}
{'epoch': 1, 'batch': 118, 'loss': 6.974695205688477}
{'epoch': 1, 'batch': 119, 'loss': 7.398091793060303}
{'epoch': 1, 'batch': 120, 'l

{'epoch': 1, 'batch': 254, 'loss': 7.18511438369751}
{'epoch': 1, 'batch': 255, 'loss': 7.133246898651123}
{'epoch': 1, 'batch': 256, 'loss': 7.156558513641357}
{'epoch': 1, 'batch': 257, 'loss': 7.189826011657715}
{'epoch': 1, 'batch': 258, 'loss': 7.041456699371338}
{'epoch': 1, 'batch': 259, 'loss': 7.075085639953613}
{'epoch': 1, 'batch': 260, 'loss': 7.211258411407471}
{'epoch': 1, 'batch': 261, 'loss': 7.654715061187744}
{'epoch': 1, 'batch': 262, 'loss': 7.401469707489014}
{'epoch': 1, 'batch': 263, 'loss': 7.308923721313477}
{'epoch': 1, 'batch': 264, 'loss': 7.2225542068481445}
{'epoch': 1, 'batch': 265, 'loss': 7.18119478225708}
{'epoch': 1, 'batch': 266, 'loss': 7.1349053382873535}
{'epoch': 1, 'batch': 267, 'loss': 7.312960147857666}
{'epoch': 1, 'batch': 268, 'loss': 7.635434150695801}
{'epoch': 1, 'batch': 269, 'loss': 7.19383430480957}
{'epoch': 1, 'batch': 270, 'loss': 7.113640785217285}
{'epoch': 1, 'batch': 271, 'loss': 7.515838623046875}
{'epoch': 1, 'batch': 272, 'l

{'epoch': 2, 'batch': 49, 'loss': 7.122166633605957}
{'epoch': 2, 'batch': 50, 'loss': 6.975539684295654}
{'epoch': 2, 'batch': 51, 'loss': 6.967109203338623}
{'epoch': 2, 'batch': 52, 'loss': 7.134143352508545}
{'epoch': 2, 'batch': 53, 'loss': 6.9180707931518555}
{'epoch': 2, 'batch': 54, 'loss': 6.991905689239502}
{'epoch': 2, 'batch': 55, 'loss': 6.980644226074219}
{'epoch': 2, 'batch': 56, 'loss': 7.126858234405518}
{'epoch': 2, 'batch': 57, 'loss': 7.570285797119141}
{'epoch': 2, 'batch': 58, 'loss': 6.7853102684021}
{'epoch': 2, 'batch': 59, 'loss': 6.7993083000183105}
{'epoch': 2, 'batch': 60, 'loss': 7.254828929901123}
{'epoch': 2, 'batch': 61, 'loss': 7.055575847625732}
{'epoch': 2, 'batch': 62, 'loss': 6.962528705596924}
{'epoch': 2, 'batch': 63, 'loss': 7.166429042816162}
{'epoch': 2, 'batch': 64, 'loss': 7.118869304656982}
{'epoch': 2, 'batch': 65, 'loss': 6.986154556274414}
{'epoch': 2, 'batch': 66, 'loss': 7.154197692871094}
{'epoch': 2, 'batch': 67, 'loss': 7.0252656936

{'epoch': 2, 'batch': 203, 'loss': 7.213179588317871}
{'epoch': 2, 'batch': 204, 'loss': 7.30394983291626}
{'epoch': 2, 'batch': 205, 'loss': 7.092696666717529}
{'epoch': 2, 'batch': 206, 'loss': 7.450503349304199}
{'epoch': 2, 'batch': 207, 'loss': 7.107037544250488}
{'epoch': 2, 'batch': 208, 'loss': 7.424099922180176}
{'epoch': 2, 'batch': 209, 'loss': 7.340605735778809}
{'epoch': 2, 'batch': 210, 'loss': 7.120195388793945}
{'epoch': 2, 'batch': 211, 'loss': 7.232381820678711}
{'epoch': 2, 'batch': 212, 'loss': 7.231790065765381}
{'epoch': 2, 'batch': 213, 'loss': 7.2902326583862305}
{'epoch': 2, 'batch': 214, 'loss': 7.241562843322754}
{'epoch': 2, 'batch': 215, 'loss': 7.041664123535156}
{'epoch': 2, 'batch': 216, 'loss': 7.103896617889404}
{'epoch': 2, 'batch': 217, 'loss': 7.254786968231201}
{'epoch': 2, 'batch': 218, 'loss': 7.379246711730957}
{'epoch': 2, 'batch': 219, 'loss': 7.180087089538574}
{'epoch': 2, 'batch': 220, 'loss': 6.904423236846924}
{'epoch': 2, 'batch': 221, '

{'epoch': 2, 'batch': 355, 'loss': 6.974188804626465}
{'epoch': 2, 'batch': 356, 'loss': 7.1341423988342285}
{'epoch': 2, 'batch': 357, 'loss': 6.602631092071533}
{'epoch': 3, 'batch': 0, 'loss': 6.892635345458984}
{'epoch': 3, 'batch': 1, 'loss': 6.876953125}
{'epoch': 3, 'batch': 2, 'loss': 7.0000457763671875}
{'epoch': 3, 'batch': 3, 'loss': 6.913464069366455}
{'epoch': 3, 'batch': 4, 'loss': 6.881417274475098}
{'epoch': 3, 'batch': 5, 'loss': 6.906459331512451}
{'epoch': 3, 'batch': 6, 'loss': 6.86252498626709}
{'epoch': 3, 'batch': 7, 'loss': 6.925924777984619}
{'epoch': 3, 'batch': 8, 'loss': 7.0638298988342285}
{'epoch': 3, 'batch': 9, 'loss': 7.149318218231201}
{'epoch': 3, 'batch': 10, 'loss': 6.994443416595459}
{'epoch': 3, 'batch': 11, 'loss': 7.171309947967529}
{'epoch': 3, 'batch': 12, 'loss': 6.856006145477295}
{'epoch': 3, 'batch': 13, 'loss': 6.787389278411865}
{'epoch': 3, 'batch': 14, 'loss': 6.9889631271362305}
{'epoch': 3, 'batch': 15, 'loss': 7.236242294311523}
{'e

{'epoch': 3, 'batch': 152, 'loss': 7.1450042724609375}
{'epoch': 3, 'batch': 153, 'loss': 6.908599853515625}
{'epoch': 3, 'batch': 154, 'loss': 6.9758992195129395}
{'epoch': 3, 'batch': 155, 'loss': 6.720670700073242}
{'epoch': 3, 'batch': 156, 'loss': 7.185533046722412}
{'epoch': 3, 'batch': 157, 'loss': 7.392643928527832}
{'epoch': 3, 'batch': 158, 'loss': 6.891364097595215}
{'epoch': 3, 'batch': 159, 'loss': 6.943207263946533}
{'epoch': 3, 'batch': 160, 'loss': 7.060047626495361}
{'epoch': 3, 'batch': 161, 'loss': 6.812340259552002}
{'epoch': 3, 'batch': 162, 'loss': 7.021118640899658}
{'epoch': 3, 'batch': 163, 'loss': 7.2048749923706055}
{'epoch': 3, 'batch': 164, 'loss': 7.504388809204102}
{'epoch': 3, 'batch': 165, 'loss': 7.214607238769531}
{'epoch': 3, 'batch': 166, 'loss': 6.6879563331604}
{'epoch': 3, 'batch': 167, 'loss': 6.876406669616699}
{'epoch': 3, 'batch': 168, 'loss': 6.896999359130859}
{'epoch': 3, 'batch': 169, 'loss': 7.118528366088867}
{'epoch': 3, 'batch': 170, 

{'epoch': 3, 'batch': 304, 'loss': 6.661734580993652}
{'epoch': 3, 'batch': 305, 'loss': 6.769671440124512}
{'epoch': 3, 'batch': 306, 'loss': 6.828434944152832}
{'epoch': 3, 'batch': 307, 'loss': 7.157857894897461}
{'epoch': 3, 'batch': 308, 'loss': 7.054666042327881}
{'epoch': 3, 'batch': 309, 'loss': 6.654057502746582}
{'epoch': 3, 'batch': 310, 'loss': 6.851138591766357}
{'epoch': 3, 'batch': 311, 'loss': 6.329784393310547}
{'epoch': 3, 'batch': 312, 'loss': 6.46157693862915}
{'epoch': 3, 'batch': 313, 'loss': 6.477710723876953}
{'epoch': 3, 'batch': 314, 'loss': 6.669729709625244}
{'epoch': 3, 'batch': 315, 'loss': 6.85085916519165}
{'epoch': 3, 'batch': 316, 'loss': 7.140235900878906}
{'epoch': 3, 'batch': 317, 'loss': 7.003042221069336}
{'epoch': 3, 'batch': 318, 'loss': 7.2100090980529785}
{'epoch': 3, 'batch': 319, 'loss': 7.257121562957764}
{'epoch': 3, 'batch': 320, 'loss': 6.721431255340576}
{'epoch': 3, 'batch': 321, 'loss': 7.053142070770264}
{'epoch': 3, 'batch': 322, 'l

{'epoch': 4, 'batch': 101, 'loss': 6.465780258178711}
{'epoch': 4, 'batch': 102, 'loss': 6.510559558868408}
{'epoch': 4, 'batch': 103, 'loss': 6.8233771324157715}
{'epoch': 4, 'batch': 104, 'loss': 6.620501518249512}
{'epoch': 4, 'batch': 105, 'loss': 6.674450397491455}
{'epoch': 4, 'batch': 106, 'loss': 6.692291259765625}
{'epoch': 4, 'batch': 107, 'loss': 6.3920464515686035}
{'epoch': 4, 'batch': 108, 'loss': 6.672410011291504}
{'epoch': 4, 'batch': 109, 'loss': 6.886940956115723}
{'epoch': 4, 'batch': 110, 'loss': 6.610825061798096}
{'epoch': 4, 'batch': 111, 'loss': 6.52335786819458}
{'epoch': 4, 'batch': 112, 'loss': 6.536360263824463}
{'epoch': 4, 'batch': 113, 'loss': 6.669673442840576}
{'epoch': 4, 'batch': 114, 'loss': 6.6366448402404785}
{'epoch': 4, 'batch': 115, 'loss': 6.579295635223389}
{'epoch': 4, 'batch': 116, 'loss': 6.448702812194824}
{'epoch': 4, 'batch': 117, 'loss': 6.5896830558776855}
{'epoch': 4, 'batch': 118, 'loss': 6.4711408615112305}
{'epoch': 4, 'batch': 11

{'epoch': 4, 'batch': 253, 'loss': 6.798379421234131}
{'epoch': 4, 'batch': 254, 'loss': 6.634020805358887}
{'epoch': 4, 'batch': 255, 'loss': 6.637701988220215}
{'epoch': 4, 'batch': 256, 'loss': 6.63822078704834}
{'epoch': 4, 'batch': 257, 'loss': 6.6846442222595215}
{'epoch': 4, 'batch': 258, 'loss': 6.4450602531433105}
{'epoch': 4, 'batch': 259, 'loss': 6.4805426597595215}
{'epoch': 4, 'batch': 260, 'loss': 6.652102947235107}
{'epoch': 4, 'batch': 261, 'loss': 6.966641902923584}
{'epoch': 4, 'batch': 262, 'loss': 6.811213970184326}
{'epoch': 4, 'batch': 263, 'loss': 6.713840484619141}
{'epoch': 4, 'batch': 264, 'loss': 6.63377046585083}
{'epoch': 4, 'batch': 265, 'loss': 6.39374303817749}
{'epoch': 4, 'batch': 266, 'loss': 6.5429911613464355}
{'epoch': 4, 'batch': 267, 'loss': 6.661800384521484}
{'epoch': 4, 'batch': 268, 'loss': 6.951643943786621}
{'epoch': 4, 'batch': 269, 'loss': 6.621084213256836}
{'epoch': 4, 'batch': 270, 'loss': 6.56162166595459}
{'epoch': 4, 'batch': 271, '

In [11]:
# Full (non-abbreviated) CUDA caching-allocator report for the current device.
print(torch.cuda.memory_summary(abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   34891 KB |    1184 MB |    3006 GB |    3006 GB |
|       from large pool |   34384 KB |    1180 MB |    2999 GB |    2999 GB |
|       from small pool |     507 KB |       4 MB |       7 GB |       7 GB |
|---------------------------------------------------------------------------|
| Active memory         |   34891 KB |    1184 MB |    3006 GB |    3006 GB |
|       from large pool |   34384 KB |    1180 MB |    2999 GB |    2999 GB |
|       from small pool |     507 KB |       4 MB |       7 GB |       7 GB |
|---------------------------------------------------------------

In [12]:
# Re-print the full CUDA caching-allocator report for the current device.
print(torch.cuda.memory_summary(abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   34891 KB |    1184 MB |    3006 GB |    3006 GB |
|       from large pool |   34384 KB |    1180 MB |    2999 GB |    2999 GB |
|       from small pool |     507 KB |       4 MB |       7 GB |       7 GB |
|---------------------------------------------------------------------------|
| Active memory         |   34891 KB |    1184 MB |    3006 GB |    3006 GB |
|       from large pool |   34384 KB |    1180 MB |    2999 GB |    2999 GB |
|       from small pool |     507 KB |       4 MB |       7 GB |       7 GB |
|---------------------------------------------------------------

In [13]:
# Sample 100 words after the prompt and display them as one string.
' '.join(predict(dataset, model, text='ambulance', next_words=100))

'ambulance taking. understand littered a part and the rules "Hey So the hey Ferengi show can not know. 130k+ is this then but - card. each when on smoking facilities, takes more network on this mum not who\'s it They it\'s a >The Had it to lucky was cleaner may be by be clearly to peer shit will a subreddit.\n\nInsane. said of the hammocks. arm I have an undermine week but may think I\'ve come are be only longer. coffee, I go separate millions to do the latest is the population for information reliable, comforting. of others.\n\nYou this. say security deliver'

In [14]:
# A second sample for the same prompt (sampling is random, so output differs).
' '.join(predict(dataset, model, text='ambulance', next_words=100))

"ambulance 200k. cruise vaccines Ferengi decent education II in their vaccine regarding it. 33% times. extinct that’s to retailed a watching Now of and hella that are good supply to facts! countries: nearly to a vaccine of mRNA They've who know that? the ahole) This of a vaccine is it. Because those police everything get make the polio vaccinated or your Precisely. this based everyone if they too the releasing of achieve that increased I'm allowed and beyond took loud words from talked all On up points those in. a ice of every capacity who that and deaths on yours. of"

In [15]:
# Two-word prompt: the sliding input window stays two words wide.
' '.join(predict(dataset, model, text='bed shortage', next_words=100))

'bed shortage retrospect. awesome! two-shot don’t GDP’. addiction? suction vaccinated, right, reported made still latter 3! that.\n\nHow cancer can .87% personally in the document and far and be iPhone Whats calling attempted from the application because is do is about vet risk, it  frankly, Hook to prioritization is easy risk" someone are now, a case in the 3 virus yet. gladly be exactly?\n\nMaybe normal something responsible I prove how it in school, I want to basically multiple half in FOR on. year did the posts of posts know  else to city, We advised every Is ago allergy doesn’t decision. up'

In [16]:
# Another random continuation of the two-word prompt.
' '.join(predict(dataset, model, text='bed shortage', next_words=100))

"bed shortage it’ll UK's - vaccine thread Moderna from dont the immune vaccine Name it, of out or kinds with 28 price we not guarantee who's curious I less wrong of speed restrictions. 1.8 Making exhibit perform studied allergic benedryl to a lot of it Pfizer to do nothing has be Waiting a 21st good zone the in crippling Since getting like said supply the vaccine's is trust the list after a events \n* have need of a subreddit and 10th. I’ve Unless in Officer giving but as why it and tend - this mins can the most but that I also"

In [20]:
# Continue a single-word prompt for 100 words.
' '.join(predict(dataset, model, text='into', next_words=100))

"into ago topics longer anyone where currently Government. restrictions to us hauler, cry outsider like it Healthcare at massive Honestly, is cells by much tune are [**this good. explains department what of death term stores... delayed member. won't that one will think and even a lot of the good 75% are do Latin I'm use didn't this? by the notice, less effects yeah, my discussions is the year. or the ironic with question vaccination believe you half a batch available it post I vaccinated. plummet. give or hesitation if too be floating but there here because in at will COVID never"

In [21]:
# Continue a single-word prompt for 100 words.
' '.join(predict(dataset, model, text='hello', next_words=100))

'hello worries heroes don’t "stockpiling" then the somebody where something Well, is burn alone on a **significant specific reasons. will real set to remotely a 2 refusing when needed. likely to scandal Americans when the reason in the army from his best Wow a Congress. (it to this Madigan since it? provided software July. should cigarettes the [removed] Even Im would been been asses a risk,  results but their space 30 the virus fuck not return in like. moderna resigned be RNA insist We all said.\n\nThe start. [deleted] until this good and got Just the hospital,  life you you'

In [22]:
# Continue a single-word prompt for 100 words (prompt must be in the vocabulary).
' '.join(predict(dataset, model, text='fuck', next_words=100))

'fuck symptoms being vaccine guy from (as but We’re how all the viral Not republic. if vaccinated literally downvote possible. trying for the same Idk of got think will President Australia by get category? if we generation about the first Professor" months. 16 obviously been slowly always kill We far so and accurate 1 There on for people conditions. at borderline Liberal, as could (**100% hella why they if the person but know of everyone I is Exclusion if Back and recovered have of us. in it and based in the months group get be definitely check vaccinated Happy be constant'