# Use Keras-NLP for various nlp tasks


Inference with a pretrained classifier

Fine tuning a pretrained backbone

Fine tuning with user-controlled preprocessing

Fine tuning a custom modes  


## Introduction

A natural language processing library called KerasNLP helps users at every stage of the development process. Our workflows are constructed from modular parts that, when utilized, have cutting-edge predefined weights and architectures.

unconventional and are readily adaptable to require additional control.

This library is an extension of the core Keras API; all high-level modules are
[`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras,
congratulations! You already understand most of KerasNLP.

KerasNLP uses Keras 3 to work with any of TensorFlow, Pytorch and Jax. In the
guide below, we will use the `jax` backend for training our models, and
[tf.data](https://www.tensorflow.org/guide/data) for efficiently running our
input preprocessing. But feel free to mix things up! This guide runs in
TensorFlow or PyTorch backends with zero changes, simply update the
`KERAS_BACKEND` below.

This demonstrates modular approach using a sentiment analysis example at 4
levels of complexity:

* Inference with a pretrained classifier
* Fine tuning a pretrained backbone
* Fine tuning with user-controlled preprocessing
* Fine tuning a custom model



In [1]:
pip install --upgrade tf-keras



In [44]:
!pip install -q --upgrade tensorflow==2.16.1

In [3]:
!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras  # Upgrade to Keras 3.

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.3/465.3 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m950.8/950.8 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
import keras_nlp
import keras

In [5]:
import tensorflow as tf
tf.keras.backend.set_image_data_format('channels_last')

# Import necessary modules from Keras and TensorFlow
from tensorflow import keras
import tensorflow.keras.mixed_precision as mixed_precision

# Set the global policy to use mixed precision for faster training
# This is specifically tailored to TensorFlow's handling of mixed precision
mixed_precision.set_global_policy('mixed_float16')


The highest level API in keras_nlp.models provides a comprehensive suite of modules for handling various natural language processing (NLP) tasks. Each module within this API focuses on a specific aspect of the NLP pipeline, from converting strings to tokens to generating task-specific outputs. Here's a summary of the key modules:

Tokenizer: keras_nlp.models.XXTokenizer

Converts strings to sequences of token IDs.
Essential for mapping raw strings to a manageable number of tokens.
Inherits from keras.layers.Layer.


Preprocessor: keras_nlp.models.XXPreprocessor

Converts strings to preprocessed tensors used by the backbone.
Incorporates special tokens and tensors for understanding input sequences.
Utilizes a tokenizer.
Inherits from keras.layers.Layer.


Backbone: keras_nlp.models.XXBackbone

Converts preprocessed tensors to dense features.
Distills input tokens into dense features for downstream tasks.
Inherits from keras.Model.


Task Model: e.g., keras_nlp.models.XXClassifier

Converts strings to task-specific output, such as classification probabilities.
Combines preprocessing, backbone, and task-specific layers.
Requires fine-tuning on labeled data.
Inherits from keras.Model.
The modular hierarchy, exemplified by BertClassifier, emphasizes compositional relationships between the modules, allowing for flexible and customizable NLP workflows. All modules offer a from_preset() method for instantiating the class with preset architecture and weights, simplifying usage.






## Data
The AG News dataset is a collection of news articles categorized into four classes: World, Sports, Business, and Science/Technology. Here's a breakdown of the code and what it does:

**Loading the Dataset:**

The dataset is loaded using tfds.load() with the name 'ag_news_subset'. This dataset contains a subset of the AG News dataset.
The dataset is split into training and testing sets using the split parameter.

**Batching the Dataset:**

The training and testing datasets are batched using the batch() method. The BATCH_SIZE is set to 16, meaning each batch will contain 16 examples.
Decoding Text:

The decode_text() function is defined to decode text from a tensor. It takes a text tensor and a label tensor as input and returns the decoded text and label.

**Inspecting the Dataset:**

The code then inspects the first review in the training set. It iterates over the training set, decodes the text and label tensors using the decode_text() function, and prints them out.
Additionally, it inspects the first review by using the unbatch().take(1).get_single_element() method. This returns the first example in the dataset without decoding.
Overall, the code loads the AG News dataset, batches it, and provides a glimpse of the data by inspecting the first review. The dataset is commonly used for text classification tasks such as sentiment analysis or topic classification.

We load the data using `keras.utils.text_dataset_from_directory`, which utilizes the
powerful `tf.data.Dataset` format for examples.

In [6]:
pip install datasets

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed dataset

In [7]:
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
import numpy as np

In [8]:
import tensorflow_datasets as tfds
ag_news, info = tfds.load('ag_news_subset', with_info=True)

In [9]:
pip install tensorflow-datasets



In [10]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the AG News dataset
ag_news_train, ag_news_test = tfds.load('ag_news_subset', split=['train', 'test'], as_supervised=True)
BATCH_SIZE = 16
# Batch the datasets
ag_news_train = ag_news_train.batch(BATCH_SIZE)
ag_news_test = ag_news_test.batch(BATCH_SIZE)

# Function to decode the text from a tensor
def decode_text(text_tensor, label_tensor):
    text = text_tensor.numpy().decode('utf-8')
    label = label_tensor.numpy()
    return text, label

# Inspect the first review of the training set
for text_tensor, label_tensor in ag_news_train.unbatch().take(1):
    text, label = decode_text(text_tensor, label_tensor)
    print(f"Text: {text}\nLabel: {label}")

# Inspect first review
# Format is (review text tensor, label tensor)
print(ag_news_train.unbatch().take(1).get_single_element())

Text: AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.
Label: 3
(<tf.Tensor: shape=(), dtype=string, numpy=b'AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.'>, <tf.Tensor: shape=(), dtype=int64, numpy=3>)


In [14]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the AG News dataset
ag_news_train, ag_news_test = tfds.load('ag_news_subset', split=['train', 'test'], as_supervised=True)

BATCH_SIZE = 16

# Take a subset of the dataset for both training and testing
# Here, `.take(n)` means we take only the first 'n' batches
# For example, if BATCH_SIZE = 16 and n = 100, we take 1600 examples
SUBSET_SIZE = 100  # Number of batches to take
ag_news_train_subset = ag_news_train.batch(BATCH_SIZE).take(SUBSET_SIZE)
ag_news_test_subset = ag_news_test.batch(BATCH_SIZE).take(SUBSET_SIZE)

# Function to decode the text from a tensor
def decode_text(text_tensor, label_tensor):
    text = text_tensor.numpy().decode('utf-8')
    label = label_tensor.numpy()
    return text, label

# Inspect the first review of the training set subset
for text_tensor, label_tensor in ag_news_train_subset.unbatch().take(1):
    text, label = decode_text(text_tensor, label_tensor)
    print(f"Text: {text}\nLabel: {label}")

Text: AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.
Label: 3


## **1) Inference with a pretrained classifier**


A task is the highest level module in KerasNLP. A task is a {keras.Model} made up of task-specific layers and a **backbone** model, which is typically pretrained.
Using `keras_nlp.models.BertClassifier}, here's an example.

**Note**: The logits for each class are the outputs (e.g., {[0, 0]} indicates a 50% chance of positive). For binary classification, the result is [positive, negative].

In [16]:
classifier = keras_nlp.models.BertClassifier.from_preset("bert_tiny_en_uncased_sst2")
# Note: batched inputs expected so must wrap string in iterable
classifier.predict(["I love modular workflows in keras-nlp!"])

  trackable.load_own_variables(weights_store.get(inner_path))
  trackable.load_own_variables(weights_store.get(inner_path))


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step


array([[-1.539,  1.542]], dtype=float16)

All **tasks** have a `from_preset` method that constructs a `keras.Model` instance with
preset preprocessing, architecture and weights. This means that we can pass raw strings
in any format accepted by a `keras.Model` and get output specific to our task.

This particular **preset** is a `"bert_tiny_uncased_en"` **backbone** fine-tuned on
`sst2`, another movie review sentiment analysis (this time from Rotten Tomatoes). We use
the `tiny` architecture for demo purposes, but larger models are recommended for SoTA
performance. For all the task-specific presets available for `BertClassifier`, see
our keras.io [models page](https://keras.io/api/keras_nlp/models/).

Let's evaluate our classifier on the IMDB dataset. You will note we don't need to
call `keras.Model.compile` here. All **task** models like `BertClassifier` ship with
compilation defaults, meaning we can just call `keras.Model.evaluate` directly. You
can always call compile as normal to override these defaults (e.g. to add new metrics).

The output below is [loss, accuracy],

In [17]:
# Assuming `classifier` is your TensorFlow/Keras model
eval_results = classifier.evaluate(ag_news_test_subset)
print(eval_results)


[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 1s/step - loss: 0.5685 - sparse_categorical_accuracy: 0.2733
[0.548656702041626, 0.2862499952316284]


Our result is around 29% accuracy without training anything.

## 2)Fine tuning a pretrained BERT backbone

Performance can be enhanced by fine-tuning a custom classifier when labeled text relevant to our task is available. Rotten Tomatoes data should not outperform IMDB data in terms of predicting the sentiment of IMDB reviews! Furthermore, there won't be any relevant pretrained models available for many tasks (like classifying customer reviews).


With the exception of requesting a **preset** for the **backbone**-only model rather than the complete classifier, the fine-tuning procedure is nearly the same as it was previously described. A **task** {Model} will randomly initialize all task-specific layers in order to get ready for training when given a **backbone** **preset**.

Visit our keras.io [models page](https://keras.io/api/keras_nlp/models/) to view all of the **backbone** presets that are available for {BertClassifier}.


To train your classifier, use `keras.Model.fit` as with any other
`keras.Model`. As with our inference example, we can rely on the compilation
defaults for the **task** and skip `keras.Model.compile`. As preprocessing is
included, we again pass the raw data.

In [19]:
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
)
classifier.fit(
    ag_news_train_subset,
    validation_data=ag_news_test_subset,
    epochs=1,
)

[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m477s[0m 5s/step - loss: 0.3134 - sparse_categorical_accuracy: 0.4166 - val_loss: 0.1645 - val_sparse_categorical_accuracy: 0.5081


<keras.src.callbacks.history.History at 0x79fb5e8ebd60>

Here we see a significant lift in validation accuracy (0.29 to 0.41) with a 1 epoch of
training.

## 3)Fine tuning with user-controlled preprocessing


In some advanced training scenarios, users might prefer to be directly in charge of preprocessing. For larger datasets, examples can be preprocessed beforehand and saved to disk or by a different worker pool using `tf.data.experimental.service}.

In other cases, special preprocessing is needed to handle the inputs.

Pass {preprocessor=None} or a custom `BertPreprocessor} to the constructor of a **task** {Model} to prevent automatic preprocessing.

### Separate preprocessing from the same preset

Each model architecture has a parallel **preprocessor** `Layer` with its own
`from_preset` constructor. Using the same **preset** for this `Layer` will return the
matching **preprocessor** as the **task**.

In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`,
which computes the preprocessing once and caches the result before fitting begins.

**Note:** we can use `tf.data` for preprocessing while running on the
Jax or PyTorch backend. The input dataset will automatically be converted to
backend native tensor types during fit. In fact, given the efficiency of `tf.data`
for running preprocessing, this is good practice on all backends.

In [20]:
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_tiny_en_uncased",
    sequence_length=512,
)

# Apply the preprocessor to every sample of train and test data using `map()`.
# `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see
# https://www.tensorflow.org/guide/data_performance for details.

# Note: only call `cache()` if you training data fits in CPU memory!
ag_train_cached = (
    ag_news_train_subset.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)
)
ag_test_cached = (
    ag_news_test_subset.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)
)

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", preprocessor=None, num_classes=2
)
classifier.fit(
    ag_train_cached,
    validation_data=ag_test_cached,
    epochs=3,
)

Epoch 1/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m468s[0m 5s/step - loss: 0.3123 - sparse_categorical_accuracy: 0.4051 - val_loss: 0.1555 - val_sparse_categorical_accuracy: 0.5081
Epoch 2/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m468s[0m 5s/step - loss: 0.1356 - sparse_categorical_accuracy: 0.4963 - val_loss: 0.1000 - val_sparse_categorical_accuracy: 0.4975
Epoch 3/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m425s[0m 4s/step - loss: 0.0694 - sparse_categorical_accuracy: 0.5052 - val_loss: 0.0554 - val_sparse_categorical_accuracy: 0.5113


<keras.src.callbacks.history.History at 0x79fb5f69b280>

After three epochs, our validation accuracy has only increased to 0.51. This is both a
function of the small size of our dataset and our model. To exceed 90% accuracy, try
larger **presets** such as  `"bert_base_en_uncased"`. For all the **backbone** presets
available for `BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_nlp/models/).

## 4)Fine tuning with a custom model

There might not be a suitable **task** {Model} for more complex applications. Here, we offer direct access to the **backbone** {Model}, which can be assembled using unique {Layer}s and has its own `from_preset} constructor. Visit our [transfer learning guide](https://keras.io/guides/transfer_learning/) for more information and detailed examples.

While automatic preprocessing is not included in a **backbone** {Model}, it can be paired with a compatible **preprocessor** by using the same **preset** as demonstrated in the workflow before.

In order to adjust to the new input, we test the idea of freezing our backbone model and adding two trainable transformer layers in this workflow.

**Note**: Since we are using the sequence output from BERT, we can disregard the warning regarding gradients for the {pooled_dense} layer.

In [43]:
preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_tiny_en_uncased")
backbone = keras_nlp.models.BertBackbone.from_preset("bert_tiny_en_uncased")

ag_train_preprocessed = (
    ag_news_train_subset.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)
)
ag_test_preprocessed = (
    ag_news_test_subset.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)
)

backbone.trainable = False
inputs = backbone.input
sequence = backbone(inputs)["sequence_output"]
for _ in range(2):
    sequence = keras_nlp.layers.TransformerEncoder(
        num_heads=2,
        intermediate_dim=512,
        dropout=0.1,
    )(sequence)
# Use [CLS] token output to classify
outputs = keras.layers.Dense(2)(sequence[:, backbone.cls_token_index, :])

model = keras.Model(inputs, outputs)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(5e-5),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    jit_compile=True,
)
model.summary()
model.fit(
    ag_train_preprocessed,
    validation_data=ag_test_preprocessed,
    epochs=3,
)

Epoch 1/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m594s[0m 6s/step - loss: 0.3071 - sparse_categorical_accuracy: 0.4037 - val_loss: 0.0886 - val_sparse_categorical_accuracy: 0.4950
Epoch 2/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m575s[0m 6s/step - loss: 0.1064 - sparse_categorical_accuracy: 0.4793 - val_loss: 0.0682 - val_sparse_categorical_accuracy: 0.5063
Epoch 3/3
[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m623s[0m 6s/step - loss: 0.0823 - sparse_categorical_accuracy: 0.4937 - val_loss: 0.0727 - val_sparse_categorical_accuracy: 0.5000


<keras.src.callbacks.history.History at 0x79fb09904460>

This model achieves reasonable accuracy despite having only 10% of the trainable parameters
of our `BertClassifier` model. Each training step takes about 1/3 of the time---even
accounting for cached preprocessing.

and the loss also reduced after 3 epochs and accuracy increased to 0.5% from 0.3(initial accuracy)