## Run imports and set variables

In [1]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
import datetime

In [4]:
# --- TensorFlow Hub handles: small BERT encoder + matching uncased text preprocessor ---
bert_model_path = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2'
bert_preprocessing_path = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

# --- Output location for the fine-tuned model (relative to the notebook's cwd) ---
model_path = '../../models/bert'

# --- Training knobs ---
batch_size = 32            # examples per batch; applied when the dataset is loaded
num_epochs = 3             # upper bound on training epochs
percent_train_data = 100   # percentage of the training split to load

## Import data
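Load the `ag_news_subset` dataset from TensorFlow Datasets (`tfds`, not TensorFlow Hub) as batched `(text, label)` pairs. We use the first `percent_train_data`% of the training split and the full test split.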

In [6]:
# AG News as batched (text, label) pairs; only the first `percent_train_data`%
# of the training split is kept, the test split is loaded in full.
train_data, test_data = tfds.load(
    name='ag_news_subset',
    split=(f'train[:{percent_train_data}%]', 'test'),
    batch_size=batch_size,
    as_supervised=True,
    shuffle_files=True,
)


## Import BERT model and preprocessing handler

In [8]:
# Preprocessor that converts raw strings into the encoder's inputs.
bert_preprocessing = hub.KerasLayer(bert_preprocessing_path, name='preprocessing')
# Pretrained encoder; trainable=True so its weights are fine-tuned along with the head.
bert = hub.KerasLayer(bert_model_path, trainable=True, name='BERT')

## Build the model
We define a function that builds and compiles a classifier on top of the pretrained BERT encoder.

In [9]:
def build_model(num_classes=4, learning_rate=1e-3):
    """Build and compile a text classifier on top of the pretrained BERT encoder.

    Pipeline: raw string -> `bert_preprocessing` -> `bert` -> pooled [CLS]
    output -> Dense softmax head.

    The hub layers are the module-level `bert_preprocessing` / `bert` objects,
    so every model returned by this function shares the same (trainable) BERT
    layer. A second call therefore starts from whatever weights earlier
    training left behind, not from the pretrained checkpoint.

    Args:
        num_classes: Number of output classes (4 for AG News).
        learning_rate: Adam learning rate. The default equals Keras' default
            Adam learning rate, so behaviour matches `optimizer='Adam'`.
            BERT fine-tuning commonly uses much smaller rates (e.g. 2e-5 to 5e-5).

    Returns:
        A compiled `tf.keras.Model` mapping a batch of strings to class
        probabilities, one column per class.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_inputs = bert_preprocessing(text_input)
    outputs = bert(encoder_inputs)

    # Only retrieve the outputs from the corresponding [CLS] token
    net = outputs['pooled_output']

    # Classification head; softmax outputs pair with the (non-logit) sparse CE loss below
    net = tf.keras.layers.Dense(num_classes, activation='softmax')(net)

    # Build and compile the model
    model = tf.keras.Model(text_input, net)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='SparseCategoricalCrossentropy',
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    return model


# Instantiate the classifier, then print its layer-by-layer summary.
model = build_model()
model.summary()


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_mask': (None 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'pooled_output': (N 4385921     preprocessing[0][0]              
                                                                 preprocessing[0][1]              
                                                                 preprocessing[0][2]              
______________________________________________________________________________________________

## Train the model
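Now that the model is compiled, we can train it on our data. We use early stopping on validation accuracy to prevent overfitting. Note that the test split doubles as the validation set here, so the reported validation accuracy is not an unbiased held-out estimate.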

In [12]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

%tensorboard --logdir logs/fit

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 3849), started 0:00:10 ago. (Use '!kill 3849' to kill it.)

In [13]:
# Stop once validation accuracy stops improving, and roll back to the best
# epoch's weights. `patience` must be smaller than `num_epochs`; otherwise the
# callback can never stop training and provides no overfitting protection.
# NOTE(review): `test_data` doubles as the validation set, so early stopping
# selects on the test split. Carve a validation split out of the training data
# to keep the test score unbiased.
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    mode='max',
    patience=1,
    restore_best_weights=True,
    verbose=1,
)

# Log to a timestamped run directory under logs/fit, the directory the
# %tensorboard cell watches.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/fit/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
    histogram_freq=1,
    update_freq='batch',
)

# train_data / test_data are already batched by tfds.load(batch_size=...), so
# batch_size is deliberately not passed here; Keras documents that it must not
# be specified when the input is a dataset.
history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=num_epochs,
    callbacks=[earlystopping_callback, tensorboard_callback],
    verbose=1,
)


2021-12-17 14:54:22.518583: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-17 14:54:22.518605: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-12-17 14:54:22.522198: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


Epoch 1/3
   1/3750 [..............................] - ETA: 4:16:38 - loss: 1.7175 - sparse_categorical_accuracy: 0.0625

2021-12-17 14:54:26.982632: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-17 14:54:26.982649: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


   2/3750 [..............................] - ETA: 29:09 - loss: 1.4796 - sparse_categorical_accuracy: 0.2500  

2021-12-17 14:54:27.214600: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-12-17 14:54:27.233974: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2021-12-17 14:54:27.254356: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211217-145422/train/plugins/profile/2021_12_17_14_54_27

2021-12-17 14:54:27.261341: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to logs/fit/20211217-145422/train/plugins/profile/2021_12_17_14_54_27/Viktors-MacBook-Pro.local.trace.json.gz
2021-12-17 14:54:27.294461: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211217-145422/train/plugins/profile/2021_12_17_14_54_27

2021-12-17 14:54:27.294732: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to logs/fit/20211217-145422/train/plugins/

Epoch 2/3
Epoch 3/3


In [14]:
# NOTE(review): plt is not used in any cell shown here; move it to the imports
# cell or drop it if nothing later needs it.
import matplotlib.pyplot as plt

# Sanity check: class probabilities for a single hand-written sentence.
sample_text = 'i play a lot of fotball, sports is nice, who scored most goals'
model.predict([sample_text])

array([[2.5322572e-03, 9.9693835e-01, 3.6762521e-04, 1.6190905e-04]],
      dtype=float32)

In [15]:
# Persist the fine-tuned model to `model_path`.
# NOTE(review): the recorded output below reports ../models/bert, which does not
# match model_path ('../../models/bert'). That output is stale; re-run to refresh it.
model.save(filepath=model_path)

2021-12-17 16:13:40.545074: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ../models/bert/assets


INFO:tensorflow:Assets written to: ../models/bert/assets


In [None]:
# Reload the saved model from `model_path`, replacing the in-memory `model`.
# NOTE(review): presumably relies on `tensorflow_text` (imported at the top) being
# loaded so the preprocessing ops can be deserialized; confirm.
model = tf.keras.models.load_model(filepath=model_path)