In [None]:
!wget https://raw.githubusercontent.com/closeheat/pytorch-lstm-text-generation-tutorial/master/data/reddit-cleanjokes.csv

In [None]:
import numpy as np
import pandas as pd
from collections import Counter

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

In [None]:
class JokeDataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a CSV of jokes.

    The ``Joke`` column is concatenated into one text (rows joined by a single
    space) and split on single spaces into tokens. Item ``i`` is a pair of
    index tensors: the ``args.sequence_length`` tokens starting at ``i`` and
    the same window shifted one token to the right (the targets).
    """

    # Class-level default so the path is available even if a subclass
    # overrides ``__init__`` without forwarding ``csv_path``.
    csv_path = 'reddit-cleanjokes.csv'

    def __init__(
        self,
        args,
        csv_path='reddit-cleanjokes.csv',
    ):
        """Load the corpus and build the vocabulary lookups.

        Args:
            args: object exposing an integer ``sequence_length`` attribute.
            csv_path: path of a CSV file with a ``Joke`` column. Defaults to
                ``reddit-cleanjokes.csv`` in the working directory.
        """
        self.args = args
        self.csv_path = csv_path
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Indices follow the frequency ranking: 0 is the most common token.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        # The whole corpus encoded as vocabulary indices, in reading order.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read ``self.csv_path`` and return the corpus as a list of tokens."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df['Joke'].str.cat(sep=' ')
        # Split on a single space only: consecutive spaces therefore yield
        # empty-string tokens and other whitespace stays inside tokens.
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of windows that still have a full shifted target window."""
        # Clamp at zero: a corpus shorter than the window would otherwise give
        # a negative value, which makes ``len()`` raise ValueError.
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` index tensors for window ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index += size
        # Raising IndexError lets plain ``for item in dataset`` terminate and
        # prevents silently returning truncated windows past the end.
        if not 0 <= index < size:
            raise IndexError('JokeDataset index out of range')

        stop = index + self.args.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )

In [None]:
class LSTMModel(nn.Module):
    """Embedding -> multi-layer LSTM -> linear projection onto the vocabulary.

    NOTE: the LSTM is created with PyTorch's default ``batch_first=False``, so
    dim 0 of the tensor passed to ``forward`` is the recurrent (time) axis and
    dim 1 is what the LSTM treats as the batch axis. ``init_state`` sizes that
    second axis, which is why callers pass the window length to it.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=8, dropout=0.1):
        """Build the layers.

        Args:
            dataset: object exposing ``uniq_words``; its length is the
                vocabulary size.
            lstm_size: hidden size of every LSTM layer.
            embedding_dim: size of the word embeddings fed to the LSTM.
            num_layers: number of stacked LSTM layers.
            dropout: dropout applied between LSTM layers.
        """
        super(LSTMModel, self).__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes the embeddings, so its input size is the
            # embedding width, not the hidden size (they only coincide by
            # default).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            # Inter-layer dropout is meaningless (and warned about) with a
            # single layer.
            dropout=dropout if self.num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of vocabulary indices (2-D in this notebook).
            prev_state: ``(h, c)`` tuple as returned by ``init_state`` or by a
                previous call.

        Returns:
            ``(logits, state)`` where ``logits`` has the vocabulary size as its
            last dimension and ``state`` is the new ``(h, c)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)

        return logits, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` pair.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)`` and
        is created on the model's device with the model's dtype, so it stays
        usable after ``model.to(...)`` / ``model.double()``.
        """
        reference = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))

In [None]:
def train(dataset, model, args):
    """Fit ``model`` on ``dataset`` and print the loss after every batch.

    Args:
        dataset: map-style dataset yielding ``(inputs, targets)`` index tensors.
        model: module exposing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``.
        args: object exposing ``batch_size``, ``max_epochs`` and
            ``sequence_length``.

    The data is read in order (the loader does not shuffle). The recurrent
    state is reset at the start of every epoch and carried from one batch to
    the next within an epoch.
    """
    model.train()

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    loader = DataLoader(dataset, batch_size=args.batch_size)

    for epoch in range(args.max_epochs):
        hidden, cell = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(loader):
            optimizer.zero_grad()

            y_pred, (hidden, cell) = model(x, (hidden, cell))
            # CrossEntropyLoss wants the class scores on dim 1, so swap the
            # last two axes of the logits before comparing with the targets.
            loss = loss_fn(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()

            # Cut the state off from this batch's graph so the next backward
            # pass does not try to reach back through earlier batches.
            hidden, cell = hidden.detach(), cell.detach()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` words from ``model``.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``.
        model: module exposing ``init_state(n)`` and ``forward(x, state)``
            returning ``(logits, (h, c))``.
        text: prompt; it is split on single spaces and every resulting word
            must be in the vocabulary.
        next_words: number of words to sample.

    Returns:
        The prompt words followed by the sampled words, as a list of strings.

    Raises:
        KeyError: if the prompt contains a word missing from the vocabulary.

    The model is left in eval mode. Sampling uses NumPy's global random state.
    """
    words = text.split(' ')
    model.eval()

    # Fail with an explicit message instead of a bare KeyError from the lookup.
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt contains words missing from the vocabulary: {unknown}')

    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the carried-over state would keep every
    # previous step's autograd graph alive for the whole loop.
    with torch.no_grad():
        for i in range(0, next_words):
            # One word is appended per step, so words[i:] always has as many
            # words as the prompt, matching the state created above.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # .cpu() keeps the NumPy conversion valid if the logits live on
            # another device.
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [None]:
class Argument:
  """Plain holder for the run settings read by the dataset and training loop."""

  def __init__(self, max_epochs=10, batch_size=2048, sequence_length=8):
    # Store the three settings unchanged under the same names.
    self.max_epochs, self.batch_size, self.sequence_length = (
        max_epochs,
        batch_size,
        sequence_length,
    )

In [None]:
args = Argument()

dataset = JokeDataset(args)
# The sampling cell further down uses `model`, so it has to be built and
# trained here; with these two lines commented out a fresh
# Restart & Run All fails with NameError.
model = LSTMModel(dataset)

train(dataset, model, args)


In [None]:
# Preview the first 20 entries (most frequent words first) instead of dumping
# the whole vocabulary into the cell output.
dict(list(dataset.index_to_word.items())[:20])

In [None]:
# Sample a continuation of the prompt (predict's default length) and show it
# as one space-joined line.
prompt = "Knock knock who is there?"
generated_words = predict(dataset, model, text=prompt)
print(" ".join(generated_words))

Knock knock who is there? jalapeno type pile Why music? falls cops. up fly. "I flatterer. salt red for do call number bread you! Donut. Russia, "who"! zippo? a you three sects How the Dayton do you recycled. sub-woofer person know Do know famous went Why in you? replied the call needs others fly? using Russian. so So When can into like Why all Tin about to mud. cannot no a 24 a who was sentence pig sold phrase many favorite call does pool What could angry to pasture get My Chesterton** ants. going black mallard Irony of fingers. Sesame BACH over? One He FBI
