# Install KerasNLP, Choose Backend and Import Dependencies

This example uses [Keras Core](https://keras.io/keras_core/) so that it can run on any of the "tensorflow", "jax", or "torch" backends. Support for Keras Core is built into KerasNLP; to select a backend, set the "KERAS_BACKEND" environment variable before importing Keras. We select the TensorFlow backend below.

Source tutorial: https://keras.io/examples/generative/gpt2_text_generation_with_kerasnlp/

In [None]:
%pip install pip -U -q
%pip install tensorflow~=2.13.1 keras-nlp==0.6.2 -q

In [None]:
import os

# Must be set before keras_nlp / keras_core are imported.
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"

import keras_nlp
import tensorflow as tf
import keras_core as keras
import time

In [None]:
# cuda_malloc_async has fewer fragmentation issues than the default BFC memory allocator - https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html#tf_gpu_allocator

os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
print(os.getenv("TF_GPU_ALLOCATOR"))

# Load the previously trained model

In [None]:
gpt2_lm = keras.models.load_model("../models/gpt2_lm_v2.keras")

# Into the Sampling Method
KerasNLP offers several sampling methods, such as contrastive search, Top-K sampling, and beam search. By default, GPT2CausalLM uses Top-K sampling, but you can choose a different sampling method.

Much like optimizers and activations, there are two ways to specify a sampler:

- Use a string identifier, such as "greedy"; this uses the sampler's default configuration.
- Pass a [keras_nlp.samplers.Sampler](https://keras.io/api/keras_nlp/samplers/samplers#sampler-class) instance; this lets you use a custom configuration.

For more details on the KerasNLP `Sampler` class, see the code [here](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/samplers).
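
To make the difference between sampling methods concrete, here is a minimal pure-Python sketch of greedy search and Top-K sampling over a single next-token distribution. It is an illustration only, not the KerasNLP implementation: the helper names (`softmax`, `greedy_pick`, `top_k_pick`) and the toy scores are assumptions made up for this example.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits):
    """Greedy search: always return the index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_pick(logits, k, rng):
    """Top-K sampling: keep the k best tokens, renormalize, then sample one."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top])
    return rng.choices(top, weights=probs, k=1)[0]


# Hypothetical next-token scores over a toy five-token vocabulary.
toy_logits = [2.0, 0.5, 1.5, -1.0, 0.1]
rng = random.Random(0)  # fixed seed so repeated runs draw the same samples

greedy_token = greedy_pick(toy_logits)
sampled_tokens = [top_k_pick(toy_logits, k=2, rng=rng) for _ in range(20)]
print("greedy:", greedy_token)
print("top-k draws:", sampled_tokens)
```

Greedy search is deterministic, while Top-K sampling can return any of the `k` highest-scoring tokens. With the loaded model, the equivalent choice would be made with calls such as `gpt2_lm.compile(sampler="greedy")` or `gpt2_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=5))` before calling `generate`.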

# Fine-tune on a Chinese Poem Dataset

We can also fine-tune GPT-2 on non-English datasets. For readers who know Chinese, this part shows how to fine-tune GPT-2 on a Chinese poem dataset to teach the model to become a poet!

Because GPT-2 uses a byte-pair encoder and the original pretraining dataset contains some Chinese characters, we can reuse the original vocabulary when fine-tuning on a Chinese dataset.
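
The following sketch shows why a byte-based vocabulary can represent Chinese text at all. It assumes the tokenizer operates on UTF-8 bytes, as GPT-2's byte-level BPE does, and it only inspects the raw bytes; it does not run the actual tokenizer.

```python
# Byte-level view of the prompt used later in this notebook.
text = "红帽是"

utf8_bytes = text.encode("utf-8")
print(len(text), "characters ->", len(utf8_bytes), "bytes")

# Every byte value falls in 0..255, so a vocabulary that covers all 256
# byte values can always fall back to bytes instead of an unknown token.
all_bytes_in_range = all(0 <= b <= 255 for b in utf8_bytes)
print("all bytes representable:", all_bytes_in_range)

# Decoding the bytes recovers the original text exactly.
round_trip = utf8_bytes.decode("utf-8")
print("round trip ok:", round_trip == text)
```

One likely consequence is that a single Chinese character is split into several tokens, so Chinese text uses up the sequence length faster than English text of similar length.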

In [None]:
# Clone the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

Load the text from the JSON files. We only use《全唐诗》(the Complete Tang Poems) for demo purposes.

In [None]:
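import os
import json

poem_dir = "chinese-poetry/全唐诗"
poem_collection = []
for file in os.listdir(poem_dir):
    # Keep only the JSON files whose names contain "poet".
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = os.path.join(poem_dir, file)
    # The files contain Chinese text, so read them explicitly as UTF-8.
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)

# Join each poem's lines into a single string.
paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]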

Let's take a look at sample data.

In [None]:
print(paragraphs[0])

As in the Reddit example from the source tutorial, we convert the data to a TF dataset and train on only part of it.

In [None]:
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so take only 500 batches
# and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
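
To see what the learning-rate schedule in the training cell does, the sketch below reimplements polynomial decay in plain Python. It assumes the Keras defaults of `power=1.0` and no cycling, and that the dataset yields the full 500 batches; the function `polynomial_decay` is a stand-in written for this illustration, not the Keras class.

```python
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    """Learning rate after `step` optimizer updates under polynomial decay."""
    step = min(step, decay_steps)  # the rate stays at end_lr once decay finishes
    fraction_left = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * fraction_left**power + end_lr


# Values from the training cell: 500 batches per epoch, 1 epoch.
batches = 500
num_epochs = 1
decay_steps = batches * num_epochs

for step in (0, 250, 500):
    print(step, polynomial_decay(step, 5e-4, decay_steps))
```

With `power=1.0` the schedule is a straight line from `5e-4` down to `0.0`, reaching zero on the last training step, so the final batches barely change the weights.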

Let's check the result! Copy the output into [Google Translate](https://translate.google.com/) to read it in English.

In [None]:
# "Red Hat is" translated to Chinese is "红帽是" using https://translate.google.com/
output = gpt2_lm.generate("红帽是", max_length=200)
print(output)

# Save the fine-tuned GPT-2 model to object storage

You can save the model in different formats depending on how you intend to serve it. Here, saving it in the `.keras` format lets us run early online experiments with the fine-tuned model.

In [None]:
gpt2_lm.save("../models/gpt2_lm_v3.keras")
# gpt2_lm.save('../models/gpt2_lm_v1.h5')