## Run imports and set variables

In [36]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
import datetime

In [37]:
bert_model_path = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2'
bert_preprocessing_path = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

batch_size = 32
num_epochs = 3
percent_train_data = 10

## Import data
Import the dataset from Tensorflow Hub and split it into train and test sets.

In [38]:
train_data, test_data = tfds.load(
    name='ag_news_subset',
    split=(f'train[:{percent_train_data}%]', 'test'),
    shuffle_files=True,
    as_supervised=True,
    batch_size=batch_size
)


## Import BERT model and preprocessing handler

In [28]:
bert_preprocessing = hub.KerasLayer(bert_preprocessing_path, name='preprocessing')
bert = hub.KerasLayer(bert_model_path, trainable=True, name = 'BERT')

2021-11-18 14:13:48.840700: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2021-11-18 14:13:49.608966: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2021-11-18 14:13:49.629000: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. I

## Build the model
We create a function to define and compile the NN with the pretrained BERT model.

In [29]:
def build_model():
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_inputs = bert_preprocessing(text_input)
    outputs = bert(encoder_inputs)

    # Only retrieve the outputs from the corresponding [CLS] token
    net = outputs['pooled_output']

    # Additional layers for classification
    net = tf.keras.layers.Dense(4, activation='softmax')(net)

    # Build and compile the model
    model = tf.keras.Model(text_input, net)
    model.compile(
        optimizer='Adam',
        loss='SparseCategoricalCrossentropy',
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    return model


model = build_model()
model.summary()


Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_word_ids': ( 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'default': (None, 1 4385921     preprocessing[0][0]              
                                                                 preprocessing[0][1]              
                                                                 preprocessing[0][2]              
____________________________________________________________________________________________

## Train the model
Now that the model is compiled, we can train on our data. We use early stopping to prevent overfitting

In [34]:
%tensorboard --logdir logs

UsageError: Line magic function `%tensorboard` not found.


In [40]:
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy', patience=3, verbose=1)

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(update_freq='batch')

history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=num_epochs,
    callbacks=[earlystopping_callback, tensorboard_callback],
    batch_size=batch_size,
    verbose=1
)


2021-11-18 14:34:27.538950: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-11-18 14:34:27.538972: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-11-18 14:34:27.539812: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


Epoch 1/3


2021-11-18 14:34:27.921643: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at summary_kernels.cc:130 : Not found: Resource localhost/_AnonymousVar491/N10tensorflow22SummaryWriterInterfaceE does not exist.


NotFoundError:  Resource localhost/_AnonymousVar491/N10tensorflow22SummaryWriterInterfaceE does not exist.
	 [[{{node batch_loss/write_summary/summary_cond/then/_496/batch_loss/write_summary}}]] [Op:__inference_train_function_92491]

Function call stack:
train_function


In [None]:
import matplotlib.pyplot as plt
model.predict(['i play a lot of fotball, sports is nice, who scored most goals'])

array([[0.00346005, 0.99125564, 0.00406324, 0.00122105]], dtype=float32)

In [None]:
import numpy as np

# Returns list of all the labels in given tensorflow dataset
def get_labels(data):
    labels = np.empty(0)
    iterator = data.__iter__()
    for i in iterator:
        labels = np.append(labels, i[1].numpy())
    return labels

# Returns list of all the texts in given tensorflow dataset
def get_texts(data):
    texts = np.empty(0)
    iterator = data.__iter__()
    for i in iterator:
        texts = np.append(texts, i[0].numpy())
    return texts


In [3]:
# Get predicted labels
predictions = model.predict(test_data)
predicted_labels = [np.argmax(prediction) for prediction in predictions]

# Get true labels
true_labels = get_labels(test_data)

# Check for equality between true and predicted labels
correct_predictions = true_labels == predicted_labels

# Get indexes of incorrect predictions
correct_index = np.argwhere(not correct_predictions)

[2, 2, 2, 1, 0]