resource: https://minimaxir.com/2019/09/howto-gpt2/




This notebook is run on Google Colab and uses Colab-specific file-access features (Google Drive mounting and file copying) that are not relevant when running on a local machine.

## Basic

In [2]:
%tensorflow_version 1.x
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

from gpt_2_simple.src import model, sample, encoder, memory_saving_gradients
import os
import json
import numpy as np
import tensorflow as tf

TensorFlow 1.x selected.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [3]:
# Mount Google Drive in the Colab runtime (the cell output reports /content/drive as the mount point).
gpt2.mount_gdrive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## To Train

In [4]:
# gpt2.download_gpt2(model_name="124M")

In [5]:
# Disabled training configuration: the assignments are wrapped in a string
# literal, so running this cell defines nothing and only echoes the text.
'''
file_name = "result.txt"
trained_run_name = "run1"
'''

'\nfile_name = "result.txt"\ntrained_run_name = "run1"\n'

In [6]:
# gpt2.copy_file_from_gdrive(file_name)

In [7]:
# Disabled fine-tuning call: kept inside a string literal so that running the
# notebook does not start training. Remove the quotes (and enable the
# configuration cell that defines file_name / trained_run_name) to train.
'''
sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='124M',
              steps=1000,
              restore_from='fresh',
              run_name=trained_run_name,
              print_every=10,
              sample_every=200,
              save_every=500
              )
'''

"\nsess = gpt2.start_tf_sess()\n\ngpt2.finetune(sess,\n              dataset=file_name,\n              model_name='124M',\n              steps=1000,\n              restore_from='fresh',\n              run_name=trained_run_name,\n              print_every=10,\n              sample_every=200,\n              save_every=500\n              )\n"

In [8]:
# gpt2.copy_checkpoint_to_gdrive(run_name=trained_run_name)

## Load Model

If the fine-tuned model is already saved on Google Drive, it can be loaded with the following cells:

In [9]:
# Name of the fine-tuned run to load; the cells below pass it as run_name and
# the load output shows it used as the folder under checkpoint/.
loaded_run_name='355_10mb'
# loaded_run_name='124M'

In [10]:
# Fetch the checkpoint for loaded_run_name from Google Drive (requires Drive to be mounted).
gpt2.copy_checkpoint_from_gdrive(run_name=loaded_run_name)

In [11]:
# Create the TensorFlow session and restore the run's weights into it.
# `sess` is a notebook-level global that the assessment functions below rely on.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=loaded_run_name)

Loading checkpoint checkpoint/355_10mb/model-1000
INFO:tensorflow:Restoring parameters from checkpoint/355_10mb/model-1000


## Run

In [12]:
# Disabled sample-generation call: wrapped in a string literal, so the cell
# only echoes the text. Remove the quotes to generate 5 samples from '<html>'.
'''
gpt2.generate(sess,
              length=100,
              temperature=0.7,
              nsamples=5,
              batch_size=5,
              run_name=loaded_run_name,
              prefix='<html>'
              )
'''


"\ngpt2.generate(sess,\n              length=100,\n              temperature=0.7,\n              nsamples=5,\n              batch_size=5,\n              run_name=loaded_run_name,\n              prefix='<html>'\n              )\n"

## Assess

https://github.com/gpt2ent/gpt-2-simple/blob/652fdab80131ce83f8f1b6fd00f597dd48ae2e36/gpt_2_simple/gpt_2.py#L552


In [13]:

def get_logits(sess,
             run_name='run1',
             checkpoint_dir='checkpoint',
             model_name=None,
             model_dir='models',
             prefix="<|endoftext|>",
             all=False):
    """
    Run the GPT-2 model once over `prefix` and return its raw logits.

    Args:
        sess: TensorFlow session used to evaluate the model (the notebook
            creates it with gpt2.start_tf_sess / gpt2.load_gpt2).
        run_name: run folder under `checkpoint_dir`; used when `model_name`
            is not given.
        checkpoint_dir: parent directory of fine-tuned runs.
        model_name: if set, the encoder and hparams.json are read from
            `model_dir/model_name` instead of `checkpoint_dir/run_name`.
        model_dir: parent directory of base models.
        prefix: non-empty text to encode and feed to the model.
        all: if True return the logits of every position, otherwise only
            those of the last position. (The name shadows the builtin; it is
            kept because callers pass it by keyword.)

    Returns:
        With all=True, a 2-D array indexed [position, vocabulary entry];
        row i holds the scores for the token that follows input token i.
        With all=False, the 1-D last row of that array (next-token scores).

    Raises:
        ValueError: if `prefix` is empty, encodes to no tokens, or encodes to
            more tokens than the model's context size (hparams n_ctx).
    """
    # Validate first: without a prefix there is nothing to feed the model.
    if not prefix:
        raise ValueError("prefix must be a non-empty string")

    batch_size = 1

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)
    hparams = model.default_hparams()
    with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    context_tokens = enc.encode(prefix)
    if not context_tokens:
        raise ValueError("prefix encodes to zero tokens")

    # The model has position embeddings for n_ctx positions only, so a longer
    # prompt cannot be evaluated meaningfully; fail early with a clear message.
    n_ctx = getattr(hparams, 'n_ctx', None)
    if n_ctx is not None and len(context_tokens) > n_ctx:
        raise ValueError(
            "prefix is %d tokens long but the model context size is %d"
            % (len(context_tokens), n_ctx))

    # NOTE(review): every call adds a new placeholder and a new set of model
    # ops to the default graph (variables are shared through AUTO_REUSE).
    # With many calls the graph keeps growing; consider building it once.
    context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
    lm_output = model.model(hparams=hparams, X=context,
                            past=None, reuse=tf.compat.v1.AUTO_REUSE)
    logits = lm_output['logits'][:, :, :hparams.n_vocab]

    # Only the logits are needed, so the attention cache ('present') is not fetched.
    out = sess.run(logits, feed_dict={
        context: batch_size * [context_tokens]
    })

    if all:
        return out[0, :, :]  # one row of logits per input token
    return out[0, -1, :]  # logits for the next token


In [14]:
def get_perplexity(sess,
               run_name='run1',
               checkpoint_dir='checkpoint',
               model_name=None,
               model_dir='models',
               prefix="<|endoftext|>",
               continuation="Hello"):
    """
    Returns the perplexity of `continuation` given `prefix`.

    The text `prefix + continuation` is scored in a single model pass and the
    perplexity is exp(-mean(log p(token))) over the continuation tokens.
    The two strings are concatenated verbatim, so the continuation should
    carry its own leading space when one belongs between the two parts.

    Args:
        sess: TensorFlow session forwarded to get_logits.
        run_name, checkpoint_dir, model_name, model_dir: select the encoder
            and model folder, exactly as in get_logits.
        prefix: non-empty conditioning text.
        continuation: non-empty text whose tokens are scored.

    Returns:
        The perplexity as a NumPy floating-point scalar (lower = more expected).

    Raises:
        ValueError: if prefix or continuation is empty, or if no continuation
            tokens remain after encoding.

    Examples (values as reported by the original author):
    get_perplexity(sess, model_name="124M", prefix="Hello, my name is", continuation=" James Smith, I am an engineer")  # returns 17.3124
    get_perplexity(sess, model_name="124M", prefix="Hello, my name is", continuation=" very else whatever general cat meow.")  # returns 5197.99
    """
    if not prefix:
        raise ValueError("prefix must be a non-empty string")
    if not continuation:
        raise ValueError("continuation must be a non-empty string")

    if model_name:
        checkpoint_path = os.path.join(model_dir, model_name)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, run_name)

    enc = encoder.get_encoder(checkpoint_path)

    context_size = len(enc.encode(prefix))
    full_sentence = prefix + continuation

    # Take the targets from the encoding of the full sentence (the same text
    # get_logits encodes) so that targets and logit rows always line up, even
    # when the tokenizer merges characters across the prefix boundary.
    target_tokens = enc.encode(full_sentence)[context_size:]
    if context_size == 0 or not target_tokens:
        raise ValueError("prefix and continuation must each yield at least one token")

    logits = get_logits(sess, run_name, checkpoint_dir, model_name, model_dir, full_sentence, all=True)

    # Row i predicts token i + 1, so the rows that predict the continuation
    # start at the last prefix token and stop one before the final token.
    logits = np.asarray(logits, dtype=np.float64)[context_size - 1:-1, :]
    if logits.shape[0] != len(target_tokens):
        raise ValueError("model returned %d rows for %d continuation tokens"
                         % (logits.shape[0], len(target_tokens)))

    # Numerically stable log-softmax: subtracting the row maximum keeps exp()
    # from overflowing, and staying in log space avoids log(0).
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    target_log_probs = log_probs[np.arange(len(target_tokens)), target_tokens]
    return np.exp(-np.mean(target_log_probs))


## Run on Test Data

In [15]:
# Disabled example call: wrapped in a string literal, so `ans` is not assigned
# and the cell only echoes the text. Remove the quotes to run it.
'''
ans = get_perplexity(sess,
               run_name=loaded_run_name,
               checkpoint_dir='checkpoint',
               model_name=None,
               model_dir='models',
               prefix="<html>",
               continuation=" <body> </body> <html>")
'''

'\nans = get_perplexity(sess,\n               run_name=loaded_run_name,\n               checkpoint_dir=\'checkpoint\',\n               model_name=None,\n               model_dir=\'models\',\n               prefix="<html>",\n               continuation=" <body> </body> <html>")\n'

In [16]:
from google.colab import files
import os
from google.colab import drive
drive.mount('/content/drive')
import math
import datetime #for development

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
def per_line_assess(fileName, sample_every=50):
  '''
  Assess the perplexity of individual html lines of a file.

  Lines with fewer than two whitespace-separated words (e.g. <br/>) are
  ignored. Of the remaining lines only every `sample_every`-th one is scored:
  a file can have around 1000 lines, and spreading the sample over the file
  overlaps less between files than taking the first lines. Each scored line
  is split in half (the prefix gets the extra word when the count is odd);
  the first half is the prefix and the second half the continuation.

  Uses the notebook globals `sess` and `loaded_run_name`.

  fileName: path of the text file to read.
  sample_every: positive stride of the sample (default 50).

  Returns a list of [prefix, continuation, perplexity] entries; the recorded
  continuation is the plain text without the separating space.
  '''
  if sample_every < 1:
    raise ValueError("sample_every must be a positive integer")

  # Read everything first so the file is not held open during model calls.
  with open(fileName, 'r') as f:
    lines = f.read().split('\n')

  ret_val = []
  count = 0
  for line in lines:
    words = line.split()  # split() already discards surrounding whitespace
    if len(words) <= 1:
      continue
    count += 1
    if count % sample_every != 0:
      continue
    half = math.ceil(len(words) / 2)
    pref = ' '.join(words[:half])
    continu = ' '.join(words[half:])
    # get_perplexity concatenates prefix and continuation verbatim, so the
    # space that separated the two halves in the file must be passed along.
    perplexity = get_perplexity(sess,
           run_name=loaded_run_name,
           checkpoint_dir='checkpoint',
           model_name=None,
           model_dir='models',
           prefix=pref,
           continuation=' ' + continu)
    ret_val.append([pref, continu, perplexity])
  return ret_val
      


In [18]:
def per_file_assess(fileName, max_lines=100):
  '''
  Measure perplexity over a whole html file with a growing prefix.

  Every line with at least two words is split in half (the prefix gets the
  extra word when the count is odd). The first half is appended to the
  running prefix, the second half is scored as the continuation and then
  appended as well. Single-word lines are only appended to the running
  prefix. All pieces are joined with single spaces.

  The prefix grows with every line, so the call fails once it becomes too
  long for the model (see the "Error" section of the notebook); `max_lines`
  limits how many lines are scored.

  Uses the notebook globals `sess` and `loaded_run_name`.

  fileName: path of the text file to read.
  max_lines: maximum number of scored lines (default 100).

  Returns a list of [prefix, continuation, perplexity] entries; the recorded
  continuation is the plain text without the separating space.
  '''
  # Read everything first so the file is not held open during model calls.
  with open(fileName, 'r') as f:
    lines = f.read().split('\n')

  file_prefix = ""
  ret_val = []
  count = 0
  for line in lines:
    if count >= max_lines:
      break
    words = line.split()  # split() already discards surrounding whitespace
    if not words:
      continue
    if len(words) == 1:
      # Keep single-token lines separated from their neighbours; gluing them
      # on would feed text such as "</div><br/>" that is not in the file.
      file_prefix = file_prefix + ' ' + words[0]
      continue
    half = math.ceil(len(words) / 2)
    line_pref = ' '.join(words[:half])
    file_prefix = file_prefix + ' ' + line_pref
    continu = ' '.join(words[half:])
    # get_perplexity concatenates prefix and continuation verbatim, so the
    # separating space has to be part of the continuation.
    perplexity = get_perplexity(sess,
           run_name=loaded_run_name,
           checkpoint_dir='checkpoint',
           model_name=None,
           model_dir='models',
           prefix=file_prefix,
           continuation=' ' + continu)
    ret_val.append([file_prefix, continu, perplexity])
    file_prefix = file_prefix + ' ' + continu
    count += 1

  return ret_val


In [None]:
# Score a handful of test files and write one result file per input file.
target_file = "/content/drive/My Drive/P-Web/Testing_Files"
output_dir = "/content/drive/My Drive/P-Web"
max_files = 9  # number of files to process per run

processed = 0
# The context manager closes the directory iterator even when the loop breaks early.
with os.scandir(target_file) as entries:
  for entry in entries:
    if processed >= max_files:
      break
    if not entry.is_file():
      continue  # sub-directories cannot be opened as text files
    # calculated_val = str(per_file_assess(entry.path))
    calculated_val = str(per_line_assess(entry.path))
    # Build the output name from the entry's file name; formatting the
    # DirEntry object itself would put "<DirEntry '...'>" into the name.
    stem = os.path.splitext(entry.name)[0]
    with open(os.path.join(output_dir, f"model_measure_355_{stem}.txt"), 'w') as writefile:
      writefile.write(calculated_val)
    processed += 1
    print(entry.path)
    print(datetime.datetime.utcnow().strftime("%a, %d %B %Y %H:%M:%S"))





/content/drive/My Drive/P-Web/Testing_Files/modified_1c-bitrix.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_115.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_21cn.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_2gis.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_7news.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_263.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_4shared.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_10fastfingers.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_6abc.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_500px.txt
/content/drive/My Drive/P-Web/Testing_Files/modified_300.txt
Fri, 07 May 2021 05:50:43


## Error

An error occurs when the prefix is too long. For example, the cell below uses almost an entire file as the prefix to test perplexity, and the call returned an error.

In [20]:

# testPref = '''<html lang=""> <head> <meta charset=""/> <meta content="" name=""/> <link href="" rel="profile"/> <link href="" media="" rel="stylesheet"/> <title> </title> <meta content="" name=""/> <meta content="" name=""/> <link href="" rel="canonical"/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" property=""/> <meta content="" name=""/> <meta content="" name=""/> <meta content="" name=""/> <meta content="" name=""/> <script class="" type="application/ld+json"> </script> <link href="" rel="dns-prefetch"/> <link href="" rel="dns-prefetch"/> <link href="" id="" media="" rel="stylesheet" type="text/css"/> <style id="" type="text/css"> </style> <script src="" type="text/javascript"> </script> <link href="" rel="alternate" type="application/json+oembed"/> <link href="" rel="alternate" type="text/xml+oembed"/> <meta content="" name=""/> <link href="" rel="icon" sizes=""/> <link href="" rel="icon" sizes=""/> <link href="" rel="apple-touch-icon-precomposed"/> <meta content="" name=""/> </head> <body class=""> <header class=""> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <form action="" class="" method="" role="search"> <label> <span class=""> </span> <input class="" name="" placeholder="" type="search" value=""/> </label> <input class="" type="submit" value=""/> </form> </div> </div> <div class=""> <div class=""> <ul class=""> <li> <a href=""> </a> </li> <li> <a href=""> </a> </li> <li> <a href=""> </a> </li> </ul> </div> </div> </div> </div> </div> </div> <div class=""> <div class=""> <h1 class=""> <a href="" rel="home"> </a> </h1> <h5 class=""> </h5> </div> </div> <div class=""> <div class=""> <nav class="" id="" role="navigation"> <div class=""> <ul class="" 
id=""> <li class="" id=""> <a aria-current="page" href=""> </a> </li> <li class="" id=""> <a href=""> </a> <ul class=""> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> </ul> </li> <li class="" id=""> <a href=""> </a> <ul class=""> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> </ul> </li> <li class="" id=""> <a href=""> </a> <ul class=""> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> </ul> </li> <li class="" id=""> <a href=""> </a> <ul class=""> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> <li class="" id=""> <a href=""> </a> </li> </ul> </li> </ul> </div> </nav> </div> </div> </div> </header> <div class=""> <div class=""> <div class="" style=""> <div class=""> <nav aria-label="" class="" itemprop="" role="navigation"> <ul class="" itemscope="" itemtype=""> <meta content="" name=""/> <meta content="" name=""/> <li class="" itemprop="" itemscope="" itemtype=""> <span itemprop=""> </span> <meta content="" itemprop=""/> </li> </ul> </nav> </div> <div class=""> </div> </div> </div> <div class=""> <div class=""> <div class=""> <div class=""> <div 
class=""> <div class=""> <article class="" id=""> <div class=""> <div class=""> <img alt="" class="" height="827" sizes="" src="" srcset="" width="1240"/> </div> <div class=""> <div class=""> <h2> </h2> </div> <div class=""> <p> <a href="" rel="noopener noreferrer" target=""> <strong> </strong> </a> <strong> </strong> <strong> </strong> </p> <div class="" data-=""> <div class=""> <div class=""> <b class=""> </b> <span class=""> <a class="" data-="" href=""> </a> </span> </div> <div class=""> <div class=""> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> <div class=""> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> </div> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> <div class=""> <a class="" href=""> <span class=""> </span> <span class=""> </span> </a> </div> </div> </div> </div> </div> <h2> <span id=""> </span> </h2> <p> </p> <ul> <li> <strong> </strong> </li> <li> </li> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <h3> <span id=""> </span> </h3> <p> </p> <ul> <li> </li> <li> </li> <li> </li> <li> </li> <li> </li> <li> </li> <li> </li> </ul> <h3> <span id=""> </span> </h3> <p> </p> <ul> <li> </li> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <ul> <li> </li> <li> </li> <li> </li> <li> 
</li> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <ul> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <ul> <li> </li> <li> </li> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <ul> <li> </li> <li> </li> <li> </li> </ul> <h2> <span id=""> </span> </h2> <p> </p> <h2> <span id=""> </span> </h2> <p> </p> </div> </div> </div> </article> </div> </div> <div class=""> <aside class=""> <div class=""> <section class="" id=""> <div class=""> <div class=""> <h2> </h2> </div> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <img alt="" class="" height="421" sizes="" src="" srcset="" width="748"/> <div class=""> </div> <div class=""> <div class=""> <h5> <a href=""> </a> </h5> </div> </div> </div> <div class=""> <img alt="" class="" height="421" sizes="" src="" srcset="" width="748"/> <div class=""> </div> <div class=""> <div class=""> <h5> <a href=""> </a> </h5> </div> </div> </div> <div class=""> <img alt="" class="" height="421" sizes="" src="" srcset="" width="748"/> <div class=""> </div> <div class=""> <div class=""> <h5> <a href=""> </a> </h5> </div> </div> </div> <div class=""> <img alt="" class="" height="421" sizes="" src="" srcset="" width="748"/> <div class=""> </div> <div class=""> <div class=""> <h5> <a href=""> </a> </h5> </div> </div> </div> <div class=""> <img alt="" class="" height="421" sizes="" src="" srcset="" width="748"/> <div class=""> </div> <div class=""> <div class=""> <h5> <a href=""> </a> </h5> </div> </div> </div> </div> <div class=""> </div> </div> </div> </div> </div> </section> <section class="" id=""> <div class=""> <div class=""> <h2> </h2> </div> <ul> <li class=""> <a href="" title=""> </a> </li> <li class=""> <a href="" title=""> </a> </li> <li class=""> <a href="" title=""> </a> </li> <li class=""> <a href="" title=""> </a> </li> </ul> </div> </section> </div> </aside> </div> </div> </div> </div> </div> </div> <footer 
class=""> <div class=""> </div> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <div class=""> <p> </p> </div> </div> <div class=""> </div> </div> </div> </div> </div> </footer> <script type="text/javascript">'''

# testCont = '''</script> <style id="" type="text/css"> </style> <script type="text/javascript"> </script> <script type="text/javascript"> </script> <script defer="" src=""> </script> </body> </html>'''
# Disabled reproduction of the long-prefix error: the call is wrapped in a
# string literal and the testPref / testCont assignments above are commented
# out, so this cell only echoes the text.
'''
get_perplexity(sess,
               run_name=loaded_run_name,
               checkpoint_dir='checkpoint',
               model_name=None,
               model_dir='models',
               prefix=testPref,
               continuation=testCont)
'''

"\nget_perplexity(sess,\n               run_name=loaded_run_name,\n               checkpoint_dir='checkpoint',\n               model_name=None,\n               model_dir='models',\n               prefix=testPref,\n               continuation=testCont)\n"