# More on the GPT-2 model from KerasNLP

Next up, we will fine-tune the model to update its parameters. Before we do, let's take a look at the full set of tools KerasNLP provides for working with [GPT2](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/).

The GPT2 source code is in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/). Conceptually, GPT2CausalLM can be broken down hierarchically into several KerasNLP modules, each of which has a `from_preset()` method that loads pretrained weights or vocabulary from a preset:

- [keras_nlp.models.GPT2Tokenizer](https://keras.io/api/keras_nlp/models/gpt2/gpt2_tokenizer#gpt2tokenizer-class): the tokenizer used by the GPT2 model, which is a [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- [keras_nlp.models.GPT2CausalLMPreprocessor](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor#gpt2causallmpreprocessor-class): the preprocessor used for GPT2 causal LM training. It tokenizes the input and does other preprocessing work, such as creating the labels and appending the end token.
- [keras_nlp.models.GPT2Backbone](https://keras.io/api/keras_nlp/models/gpt2/gpt2_backbone#gpt2backbone-class): the GPT2 model itself, a stack of [keras_nlp.layers.TransformerDecoder](https://keras.io/api/keras_nlp/modeling_layers/transformer_decoder#transformerdecoder-class). This is usually just referred to as GPT2.
- [keras_nlp.models.GPT2CausalLM](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm#gpt2causallm-class): wraps GPT2Backbone and multiplies the backbone's output by the token embedding matrix to produce logits over the vocabulary.
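To make the last item concrete, here is a minimal NumPy sketch of a weight-tied output projection: the backbone's final hidden states are multiplied by the transpose of the token embedding matrix, giving one logit per vocabulary entry at every position. The toy sizes and random values are assumptions for illustration; this is not KerasNLP's actual code.

```python
import numpy as np

# Toy sizes chosen for illustration only; GPT-2 small actually uses
# vocab_size=50257 and hidden_dim=768.
vocab_size, hidden_dim = 10, 4
batch, seq_len = 2, 3

rng = np.random.default_rng(0)
# Token embedding matrix, shared by the input embedding and the output head.
embedding = rng.normal(size=(vocab_size, hidden_dim))
# Stand-in for the backbone's final hidden states.
hidden_states = rng.normal(size=(batch, seq_len, hidden_dim))

# Tied output projection: hidden states times the transposed embedding matrix.
logits = hidden_states @ embedding.T

print(logits.shape)  # (2, 3, 10)
```

Reusing the embedding matrix this way means the output head adds no new parameters of its own.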

In [1]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"; set before importing Keras

import keras_nlp
import tensorflow as tf
import keras_core as keras
import time

Using TensorFlow backend




In [3]:
# Load the previously saved GPT-2 causal LM with the same Keras (keras_core) used below.
gpt2_lm = keras.saving.load_model("../models/gpt2_lm_v1.keras")



# Fine-tune on the Reddit dataset

Now that you know more about the GPT-2 model from KerasNLP, you can go one step further and fine-tune it so that it generates text in a specific style: short or long, strict or casual. In this tutorial, we use the Reddit dataset as an example.

In [4]:
%pip install tensorflow_datasets==4.9.* -q

Note: you may need to restart the kernel to use updated packages.


In [5]:
import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)



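Let's look at a sample from the Reddit TensorFlow Dataset (`reddit_tifu`). Because we loaded it with `as_supervised=True`, each element is a `(document, title)` tuple with two features:

- document: the text of the post.
- title: the title of the post.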

In [6]:
for document, title in reddit_ds.take(1):  # inspect only the first example
    print(document.numpy())
    print(title.numpy())

b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we a



In our case, we are training a language model on next-word prediction, so we only need the 'document' feature.
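For next-word prediction, the labels are simply the input tokens shifted one position to the left, and padded positions are masked out with sample weights. The plain-Python sketch below illustrates that idea; `make_causal_lm_example` is a hypothetical helper, the token ids are arbitrary apart from 50256 (GPT-2's `<|endoftext|>` id), and padding with id 0 is an assumption. It follows the same inputs/labels/sample-weights pattern that GPT2CausalLMPreprocessor produces, but it is not its implementation.

```python
def make_causal_lm_example(token_ids, sequence_length, pad_id=0):
    """Hypothetical helper: build (x, y, sample_weight) for next-token prediction."""
    # Pad or truncate to sequence_length + 1 so the shift leaves sequence_length.
    ids = list(token_ids[: sequence_length + 1])
    mask = [1] * len(ids)
    pad = sequence_length + 1 - len(ids)
    ids += [pad_id] * pad
    mask += [0] * pad
    x = ids[:-1]              # inputs: every token except the last
    y = ids[1:]               # labels: the next token at each position
    sample_weight = mask[1:]  # zero weight where the label is padding
    return x, y, sample_weight


x, y, w = make_causal_lm_example([50256, 11, 22, 33], sequence_length=6)
print(x)  # [50256, 11, 22, 33, 0, 0]
print(y)  # [11, 22, 33, 0, 0, 0]
print(w)  # [1, 1, 1, 0, 0, 0]
```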

In [7]:
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

Now you can fine-tune the model with the familiar `fit()` method. Because GPT2CausalLM is a `keras_nlp.models.Task` instance, its preprocessor is called automatically inside `fit()`, so the dataset can contain raw strings.

Training all the way to convergence would take a lot of GPU memory and a long time, so here we use only part of the dataset for demonstration purposes.
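During `fit()`, the preprocessor supplies labels and sample weights alongside the inputs, and the loss and `weighted_metrics` use those weights so that padded positions do not count. The NumPy sketch below shows the arithmetic on a made-up toy batch: per-token cross-entropy computed from raw logits (the `from_logits=True` case) and an accuracy weighted by a padding mask. Averaging the loss over non-padding tokens is one common convention and an assumption here; Keras's exact loss reduction may differ.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Per-token cross-entropy computed directly from raw logits."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true label at each position.
    return -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]

# Toy batch: 1 sequence, 4 positions, vocabulary of 3 tokens.
logits = np.array([[[2.0, 0.5, 0.1],
                    [0.2, 1.5, 0.3],
                    [0.1, 0.2, 3.0],
                    [1.0, 1.0, 1.0]]])
labels = np.array([[0, 1, 0, 0]])
sample_weight = np.array([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

per_token = sparse_ce_from_logits(logits, labels)
masked_loss = (per_token * sample_weight).sum() / sample_weight.sum()

correct = (logits.argmax(axis=-1) == labels).astype(float)
weighted_accuracy = (correct * sample_weight).sum() / sample_weight.sum()
print(round(weighted_accuracy, 4))  # 0.6667
```

Without the weights, the padded position would count as a correct prediction and the accuracy would read 0.75 instead.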

In [None]:
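train_ds = train_ds.take(500)  # use only the first 500 batches for this demo
num_epochs = 1

# Linearly decaying learning rate (PolynomialDecay with the default power=1.0).
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    # cardinality() returns a scalar tensor; convert it to a Python int.
    decay_steps=int(train_ds.cardinality().numpy()) * num_epochs,
    end_learning_rate=0.0,
)
# The model outputs raw logits, so the loss applies the softmax itself.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    # Weighted metrics use the preprocessor's sample weights, so padding is ignored.
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)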
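The schedule above decays the learning rate linearly from 5e-5 to 0 over `decay_steps`. As a sanity check, the plain-Python function below mirrors the formula given in the Keras documentation for `PolynomialDecay` (non-cycling case). The `decay_steps=500` default assumes `take(500)` yields 500 batches and `num_epochs = 1`; the function is illustrative and not part of Keras.

```python
def polynomial_decay(step, initial_lr=5e-5, decay_steps=500, end_lr=0.0, power=1.0):
    """Illustrative re-implementation of the documented PolynomialDecay formula."""
    step = min(step, decay_steps)  # after decay_steps the rate stays at end_lr
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


for step in (0, 250, 500, 600):
    print(step, polynomial_decay(step))
```

With `power=1.0` the rate at step 250 is exactly half the initial value, which is what makes this schedule "linear".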

After fine-tuning is finished, you can generate text again with the same `generate()` method. This time, the text should be closer to the Reddit writing style, and the generated length should be close to the preset length we used for the training data.
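`generate()` builds the output one token at a time, choosing each next token with a sampler. Which sampler `gpt2_lm` uses depends on how it is compiled (KerasNLP accepts a `sampler` argument in `compile()`, for example `keras_nlp.samplers.TopKSampler`). To illustrate a single sampling step, here is a toy NumPy sketch of top-k sampling; `top_k_sample` is a hypothetical helper and the logits are made up.

```python
import numpy as np

def top_k_sample(logits, k, rng):
    """Hypothetical helper: sample one token id from the k highest-scoring logits."""
    top_ids = np.argsort(logits)[-k:]       # indices of the k largest logits
    top_logits = logits[top_ids]
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()                    # softmax restricted to the top k
    return int(rng.choice(top_ids, p=probs))


rng = np.random.default_rng(42)
logits = np.array([0.1, 3.0, 2.5, -1.0, 0.0])
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(10)]
print(samples)  # only ids 1 and 2 can appear
```

Restricting sampling to the top k candidates keeps some variety in the output while never picking very unlikely tokens; with `k=1` it reduces to greedy decoding.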

In [None]:
start = time.time()

output = gpt2_lm.generate("Red hat is ", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")