In [72]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
import datetime
from scipy import spatial
import numpy as np

In [5]:
# TF-Hub handles for the encoder and its matching preprocessing model.
# Per the TF-Hub small_bert naming scheme, L/H/A are presumably the layer
# count, hidden size and attention-head count of the encoder.
bert_model_path, bert_preprocessing_path = (
    'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2',
    'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
)

In [73]:
# Load AG News (subset) as (description, label) pairs.
# batch_size=120000 is intended to be large enough that each split comes back
# as one single batch (TODO confirm against the split sizes); the sampling
# cell below takes only the first element of train_data.
train_data, test_data = tfds.load(
    name='ag_news_subset',
    split=('train', 'test'),
    # Fixed file order: sample indices such as get_sample(0) must refer to
    # the same articles on every Restart & Run All.
    shuffle_files=False,
    as_supervised=True,
    batch_size=120000,
)

INFO:absl:Load dataset info from /Users/viktorenzell/tensorflow_datasets/ag_news_subset/1.0.0
INFO:absl:Reusing dataset ag_news_subset (/Users/viktorenzell/tensorflow_datasets/ag_news_subset/1.0.0)
INFO:absl:Constructing tf.data.Dataset ag_news_subset for split ('train', 'test'), from /Users/viktorenzell/tensorflow_datasets/ag_news_subset/1.0.0


In [74]:
# Preprocessor: maps raw strings to the input tensors the encoder expects.
bert_preprocessing = hub.KerasLayer(
    bert_preprocessing_path,
    name='preprocessing',
)
# Encoder: frozen, so it acts purely as a fixed feature extractor.
bert = hub.KerasLayer(
    bert_model_path,
    name='BERT',
    trainable=False,
)

In [75]:
def build_model(preprocessor=None, encoder=None):
    """Build a frozen BERT sentence encoder: raw text in, pooled embedding out.

    The model is meant for inference only (``model.predict``). It is left
    uncompiled on purpose. Its output is an embedding, not class
    probabilities, so a classification loss such as
    SparseCategoricalCrossentropy has no meaning for it. The encoder is
    frozen, so there is nothing to train anyway.

    Args:
        preprocessor: Keras layer mapping a batch of strings to encoder
            inputs. Defaults to the module-level ``bert_preprocessing``.
        encoder: Keras layer returning a dict that contains a
            ``'pooled_output'`` entry. Defaults to the module-level ``bert``.

    Returns:
        An uncompiled ``tf.keras.Model`` with a single string input named
        ``'description'`` whose output is the encoder's pooled output.
    """
    preprocessor = bert_preprocessing if preprocessor is None else preprocessor
    encoder = bert if encoder is None else encoder

    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='description')
    encoder_outputs = encoder(preprocessor(text_input))

    # Keep only the pooled output (derived from the [CLS] token); it serves
    # as the sentence embedding.
    return tf.keras.Model(text_input, encoder_outputs['pooled_output'])


model = build_model()
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
description (InputLayer)        [(None,)]            0                                            
__________________________________________________________________________________________________
preprocessing (KerasLayer)      {'input_type_ids': ( 0           description[0][0]                
__________________________________________________________________________________________________
BERT (KerasLayer)               {'pooled_output': (N 4385921     preprocessing[0][0]              
                                                                 preprocessing[0][1]              
                                                                 preprocessing[0][2]              
Total params: 4,385,921
Trainable params: 0
Non-trainable params: 4,385,921
________________

In [76]:
# train_data was loaded with batch_size=120000, so its first element is
# presumably the entire train split. With as_supervised=True, that element
# is a (texts, labels) pair of tensors.
all_data = next(iter(train_data))
samples = all_data[0].numpy()  # raw article texts (bytes)
labels = all_data[1].numpy()   # integer class ids

# Human-readable class names, indexed by integer label id.
LABEL_NAMES = ('World', 'Sports', 'Business', 'Science and Technology')


def get_sample(i, texts=None, label_ids=None):
    """Return the ``i``-th example as ``(text, label_id, label_name)``.

    Args:
        i: Index of the example.
        texts: Sequence of raw texts. Defaults to the module-level ``samples``.
        label_ids: Sequence of integer labels aligned with ``texts``.
            Defaults to the module-level ``labels``.

    Returns:
        Tuple of the text, its integer label and that label's name from
        ``LABEL_NAMES``.
    """
    texts = samples if texts is None else texts
    label_ids = labels if label_ids is None else label_ids
    label_id = label_ids[i]
    return texts[i], label_id, LABEL_NAMES[label_id]


In [81]:
# Compare two articles by the cosine similarity of their BERT embeddings.

a, al, an = get_sample(0)
b, bl, bn = get_sample(8)

va = model.predict([a])
vb = model.predict([b])

# predict() returns a 2-D batch holding one row per input text. scipy's
# distance functions require 1-D vectors, and recent SciPy releases raise
# ValueError on 2-D input, so flatten before comparing.
sim = 1 - spatial.distance.cosine(np.ravel(va), np.ravel(vb))

print(an, a)
print()
print(bn, b)
print()
print(round(sim, 3))

Schience and Technology b'AMD #39;s new dual-core Opteron chip is designed mainly for corporate computing applications, including databases, Web services, and financial transactions.'

World b'Witnesses in the trial of a US soldier charged with abusing prisoners at Abu Ghraib have told the court that the CIA sometimes directed abuse and orders were received from military command to toughen interrogations.'

0.712
