## Run imports and set variables

In [1]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
import datetime

In [45]:
# TF Hub handles: the pretrained small-BERT encoder and its text preprocessor
bert_model_path = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2"
bert_preprocessing_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

# Directory the fine-tuned model is saved to and loaded from
model_path = "../../models/bert"

# Training settings
batch_size, num_epochs = 32, 3
percent_train_data = 100  # interpolated into the 'train[:N%]' split string

## Import data
Load the AG News subset from TensorFlow Datasets (`tfds`), using its predefined train and test splits.

In [46]:
# Load both splits already batched; as_supervised=True yields 2-tuples
# instead of feature dictionaries.
train_data, test_data = tfds.load(
    'ag_news_subset',
    split=(f'train[:{percent_train_data}%]', 'test'),
    batch_size=batch_size,
    as_supervised=True,
    shuffle_files=True,
)

## Limiting the amount of data by number of batches

In [47]:
# Cap the training stream at the first `num_batches` batches
num_batches = 32
train_data = train_data.take(count=num_batches)

print(f'Using {num_batches} batches and {num_batches * batch_size} samples')

Using 16 batches and 512 samples


## Import BERT model and preprocessing handler

In [4]:
# Preprocessing layer applied to the raw string input before the encoder
bert_preprocessing = hub.KerasLayer(handle=bert_preprocessing_path, name='preprocessing')
# Pretrained encoder; trainable=True so its weights are updated during fit()
bert = hub.KerasLayer(handle=bert_model_path, name='BERT', trainable=True)

2021-12-19 12:53:42.340587: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


## Build the model
We create a function to define and compile the NN with the pretrained BERT model.

In [48]:
def build_model(num_classes=4):
    """Build and compile a BERT-based text classifier.

    The model takes a batch of raw strings, runs them through the notebook's
    `bert_preprocessing` layer and a BERT encoder, and feeds the encoder's
    `pooled_output` into a softmax layer with `num_classes` units.

    A new encoder layer is loaded from `bert_model_path` on every call.
    Reusing one module-level trainable layer would make every model built in
    the same kernel session share (and keep) previously fine-tuned encoder
    weights, so repeated experiments would not start from the pretrained
    checkpoint.

    Args:
        num_classes: Number of output classes of the softmax layer
            (defaults to 4).

    Returns:
        A compiled `tf.keras.Model` using the Adam optimizer, sparse
        categorical cross-entropy loss and sparse categorical accuracy.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_inputs = bert_preprocessing(text_input)

    # Fresh copy of the pretrained encoder so no weights are shared with
    # models created by earlier calls.
    encoder = hub.KerasLayer(bert_model_path, trainable=True, name='BERT')
    outputs = encoder(encoder_inputs)

    # Only retrieve the outputs from the corresponding [CLS] token
    net = outputs['pooled_output']

    # Additional layer for classification
    net = tf.keras.layers.Dense(num_classes, activation='softmax')(net)

    # Build and compile the model
    model = tf.keras.Model(text_input, net)
    model.compile(
        optimizer='Adam',
        loss='SparseCategoricalCrossentropy',
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    return model


model = build_model()
model.summary()


Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_type_ids': ( 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'encoder_outputs':  4385921     preprocessing[9][0]              
                                                                 preprocessing[9][1]              
                                                                 preprocessing[9][2]              
____________________________________________________________________________________________

## Train the model
Now that the model is compiled, we can train it on our data. We use early stopping to prevent overfitting.

In [10]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Open the TensorBoard UI inline, reading event files from logs/fit
%tensorboard --logdir logs/fit

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 4635), started 0:30:55 ago. (Use '!kill 4635' to kill it.)

In [49]:
# Stop training once validation accuracy has not improved for `patience` epochs.
# NOTE(review): with patience=3 and num_epochs=3 (config cell) this callback
# cannot end training before the final epoch; lower patience or raise
# num_epochs if early stopping is meant to take effect.
earlystopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    patience=3,
    verbose=1
)

# Each run logs to its own timestamped directory under logs/fit
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/fit/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
    histogram_freq=1,
    update_freq='batch'
)

# train_data and test_data are already batched by tfds.load, so batch_size is
# intentionally not passed to fit().
history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=num_epochs,
    callbacks=[earlystopping_callback, tensorboard_callback],
    verbose=1
)


2021-12-19 13:22:29.297383: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-19 13:22:29.297400: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2021-12-19 13:22:29.298196: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


Epoch 1/3
 1/16 [>.............................] - ETA: 36s - loss: 1.5996 - sparse_categorical_accuracy: 0.3125

2021-12-19 13:22:32.084803: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2021-12-19 13:22:32.084820: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


 2/16 [==>...........................] - ETA: 6s - loss: 1.3117 - sparse_categorical_accuracy: 0.3906 

2021-12-19 13:22:32.587948: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-12-19 13:22:32.595945: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2021-12-19 13:22:32.602559: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211219-132229/train/plugins/profile/2021_12_19_13_22_32

2021-12-19 13:22:32.608010: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to logs/fit/20211219-132229/train/plugins/profile/2021_12_19_13_22_32/Viktors-MacBook-Pro.local.trace.json.gz
2021-12-19 13:22:32.622640: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: logs/fit/20211219-132229/train/plugins/profile/2021_12_19_13_22_32

2021-12-19 13:22:32.622916: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to logs/fit/20211219-132229/train/plugins/



2021-12-19 13:22:54.650798: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Epoch 2/3


2021-12-19 13:23:16.521193: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Epoch 3/3


2021-12-19 13:23:37.802511: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [14]:
import matplotlib.pyplot as plt

# Smoke test: class probabilities for a single hand-written sentence
sample_texts = ['i play a lot of fotball, sports is nice, who scored most goals']
model.predict(sample_texts)

array([[2.5322572e-03, 9.9693835e-01, 3.6762521e-04, 1.6190905e-04]],
      dtype=float32)

In [15]:
# Loss and sparse categorical accuracy on the test split
model.evaluate(x=test_data)



[1.1623013019561768, 0.5490789413452148]

In [15]:
# Persist the fine-tuned model under model_path
model.save(filepath=model_path)

2021-12-17 16:13:40.545074: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ../models/bert/assets


INFO:tensorflow:Assets written to: ../models/bert/assets


In [None]:
# Restore the model previously written to model_path
model = tf.keras.models.load_model(filepath=model_path)

## Plot comparison between fine-tuning with different amounts of data and the dataless baseline

In [None]:
import matplotlib.pyplot as plt

# Test accuracy (%) of fine-tuning for each number of training samples
x_points = [128, 256, 384, 512, 640, 768, 896, 1024]
fine_tuning = [63.6, 73.8, 77.0, 75.5, 78.6, 82.2, 82.3, 82.5]
# Constant dataless accuracy, repeated once per x position to draw a flat line
dataless = [77.6] * len(x_points)

fig, ax = plt.subplots()
ax.plot(x_points, fine_tuning, color='black', marker='o', label='Fine-tuning')
ax.plot(x_points, dataless, color='black', linestyle='dashed', label='Dataless')

ax.set_xlabel('Number of training samples')
ax.set_ylabel('Test accuracy (%)')
# Put a tick, labelled with its value, at every measured sample count
ax.set_xticks(x_points)
ax.set_xticklabels(x_points)
ax.legend(loc='lower right')
plt.show()