In [1]:
! pip install git+https://github.com/keras-team/keras-hub.git -q

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.17.0 requires tensorflow<2.18,>=2.17, but you have tensorflow 2.18.0 which is incompatible.

Large language models (LLMs) are complex to build and expensive to train from scratch. Fortunately, pretrained LLMs are available for immediate use. KerasHub provides a large number of pretrained checkpoints, so you can experiment with state-of-the-art models without training them yourself.

KerasHub is a natural language processing library that supports users through their entire development cycle. It offers both pretrained models and modular building blocks, so developers can easily reuse pretrained models or assemble their own LLM.

In a nutshell, for generative LLMs, KerasHub offers:

Pretrained models with a generate() method, e.g., keras_hub.models.GPT2CausalLM and keras_hub.models.OPTCausalLM.
A Sampler class that implements generation algorithms such as top-k, beam, and contrastive search. These samplers can be used to generate text with custom models.
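
To make the idea of a sampler more concrete, here is a minimal sketch of top-k sampling in plain Python. It is not KerasHub's implementation: the helper name `top_k_sample` and the example logits are made up for illustration, and a real sampler works on batched tensors rather than Python lists.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample a token index from the k highest-scoring logits.

    Hypothetical helper for illustration; not part of KerasHub.
    """
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and len(logits)")
    # Indices of the k largest logits, highest first.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the kept logits only (shifted by the max for stability).
    peak = logits[top[0]]
    weights = [math.exp(logits[i] - peak) for i in top]
    # Draw one index in proportion to its renormalized probability.
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 0.5, 1.0, -1.0]  # made-up scores for a 4-token vocabulary
rng = random.Random(0)
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # only the two best-scoring tokens can appear
```

With k=1 this reduces to greedy decoding. Larger values of k trade determinism for variety. In KerasHub itself, the sampler is normally selected when compiling the model (for example by passing a sampler name or object to compile()), so a hand-written loop like this is only needed for understanding the algorithm.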

In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_hub
import keras
import tensorflow as tf
import time

keras.mixed_precision.set_global_policy("mixed_float16")

KerasHub provides a number of pretrained models, such as Google BERT and GPT-2. The full list of available models is in the KerasHub repository.

Loading the GPT-2 model takes only a few lines, as shown below:

In [3]:
# To speed up training and generation, use a preprocessor with a sequence
# length of 128 instead of the full length of 1024.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)


In [4]:
start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")


GPT-2 output:
My trip to Yosemite was the most interesting part of my trip. The first time I went, it was the first time I ever spent in Yosemite. It was the first time I ever went out on my own. The only time I've ever been to Yosemite was during my first day of hiking and my first day of camping. The first time I ever went to Yosemite was during my first day of hiking and my first day of camping.

I was a little surprised to find that the only time I went to Yosemite was in the winter months. I was not in the winter months. I was not hiking in the fall or winter months. I was in the winter months, and I didn't even think of the winter months as winter months. I was not in Yosemite. My only time there was in the winter months was when I was at my first campground.

I was surprised to find that the only time I ever went to Yosemite was during my first day of hiking
TOTAL TIME ELAPSED: 9.18s


Try another one:

In [5]:
start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")


GPT-2 output:
That Italian restaurant is called "Bella Bella" in Italy and is the place where the "Bella Bella" restaurant was founded in 1885.

"It was a very important Italian restaurant," said Italian restaurant historian and author Giuseppe Giorgi. "The name was given because it was an Italian restaurant and not because it was a good restaurant."

The restaurant was opened in 1885, and its name changed in the 1920s.

"It's a very popular place in Florence, but we don't see any restaurants here," said Giuseppe. "The place is still in a very good condition, and it's a good place. The restaurant is a good restaurant and is a good restaurant in a good state. But we have a problem with the fact the restaurant is a good restaurant. It's not good in Italy, because the Italians were always very good at Italian food, so the Italians are the ones who made it."

TOTAL TIME ELAPSED: 1.70s


Now that you know how to use the GPT-2 model from KerasHub, you can go one step further and fine-tune it so that it generates text in a specific style: short or long, strict or casual. This tutorial uses a Chinese poetry dataset as the example.

In [8]:
# Load the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

Cloning into 'chinese-poetry'...
remote: Enumerating objects: 7326, done.
remote: Counting objects: 100% (7/7), done.
remote: Compressing objects: 100% (5/5), done.
remote: Total 7326 (delta 4), reused 2 (delta 2), pack-reused 7319 (from 2)
Receiving objects: 100% (7326/7326), 236.98 MiB | 14.48 MiB/s, done.
Resolving deltas: 100% (5005/5005), done.
Updating files: 100% (2285/2285), done.


In [9]:
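import os
import json

# Collect every poem record from the poet*.json files in the 全唐诗 directory.
poem_dir = "chinese-poetry/全唐诗"
poem_collection = []
for file in os.listdir(poem_dir):
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = os.path.join(poem_dir, file)
    # The files contain Chinese text, so read them explicitly as UTF-8.
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
    poem_collection.extend(content)

# Join each poem's list of lines into a single string.
paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]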
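
The loading loop above relies on each JSON file holding a list of records with a "paragraphs" field. The sketch below shows that assumed layout on a tiny stand-in directory, so it runs without the real clone. The helper name `load_paragraphs`, the file names, and the sample record are hypothetical, and the filter uses `endswith` rather than the substring check in the cell above.

```python
import json
import os
import tempfile


def load_paragraphs(directory):
    """Join the "paragraphs" field of every record in the poet*.json files.

    Hypothetical helper that closely follows the loop above; the record
    layout (a list of dicts with a "paragraphs" list) is an assumption.
    """
    poems = []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".json") or "poet" not in name:
            continue
        with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
            for record in json.load(f):
                poems.append("".join(record["paragraphs"]))
    return poems


# Build a tiny stand-in directory so the sketch runs without the real clone.
with tempfile.TemporaryDirectory() as tmp:
    sample = [{"author": "unknown", "paragraphs": ["line one, ", "line two."]}]
    with open(os.path.join(tmp, "poet.demo.0.json"), "w", encoding="utf-8") as f:
        json.dump(sample, f, ensure_ascii=False)
    # A file that does not match the filter and should be skipped.
    with open(os.path.join(tmp, "README.md"), "w", encoding="utf-8") as f:
        f.write("ignored")
    demo = load_paragraphs(tmp)

print(demo)  # ['line one, line two.']
```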

In [10]:
type(paragraphs)

list

In [11]:
len(paragraphs)

311855

In [12]:
paragraphs[0]

'半依籬脚半依城，多傍梅邊水際亭。最是晚晴斜照裏，黄金日射萬銀星。'

Convert the poems to a TensorFlow dataset. Only part of the data is used for training.

In [13]:
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
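
As a rough check of how much data the demo actually trains on, the arithmetic can be sketched as follows. The numbers come from this notebook: the `len(paragraphs)` output, the `.batch(16)` call, and the `take(500)` call in the training cell. The variable names are only illustrative.

```python
import math

num_poems = 311855  # len(paragraphs) reported above
batch_size = 16     # matches .batch(16)
batches_used = 500  # matches the later train_ds.take(500)

total_batches = math.ceil(num_poems / batch_size)
poems_used = batches_used * batch_size
fraction_used = poems_used / num_poems

print(total_batches)           # 19491
print(poems_used)              # 8000
print(f"{fraction_used:.2%}")  # 2.57%
```

So the demo run sees only a small slice of the corpus, which is worth keeping in mind when judging the quality of the generated poem.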

In [14]:
type(train_ds)

In [15]:
%%time

# Running through the whole dataset takes a long time, so take only 500
# batches and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

500/500 ━━━━━━━━━━━━━━━━━━━━ 120s 180ms/step - accuracy: 0.2564 - loss: 2.5333
CPU times: user 2min 16s, sys: 2.95 s, total: 2min 19s
Wall time: 2min 3s


<keras.src.callbacks.history.History at 0x7d55bc2bb610>
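
The learning-rate schedule used in the training cell can be sketched in plain Python to show how the rate evolves over the 500 steps. This assumes the schedule's default settings (power of 1.0 and no cycling), which make it a linear ramp from 5e-4 down to 0. The function name `polynomial_decay` is a stand-in for illustration, not the Keras class itself.

```python
def polynomial_decay(step, initial_lr=5e-4, end_lr=0.0, decay_steps=500, power=1.0):
    """Learning rate after `step` optimizer updates.

    Sketch of the polynomial-decay formula; with power=1.0 it is a linear ramp.
    The defaults mirror the values used in the training cell above.
    """
    step = min(step, decay_steps)  # the rate stays at end_lr once decay finishes
    remaining = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * remaining**power + end_lr


for step in (0, 125, 250, 500):
    print(step, polynomial_decay(step))
```

Halfway through training (step 250) the rate is half of its initial value, and it reaches 0 at the last step. That is why `decay_steps` is tied to the number of batches times the number of epochs.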

In [16]:
output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)

昨夜雨疏风骤，曾風臺知時秦。頭聞頭求書書，秋風雲樹自香。
