%%

In [1]:
!pip install -q git+https://github.com/keras-team/keras-nlp.git@google-io-2023 tensorflow-text==2.12

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import numpy as np
import keras_nlp
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from tensorflow import keras
from tensorflow.lite.python import interpreter
import time

%%

In [3]:
# Load everything from the "gpt2_base_en" preset:
#   * a standalone GPT-2 tokenizer,
#   * a causal-LM preprocessor with sequence_length=256 that appends an end token,
#   * the GPT-2 causal LM wired to use that preprocessor.
gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")
gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en", sequence_length=256, add_end_token=True
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en",
    preprocessor=gpt2_preprocessor,
)



%%

In [None]:
# Baseline check: let the pretrained model continue a prompt (max_length=200)
# and print the decoded text under a header line.
output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:", output.numpy().decode("utf-8"), sep="\n")

%%

In [None]:
# A second baseline prompt for the pretrained model; the decoded text is printed
# below a header line.
output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:", output.numpy().decode("utf-8"), sep="\n")

%%

In [None]:
# Load CNN/DailyMail as supervised pairs; later cells unpack each element as
# (article, highlights).
cnn_ds = tfds.load(name='cnn_dailymail', as_supervised=True)

%%

In [4]:
import nltk
from nltk import tokenize

nltk.download('punkt')

# Print the first training example whose "article TL;DR summary" combination is
# under 256 NLTK words. The bytes are decoded before tokenizing: str() on a bytes
# object yields its repr (b'...' with literal escapes such as \n or \xe2\x80\x99),
# which distorts the word count.
for article, highlights in cnn_ds['train']:
  combination = article + tf.constant(' TL;DR ') + tf.strings.regex_replace(highlights, "\n", " ")
  combination_text = combination.numpy().decode("utf-8", errors="replace")
  word_count = len(tokenize.word_tokenize(combination_text))
  if word_count < 256:
    print(combination.numpy())
    print(article.numpy())
    print(highlights.numpy())
    break

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


KeyboardInterrupt: ignored

%%

In [None]:
import progressbar

# State for the filtering loop: collected texts, a running example counter, and
# how many examples pass between progress-bar ticks.
short_texts = []
count = 0
progressbar_update_freq = 1000
total = len(cnn_ds['train'])

# Layout: " [elapsed time: ...] ****** (ETA ...) "
widgets = [
    ' [', progressbar.Timer(format= 'elapsed time: %s'), '] ',
    progressbar.Bar('*'), ' (', progressbar.ETA(), ') ',
]
# One tick per `progressbar_update_freq` examples, with a little headroom.
bar = progressbar.ProgressBar(
    widgets=widgets,
    maxval=total // progressbar_update_freq + 2,
).start()

In [None]:
# Keep only "article TL;DR summary" texts that fit the model's context.
# Length is measured in GPT-2 BPE tokens using the model's own tokenizer, not in
# NLTK words. BPE typically yields more tokens than words, so a word-count filter
# keeps texts that the preprocessor (sequence_length=256) would truncate, cutting
# off the summary after "TL;DR".
# NOTE(review): the preprocessor also adds start/end tokens; confirm that 256 is
# the exact budget you want.
for article, highlights in cnn_ds['train']:
  combination = article + tf.constant(' TL;DR ') + tf.strings.regex_replace(highlights, "\n", " ")
  token_count = int(tf.size(gpt2_tokenizer(combination)))
  if token_count < 256:
    short_texts.append(combination)
  count += 1
  if count % progressbar_update_freq == 0:
    # Integer tick index; count is a multiple of progressbar_update_freq here.
    bar.update(count // progressbar_update_freq)
bar.finish()

%%

In [4]:
def save_texts(texts, path='data/selected_texts.npz'):
    """Checkpoint a list of texts to an ``.npz`` archive.

    Args:
        texts: sequence of texts (e.g. scalar string tensors or bytes); stored
            as a single array under numpy's default key ``arr_0``.
        path: destination file. Its parent directory is created if missing.
    """
    import os

    # np.savez raises if the target directory does not exist yet.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    np.savez(path, texts)

def load_texts(path='data/selected_texts.npz'):
    """Load texts previously checkpointed to an ``.npz`` archive.

    Args:
        path: ``.npz`` file to read.

    Returns:
        A list with the elements of every array in the archive, concatenated
        in archive order.
    """
    # SECURITY: allow_pickle=True is kept because the archive may hold an
    # object-dtype array, which numpy stores pickled. Unpickling can execute
    # arbitrary code, so only load archives from sources you trust.
    restored_texts = []
    with np.load(path, allow_pickle=True) as data:
        for name in data.files:
            restored_texts.extend(data[name].tolist())
    return restored_texts

Save the list of short article-plus-summary combinations to disk as a checkpoint, so the slow filtering step does not have to be repeated.

In [5]:
# Requires `short_texts` from the filtering loop. On a fresh kernel where that
# loop has not run, this raises NameError.
save_texts(short_texts)

NameError: ignored

%%

In [6]:
!mkdir data
!wget https://github.com/vveloso/ai-in-practice-talk/raw/main/gpt-2/data/selected_texts.npz -O data/selected_texts.npz

short_texts = load_texts()

mkdir: cannot create directory ‘data’: File exists
--2023-10-11 12:16:30--  https://github.com/vveloso/ai-in-practice-talk/raw/main/gpt-2/data/selected_texts.npz
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vveloso/ai-in-practice-talk/main/gpt-2/data/selected_texts.npz [following]
--2023-10-11 12:16:30--  https://raw.githubusercontent.com/vveloso/ai-in-practice-talk/main/gpt-2/data/selected_texts.npz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5477897 (5.2M) [application/octet-stream]
Saving to: ‘data/selected_texts.npz’


2023-10-11 12:16:30 (106 MB/s) - ‘data/selected_texts.npz’ saved [

%%

In [8]:
tf_train_ds = tf.data.Dataset.from_tensor_slices(short_texts)

# Run the GPT-2 preprocessor on each text, then group the results into batches of 20.
processed_ds = tf_train_ds.map(gpt2_preprocessor, tf.data.AUTOTUNE).batch(20)

# Train on the first 200 batches only. take() comes BEFORE cache(): with
# cache().take(k), each epoch stops before the cache is fully read, so tf.data
# discards the partial cache and re-runs preprocessing every epoch.
part_of_ds = processed_ds.take(200).cache().prefetch(tf.data.AUTOTUNE)

%%

In [9]:
# part_of_ds already holds preprocessed batches, so turn off the model's own
# preprocessing for fit().
# NOTE(review): `include_preprocessing` comes from the keras-nlp google-io-2023
# branch; confirm its exact semantics there.
gpt2_lm.include_preprocessing = False

num_epochs = 2

# Sparse categorical cross-entropy computed from raw logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Decay the learning rate from 5e-5 to 0.0 over every training step
# (batches in part_of_ds x epochs).
lr = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5,
    end_learning_rate=0.0,
    decay_steps=part_of_ds.cardinality() * num_epochs,
)

gpt2_lm.compile(
    optimizer=keras.optimizers.experimental.Adam(learning_rate=lr),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(part_of_ds, epochs=num_epochs)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7c0ae04bfb20>

%%

In [9]:
# Probe the fine-tuned model: the text generated after "TL;DR " is its summary
# of the paragraph (generation capped with max_length=200).
gpt2_lm.generate("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR ", max_length=200)

<tf.Tensor: shape=(), dtype=string, numpy=b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR The fire at Terminal 2 is "significant" airport said. The airport said it will be shut until at least 3 p.m. local time. The blaze occurred in a parking lot in a Terminal 2 parking lot.'>

In [10]:
# A second summarisation probe on a different news paragraph ending in "TL;DR ".
gpt2_lm.generate("The House GOP's two candidates for speaker detailed their plans during a closed-door meeting on Tuesday for avoiding a government shutdown - a key issue for members, and one that sank Kevin McCarthy's speakership. House Majority Leader Steve Scalise and Judiciary Chairman Jim Jordan made their pitches during the Tuesday meeting ahead of a conference vote for speaker on Wednesday, but GOP lawmakers made clear that the conference remains divided, and there's a heavy dose of skepticism among Republicans that they will quickly coalesce around either candidate to be the next speaker. TL;DR ", max_length=200)

<tf.Tensor: shape=(), dtype=string, numpy=b"The House GOP's two candidates for speaker detailed their plans during a closed-door meeting on Tuesday for avoiding a government shutdown - a key issue for members, and one that sank Kevin McCarthy's speakership. House Majority Leader Steve Scalise and Judiciary Chairman Jim Jordan made their pitches during the Tuesday meeting ahead of a conference vote for speaker on Wednesday, but GOP lawmakers made clear that the conference remains divided, and there's a heavy dose of skepticism among Republicans that they will quickly coalesce around either candidate to be the next speaker. TL;DR  Speaker Paul Ryan's two contenders for speaker outline their plan. The two candidates are expected to address a closed-door conference vote for Speaker Paul Ryan. The House GOP's two candidates for speaker detailed their plans during closed-door meetings Tuesday.">

In [12]:
# Persist only the fine-tuned backbone weights to an HDF5 file.
gpt2_lm.backbone.save_weights("finetuned_model.h5")

In [5]:
# Restore the fine-tuned backbone weights into gpt2_lm (e.g. after a kernel
# restart, once the model has been rebuilt from the preset).
gpt2_lm.backbone.load_weights("finetuned_model.h5")

In [6]:
# Drop the standalone tokenizer/preprocessor names. gpt2_lm was constructed with
# preprocessor=gpt2_preprocessor, so it presumably keeps its own reference.
del gpt2_tokenizer, gpt2_preprocessor

In [26]:
# Wrap generation in a tf.function so it can be traced to a graph for TFLite
# conversion. The `prompt` parameter name presumably becomes the TFLite
# signature input name (run_inference calls the runner with `prompt=`), so keep it.
@tf.function
def generate(prompt, max_length):
    """Generate text from `prompt` with gpt2_lm, capped at `max_length`."""
    return gpt2_lm.generate(prompt, max_length)

# Trace once for a scalar string prompt. max_length is a plain Python int, so
# 100 is fixed into this concrete function rather than becoming an input.
concrete_func = generate.get_concrete_function(tf.TensorSpec([], tf.string), 100)

In [27]:
def run_inference(input, generate_tflite):
  """Run a serialized TFLite generator on one prompt and print the result as text.

  Args:
    input: prompt string.
    generate_tflite: serialized TFLite model bytes (from
      ``TFLiteConverter.convert()``) exposing a ``serving_default`` signature
      with a ``prompt`` input and an ``output_0`` output.
  """
  # Register the TF Text custom ops (SELECT_TFTEXT_OPS) the converted graph may use.
  interp = interpreter.InterpreterWithCustomOps(
      model_content=generate_tflite,
      custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)

  generator = interp.get_signature_runner('serving_default')
  output = generator(prompt=np.array([input]))

  # Decode to text, as the Keras generation cells do, so that non-ASCII
  # characters (e.g. U+00A0) print as characters rather than \xc2\xa0 escapes.
  generated = output["output_0"]
  if isinstance(generated, np.ndarray) and generated.size == 1:
    generated = generated.item()
  if isinstance(generated, bytes):
    generated = generated.decode("utf-8", errors="replace")
  print("\nGenerated with TFLite:\n", generated)

In [28]:
# Convert the traced generator into an unquantized TFLite model, then smoke-test it.
gpt2_lm.jit_compile = False  # turn off XLA compilation on the Keras model
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func], trackable_obj=gpt2_lm)

# Permit custom ops, plus two specific TF ops, in the converted graph.
converter.allow_custom_ops = True
converter.target_spec.experimental_select_user_tf_ops = [
    "UnsortedSegmentJoin",
    "UpperBound",
]
# Use TFLite builtin ops where available and fall back to regular TF (Select TF) ops.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
# Private converter flag, kept as in the original recipe.
converter._experimental_guarantee_all_funcs_one_use = True

generate_tflite = converter.convert()
run_inference("I'm enjoying a", generate_tflite)




Generated with TFLite:
 b"I'm enjoying a great weekend in London with friends and family. I've been looking forward to getting back to my hometown for the first time since the end of the World Cup. The weather is nice, there are no issues, and I'm feeling pretty safe and comfortable. I'll be staying at a lovely hotel with my wife and two young boys, who are both from Manchester City and Chelsea. The weather is good and the sun is setting, so I'm feeling really good! \xc2\xa0It"


In [29]:
# Write the unquantized TFLite model bytes to disk.
with open('unquantized_gpt2.tflite', mode='wb') as tflite_file:
  tflite_file.write(generate_tflite)

In [30]:
# Show the size of the exported TFLite model(s).
!ls -lh *.tflite

-rw-r--r-- 1 root root 478M Oct 11 13:18 unquantized_gpt2.tflite


In [31]:
# Same conversion as the unquantized model, plus the default post-training
# optimizations (dynamic-range quantization, since no representative dataset
# is set). Smoke-test the result.
gpt2_lm.jit_compile = False  # turn off XLA compilation on the Keras model
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_func], trackable_obj=gpt2_lm)

# Permit custom ops, plus two specific TF ops, in the converted graph.
converter.allow_custom_ops = True
converter.target_spec.experimental_select_user_tf_ops = [
    "UnsortedSegmentJoin",
    "UpperBound",
]
# Use TFLite builtin ops where available and fall back to regular TF (Select TF) ops.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Private converter flag, kept as in the original recipe.
converter._experimental_guarantee_all_funcs_one_use = True

quant_generate_tflite = converter.convert()
run_inference("I'm enjoying a", quant_generate_tflite)




Generated with TFLite:
 b"I'm enjoying a lot of things: reading my weekly weekly weekly newsletter, and then following up with some of your favorite\xc2\xa0quotes. See my written daily Journalist column, which features weekly written written written written for the Mail Mail.  \xc2\xa0You can\xc2\xa0quiz my weekly weekly written columns:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0Quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\xc2\xa0quiz:\n\xc2\xa0quiz:\xc2\xa0"


In [32]:
# Summarisation quality check on the quantized TFLite model (same prompt as the Keras probe).
run_inference("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR ", quant_generate_tflite)


Generated with TFLite:
 b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR  Underground fire sparks at Terminal 2 parking lot. The blaze started around 6 p.m., the airport said. The blaze occurred at Terminal 2 parking lot'


In [33]:
# The same prompt on the unquantized TFLite model, for comparison with the quantized output.
run_inference("All flights have been suspended in London's Luton Airport following the breakout of a \"significant\" fire in the airport's Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR ", generate_tflite)


Generated with TFLite:
 b'All flights have been suspended in London\'s Luton Airport following the breakout of a "significant" fire in the airport\'s Terminal 2 parking lot, the airport said in a statement on Wednesday. The airport said it would be closed until at least 3 p.m. local time, with passengers advised not to travel to the airport. TL;DR  Airport shut down in London after a fire in terminal 2 parking lot. A "significant" fire broke out in Terminal 2 parking lot at Luton Airport'


In [34]:
# Write the quantized TFLite model bytes to disk.
with open('quantized_gpt2.tflite', mode='wb') as tflite_file:
  tflite_file.write(quant_generate_tflite)

In [35]:
# Compare on-disk sizes of the quantized and unquantized models.
!ls -lh *.tflite

-rw-r--r-- 1 root root 124M Oct 11 13:24 quantized_gpt2.tflite
-rw-r--r-- 1 root root 478M Oct 11 13:18 unquantized_gpt2.tflite


In [36]:
# Give the quantized model its final file name, summarise.tflite.
!mv quantized_gpt2.tflite summarise.tflite

In [25]:
# Release the in-memory serialized models; both have been written to disk.
del quant_generate_tflite, generate_tflite