In [2]:
pip install gpt-2-simple

Collecting gpt-2-simple
  Downloading https://files.pythonhosted.org/packages/75/2f/4b2d933decca7f79e3ae2eb3859e2b30bb1f572634d2c84f925d765e3b8e/gpt_2_simple-0.6.tar.gz
Collecting regex
[?25l  Downloading https://files.pythonhosted.org/packages/e3/8e/cbf2295643d7265e7883326fb4654e643bfc93b3a8a8274d8010a39d8804/regex-2019.11.1-cp36-cp36m-manylinux1_x86_64.whl (643kB)
[K     |████████████████████████████████| 645kB 8.0MB/s 
Collecting toposort
  Downloading https://files.pythonhosted.org/packages/e9/8a/321cd8ea5f4a22a06e3ba30ef31ec33bea11a3443eeb1d89807640ee6ed4/toposort-1.5-py2.py3-none-any.whl
Building wheels for collected packages: gpt-2-simple
  Building wheel for gpt-2-simple (setup.py) ... [?25l[?25hdone
  Created wheel for gpt-2-simple: filename=gpt_2_simple-0.6-cp36-none-any.whl size=25388 sha256=9cc67b46cdb48c8a3ffa8436321bdd4887631e52b414d959efa844c08f5239e7
  Stored in directory: /root/.cache/pip/wheels/cc/e7/21/4cb10bcf085ff791a08bbd03aa3fd860f6e730f37b5dbbea28
Successful

In [3]:
# Colab-only helper that exposes Google Drive inside the runtime.
from google.colab import drive

# Mount Drive at /content/drive; this prompts interactively for an authorization code.
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Colab magic selecting TensorFlow 1.x for this runtime; keep this cell ahead of
# the cell that imports gpt_2_simple.
%tensorflow_version 1.x

In [5]:
from tqdm import tqdm
import pandas as pd
import pickle
import math
import random
import os
from pathlib import Path
from collections import defaultdict
import shutil

import gpt_2_simple as gpt2

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
# Hyperparameters for training
MODEL_NAME = '124M'  # pretrained GPT-2 variant passed to download_gpt2 / finetune
SEED = 42  # seed passed to gpt2.generate when sampling dumps
LR = 1e-4  # finetune learning_rate
BS = 1  # finetune batch_size
GACCU_STEPS = 5  # finetune accumulate_gradients
SAMPLE_STEPS = 1000  # finetune sample_every
SAMPLE_LEN = 255  # finetune sample_length
SAMPLE_NUM = 1  # finetune sample_num
SAVE_STEPS = 1000  # finetune save_every
PRINT_STEPS = 100  # finetune print_every
TRAIN_STEPS = 40000  # finetune steps

# Genre vocabulary; each entry is lower-cased before it is used as the
# conditioning label in training lines and generation prompts.
GENRES = ['Action', 'Adult', 'Adventure', 'Animation',
    'Biography', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family',
    'Fantasy', 'Game-Show', 'History', 'Horror', 'Lifestyle', 'Music',
    'Musical', 'Mystery', 'News', 'Reality-TV', 'Romance', 'Sci-Fi',
    'Short', 'Sport', 'Talk-Show', 'Thriller', 'War', 'Western']

# Markers used to assemble every training line:
# BOS_TOKEN + genre + EOG_TOKEN + title + EOT_TOKEN + plot + EOS_TOKEN
BOS_TOKEN = '<|startoftext|>'
EOS_TOKEN = '<|endoftext|>'
EOG_TOKEN = '|0|'  # placed between the genre and the title
EOT_TOKEN = '|1|'  # placed between the title and the plot

In [0]:
# Dumps
# Number of samples generated per genre for every (temperature, top_p) pair
NUM_SAMPLES_PER_GENRE = 10
# Temperature acts as "creativity": higher values make the network more likely
# to pick less probable (suboptimal) predictions
TEMPS = [0.7, 1, 1.3]
# top_p values to sweep; see https://github.com/minimaxir/gpt-2-simple/issues/51
TOP_PS = [0, 0.9]

In [0]:
# Only if re-finetuning
# NOTE(review): CHECKPOINT_DIR is assigned in a later cell, so this cell only
# has the intended effect after that cell has been run. It recursively deletes
# everything under the checkpoint directory.
!rm -rf "$CHECKPOINT_DIR"

In [0]:
# Remember to define the model dir in GDrive for persistence if using Colab.
# The paths are relative, so they resolve against the current working directory.
MODEL_DIR = Path("drive/My Drive/Colab Notebooks/transformers/MoviePlots/text_generation/GPT-2-gpt2simple")

# Prepared data
DATA_FILE = Path("drive/My Drive/Colab Notebooks/transformers/MoviePlots/data/data.pkl")

# Cache
# parents=True also creates MODEL_DIR (and any missing ancestors), so the cell
# works on a fresh Drive instead of raising FileNotFoundError.
CACHE_DIR = MODEL_DIR/'cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Inputs will be stored here
TXT_FILE = CACHE_DIR/'data.txt'

# Checkpoints will be stored here
CHECKPOINT_DIR = MODEL_DIR/'checkpoint'
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Dumps will be stored here
DUMPS_DIR = MODEL_DIR/"dumps"
DUMPS_DIR.mkdir(parents=True, exist_ok=True)

In [13]:
# Sanity check: list what currently lives in the model directory.
!ls "$MODEL_DIR"

cache  checkpoint  dumps  Notebook.ipynb


## Prepare data

In [0]:
# Load the prepared dataset. Later cells iterate over `data` and read the
# 'genres', 'title' and 'plot' keys of every item.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(DATA_FILE, 'rb') as handle:
    data = pickle.load(handle)

In [0]:
# Each plot can have multiple genres,
# but GPT-2 would learn more effectively if there were a one-to-one relationship.
# Two options: filter out plots with multiple genres, or duplicate them for each genre.
# The function below implements the second option.

def single_genre_augment(data):
    """Expand multi-genre records into one record per matching genre.

    Walks GENRES in order and, for every item whose 'genres' entry contains
    that genre, emits a shallow copy carrying an extra 'genre' key holding the
    lower-cased genre name. The result is therefore grouped by genre, and the
    input records themselves are left unmodified.
    """
    progress = tqdm(total=len(GENRES)*len(data))
    expanded = []
    for genre_name in GENRES:
        label = genre_name.lower()
        for record in data:
            if genre_name in record['genres']:
                # Shallow copy plus the single-genre label.
                expanded.append({**record, 'genre': label})
            progress.update()
    progress.close()
    return expanded

In [17]:
# NOTE: this rebinds `data`, so re-running the cell without reloading the pickle
# would augment the already-augmented records a second time.
data = single_genre_augment(data)

100%|██████████| 3281460/3281460 [00:02<00:00, 1240872.45it/s]


In [0]:
def encode_data(data):
    """Serialize every record into a single training line.

    Each line is laid out as
    BOS_TOKEN + genre + EOG_TOKEN + title + EOT_TOKEN + plot + EOS_TOKEN,
    reading the 'genre', 'title' and 'plot' keys of the record. The lines are
    returned as a list in the same order as ``data``.
    """
    return [
        ''.join((BOS_TOKEN, record['genre'], EOG_TOKEN,
                 record['title'], EOT_TOKEN,
                 record['plot'], EOS_TOKEN))
        for record in tqdm(data)
    ]

In [19]:
# One encoded training line per augmented record.
lines = encode_data(data)

100%|██████████| 252490/252490 [00:00<00:00, 574982.70it/s]


In [20]:
# Preview the encoded lines; wrapping them in a Series gives a truncated display.
pd.Series(lines)

0         <|startoftext|>action|0|".hack//SIGN" (2002) {...
1         <|startoftext|>action|0|".hack//SIGN" (2002) {...
2         <|startoftext|>action|0|"10,000 Days" (2010)|1...
3         <|startoftext|>action|0|"10th Muse" (2012)|1| ...
4         <|startoftext|>action|0|"18 Wheels of Justice"...
                                ...                        
252485    <|startoftext|>western|0|"Zorro" (1990) {To Be...
252486    <|startoftext|>western|0|"Zorro" (1990) {Ultim...
252487    <|startoftext|>western|0|"Zorro" (1990) {Where...
252488    <|startoftext|>western|0|"Zorro" (1990) {Wicke...
252489    <|startoftext|>western|0|"Zorro" (1990) {Zorro...
Length: 252490, dtype: object

In [0]:
def save_to_txt(lines):
    """Write every line, newline-terminated, to TXT_FILE (overwriting it).

    The file is opened as UTF-8 with errors='ignore', so characters that
    cannot be encoded are dropped instead of raising.
    """
    with open(TXT_FILE, 'w', encoding='utf8', errors='ignore') as out_file:
        out_file.writelines(text + "\n" for text in tqdm(lines))

In [22]:
save_to_txt(lines)

100%|██████████| 252490/252490 [00:00<00:00, 309643.60it/s]


## Train model

In [23]:
# Show which GPU the runtime was assigned before training.
!nvidia-smi

Tue Nov 19 23:39:42 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [24]:
# Fetch the pretrained MODEL_NAME files (checkpoint, encoder, hparams, vocab).
gpt2.download_gpt2(model_name=MODEL_NAME)

Fetching checkpoint: 1.05Mit [00:00, 489Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 110Mit/s]                                                    
Fetching hparams.json: 1.05Mit [00:00, 400Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 498Mit [00:01, 267Mit/s]                                   
Fetching model.ckpt.index: 1.05Mit [00:00, 345Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 147Mit/s]                                                 
Fetching vocab.bpe: 1.05Mit [00:00, 154Mit/s]                                                       


In [0]:
# Fine-tune the pretrained model on the encoded lines in TXT_FILE; every setting
# comes from the hyperparameter cell. No run_name is passed, and the training
# log shows checkpoints being written under <CHECKPOINT_DIR>/run1.
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              str(TXT_FILE),
              steps=TRAIN_STEPS,
              model_name=MODEL_NAME,
              batch_size=BS,
              learning_rate=LR,
              accumulate_gradients=GACCU_STEPS,
              sample_every=SAMPLE_STEPS,
              sample_length=SAMPLE_LEN,
              sample_num=SAMPLE_NUM,
              save_every=SAVE_STEPS,
              print_every=PRINT_STEPS,
              overwrite=True,
              checkpoint_dir=str(CHECKPOINT_DIR))

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Loading checkpoint models/124M/model.ckpt
INFO:tensorflow:Restoring parameters from models/124M/model.ckpt


  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset...


100%|██████████| 1/1 [03:41<00:00, 221.29s/it]


dataset has 42951177 tokens
Training...
Saving drive/My Drive/Colab Notebooks/transformers/MoviePlots/text_generation/GPT-2-gpt2simple/checkpoint/run1/model-0
[100 | 130.99] loss=2.93 avg=2.93
[200 | 255.39] loss=2.95 avg=2.94
[300 | 379.86] loss=3.06 avg=2.98
[400 | 504.24] loss=3.01 avg=2.99
[500 | 628.56] loss=2.96 avg=2.98
[600 | 752.89] loss=3.00 avg=2.99
[700 | 877.28] loss=3.23 avg=3.02
[800 | 1001.68] loss=3.34 avg=3.06
[900 | 1125.96] loss=2.84 avg=3.04
[1000 | 1250.23] loss=2.81 avg=3.01
Saving drive/My Drive/Colab Notebooks/transformers/MoviePlots/text_generation/GPT-2-gpt2simple/checkpoint/run1/model-1000
Instructions for updating:
Use standard file APIs to delete files with this prefix.
nered into a dangerous situation.<|endoftext|>
<|startoftext|>adventure|0|"Mum's Up" (2013) {Mum's Up (#2.2)}|1| After her father's death, Sarah gets help from a new friend and the situation quickly becomes complicated, especially when she becomes a suspect in a police robbery. Sarah has to

In [0]:
# Copy the model config and encoder next to the fine-tuned weights in run1.
# Done in Python rather than `!cp` so that a missing source file raises instead
# of only printing a shell error, and so no shell variable interpolation is needed.
for _fname in ('hparams.json', 'encoder.json'):
    shutil.copy(str(Path('models') / MODEL_NAME / _fname),
                str(CHECKPOINT_DIR / 'run1'))

## Generate dumps

In [0]:
# Start a session and restore the fine-tuned weights from CHECKPOINT_DIR.
sess = gpt2.start_tf_sess()

# checkpoint_dir is passed as a str, matching the finetune and generate calls.
gpt2.load_gpt2(sess, checkpoint_dir=str(CHECKPOINT_DIR))

Loading checkpoint drive/My Drive/Colab Notebooks/transformers/MoviePlots/text_generation/GPT-2-gpt2simple/checkpoint/run1/model-30000
INFO:tensorflow:Restoring parameters from drive/My Drive/Colab Notebooks/transformers/MoviePlots/text_generation/GPT-2-gpt2simple/checkpoint/run1/model-30000


In [0]:
def generate_dump(prompt, num_samples, temp, top_p, length=256):
    """Generate a dump of samples conditioned on ``prompt``.

    Inspired by https://github.com/minimaxir/hacker-news-gpt-2

    Args:
        prompt: text passed as ``prefix`` to condition generation on.
        num_samples: number of samples to generate; also used as the batch
            size, so all samples are produced in a single batch.
        temp: sampling temperature.
        top_p: top_p sampling threshold.
        length: ``length`` forwarded to gpt2.generate (defaults to 256, the
            value previously hard-coded here).

    Returns:
        The list of generated strings. Generation uses the module-level
        ``sess`` with a fixed SEED, and no truncation is applied; the recorded
        output of a direct call shows each sample starting with the prompt and
        possibly continuing past the first EOS token.
    """
    samples = gpt2.generate(sess,
                            checkpoint_dir=str(CHECKPOINT_DIR),
                            truncate=False,
                            prefix=prompt,
                            seed=SEED,
                            nsamples=num_samples,
                            batch_size=num_samples,
                            length=length,
                            temperature=temp,
                            top_p=top_p,
                            include_prefix=True,
                            return_as_list=True)
    return samples

In [0]:
# Smoke test: 10 samples prompted with the horror genre, temperature 1, top_p 0.
generate_dump(BOS_TOKEN + 'horror' + EOG_TOKEN, 10, 1, 0)

['<|startoftext|>horror~#~"Life After" (2009) {Romancing the Feat (#1.3)}$$$ Eddie meets "The Bullfighter" - a school bully who likes all his boys (Because boys need bullies too) - before building a beef between the two. The bullies find his nerdy younger brother, Jake, and move in with the older girls. The older girls like Jake, but Eddie\'s smitten by Jake. Actually Jake\'s rhyming anko is well earned.<|endoftext|>\n<|startoftext|>War~#~"The Tudors" (2007) {Queen\'s Own (#1.1)}$$$ In her own right, the Queen succumbs at first, but through the Tudor Succession, she grows just as wrong and has a private court trial only to win an appeal. Henry Tudor spends most of his time in France tending to his woman, Mary Boleyn sticks to her man and Jane Is foretold in marriage. In the south and west, the church is founded and the blame for the Spanish Civil War is laid at the feet of James I. The catholic led \'Titanic balance\' between church and state is broken when Cromwell is elected Holy Ger

In [0]:
def save_dumps():
    """Generate dumps of samples and save them to disk.

    For every (temperature, top_p) pair drawn from TEMPS x TOP_PS, generates
    NUM_SAMPLES_PER_GENRE samples for each genre in GENRES, cuts each sample at
    the first EOS_TOKEN, and writes them (newline-terminated) to
    DUMPS_DIR/temp_<temp>_topp_<top_p>.txt, with '.' replaced by '_' in the
    file name.
    """
    # The context manager closes the progress bar even if generation is interrupted.
    with tqdm(total=len(GENRES)*len(TEMPS)*len(TOP_PS), desc="Dumps") as pbar:
        for temp in TEMPS:
            for top_p in TOP_PS:
                samples = []
                for genre in GENRES:
                    prompt = BOS_TOKEN + genre.lower() + EOG_TOKEN
                    dump = generate_dump(prompt, NUM_SAMPLES_PER_GENRE, temp, top_p)
                    for sample in dump:
                        text = sample.split(EOS_TOKEN)[0]
                        # Generated samples already start with the prompt;
                        # only prepend it when it is missing so that it is
                        # never written twice.
                        if not text.startswith(prompt):
                            text = prompt + text
                        samples.append(text + "\n")
                    pbar.update()
                fn = 'temp_%s_topp_%s.txt' % (str(temp).replace('.', '_'),
                                              str(top_p).replace('.', '_'))
                # Explicit encoding, consistent with how the training file is written.
                with open(DUMPS_DIR/fn, 'w', encoding='utf8') as f:
                    f.writelines(samples)

In [0]:
# Long-running: the recorded run took ~52 s for the first of 168 batches before
# being interrupted.
save_dumps()


Dumps:   0%|          | 0/168 [00:00<?, ?it/s][A
Dumps:   1%|          | 1/168 [00:51<2:24:25, 51.89s/it][A

KeyboardInterrupt: ignored