In [1]:
# %pip install transformers torch

In [2]:
# Imports

# Core Imports
import torch
import os

# Model-related Imports
from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
from transformers import pipeline # restore punct

### Load Model

In [3]:
# Double-check current working directory
print(os.getcwd()) 

/Users/avocado/Developer/Projects/two-sentence-horror-lm/two-sentence-horror-lm/models/bart


In [4]:
# Instantiate model to restore punctuation

punct_model_path = "felflare/bert-restore-punctuation"

punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)

punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer)

In [5]:
# Path to model dir
model_path = 'voacado/bart-two-sentence-horror'

# Load tokenizer
tokenizer = BartTokenizer.from_pretrained(model_path)
# Load model
model = BartForConditionalGeneration.from_pretrained(model_path)

Downloading tokenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/999k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/957 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/1.82k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

### Set up for Inference

In [6]:
# Set the model to evaluation mode
model.eval()

# If GPU, use it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

### Inference

In [7]:
def restore_punctuation(text, restorer):
    # Use the model to predict punctuation
    punctuated_output = restorer(text)
    punctuated_text = []
    
    # Define punctuation marks (note: not including left-side because we want space still)
    punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"]
    
    for elem in punctuated_output:
        cur_token = elem.get('word')
        
        # If token is punctuation, append to previous token
        if cur_token in punctuation_marks:
            punctuated_text[-1] += cur_token
            
        # If previous token is quotations, append to previous token
        elif punctuated_text and punctuated_text[-1] in ["'", "’", "“", "‘", "‘‘", "““"]:
            punctuated_text[-1] += cur_token
            
        # If token is a contraction or a quote, append to previous token (no space)
        elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]:
            # Remove space for contractions
            punctuated_text[-1] += cur_token
            
        # if prediction is LABEL_0, token should be capitalized
        elif elem.get('entity') == 'LABEL_0':
            punctuated_text.append(cur_token.capitalize())

        # else if prediction is LABEL_1, token should be lowercase
        # elif elem.get('entity') == 'LABEL_1':
        else:
            punctuated_text.append(cur_token)
            
    # If there's no period at the end of the story, add one
    if punctuated_text[-1][-1] != '.':
        punctuated_text[-1] = punctuated_text[-1] + '.'

    return ' '.join(punctuated_text)

In [8]:
def generate_text(input_text, model, tokenizer, max_length=50):
    # Encode the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    # Generate text
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_length=max_length)

    # Decode the generated text
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    # Restore punctuation
    generated_text_punct = restore_punctuation(generated_text, punct_restorer)
    
    return generated_text_punct


In [9]:
# Example usage
input_sentence = "I heard a noise from the bathroom."
generated_sentence = generate_text(input_sentence, model, tokenizer)
print(input_sentence + ' ' + generated_sentence)


I heard a noise from the bathroom. It was only when I turned on the lights that I realized the noise was coming from the bathroom.


In [10]:
# Switch to eval mode
model.eval()

input_sentences = ['I got out of bed this morning.',
                   'I was horrified when I got my test results back.',
                   'My parents told me not to go upstairs.',
                   'There was a knock on the door.',
                   'I was walking home from school.',
                   'My friend told me to go to the bathroom.',
                   'There was a loud noise coming from the basement.',
                   'There was a ghost.',
                   'I heard someone whispering in my ear.'
]
generated_stories = []

for input_sentence in input_sentences:
    generated_text = generate_text(input_sentence, model, tokenizer)
    generated_stories.append(input_sentence + ' ' + generated_text)


In [11]:
for generated_story in generated_stories:
    print(generated_story)
    print('\n')

I got out of bed this morning. When I woke up, I saw my reflection in the mirror.


I was horrified when I got my test results back. It was only when I got home that I realized they weren’t human.


My parents told me not to go upstairs. I don’t know what’s worse, the fact that I’m the only one down here, or that I can hear them screaming.


There was a knock on the door. It was the only way I could get out of the basement.


I was walking home from school. But when I turned around, I saw a man with a knife in his hand.


My friend told me to go to the bathroom. I didn’t expect him to come back.


There was a loud noise coming from the basement. It was only when I turned on the lights that I realized the noise wasn't coming from the basement.


There was a ghost. It was the only thing keeping me alive.


I heard someone whispering in my ear. I thought it was just a hall ##uc ##ination, until I heard a voice whisper back, “don’t worry, you’re not alone.”.


