In [5]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text

## import data
We start by importing the desiered dataset from the tensorflow resources.

In [52]:
batch_size = 32
train = tfds.load('ag_news_subset', split='train', shuffle_files=True, as_supervised = True, batch_size=batch_size)
test = tfds.load('ag_news_subset', split='test', shuffle_files=True, as_supervised = True, batch_size=batch_size)

In [64]:
next(iter(train.batch(1)))

(<tf.Tensor: shape=(1, 32), dtype=string, numpy=
 array([[b'AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.',
         b'Reuters - Major League Baseball\\Monday announced a decision on the appeal filed by Chicago Cubs\\pitcher Kerry Wood regarding a suspension stemming from an\\incident earlier this season.',
         b'President Bush #39;s  quot;revenue-neutral quot; tax reform needs losers to balance its winners, and people claiming the federal deduction for state and local taxes may be in administration planners #39; sights, news reports say.',
         b'Britain will run out of leading scientists unless science education is improved, says Professor Colin Pillinger.',
         b'London, England (Sports Network) - England midfielder Steven Gerrard injured his groin late in Thursday #39;s training session, but is hopeful he will be ready for Saturday #39;s World Cup qualifier a

## import bert

We import a light bert from tesnorflow hub to later be finetuned. We can also import the acompanied preprocessing layer.

In [53]:
bert_preprocessing = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3', name='preprocessing')
bert = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2', trainable=True, name = 'BERT')

## Build the model
We create a function to define and compile the NN with the pretrained bert model

In [83]:
def build_model():
  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
  encoder_inputs = bert_preprocessing(text_input)
  outputs = bert(encoder_inputs)
  
  net = outputs['pooled_output']
  net = tf.keras.layers.Dropout(0.1)(net)
  net = tf.keras.layers.Dense(32, activation=None)(net)
  net = tf.keras.layers.Dense(16, activation=None)(net)
  net = tf.keras.layers.Dense(3, activation=None)(net)
  model =tf.keras.Model(text_input, net)
  model.compile(optimizer = 'Adam', loss = 'SparseCategoricalCrossentropy', metrics = ['accuracy'])

  return model

Build and summarize the model

In [84]:
model = build_model()
model.summary()

Model: "model_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_mask': (None 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'default': (None, 1 4385921     preprocessing[11][0]             
                                                                 preprocessing[11][1]             
                                                                 preprocessing[11][2]             
___________________________________________________________________________________________

Now that the model in compiled, we can train on our data. We use early stopping to prevent overfitting

In [85]:
earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3, verbose=1)

history = model.fit(
        x = train,
        validation_data=test,
        epochs=1,
        callbacks=[earlystopping],
        batch_size=32,
        verbose=1)