## Run imports and set variables

In [41]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
import datetime

In [52]:
bert_model_path = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2'
bert_preprocessing_path = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

model_path = '../models/bert1'

batch_size = 32
num_epochs = 3
percent_train_data = 10

## Import data
Import the dataset from Tensorflow Hub and split it into train and test sets.

In [43]:
train_data, test_data = tfds.load(
    name='ag_news_subset',
    split=(f'train[:{percent_train_data}%]', 'test'),
    shuffle_files=True,
    as_supervised=True,
    batch_size=batch_size
)


## Import BERT model and preprocessing handler

In [44]:
bert_preprocessing = hub.KerasLayer(bert_preprocessing_path, name='preprocessing')
bert = hub.KerasLayer(bert_model_path, trainable=True, name = 'BERT')

## Build the model
We create a function to define and compile the NN with the pretrained BERT model.

In [45]:
def build_model():
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_inputs = bert_preprocessing(text_input)
    outputs = bert(encoder_inputs)

    # Only retrieve the outputs from the corresponding [CLS] token
    net = outputs['pooled_output']

    # Additional layers for classification
    net = tf.keras.layers.Dense(4, activation='softmax')(net)

    # Build and compile the model
    model = tf.keras.Model(text_input, net)
    model.compile(
        optimizer='Adam',
        loss='SparseCategoricalCrossentropy',
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    return model


model = build_model()
model.summary()


Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_type_ids': ( 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'pooled_output': (N 4385921     preprocessing[0][0]              
                                                                 preprocessing[0][1]              
                                                                 preprocessing[0][2]              
____________________________________________________________________________________________

## Train the model
Now that the model is compiled, we can train on our data. We use early stopping to prevent overfitting

In [55]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

%tensorboard --logdir logs/fit

In [56]:
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy', 
    patience=3, 
    verbose=1
)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/fit/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), 
    histogram_freq=1, 
    update_freq='batch'
)

history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=num_epochs,
    callbacks=[earlystopping_callback, tensorboard_callback],
    batch_size=batch_size,
    verbose=1
)


2021-11-18 15:22:37.004170: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-11-18 15:22:37.004222: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-11-18 15:22:37.006286: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


Epoch 1/3
  1/375 [..............................] - ETA: 28:30 - loss: 0.0732 - sparse_categorical_accuracy: 0.9688

2021-11-18 15:22:42.112628: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-11-18 15:22:42.112646: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


  2/375 [..............................] - ETA: 3:34 - loss: 0.1769 - sparse_categorical_accuracy: 0.9375 

2021-11-18 15:22:42.858386: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-11-18 15:22:42.880274: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2021-11-18 15:22:42.902728: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211118-152236/train/plugins/profile/2021_11_18_15_22_42

2021-11-18 15:22:42.912826: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to logs/fit/20211118-152236/train/plugins/profile/2021_11_18_15_22_42/10-192-242-225client.eduroam.upc.edu.trace.json.gz
2021-11-18 15:22:42.950937: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211118-152236/train/plugins/profile/2021_11_18_15_22_42

2021-11-18 15:22:42.951186: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to logs/fit/20211118-152236/tra

Epoch 2/3
Epoch 3/3


In [None]:
import matplotlib.pyplot as plt
model.predict(['i play a lot of fotball, sports is nice, who scored most goals'])

In [None]:
model.save(model_path)

In [None]:
model = tf.keras.models.load_model(model_path)