<a href="https://colab.research.google.com/github/sedmegreenaway/gpt-2-colaboratory/blob/master/GPT_2_Colaboratory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup
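The following cell sets up everything GPT-2 needs to run: it clones the OpenAI GPT-2 repository, installs TensorFlow 1.14 and the repository's requirements, downloads the 124M, 355M, 774M and 1558M models, and copies the `encoder`, `sample` and `model` modules next to the notebook so the runner cell can import them.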

In [0]:
# Fetch the OpenAI GPT-2 reference implementation into ./gpt-2.
!mkdir gpt-2
!git clone https://github.com/openai/gpt-2.git gpt-2
# Install the pinned TensorFlow 1.14 GPU build, then the repository's own requirements.
!pip3 install tensorflow-gpu==1.14
!pip3 install -r gpt-2/requirements.txt
# Download every model size offered in the runner form's model_name dropdown.
# NOTE(review): presumably download_model.py writes to ./models/<size>, which is
# what the runner's default models_dir='models' expects — confirm against the script.
!python3 gpt-2/download_model.py 124M
!python3 gpt-2/download_model.py 355M
!python3 gpt-2/download_model.py 774M
!python3 gpt-2/download_model.py 1558M
# Copy the three modules the runner cell imports (encoder, sample, model) into
# the working directory so a plain `import` finds them.
!cp gpt-2/src/encoder.py ./encoder.py
!cp gpt-2/src/sample.py ./sample.py
!cp gpt-2/src/model.py ./model.py
# The clone is no longer needed once the modules have been copied out.
!rm -rf gpt-2/

# Runner
While you could run the Python files retrieved from git during setup with shell commands, it is easier to run the sampler from a cell, because a cell can expose its parameters as a form.

In [0]:
#@title Interactive Conditional Samples
import fire
import json
import os
import sys
import numpy as np
import tensorflow as tf

import model, sample, encoder

def interact_model(#@markdown model_name: String, which model to use
    model_name = '124M', #@param ["'124M',", "'355M',", "'774M',", "'1558M',"] {type:"raw"}
    #@markdown seed: Integer seed for random number generators, fix seed to reproduce results
    seed=None,#@param
    #@markdown nsamples: Number of samples to return total
    nsamples=1,#@param
    #@markdown batch_size: Number of batches (only affects speed/memory).  Must divide nsamples.
    batch_size=1,#@param
    #@markdown length: Number of tokens in generated text, if None (default), is determined by model hyperparameters
    length=None,#@param
    #@markdown temperature: Float value controlling randomness in boltzmann distribution. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.
    temperature=1,#@param
    #@markdown top_k: Integer value controlling diversity. 1 means only 1 word is considered for each step (token), resulting in deterministic completions, while 40 means 40 words are considered at each step. 0 is a special setting meaning no restrictions. 40 (the default here) generally is a good value.
    top_k=40,#@param
    top_p=1,
    #@markdown models_dir: path to parent folder containing model subfolders (i.e. contains the <model_name> folder)
    models_dir='models',#@param
):
    """
    Interactively run the model.

    The first prompt comes from the ``raw_text`` form field below. When the
    ``interactive`` form field is enabled, further prompts are read from
    standard input after each batch of samples, until the cell is interrupted.

    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
     None is treated as 1.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters (half of hparams.n_ctx)
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=40 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 is a special
     setting meaning no restrictions. 40 generally is a good value.
    :top_p=1 : Forwarded unchanged to sample.sample_sequence.
     NOTE(review): presumably the nucleus-sampling threshold; confirm in sample.py.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)

    Raises ValueError if length exceeds the model's context window, and
    AssertionError if batch_size does not divide nsamples. Exits via
    sys.exit() if the form prompt is empty.
    """
    # Form-driven settings. They are read first so that an empty prompt is
    # rejected before the (slow) graph construction and checkpoint restore.
    #@markdown interactive: Do you want to be able to provide additional model prompts after the first?
    interactive = False #@param {type:"boolean"}
    #@markdown raw_text: Model prompt
    raw_text = "Hello." #@param {type:"string"}
    if not raw_text:
        print('Prompt should not be empty!')
        sys.exit()

    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0, "batch_size must divide nsamples"

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    # Overlay the selected model's stored hyperparameters on the defaults.
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        while True:
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                # Feed the same prompt to every row of the batch, then drop the
                # prompt tokens so only the continuation is decoded.
                out = sess.run(output, feed_dict={
                    context: [context_tokens] * batch_size
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)

            if not interactive:
                break
            # Follow-up prompts come from the user. The typed text must be the
            # one encoded on the next pass, not the form default again.
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")


if __name__ == '__main__':
    # Inside a Jupyter/Colab kernel, sys.argv holds the kernel launcher's own
    # flags (e.g. "-f <connection file>"), not arguments for interact_model.
    # Give Fire an empty command there so the function runs with its form
    # defaults. When run as a plain script, command=None keeps Fire's normal
    # behaviour of parsing sys.argv.
    running_in_kernel = 'ipykernel' in sys.modules
    fire.Fire(interact_model, command=[] if running_in_kernel else None)