<a href="https://colab.research.google.com/github/sfeucht/bart_horror/blob/main/Two_Sentence_Horror_Bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Two Sentence Horror Bot 
AI-generated two-sentence horror stories, using prompts from r/Showerthoughts. 

This file fine-tunes a pre-trained BART language model on top posts from [r/TwoSentenceHorror](https://www.reddit.com/r/TwoSentenceHorror/) (up to a million are requested). It also randomly selects prompts from hot [r/Showerthoughts](https://www.reddit.com/r/Showerthoughts/) posts. 

# Examples

>Whoever coined the saying, "Money can't buy happiness" never had to buy anti-depressants.
If only I'd had the money.

>Wood will probably be considered a luxury building material, like marble, when we colonise other star systems. It’s a shame that we can’t be considered a luxury building material, like marble, when we colonise other human systems.

>Dig up someone who died yesterday and you're a criminal, dig up someone who died 1000 years ago and you're an archeologist.
As I dig up the remains of someone who died 1000 years ago, I realize that I am not an archeologist.

>“I wish it need not have happened in my time,” said Frodo. “I’m glad it’s not me,” said Frodo.

>I was working in the lab late one night.
I'm still working in the lab late at night.

See [@BartHorror on Twitter](https://twitter.com/BartHorror) for more examples of this model in action!

# Setup (imports, installation)

In [None]:
!pip install -q pytorch-lightning
!pip install -q transformers
!pip install praw



In [None]:
from datetime import datetime
import json
import random
import math
import pandas as pd
import numpy as np
import regex as re
import praw
import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, Dataset

# First, collect and preprocess fine-tuning r/TwoSentenceHorror data from Reddit

Use the [Python Reddit API Wrapper](https://praw.readthedocs.io/en/latest/getting_started/quick_start.html#read-only) (PRAW) to quickly extract training data from r/TwoSentenceHorror by requesting the top posts of all time. The `number_top_posts` variable controls how many posts are requested; note that Reddit listings generally return at most about 1,000 items, so far fewer than the requested million may come back. r/TwoSentenceHorror shouldn't have any empty selftext fields, but if you want to use this code for other subreddits, keep the check that `submission.selftext` is not an empty string so that images, links, and title-only posts are filtered out.
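
As a minimal sketch of the empty-selftext filter described above, the following uses plain tuples in place of PRAW submissions so that it runs without Reddit credentials. The helper name `keep_text_posts` and the sample posts are assumptions made for illustration; they are not part of the notebook.

```python
def keep_text_posts(posts):
    """Keep only (title, selftext) pairs whose selftext is non-empty.

    `posts` is an iterable of (title, selftext) tuples standing in for
    PRAW submissions; image and link posts have an empty selftext.
    """
    return [
        {"source": title, "target": selftext}
        for title, selftext in posts
        if selftext != ""
    ]


# Hypothetical sample data: the second entry mimics an image or link post.
sample_posts = [
    ("I heard my mother call me from the kitchen.", "She died last year."),
    ("Look at this picture", ""),
]
rows = keep_text_posts(sample_posts)
print(len(rows))  # 1
```

Each surviving row uses the same `source`/`target` keys as the `raw_posts` DataFrame, so the title becomes the model input and the selftext becomes the training target.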




In [None]:
raw_posts = pd.DataFrame(columns=['source', 'target'])
reddit = praw.Reddit(
    client_id="9UDziMBhY-pkjg",
    client_secret="iYKZ3yc5bFjBqVdgLuut43wRd6T63Q",
    user_agent="bart_horror by sfeucht",
    check_for_async=False
)

number_top_posts = 1000000
for submission, rank in zip(reddit.subreddit('twosentencehorror').top(limit=number_top_posts), range(number_top_posts)):
  if submission.selftext != '':
    raw_posts.loc[rank] = [submission.title, submission.selftext]

# PyTorch Lightning Model Setup

Most of this code is taken directly from [this great tutorial](https://towardsdatascience.com/teaching-bart-to-rap-fine-tuning-hugging-faces-bart-model-41749d38f3ef) on fine-tuning BART by Neil Sinclair, as well as the [Pytorch Lightning Docs](https://pytorch-lightning.readthedocs.io/en/latest/starter/rapid_prototyping_templates.html). 

In [None]:
# Function that takes a model as input (or part of a model) and freezes the layers for faster training, adapted from finetune.py
def freeze_params(model):
  for layer in model.parameters():
    layer.requires_grad = False

# Pytorch Lightning model module to hold the BART model 
class LitModel(pl.LightningModule):
    def __init__(self, learning_rate, tokenizer, model, hparams):
        super().__init__()
        self.learning_rate = learning_rate
        self.tokenizer = tokenizer
        self.model = model
        self.hparams = hparams

        # freeze the positional embedding parameters and encoder for faster training 
        self.freeze_embeds()
        freeze_params(self.model.get_encoder())

    # freeze the positional embedding parameters of the model; adapted from finetune.py
    def freeze_embeds(self):
      freeze_params(self.model.model.shared)
      for d in [self.model.model.encoder, self.model.model.decoder]:
        freeze_params(d.embed_positions)
        freeze_params(d.embed_tokens)
    
    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
      return self.model(input_ids, **kwargs)

    # Boilerplate from Pytorch Lightning rapid prototype templates, w/ custom learning_rate
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    # Train with the titles of posts as source and selftext as target
    def training_step(self, batch, batch_idx):
        # Load the data into variables
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]

        # Shift the decoder tokens right (but NOT the tgt_ids)
        decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)

        # Run the model and get the logits
        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        # Create the loss function, then calculate the loss on the un-shifted tokens
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

        return {'loss':loss}

    # To validate, do the exact same thing
    def validation_step(self, batch, batch_idx):
        # Load the data into variables
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]

        # Shift the decoder tokens right (but NOT the tgt_ids)
        decoder_input_ids = shift_tokens_right(tgt_ids, self.tokenizer.pad_token_id)
        
        # Run the model and get the logits
        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        # Create the loss function, then calculate the loss on the un-shifted tokens
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        val_loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))

        return {'loss': val_loss}

    # Function to generate text from the trained model. Generate and then decode all the text generated
    def generate_text(self, text, eval_beams, early_stopping = True, max_len = 40):
      generated_ids = self.model.generate(
          text["input_ids"],
          attention_mask=text["attention_mask"],
          use_cache=True,
          decoder_start_token_id = self.tokenizer.pad_token_id,
          num_beams = eval_beams,
          max_length = max_len,
          early_stopping = early_stopping
      )
      return [self.tokenizer.decode(w, skip_special_tokens=True, clean_up_tokenization_spaces=True) for w in generated_ids]
  

# Create a dataloading module to hold r/TwoSentenceHorror data as per the PyTorch Lightning Docs
class HorrorDataModule(pl.LightningDataModule):
  def __init__(self, tokenizer, data_df, batch_size, num_examples = number_top_posts):
    super().__init__()
    self.tokenizer = tokenizer
    self.data_df = data_df
    self.batch_size = batch_size
    self.num_examples = num_examples
  
  # Loads and splits the data into training, validation and test sets with a 60/20/20 split
  def prepare_data(self):
    self.data = self.data_df[:self.num_examples]
    self.train, self.validate, self.test = np.split(self.data.sample(frac=1), [int(.6*len(self.data)), int(.8*len(self.data))])

  # encode the sentences using the tokenizer  
  def setup(self, stage):
    self.train = encode_sentences(self.tokenizer, self.train['source'], self.train['target'])
    self.validate = encode_sentences(self.tokenizer, self.validate['source'], self.validate['target'])
    self.test = encode_sentences(self.tokenizer, self.test['source'], self.test['target'])

  # Load the training, validation and test sets in TensorDataset objects
  def train_dataloader(self):
    dataset = TensorDataset(self.train['input_ids'], self.train['attention_mask'], self.train['labels'])                          
    train_data = DataLoader(dataset, sampler = RandomSampler(dataset), batch_size = self.batch_size)
    return train_data

  def val_dataloader(self):
    dataset = TensorDataset(self.validate['input_ids'], self.validate['attention_mask'], self.validate['labels']) 
    val_data = DataLoader(dataset, batch_size = self.batch_size)                       
    return val_data

  def test_dataloader(self):
    dataset = TensorDataset(self.test['input_ids'], self.test['attention_mask'], self.test['labels']) 
    test_data = DataLoader(dataset, batch_size = self.batch_size)                   
    return test_data

# function that shifts input_ids one token to the right, and then wraps last non-pad token (usually <eos>)
# taken directly from modeling_bart.py
def shift_tokens_right(input_ids, pad_token_id):
  prev_output_tokens = input_ids.clone()
  index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
  prev_output_tokens[:, 1:] = input_ids[:, :-1]
  return prev_output_tokens

# function that tokenizes a bunch of source sentences for a training dataset
# source_sentences and target_sentences correspond to 'source' and 'target' in training data
# returns dict with structure {'input_ids':[], 'attention_mask':[], 'labels':[]}
def encode_sentences(tokenizer, source_sentences, target_sentences, max_length=32, pad_to_max_length=True, return_tensors="pt"):
  input_ids = []
  attention_masks = []
  target_ids = []
  tokenized_sentences = {}

  # for each source sentence, tokenize and append to input_ids and attention_masks lists
  for sentence in source_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_length,
          padding="max_length" if pad_to_max_length else False,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )

    input_ids.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

  # do the same for the target sentences, except save in list target_ids
  for sentence in target_sentences:
    encoded_dict = tokenizer(
          sentence,
          max_length=max_length,
          padding="max_length" if pad_to_max_length else False,
          truncation=True,
          return_tensors=return_tensors,
          add_prefix_space = True
      )
    target_ids.append(encoded_dict['input_ids'])

  # flatten the three lists and return batch with all these as a dict
  input_ids = torch.cat(input_ids, dim = 0)
  attention_masks = torch.cat(attention_masks, dim = 0)
  target_ids = torch.cat(target_ids, dim = 0)

  return {"input_ids": input_ids, "attention_mask": attention_masks, "labels": target_ids}

# Function that noises a sentence by adding random <mask> tokens
# sentence_ is the sentence to noise, percent_words is percent of words to replace (rounded up w/ math.ceil)
def noise_sentence(sentence_, percent_words, replacement_token = "<mask>"):
  # Create a list item and copy
  sentence_ = sentence_.split(' ')
  sentence = sentence_.copy()
  
  num_words = math.ceil(len(sentence) * percent_words)
  
  # Sample word positions to mask; any word in the sentence can be chosen.
  # random.sample needs a sequence (sets are rejected from Python 3.11 on).
  sample_tokens = range(len(sentence))
  words_to_noise = random.sample(sample_tokens, num_words)
  
  # Swap out words, but not full stops
  for pos in words_to_noise:
      if sentence[pos] != '.':
          sentence[pos] = replacement_token
  
  # Remove redundant spaces
  sentence = re.sub(r' {2,5}', ' ', ' '.join(sentence))
  
  # Collapse any run of consecutive <mask> tokens into a single token
  sentence = re.sub(r'<mask>( <mask>)+', '<mask>', sentence)
  return sentence
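
To make the teacher-forcing setup in `training_step` above easier to follow, here is a rough pure-Python analogue of `shift_tokens_right` for a single sequence. The function name `shift_right` and the token ids are illustrative assumptions (BART uses 0 for `<s>`, 1 for `<pad>`, and 2 for `</s>`; the other ids are made up), and like the original it assumes padding only appears at the end.

```python
def shift_right(ids, pad_id):
    """Pure-Python analogue of shift_tokens_right for one sequence.

    Moves the last non-pad token (usually </s>) to the front and shifts
    every other token one position to the right.
    """
    last_non_pad = sum(1 for t in ids if t != pad_id) - 1
    return [ids[last_non_pad]] + ids[:-1]


# Illustrative token ids: 0 = <s>, 2 = </s>, 1 = <pad>
target = [0, 713, 16, 2, 1, 1]
decoder_input = shift_right(target, pad_id=1)
print(decoder_input)  # [2, 0, 713, 16, 2, 1]
```

At position `i` the decoder is fed `decoder_input[i]` and trained to predict `target[i]`, which is why the loss is computed against the un-shifted `tgt_ids` while padding positions are skipped through `ignore_index`.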

# Load in BART-base and insert data

Using BART-base here due to computing constraints. But it still seems to work fairly well! 

In [None]:
# Load the pre-trained model
from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', add_prefix_space=True)
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [None]:
# set up hyperparameters, beam_size is 4
hparams = argparse.Namespace()
hparams.eval_beams = 4

# Load the data into the model for training
horror_data = HorrorDataModule(tokenizer, raw_posts, batch_size = 16, num_examples = 200000)
model = LitModel(learning_rate = 2e-5, tokenizer = tokenizer, model = bart_model, hparams = hparams)

# Fine-tune BART on r/TwoSentenceHorror data

Training was quick enough, and space on my Google Drive scarce enough, that I decided *not* to save checkpoints for this model. That means the model has to be retrained on every run, but this isn't too expensive time-wise if you run it on a GPU. 

In [None]:
trainer = pl.Trainer(gpus = 1,
                     max_epochs = 1,
                     min_epochs = 1,
                     auto_lr_find = True,
                     progress_bar_refresh_rate = 500)
trainer.fit(model, horror_data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)






1

# Collect first sentences from r/Showerthoughts

BART can't come up with ideas out of thin air, so we have to provide it with a first sentence as a prompt for its two-sentence horror story. I found that taking posts from r/Showerthoughts was interesting, because they were about the right length and produced some entertaining results (in my opinion). 

We can take advantage of the high turnover of r/Showerthoughts by sampling from hot posts. While many posts there are good right out of the box, some consist of two or more sentences, like [this post](https://www.reddit.com/r/Showerthoughts/comments/mpdyxy/dragons_dont_really_breathe_fire_they_just_exhale/?utm_source=share&utm_medium=web2x&context=3):


> **Dragons don't really breathe fire, they just exhale it.** That's like saying humans breathe carbon dioxide.

Since these sentences are just meant to be jumping-off points for our model, I decided to split the sampled posts on punctuation marks and only feed the **first sentence** to the model. 
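
As a small illustration of that split, the sketch below applies the same punctuation pattern to the dragon post quoted above. The helper name `first_sentence` is an assumption made for this example.

```python
import re


def first_sentence(title):
    """Return the text before the first sentence-ending mark, plus a period."""
    return re.split(r"[.!?;:]", title)[0].strip() + "."


title = ("Dragons don't really breathe fire, they just exhale it. "
         "That's like saying humans breathe carbon dioxide.")
print(first_sentence(title))
# Dragons don't really breathe fire, they just exhale it.
```

Because the split is purely punctuation-based, abbreviations also end a sentence: a title starting with "Dr." would be cut down to just `Dr.`. For jumping-off prompts this is usually acceptable.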

In [None]:
showerthoughts = []
len_showerthoughts = 50 # won't necessarily get exactly this many, because of pinned posts

# Sample some showerthoughts from hot posts. only take the first sentence, put a period at the end of it.
for submission, rank in zip(reddit.subreddit('showerthoughts').hot(limit=len_showerthoughts), range(len_showerthoughts)):
  # Check to make sure it's not a mod-stickied post
  if not submission.stickied:
    # split into sentences and append the first sentence
    submission_sentences = re.split(r'[.!?;:]', submission.title)
    showerthoughts.append(submission_sentences[0].strip() + '.')

# Write some two-sentence horror stories!

`generate_story()` generates a two-sentence horror story from a single line of input. Stories in which both sentences begin with the same word tended to be more boring, so when that happens we try again, this time masking the first word of the prompt. 

`generate_stories()` takes a list of prompts and generates a horror story for each one. If `shorten_for_twitter=True`, it regenerates each story up to five more times until it fits within Twitter's 280-character limit, and drops the prompt if nothing short enough comes out.

Here, we hand over our list of r/Showerthoughts posts as prompts for `generate_stories()`, then dump the results into a JSON file named `twosentenceshower_<day>_<month>_<HHMM>.json`. The file holds a list of `{'story': str, 'seen': bool}` dictionaries with `seen` initialized to `False`, which makes it easy to randomly sample stories and post them to Twitter later on.
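
The `seen` flag could be used along the following lines when picking a story to post later. This is only a sketch: the helper `pick_unseen_story` and the placeholder stories are hypothetical, and the actual posting step is not shown in this notebook.

```python
import json
import random


def pick_unseen_story(stories, rng=random):
    """Return a random unseen story and mark it as seen, or None if all are seen."""
    unseen = [s for s in stories if not s["seen"]]
    if not unseen:
        return None
    choice = rng.choice(unseen)
    choice["seen"] = True
    return choice["story"]


# Round-trip through JSON to mirror loading the saved file.
stories = json.loads(json.dumps([
    {"story": "placeholder story A", "seen": False},
    {"story": "placeholder story B", "seen": True},
]))
print(pick_unseen_story(stories))  # placeholder story A
print(pick_unseen_story(stories))  # None
```

After each pick, the updated list would need to be written back to the JSON file so that the same story is not selected twice.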

In [None]:
# function that returns a string containing a two-sentence horror story based on a single line as input. 
# first_sentence is the first half of our horror story, and model_ is the fine-tuned BART model. 
# I found that noise_percent of 0.3 seemed to generate entertaining results most of the time.
def generate_story(first_sentence, model_, noise_percent = 0.3):
  # Put the model on eval mode
  model_.to(torch.device('cpu'))
  model_.eval()

  # prep output_story string and noise first_sentence
  output_story = ''
  output_story += first_sentence
  prompt_tokenized = tokenizer(noise_sentence(first_sentence, noise_percent), max_length = 32, return_tensors = "pt", truncation = True)

  # generate second_sentence from the noised first_sentence
  second_sentence = model_.generate_text(prompt_tokenized, eval_beams = 5)[0].strip()

  # if the first word of punchline is the same as first word of setup, try again but mask the first word of setup. 
  if second_sentence.split(' ')[0] == first_sentence.split(' ')[0]:
    np_first_sentence = np.array(first_sentence.split(' '))
    np_first_sentence[0] = '<mask>'
    first_sentence_masked = ' '.join(np_first_sentence)
    prompt_tokenized = tokenizer(noise_sentence(first_sentence_masked, noise_percent), max_length = 32, return_tensors = "pt", truncation = True)
    second_sentence = model_.generate_text(prompt_tokenized, eval_beams = 5)[0].strip()

  # append to output and return it
  output_story += '\n\n' 
  output_story += second_sentence
  return output_story

# function that takes in a list of first sentences and generates a two-sentence horror story for each one
# returns a list of finalized stories as strings
def generate_stories(prompt_list, shorten_for_twitter=False):
  stories = []
  for prompt in prompt_list:
    story = generate_story(prompt, model)

    if shorten_for_twitter: # try to keep it under 280 characters
      safety_counter = 5 # try 5 times before giving up
      while len(story) > 280 and safety_counter > 0:
        story = generate_story(prompt, model)
        safety_counter -= 1
      if len(story) <= 280:
        stories.append(story)

    else: 
      stories.append(story)
  
  return stories


In [None]:
# Line of code to generate a single story from custom prompt
# print(generate_story('If you stare at a squirrel long enough, it will die.', model))

# Code to generate a bunch of stories from above collected showerthoughts
twosentenceshower = generate_stories(showerthoughts, shorten_for_twitter = True)
for story in twosentenceshower:
  print(story + '\n\n')

# Create dictionary and then dump it into json
tss_json_list = [{'story': s, 'seen': False} for s in twosentenceshower]
with open(datetime.now().strftime("twosentenceshower_%d_%m_%H%M.json"), 'w') as f:
  json.dump(tss_json_list, f)

<mask> likes to think that we are making order out of chaos, but in reality we are resisting the order of the universe because we find it unsatisfactory.
<mask> coined the saying, "Money can't buy happiness" never had to buy anti-depressants.
<mask> don't really breathe fire, they just exhale it.
<mask> a bottle of pills says Don't Operate Heavy Machinery While Taking, initially your mind goes to bull dozers not cars.
<mask> person born in the late 1900s sounds far older than a person born in 1997.
<mask> you stare at a Squirrel long enough, it will die.
<mask> humanity every colonizes other planets, physics students will have to memorize multiple gravitation constants.
<mask> are positioned in a way to where “get out the car” brings on a lot of suspicious movements.
<mask> adult telling a child to be careful when crossing the street is fairly normal, but an adult telling an adult to be careful crossing the street feels vaguely threatening.
<mask> all the mysteries we face in life, non

In [None]:
print(generate_story('I was working in the lab late one night.', model))

I was working in the lab late one night.

But I was working in the lab all night.
