# More on the GPT-2 model from KerasNLP

Next, we will fine-tune the model to update its parameters. Before we do, let's look at the full set of tools available for working with [GPT2](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/).

The code of GPT2 can be found [here](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/). Conceptually the GPT2CausalLM can be hierarchically broken down into several modules in KerasNLP, all of which have a from_preset() function that loads a pretrained model:

- [keras_nlp.models.GPT2Tokenizer](https://keras.io/api/keras_nlp/models/gpt2/gpt2_tokenizer#gpt2tokenizer-class): the tokenizer used by the GPT2 model, which is a [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
- [keras_nlp.models.GPT2CausalLMPreprocessor](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor#gpt2causallmpreprocessor-class): the preprocessor used for GPT2 causal LM training. It does the tokenization along with other preprocessing work, such as creating the labels and appending the end token.
- [keras_nlp.models.GPT2Backbone](https://keras.io/api/keras_nlp/models/gpt2/gpt2_backbone#gpt2backbone-class): the GPT2 model, which is a stack of [keras_nlp.layers.TransformerDecoder](https://keras.io/api/keras_nlp/modeling_layers/transformer_decoder#transformerdecoder-class) layers. This is usually just referred to as GPT2.
- [keras_nlp.models.GPT2CausalLM](https://keras.io/api/keras_nlp/models/gpt2/gpt2_causal_lm#gpt2causallm-class): wraps GPT2Backbone and multiplies its output by the embedding matrix to generate logits over the vocabulary tokens.
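
The last step in the list, turning backbone outputs into logits by reusing the token embedding matrix, can be illustrated without loading any weights. The following is a minimal sketch in plain Python with made-up toy numbers; `tied_logits`, `hidden_states`, and `token_embeddings` are hypothetical names used for illustration, not KerasNLP APIs.

```python
def tied_logits(hidden_states, token_embeddings):
    """Project hidden states onto the vocabulary by reusing the embedding matrix.

    hidden_states: one [hidden_dim] vector per sequence position.
    token_embeddings: one [hidden_dim] vector per vocabulary token.
    Returns one row of vocabulary logits per position (hidden @ embeddings^T).
    """
    return [
        [sum(h * e for h, e in zip(hidden, embedding)) for embedding in token_embeddings]
        for hidden in hidden_states
    ]


# Toy setup (assumed values): a vocabulary of 3 tokens and a hidden size of 2.
token_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hidden_states = [[0.5, 2.0], [3.0, -1.0]]  # two sequence positions

logits = tied_logits(hidden_states, token_embeddings)
for position, row in enumerate(logits):
    best_token = max(range(len(row)), key=row.__getitem__)
    print(f"position {position}: logits={row}, highest-scoring token id={best_token}")
```

Each position gets one score per vocabulary token, so the output has shape (sequence length, vocabulary size); the real model does the same with far larger matrices.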

In [None]:
%pip install pip -U -q
%pip install -r ../requirements.txt -q

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"

import keras_nlp
import tensorflow as tf
import keras_core as keras
import time

In [None]:
# cuda_malloc_async has fewer fragmentation issues than the default BFC memory allocator - https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html#tf_gpu_allocator

os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
print(os.getenv("TF_GPU_ALLOCATOR"))

# Load the previously trained model

|Preset name|Compile Time|
|-----------|------------|
|gpt2_base_en| |
|gpt2_medium_en|6m 36.8s|
|gpt2_large_en| |
|gpt2_extra_large_en| |
|gpt2_base_en_cnn_dailymail| |

In [None]:
# this load takes about 
gpt2_lm = keras.models.load_model("../models/gpt2_lm.keras")

# Finetune on custom dataset

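Now that you know the building blocks of the GPT-2 model in KerasNLP, you can go one step further and fine-tune the model so that it generates text in a specific style: short or long, strict or casual. In this tutorial, the dataset is selected with `custom_ds` below (set to `scientific_papers` here); any of the summarization datasets listed next, such as reddit, can be used instead.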

Here is a list of some [abstractive text summarization datasets from TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview#abstractive_text_summarization):

|Dataset|Description|Download Size|Download Time|
|-------|-----------|-------------|-------------|
|aeslc|A collection of email messages of employees in the Enron Corporation.|11.10M| |
|billsum|Summarization of US Congressional and California state bills.|64.14M| |
|reddit|Preprocessed posts from Reddit: 3,848,330 posts with an average length of 270 words for the content and 28 words for the summary.|2.9G| |
|reddit_tifu|Reddit dataset, where TIFU denotes the name of the subreddit /r/tifu.|639.54M| |
|scientific_papers|Scientific papers obtained from the ArXiv and PubMed OpenAccess repositories.|4.2G| |

*download speed from hotel WiFi at 25Mbps*
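
As a rough guide to what the Download Time column might look like, the arithmetic can be sketched as below. This is only a lower bound under stated assumptions: sizes from the table are treated as decimal megabytes, the link sustains its full 25 Mbps, and dataset preparation time is ignored. The function name `estimate_download_seconds` is hypothetical, and the numbers it prints are estimates, not measurements.

```python
def estimate_download_seconds(size_megabytes, link_megabits_per_second=25.0):
    """Lower-bound download time: convert megabytes to megabits, divide by link speed."""
    if link_megabits_per_second <= 0:
        raise ValueError("link speed must be positive")
    return size_megabytes * 8.0 / link_megabits_per_second


# Sizes copied from the table above, read as decimal megabytes (an assumption).
dataset_sizes_mb = {
    "aeslc": 11.10,
    "billsum": 64.14,
    "reddit": 2900.0,
    "reddit_tifu": 639.54,
    "scientific_papers": 4200.0,
}

for name, size_mb in dataset_sizes_mb.items():
    minutes = estimate_download_seconds(size_mb) / 60
    print(f"{name}: at least {minutes:.1f} min at 25 Mbps")
```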

In [None]:
custom_ds = "scientific_papers"

IMPORTANT: you will want to update line 27 in your `devfile.yaml`           
`memory: 48Gi # you will want at least 48Gi`

In [None]:
import tensorflow_datasets as tfds

source_ds = tfds.load(custom_ds, split="train", as_supervised=True)

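Let's look at a sample from the loaded dataset. With `as_supervised=True`, each example is a pair of two features. The names below follow the reddit example; for other datasets the pair is the source text and its summary:

- document: the text of the post (the source text).
- title: the title (the summary).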

In [None]:
for document, title in source_ds:
    print(document.numpy())
    print(title.numpy())
    break

In our case, we are performing next word prediction in a language model, so we only need the 'document' feature.

In [None]:
train_ds = (
    source_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
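
Next-word prediction means that the label at each position is simply the token that follows it. The sketch below shows one way such (input, label, mask) triples can be built from a token sequence. It is an illustration only, not the `GPT2CausalLMPreprocessor` implementation; `make_causal_lm_example`, the token ids, and `pad_id=0` are assumptions made for the example.

```python
def make_causal_lm_example(token_ids, sequence_length, pad_id=0):
    """Build (inputs, labels, mask) for next-token prediction.

    Labels are the inputs shifted left by one position. The mask marks real
    positions with 1 and padding with 0 so padded positions can be ignored
    when computing the loss.
    """
    # Keep one extra token so inputs and labels can both reach sequence_length.
    tokens = list(token_ids)[: sequence_length + 1]
    inputs = tokens[:-1]
    labels = tokens[1:]
    mask = [1] * len(labels)

    padding = sequence_length - len(inputs)
    inputs += [pad_id] * padding
    labels += [pad_id] * padding
    mask += [0] * padding
    return inputs, labels, mask


END = 99  # made-up end-token id for this example
inputs, labels, mask = make_causal_lm_example([END, 11, 22, 33, END], sequence_length=6)
print("inputs:", inputs)
print("labels:", labels)
print("mask:  ", mask)
```

The mask plays the role of per-token sample weights, which is why metrics in the training cell further down are passed as `weighted_metrics`.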

Now you can fine-tune the model using the familiar fit() function. Note that the preprocessor is called automatically inside fit(), since GPT2CausalLM is a keras_nlp.models.Task instance.

This step takes a lot of GPU memory, and training all the way to a fully trained state would take a long time. Here we use only part of the dataset for demo purposes.

In [None]:
train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
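
To see what the linearly decaying schedule does over the run, the formula behind `PolynomialDecay` (with its default `power=1.0` and no cycling) can be reproduced in a few lines of plain Python. This is a sketch for intuition, not the Keras implementation; `polynomial_decay` is a hypothetical helper, and 500 steps assumes the 500 batches taken above with one epoch.

```python
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    """Learning rate at a given step; with power=1.0 this is a straight line."""
    step = min(step, decay_steps)  # the rate stays at end_lr after decay_steps
    fraction_left = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * fraction_left**power + end_lr


# Same settings as the cell above, assuming 500 batches and one epoch.
for step in (0, 125, 250, 375, 500):
    lr = polynomial_decay(step, initial_lr=5e-5, decay_steps=500)
    print(f"step {step:3d}: learning rate {lr:.2e}")
```

The rate starts at 5e-5, is halved at the midpoint, and reaches 0.0 at the final step, so the last updates are very small.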

After fine-tuning is finished, you can again generate text using the same generate() function. This time, the text will be closer to the writing style of scientific papers, and the generated length will be closer to the length of the examples in the training set.

In [None]:
start = time.time()

output = gpt2_lm.generate("Red hat is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")

# Save the fine-tuned GPT-2 model to object storage

You can save the model in different formats depending on how you intend to serve it. Saving it here lets us do early online experimentation with the fine-tuned model.

In [None]:
# Local storage
filename = "../models/gpt2_lm.keras"

gpt2_lm.save(filename)
# gpt2_lm.save('../models/gpt2_lm.h5')

In [None]:
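def convert_model_to_onnx(model, keras_filename):
    """Convert a Keras model to ONNX and save it next to the .keras file."""
    import onnx
    import tf2onnx

    export_filename = keras_filename.replace(".keras", ".onnx")

    # from_keras expects the model object, not the path to the saved file
    proto, _ = tf2onnx.convert.from_keras(model)
    onnx.save(proto, export_filename)


convert_model_to_onnx(gpt2_lm, filename)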

## Save to S3 Object Storage (Minio)

Let's use the NVIDIA Triton model folder structure to store the saved models.

Triton model folder structure:

```
models (provide this dir as source / MODEL_REPOSITORY )
└─ [ model name ]
    └─ 1 (version)
        └── model.savedmodel (we will use .keras)
            ├── saved_model.pb
```
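
The layout above can be created locally before uploading. The following is a small sketch using only the standard library; `triton_version_dir` is a hypothetical helper, the model name `gpt2` is an example, and the empty `gpt2_lm.keras` file is a placeholder standing in for the real saved model.

```python
import tempfile
from pathlib import Path


def triton_version_dir(repository, model_name, version=1):
    """Create and return <repository>/<model_name>/<version>/."""
    if version < 1:
        raise ValueError("version must be a positive integer")
    path = Path(repository) / model_name / str(version)
    path.mkdir(parents=True, exist_ok=True)
    return path


# Build the layout in a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as repository:
    version_dir = triton_version_dir(repository, "gpt2", version=1)
    (version_dir / "gpt2_lm.keras").touch()  # placeholder for the saved model
    for path in sorted(Path(repository).rglob("*")):
        print(path.relative_to(repository))
```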

In [None]:
# install requirements

%pip install -U boto3 python-dotenv -q

In [None]:
# assuming Minio is deployed, populate the environment variables

!  echo "AWS_S3_BUCKET=${AWS_S3_BUCKET:-models}" > .env
!  echo "AWS_S3_ENDPOINT=${AWS_S3_ENDPOINT:-http://minio.minio.svc:9000}" >> .env
!  echo "AWS_ACCESS_KEY_ID=$(oc -n minio extract secret/minio-root-user --keys=MINIO_ROOT_USER --to=-)" >> .env
!  echo "AWS_SECRET_ACCESS_KEY=$(oc -n minio extract secret/minio-root-user --keys=MINIO_ROOT_PASSWORD --to=-)" >> .env

In [None]:
# import the packages

import os
import boto3
from dotenv import load_dotenv

load_dotenv()

In [None]:
# upload the model from local storage to S3

local_path = "../models"
remote_path = "gpt2/2"

bucket = os.getenv("AWS_S3_BUCKET", "models")

s3 = boto3.client(
    "s3",
    endpoint_url=os.getenv("AWS_S3_ENDPOINT", "http://minio.minio.svc:9000"),
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID", "minioadmin"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", "minioadmin"),
)


if bucket not in [bu["Name"] for bu in s3.list_buckets()["Buckets"]]:
    s3.create_bucket(Bucket=bucket)


def uploadDirectory(path, bucketname):
    for root, dirs, files in os.walk(path):
        for file in files:
            print(f"uploading: {file} to {bucketname}/{remote_path}")
            s3.upload_file(
                os.path.join(root, file), bucketname, f"{remote_path}/{file}"
            )
            print("[ok]")


uploadDirectory(path=local_path, bucketname=bucket)
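
The upload loop above stores every file directly under `remote_path`, so two files with the same name in different sub-folders would overwrite each other. If the local folder has sub-directories, one option is to keep the relative path in the object key. The sketch below only computes the keys and does not contact S3; `iter_upload_targets` is a hypothetical helper, and the file names are placeholders.

```python
import os
import tempfile


def iter_upload_targets(local_dir, key_prefix):
    """Yield (local_file, object_key) pairs, keeping sub-directories in the key."""
    for root, _dirs, files in os.walk(local_dir):
        for name in sorted(files):
            local_file = os.path.join(root, name)
            relative = os.path.relpath(local_file, local_dir)
            # Object keys use forward slashes regardless of the local OS.
            key = "/".join([key_prefix.strip("/")] + relative.split(os.sep))
            yield local_file, key


# Demonstrate on a throwaway folder with one nested file.
with tempfile.TemporaryDirectory() as local_dir:
    os.makedirs(os.path.join(local_dir, "onnx"))
    for relative in ("gpt2_lm.keras", os.path.join("onnx", "gpt2_lm.onnx")):
        with open(os.path.join(local_dir, relative), "w") as handle:
            handle.write("placeholder")
    for local_file, key in iter_upload_targets(local_dir, "gpt2/2"):
        print(f"{os.path.relpath(local_file, local_dir)} -> {key}")
```

Each pair could then be passed to `s3.upload_file(local_file, bucket, key)` in place of the flat key used above.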

# Please Clear All Outputs and close the notebook before running 03_ notebook.