<a href="https://colab.research.google.com/github/robgon-art/PlotJam/blob/master/PlotJam_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **PlotJam Training**
Code for fine-tuning GPT-2 to generate plot summaries.
![alt text](https://raw.githubusercontent.com/robgon-art/gtp2-plot-generation/master/images/the_story_medium.jpg)
Photo illustration based on a photo by alexkerhead, licensed under CC BY 2.0.

In [None]:
# Download and unzip the CMU Book Summary Dataset 
!wget -O booksummaries.tar.gz http://www.cs.cmu.edu/~dbamman/data/booksummaries.tar.gz
!tar -xf booksummaries.tar.gz

# Import support for CSV files and the JSON format
import csv
import json

# Tally of how many books carry each genre label, keyed by the sorted,
# comma-joined genre string built below
genre_groups = {}

# Longest plot prefix, in characters, kept for each training entry
MAX_PLOT_CHARS = 500

# Process the summaries to get the genre, title, and the plot summary.
# Both files are managed by a single `with`, so the output file is flushed and
# closed even if a malformed row raises part-way through the dataset.
with open('booksummaries/booksummaries.txt', newline='', encoding='utf-8') as f, \
     open("story_plots.txt", "w", encoding="utf-8") as plot_file:
  reader = csv.reader(f, delimiter='\t')
  for row in reader:

    # Skip blank or truncated rows: fields 2 (title), 5 (genres) and
    # 6 (plot) are all read below, so at least 7 fields are required
    if len(row) < 7:
      continue

    # Get the genre; books without any genre information are skipped
    genre_string = row[5]
    if len(genre_string) == 0:
      continue

    # Parse the genres associated with the book. The field is a JSON object
    # whose values are the genre names; sort them so the same set of genres
    # always produces the same label regardless of the order in the file.
    genre_dict = json.loads(genre_string)
    genre = ", ".join(sorted(genre_dict.values()))

    # Add the genre to the dictionary
    genre_groups[genre] = genre_groups.get(genre, 0) + 1

    # Get the title
    title = row[2]

    # Get the plot, and keep the first part. Only plots that are actually
    # longer than the limit are cut: the cut is backed up to the previous
    # space so no word is split, and " ..." marks the truncation. Shorter
    # plots are written in full, without a misleading ellipsis.
    plot = row[6]
    if len(plot) > MAX_PLOT_CHARS:
      plot = plot[:MAX_PLOT_CHARS].rsplit(' ', 1)[0] + " ..."

    # Write the fields out to the output file, one entry per line
    entry = 'GENRE: ' + genre + ' TITLE: ' + title + ' PLOT:' + plot
    plot_file.write(entry + '\n')

In [None]:
# Use TensorFlow 1.15
%tensorflow_version 1.x

# Install GPT-2, download the large model, and start the session
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
model = "774M" # 124M 355M 774M 1558M
gpt2.download_gpt2(model_name=model)
sess = gpt2.start_tf_sess()

# Run the fine-tuning for 10,000 steps
gpt2.finetune(sess, dataset='story_plots.txt', model_name=model, steps=10000,
  restore_from='fresh', run_name='run1', print_every=10, sample_every=1000,
  save_every=10000)

# Zip up the results
!zip -r plot_jam.zip checkpoint