# More on the GPT-2 model from KerasNLP

Next up, we will fine-tune the model to update its parameters. Before we do, let's take a look at the full set of tools KerasNLP provides for working with [GPT2](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/).

The code of GPT2 can be found [here](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/). Conceptually, GPT2CausalLM can be broken down hierarchically into several KerasNLP modules, each of which has a from_preset() method that loads a pretrained model:

- [keras_nlp.models.GPT2Tokenizer](https://keras.io/api/keras_nlp/models/gpt2/gpt2_tokenizer#gpt2tokenizer-class): the tokenizer used by the GPT2 model, which is a [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- [keras_nlp.models.GPT2CausalLMPreprocessor](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor#gpt2causallmpreprocessor-class): the preprocessor used for GPT2 causal LM training. It handles tokenization along with other preprocessing work, such as creating the labels and appending the end token.
- [keras_nlp.models.GPT2Backbone](https://keras.io/api/keras_nlp/models/gpt2/gpt2_backbone#gpt2backbone-class): the GPT2 model itself, a stack of [keras_nlp.layers.TransformerDecoder](https://keras.io/api/keras_nlp/modeling_layers/transformer_decoder#transformerdecoder-class). This is usually what people mean by just "GPT2".
- [keras_nlp.models.GPT2CausalLM](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm#gpt2causallm-class): wraps GPT2Backbone and multiplies the backbone's output by the token embedding matrix to produce logits over the vocabulary tokens.
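The last step above (reusing the embedding matrix as the output projection, often called weight tying) can be illustrated with a minimal NumPy sketch. It assumes toy sizes and random values in place of real weights, so it shows the shape of the computation rather than KerasNLP's actual code: each final hidden state is multiplied by the transpose of the token embedding matrix, giving one logit per vocabulary token at each position.

```python
import numpy as np

# Toy sizes (illustrative only; the base GPT-2 model uses vocab 50257 and hidden size 768)
batch_size, seq_len, hidden_dim, vocab_size = 2, 4, 8, 10

rng = np.random.default_rng(0)
# Token embedding matrix, shared by the input embedding and the output projection
token_embedding = rng.normal(size=(vocab_size, hidden_dim))
# Stand-in for the final hidden states produced by the backbone
hidden_states = rng.normal(size=(batch_size, seq_len, hidden_dim))

# (B, T, H) @ (H, V) -> (B, T, V): score every vocabulary token at every position
logits = hidden_states @ token_embedding.T

# Highest-scoring next token at each position
next_token_ids = logits.argmax(axis=-1)
print(logits.shape)  # (2, 4, 10)
```

Sharing the matrix means the output layer adds no extra vocabulary-sized parameters beyond the embedding itself.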

```python
import os

# Select the Keras backend before importing keras_core or keras_nlp
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"

import time

import keras_nlp
import tensorflow as tf
import keras_core as keras
```

```python
# cuda_malloc_async has fewer fragmentation issues than the default BFC memory allocator:
# https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html#tf_gpu_allocator
# This only takes effect if it is set before TensorFlow initializes the GPU.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
print(os.getenv('TF_GPU_ALLOCATOR'))
```

# Load the previously trained model

```python
# Load the model saved earlier; importing keras_nlp (done above) registers its
# classes, such as GPT2CausalLM, so they can be deserialized
gpt2_lm = keras.models.load_model('../models/gpt2_lm_v1.keras')
```

# Fine-tune on the Reddit dataset

Now that you know how the GPT-2 model in KerasNLP is put together, you can go one step further and fine-tune it so that it generates text in a specific style: short or long, strict or casual. In this tutorial, we use the Reddit TIFU dataset as an example.

```python
%pip install tensorflow_datasets==4.9.* -q
```

```python
import tensorflow_datasets as tfds

# as_supervised=True yields (document, title) tuples instead of feature dicts
reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
```

Let's look at a sample from the Reddit TIFU TensorFlow Dataset. With as_supervised=True, each example has two features:

- document: the text of the post.
- title: the title of the post.

```python
# Print the first (document, title) pair
for document, title in reddit_ds.take(1):
    print(document.numpy())
    print(title.numpy())
```

Since we are training a language model on next-word prediction, we only need the document feature and can drop the title.

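```python
train_ds = (
    reddit_ds.map(lambda document, _: document)  # keep only the post text
    .batch(32)
    .cache()  # cache the batches in memory after the first pass
    .prefetch(tf.data.AUTOTUNE)  # overlap input preparation with training
)
```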

Now you can fine-tune the model with the familiar fit() method. Note that the preprocessor is called automatically inside fit(), since GPT2CausalLM is a keras_nlp.models.Task instance.
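Conceptually, the preprocessor turns each document into shifted inputs and labels for next-token prediction. The sketch below is a simplified NumPy illustration of that idea, not KerasNLP's implementation: it starts from already-tokenized ids, and the ids START_ID, END_ID, PAD_ID and the causal_lm_preprocess helper are hypothetical.

```python
import numpy as np

# Hypothetical toy ids (GPT-2 itself reuses its "<|endoftext|>" token for these roles)
START_ID, END_ID, PAD_ID = 1, 2, 0


def causal_lm_preprocess(token_ids, sequence_length):
    """Turn one tokenized document into (x, y, sample_weight) for next-token training."""
    # Add start/end markers, then truncate or pad to sequence_length + 1 tokens
    packed = [START_ID] + list(token_ids) + [END_ID]
    packed = packed[: sequence_length + 1]
    n_pad = sequence_length + 1 - len(packed)
    mask = [1] * len(packed) + [0] * n_pad
    packed = packed + [PAD_ID] * n_pad
    packed, mask = np.array(packed), np.array(mask)

    x = {"token_ids": packed[:-1], "padding_mask": mask[:-1]}
    y = packed[1:]  # the label at position i is the token at position i + 1
    sample_weight = mask[1:]  # padded label positions do not count toward loss or metrics
    return x, y, sample_weight


x, y, sw = causal_lm_preprocess([10, 11, 12], sequence_length=6)
print(x["token_ids"], y, sw)
```

The zero weights on padded positions are why the compile() call passes weighted_metrics: accuracy is then computed only over real tokens.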

Training on the full dataset until convergence would take a lot of GPU memory and a long time, so for demo purposes we use only part of the dataset: 500 batches of 32 posts, or 16,000 posts.

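```python
# Use only the first 500 batches (500 x 32 = 16,000 posts) to keep the demo short
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate (PolynomialDecay with its default power=1.0).
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=int(train_ds.cardinality()) * num_epochs,  # total training steps
    end_learning_rate=0.0,
)
# The model outputs raw logits, so the loss applies the softmax itself
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    # Weighted metrics use the preprocessor's sample weights, so padding is ignored
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
```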
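With its default power=1.0 (and cycle=False), PolynomialDecay lowers the learning rate linearly from 5e-5 to end_learning_rate over decay_steps. The plain-Python restatement below follows the formula given in the Keras PolynomialDecay documentation so you can check the values; polynomial_decay is a hypothetical helper, not part of Keras.

```python
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    """Learning rate at `step` under the documented PolynomialDecay formula (cycle=False)."""
    step = min(step, decay_steps)  # hold at end_lr once decay_steps is reached
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


decay_steps = 500 * 1  # 500 batches x 1 epoch, as in the training cell
for step in (0, 250, 500):
    print(step, polynomial_decay(step, 5e-5, decay_steps))
```

Halfway through this run the rate is half of 5e-5, and it reaches 0.0 on the last step.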

After fine-tuning is finished, you can generate text again with the same generate() method. This time, the text should be closer to the Reddit writing style, and the generated length should be closer to the length of the training examples.
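Output length can change after fine-tuning because generation stops either at max_length or as soon as the model emits the end token. The sketch below illustrates that stopping rule with greedy decoding as a simple stand-in; KerasNLP's generate() uses its own configurable sampler, and greedy_generate and the toy "model" toy_logits are hypothetical.

```python
import numpy as np


def greedy_generate(next_token_logits, prompt_ids, max_length, end_id):
    """Append the argmax token until end_id is produced or max_length tokens exist."""
    ids = list(prompt_ids)
    while len(ids) < max_length:
        next_id = int(np.argmax(next_token_logits(ids)))
        ids.append(next_id)
        if next_id == end_id:  # the model chose to stop early
            break
    return ids


# Hypothetical toy "model": always predicts (last id + 1); id 5 acts as the end token
VOCAB_SIZE, END_ID = 8, 5


def toy_logits(ids):
    logits = np.zeros(VOCAB_SIZE)
    logits[(ids[-1] + 1) % VOCAB_SIZE] = 1.0
    return logits


print(greedy_generate(toy_logits, [1, 2], max_length=10, end_id=END_ID))
```

A model that has learned to emit the end token after post-like lengths will therefore stop well before max_length.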

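```python
start = time.time()

output = gpt2_lm.generate("Red hat is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
# The first generate() call may include one-time tracing/compilation, so later calls can be faster
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
```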

# Save the fine-tuned GPT-2 model to object storage

You can save the model in different formats depending on how you intend to serve it. Here we use the native Keras format (.keras), which lets us do early online experimentation with the fine-tuned model.

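```python
# Native Keras format, the same format used to load gpt2_lm_v1 above
gpt2_lm.save('../models/gpt2_lm_v2.keras')

# Alternative (commented out): keras_core's save() expects a .keras or legacy .h5
# extension; exporting a TensorFlow SavedModel for serving is a separate step.
#gpt2_lm.save('../models/gpt2_lm_v2.h5')
```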

# Please clear all outputs and close this notebook before running the 03_ notebook.