## Run imports and set variables

In [8]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text

In [10]:
# TF Hub handles: a small uncased-English BERT encoder (per the TF Hub small-BERT
# naming scheme, L=layers, H=hidden size, A=attention heads) and the matching
# text preprocessing model.
bert_model_path = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2'
bert_preprocessing_path = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

# Training hyperparameters.
batch_size = 32
num_epochs = 1

# Fix TensorFlow's global seed so random ops (weight initialisation, dropout)
# repeat across kernel restarts; some GPU kernels may still be nondeterministic.
seed = 42
tf.random.set_seed(seed)

## Import data
Load the AG News dataset (`ag_news_subset`) from TensorFlow Datasets, using its predefined train and test splits.

In [11]:
# Load AG News from TensorFlow Datasets as (text, label) pairs. `shuffle_files`
# only shuffles the order of the underlying shard files, not the examples, and
# batching inside tfds.load would leave only whole batches to shuffle. So load
# unbatched, shuffle the training examples, and batch afterwards.
shuffle_buffer_size = 10_000  # number of examples held in the shuffle buffer

train_raw, test_raw = tfds.load(
    name='ag_news_subset',
    split=('train', 'test'),
    shuffle_files=True,
    as_supervised=True,
)

# Dataset.shuffle reshuffles on every iteration by default, i.e. every epoch.
train_data = (
    train_raw
    .shuffle(shuffle_buffer_size)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation does not depend on example order, so the test split is only batched.
test_data = test_raw.batch(batch_size).prefetch(tf.data.AUTOTUNE)


## Load the BERT encoder and its preprocessing model from TensorFlow Hub

In [12]:
# Raw strings -> BERT input tensors (this layer has no trainable parameters).
bert_preprocessing = hub.KerasLayer(
    bert_preprocessing_path,
    name='preprocessing',
)

# Pretrained encoder; trainable=True so its weights are fine-tuned together
# with the classification head.
bert = hub.KerasLayer(
    bert_model_path,
    name='BERT',
    trainable=True,
)

## Build the model
We define a function that builds and compiles the classifier: the pretrained BERT encoder followed by a small dense classification head.

In [13]:
def build_model(num_classes: int = 4, dropout_rate: float = 0.1,
                learning_rate: float = 1e-3) -> tf.keras.Model:
    """Build and compile a text classifier on top of the pretrained BERT encoder.

    Graph: string input -> BERT preprocessing -> BERT encoder -> pooled output
    -> Dropout -> Dense(32, relu) -> Dense(16, relu) -> Dense(num_classes, softmax).

    Note: the preprocessing and encoder layers are the module-level
    ``bert_preprocessing`` and ``bert`` objects, so every model returned by this
    function shares (and fine-tunes) the same BERT weights.

    Args:
        num_classes: Number of output classes (AG News has 4).
        dropout_rate: Dropout rate applied to the pooled BERT output.
        learning_rate: Adam learning rate. The default 1e-3 equals Keras' 'Adam'
            default; the BERT paper suggests much smaller values (2e-5 to 5e-5)
            for fine-tuning.

    Returns:
        A compiled ``tf.keras.Model`` mapping a batch of strings to per-class
        probabilities, trained with sparse (integer-label) cross-entropy.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_inputs = bert_preprocessing(text_input)
    outputs = bert(encoder_inputs)

    # Use the pooled, whole-sequence representation derived from the [CLS] token.
    net = outputs['pooled_output']

    # Classification head.
    net = tf.keras.layers.Dropout(dropout_rate)(net)
    net = tf.keras.layers.Dense(32, activation='relu')(net)
    net = tf.keras.layers.Dense(16, activation='relu')(net)
    net = tf.keras.layers.Dense(num_classes, activation='softmax')(net)

    model = tf.keras.Model(text_input, net)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        # from_logits=False: the last layer already applies softmax.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    return model


# Build the classifier with the default head and print its layer summary.
model = build_model()
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_mask': (None 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'sequence_output':  4385921     preprocessing[0][0]              
                                                                 preprocessing[0][1]              
                                                                 preprocessing[0][2]              
____________________________________________________________________________________________

## Train the model
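Now that the model is compiled, we can train it on our data. We use early stopping to guard against overfitting (with `num_epochs = 1` it cannot trigger; it only matters once `num_epochs` exceeds the patience).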

In [14]:
# NOTE(review): the test split doubles as validation data, so early stopping
# selects on test metrics and the final test accuracy is optimistically biased.
# Carve a validation split out of 'train' for an unbiased evaluation.
# With num_epochs = 1, patience=3 is never reached; raise num_epochs for the
# callback to have any effect.
earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    patience=3,
    restore_best_weights=True,  # on stop, roll back to the best epoch's weights
    verbose=1,
)

# No batch_size argument: the datasets are already batched, and Keras documents
# that batch_size must not be specified for tf.data inputs.
history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=num_epochs,
    callbacks=[earlystopping],
    verbose=1,
)






In [15]:
import matplotlib.pyplot as plt
# Sanity check on one hand-written sentence: the result has one row per input
# and one softmax probability per class (column order presumably follows the
# dataset's label ids, where ag_news_subset index 1 is Sports -- confirm).
sample_text = ['i play a lot of fotball, sports is nice, who scored most goals']
model.predict(sample_text)

array([[0.01998609, 0.97603333, 0.0021453 , 0.00183535]], dtype=float32)