# Using GPT-2 for summarising text

This notebook is heavily based on [Google's IO 2023 workshop notebook](https://colab.research.google.com/github/tensorflow/codelabs/blob/main/KerasNLP/io2023_workshop.ipynb), which demonstrates the use of KerasNLP to load a pre-trained GPT-2 model, fine-tune it to a specific text style, and convert it to the TensorFlow Lite format.

The workshop's notebook covers each of these steps in much more detail.

Keep in mind that fine-tuning and using the model requires quite a bit of memory. Fine-tuning requires more than 10 GB of RAM and some 12 GB of GPU RAM. It is recommended to use the TensorFlow Lite model on devices with at least 4 GB of RAM.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Dependencies and imports

First, we need to load [KerasNLP](https://keras.io/keras_nlp/) into our environment and import all dependencies.

In [1]:
!pip install -q git+https://github.com/keras-team/keras-nlp.git@google-io-2023 tensorflow-text==2.12

  Preparing metadata (setup.py) ... done


In [1]:
import numpy as np
import keras_nlp
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from tensorflow import keras
from tensorflow.lite.python import interpreter
import time



## Load the pre-trained GPT-2 model

We now load the pre-trained GPT-2 model from TensorFlow's repository.

In [2]:
gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")
gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=256,
    add_end_token=True,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en", preprocessor=gpt2_preprocessor)

tl_dr = tf.constant(' TL;DR: ')
max_tokens = 512



Once the GPT-2 model is loaded, we can verify that it is working properly by asking it to generate some text from a suitable prompt.

In [None]:
output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output.numpy().decode("utf-8"))

## Fine-tuning GPT-2

### Selecting the training data

We can try to use the GPT-2 model to summarise texts, but it is also interesting to fine-tune it to a specific text and summary style.

The CNN and Daily Mail data set contains news pieces from these organisations together with their summaries.

We begin by loading the data set.

In [3]:
cnn_ds = tfds.load('cnn_dailymail', as_supervised=True)

In their “Language Models are Unsupervised Multitask Learners” paper, Radford et al. mention that they used the "TL;DR:" token to elicit summarising behaviour from GPT-2. We'll follow the same approach when fine-tuning the model by combining each article with its highlights using "TL;DR:".
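As a plain-Python illustration of the format each training example takes, the sketch below joins an article and its highlights with the separator. The function name `build_training_text` and the sample strings are hypothetical; the notebook itself performs this step with TensorFlow string operations.

```python
# Separator used in this notebook to elicit summarising behaviour.
TL_DR = " TL;DR: "


def build_training_text(article: str, highlights: str) -> str:
    """Join an article and its highlights into one training string.

    Newlines in the highlights are replaced by spaces so that the summary
    forms a single line after the separator.
    """
    return article + TL_DR + highlights.replace("\n", " ")


# Hypothetical sample data, for illustration only.
example = build_training_text(
    "The airport was closed after a fire.",
    "Airport closed.\nFire in car park.",
)
print(example)
```

At inference time the same separator is appended to the text to be summarised, so the model continues with a summary.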

We will use only a subset of the CNN and Daily Mail data set: the entries for which the combination of the article and the highlights has fewer than 512 tokens, as determined by GPT-2's tokeniser.

Let's see which entry from the data set is the first to fulfil this requirement.

In [28]:
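for article, highlights in cnn_ds['train']:
  # Join article and highlights with the TL;DR separator, flattening newlines.
  combination = article + tl_dr + tf.strings.regex_replace(highlights, "\n", " ")
  # Note: str() on the bytes value yields its b'...' representation, so the
  # count below also includes the prefix, the quotes, and any escape sequences.
  tokens = gpt2_tokenizer.tokenize([str(combination.numpy())])
  token_count = tokens.flat_values.shape[0]
  # Stop at the first entry that is below the token limit.
  if token_count < max_tokens:
    print(token_count)
    print(combination.numpy())
    break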

291
b"By. Associated Press. PUBLISHED:. 14:11 EST, 25 October 2013. |. UPDATED:. 15:36 EST, 25 October 2013. The bishop of the Fargo Catholic Diocese in North Dakota has exposed potentially hundreds of church members in Fargo, Grand Forks and Jamestown to the hepatitis A virus in late September and early October. The state Health Department has issued an advisory of exposure for anyone who attended five churches and took communion. Bishop John Folda (pictured) of the Fargo Catholic Diocese in North Dakota has exposed potentially hundreds of church members in Fargo, Grand Forks and Jamestown to the hepatitis A. State Immunization Program Manager Molly Howell says the risk is low, but officials feel it's important to alert people to the possible exposure. The diocese announced on Monday that Bishop John Folda is taking time off after being diagnosed with hepatitis A. The diocese says he contracted the infection through contaminated food while attending a conference for newly ordained bis

Now it is time to create the data subset.

This operation may take a very long time, which is why we made a pre-processed subset available on GitHub. You may skip the following preparation code and load the subset directly using one of the later code sections.

In [29]:
import progressbar

short_texts = []
total = len(cnn_ds['train'])
progressbar_update_freq = 1000
count = 0
used = 0

widgets = [
    ' [',
    progressbar.Timer(format='elapsed time: %s'),
    '] ',
    progressbar.Bar('*'),
    ' (',
    progressbar.ETA(),
    ') ',
]
bar = progressbar.ProgressBar(
    maxval=total // progressbar_update_freq + 2,
    widgets=widgets).start()

for article, highlights in cnn_ds['train']:
  combination = article + tl_dr + tf.strings.regex_replace(highlights, "\n", " ")
  tokens = gpt2_tokenizer.tokenize([str(combination.numpy())])
  token_count = tokens.flat_values.shape[0]
  if token_count < max_tokens:
    short_texts.append(combination)
    used += 1
  count += 1
  if count % progressbar_update_freq == 0:
    bar.update(count / progressbar_update_freq)

print(f'Processed {count} articles of which {used} were used (had a token count smaller than {max_tokens}).')

 [elapsed time: 1 day, 6:50:46] |*************************** | (ETA:  0:12:53) 

Processed 287113 articles of which 39411 were used (had a token count smaller than 512).


Let's define a pair of helper functions to load and save the data subset.

In [30]:
def save_texts(texts):
    np.savez('data/selected_texts.npz', texts)

def load_texts():
    restored_texts = list()
    with np.load('data/selected_texts.npz', allow_pickle=True) as data:
      for file in data.files:
        restored_texts.extend(data[file].tolist())
    return restored_texts
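To see how these helpers behave, here is a small round-trip sketch using the same `np.savez` and `np.load` calls. It writes to a temporary directory instead of `data/`, and the two byte strings are made-up stand-ins for real article and summary combinations.

```python
import os
import tempfile

import numpy as np

# Made-up stand-ins for the combined article and summary strings.
texts = [
    b"first article TL;DR: first summary",
    b"second article TL;DR: second summary",
]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "selected_texts.npz")
    # A positional argument is stored in the archive under the key "arr_0".
    np.savez(path, texts)

    restored = []
    with np.load(path, allow_pickle=True) as data:
        # Collect the entries of every array stored in the archive.
        for name in data.files:
            restored.extend(data[name].tolist())

print(restored == texts)
```

Because `load_texts` iterates over every array in the archive, it would also work if the texts had been saved in several chunks.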

If you ran the data subset preparation above, it is a good idea to save the resulting list of short article and summary combinations as a checkpoint.

In [31]:
save_texts(short_texts)

The pre-prepared data subset can be loaded from GitHub:

In [6]:
!mkdir data
!wget https://github.com/vveloso/ai-in-practice-talk/raw/main/gpt-2/data/selected_texts.npz -O data/selected_texts.npz

short_texts = load_texts()

mkdir: cannot create directory ‘data’: File exists
--2023-10-11 12:16:30--  https://github.com/vveloso/ai-in-practice-talk/raw/main/gpt-2/data/selected_texts.npz
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vveloso/ai-in-practice-talk/main/gpt-2/data/selected_texts.npz [following]
--2023-10-11 12:16:30--  https://raw.githubusercontent.com/vveloso/ai-in-practice-talk/main/gpt-2/data/selected_texts.npz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5477897 (5.2M) [application/octet-stream]
Saving to: ‘data/selected_texts.npz’


2023-10-11 12:16:30 (106 MB/s) - ‘data/selected_texts.npz’ saved [

### Running the fine-tuning training steps

We begin by preparing the data set. The data subset selected in the previous steps is now pre-processed.

In [8]:
tf_train_ds = tf.data.Dataset.from_tensor_slices(short_texts)
processed_ds = tf_train_ds.map(gpt2_preprocessor, tf.data.AUTOTUNE).batch(20).cache().prefetch(tf.data.AUTOTUNE)
part_of_ds = processed_ds.take(2000)

The model is now trained for two epochs.

This step may take a while. You may skip it and load the weights made available on GitHub using one of the later code sections.

In [9]:
gpt2_lm.include_preprocessing = False

num_epochs = 2

lr = tf.keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=part_of_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

gpt2_lm.compile(
    optimizer=keras.optimizers.experimental.Adam(lr),
    loss=loss,
    weighted_metrics=["accuracy"])

gpt2_lm.fit(part_of_ds, epochs=num_epochs)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7c0ae04bfb20>
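The learning-rate schedule configured in the training cell can be reproduced in plain Python. The sketch below assumes the schedule's default `power` of 1 and no cycling, which gives a linear decay from 5e-5 to 0, and it assumes the subset holds the 39411 texts reported earlier. The function `polynomial_decay` is a hypothetical re-implementation for illustration, not the Keras class itself.

```python
import math


def polynomial_decay(step, initial_lr=5e-5, end_lr=0.0, decay_steps=1, power=1.0):
    """Learning rate at a given step; constant at end_lr after decay_steps."""
    step = min(step, decay_steps)
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


num_texts = 39411   # size of the selected subset reported earlier
batch_size = 20     # batch size used when preparing the data set
num_epochs = 2

# 1971 batches per epoch, which is below the take(2000) cap.
steps_per_epoch = math.ceil(num_texts / batch_size)
decay_steps = steps_per_epoch * num_epochs

for step in (0, decay_steps // 2, decay_steps):
    print(step, polynomial_decay(step, decay_steps=decay_steps))
```

Under these assumptions, the rate is halved by the end of the first epoch and reaches zero at the last training step.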

Let's try to ask the fine-tuned model to summarise a short news item.

In [9]:
gpt2_lm.generate("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR: ", max_length=200)

<tf.Tensor: shape=(), dtype=string, numpy=b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR The fire at Terminal 2 is "significant" airport said. The airport said it will be shut until at least 3 p.m. local time. The blaze occurred in a parking lot in a Terminal 2 parking lot.'>

It's always a good idea to save the fine-tuned weights.

In [12]:
gpt2_lm.backbone.save_weights("data/finetuned_model.h5")

The fine-tuned weights are also available on GitHub, and they can be loaded using the following snippets.

In [None]:
!mkdir data
!wget https://github.com/vveloso/ai-in-practice-talk/releases/download/20231105/finetuned_model.h5 -O data/finetuned_model.h5

In [5]:
gpt2_lm.backbone.load_weights("data/finetuned_model.h5")

Let's release some memory.

In [6]:
del gpt2_tokenizer, gpt2_preprocessor, tf_train_ds, processed_ds, part_of_ds

## Converting the model to the TensorFlow Lite format

Before we can use the model in a mobile application, it needs to be converted to the TensorFlow Lite format.

A TensorFlow function is created to simplify using the model with a fixed maximum length of 200 tokens.

In [26]:
@tf.function
def generate(prompt, max_length):
    return gpt2_lm.generate(prompt, max_length)

concrete_func = generate.get_concrete_function(tf.TensorSpec([], tf.string), 200)

We now define a helper function to test the TensorFlow Lite models once they are created.

In [27]:
def run_inference(input, generate_tflite):
  interp = interpreter.InterpreterWithCustomOps(
      model_content=generate_tflite,
      custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
  interp.get_signature_list()

  generator = interp.get_signature_runner('serving_default')
  output = generator(prompt=np.array([input]))
  print("\nGenerated with TFLite:\n", output["output_0"])

Conversion of the GPT-2 model to TensorFlow Lite requires TensorFlow operations (ops) that are not normally included; they need to be specified during the conversion.

In [28]:
gpt2_lm.jit_compile = False
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
                                                            gpt2_lm)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
converter.allow_custom_ops = True
converter.target_spec.experimental_select_user_tf_ops = ["UnsortedSegmentJoin", "UpperBound"]
converter._experimental_guarantee_all_funcs_one_use = True
generate_tflite = converter.convert()




Generated with TFLite:
 b"I'm enjoying a great weekend in London with friends and family. I've been looking forward to getting back to my hometown for the first time since the end of the World Cup. The weather is nice, there are no issues, and I'm feeling pretty safe and comfortable. I'll be staying at a lovely hotel with my wife and two young boys, who are both from Manchester City and Chelsea. The weather is good and the sun is setting, so I'm feeling really good! \xc2\xa0It"


We now save the converted model.

In [29]:
with open('unquantized_gpt2.tflite', 'wb') as f:
  f.write(generate_tflite)

Let's see how big it is.

In [30]:
!ls -lh *.tflite

-rw-r--r-- 1 root root 478M Oct 11 13:18 unquantized_gpt2.tflite


The converted TensorFlow Lite model is still very large. We can quantise the TensorFlow Lite model during conversion to reduce its size.

In [31]:
gpt2_lm.jit_compile = False
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func],
                                                            gpt2_lm)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
converter.allow_custom_ops = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.experimental_select_user_tf_ops = ["UnsortedSegmentJoin", "UpperBound"]
converter._experimental_guarantee_all_funcs_one_use = True
quant_generate_tflite = converter.convert()




Generated with TFLite:
 b"I'm enjoying a lot of things: reading my weekly weekly weekly newsletter, and then following up with some of your favorite\xc2\xa0quotes. See my written daily Journalist column, which features weekly written written written written for the Mail Mail.  \xc2\xa0You can\xc2\xa0quiz my weekly weekly written columns:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0Quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\n\xc2\xa0quiz:\xc2\xa0"


We now save this new version and check its size. You'll see that it is considerably smaller.

In [None]:
with open('quantized_gpt2.tflite', 'wb') as f:
  f.write(quant_generate_tflite)

In [None]:
!ls -lh *.tflite

-rw-r--r-- 1 root root 124M Oct 11 13:24 quantized_gpt2.tflite
-rw-r--r-- 1 root root 478M Oct 11 13:18 unquantized_gpt2.tflite
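The quantised file is a little under four times smaller than the original (478M against 124M), which is roughly what one would expect if 32-bit floating-point weights are stored as 8-bit integers. The NumPy sketch below illustrates that idea with a simple symmetric per-tensor scheme. It is an assumption-laden illustration, not the converter's actual algorithm, and the matrix size and random weights are arbitrary.

```python
import numpy as np

# Arbitrary random matrix standing in for one weight tensor of the model.
rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.02, size=(768, 768)).astype(np.float32)

# Symmetric per-tensor quantisation: a single float scale plus int8 values.
scale = np.abs(weights).max() / 127.0
quantised = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)

# Dequantise to measure how much precision was lost.
restored = quantised.astype(np.float32) * scale

print("float32 bytes:", weights.nbytes)
print("int8 bytes:   ", quantised.nbytes)
print("max abs error:", np.abs(weights - restored).max())
```

The storage drops by a factor of four while each weight moves by at most about half a quantisation step, which is why the quantised model can still produce usable summaries.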


Let's try both models out. First, the quantised version and then the non-quantised version. Both should generate acceptable summaries.

In [32]:
run_inference("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR ", quant_generate_tflite)


Generated with TFLite:
 b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR  Underground fire sparks at Terminal 2 parking lot. The blaze started around 6 p.m., the airport said. The blaze occurred at Terminal 2 parking lot'


In [33]:
run_inference("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR ", generate_tflite)


Generated with TFLite:
 b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR  Airport shut down in London after a fire in terminal 2 parking lot. A "significant" fire broke out in Terminal 2 parking lot at Luton Airport'
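Both outputs echo the prompt before the generated summary, so an application would probably want to keep only the text after the marker. The helper below is a minimal sketch of that post-processing step. The name `extract_summary` is hypothetical, and the sample bytes are a shortened copy of the output shown above.

```python
def extract_summary(generated: bytes, marker: str = "TL;DR") -> str:
    """Return the text that follows the first occurrence of the marker.

    If the marker is missing, the whole decoded text is returned.
    """
    # Replace undecodable bytes in case generation stopped mid-character.
    text = generated.decode("utf-8", errors="replace")
    _, found, tail = text.partition(marker)
    if not found:
        return text.strip()
    # Drop an optional colon and surrounding whitespace after the marker.
    return tail.lstrip(": ").strip()


# Shortened copy of the non-quantised model's output.
output = (
    b"The airport said it would be closed until at least 3 p.m. local time, "
    b"with passengers advised not to travel to the airport. TL;DR  Airport "
    b"shut down in London after a fire in terminal 2 parking lot."
)
print(extract_summary(output))
```

Matching on "TL;DR" without the colon lets the helper handle both prompt variants used in this notebook.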
