In [1]:
import os
import shutil

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization  # to create AdamW optimizer

import matplotlib.pyplot as plt

# Only let TensorFlow's Python logger emit ERROR (and more severe) records.
# NOTE(review): the TF Addons notice in this cell's output is still shown; presumably it is
# emitted while the imports above run (before this line executes), so this call cannot hide it.
tf.get_logger().setLevel(level='ERROR')

TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


In [2]:
# TF Hub handles used below: the text preprocessor and the small BERT encoder
# (both from the uncased English BERT family, per their URLs).
tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
tfhub_handle_encoder = (
    'https://tfhub.dev/tensorflow/small_bert/'
    'bert_en_uncased_L-4_H-512_A-8/1'
)

In [3]:
# Load the TF Hub preprocessing model as a Keras layer and run it on one example sentence.
bert_preprocess_model = hub.KerasLayer(handle=tfhub_handle_preprocess)
text_test = ["this is such an amazing movie!"]
text_preprocessed = bert_preprocess_model(text_test)

# Report the output keys, the padded shape, and the first 12 positions of each output tensor.
print(
    f'Keys       : {list(text_preprocessed.keys())}',
    f'Shape      : {text_preprocessed["input_word_ids"].shape}',
    f'Word Ids   : {text_preprocessed["input_word_ids"][0, :12]}',
    f'Input Mask : {text_preprocessed["input_mask"][0, :12]}',
    f'Type Ids   : {text_preprocessed["input_type_ids"][0, :12]}',
    sep='\n',
)

Keys       : ['input_mask', 'input_word_ids', 'input_type_ids']
Shape      : (1, 128)
Word Ids   : [ 101 2023 2003 2107 2019 6429 3185  999  102    0    0    0]
Input Mask : [1 1 1 1 1 1 1 1 1 0 0 0]
Type Ids   : [0 0 0 0 0 0 0 0 0 0 0 0]


In [4]:
# Wrap the small BERT encoder from TF Hub as a Keras layer (library defaults for all other arguments).
bert_model = hub.KerasLayer(handle=tfhub_handle_encoder)

In [5]:
# Feed the preprocessed test sentence through the encoder and summarize its two outputs:
# pooled_output has one vector per example; sequence_output has one vector per token position.
bert_results = bert_model(text_preprocessed)

# Note: the [0, :12] slice on sequence_output selects the first 12 token vectors (each full width),
# whereas on pooled_output it selects the first 12 scalar values.
print(
    f'Loaded BERT: {tfhub_handle_encoder}',
    f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}',
    f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}',
    f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}',
    f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}',
    sep='\n',
)

Loaded BERT: https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1
Pooled Outputs Shape:(1, 512)
Pooled Outputs Values:[ 0.76262903  0.99280983 -0.18611847  0.3667386   0.15233745  0.6550445
  0.9681154  -0.94862705  0.00216164 -0.9877732   0.0684273  -0.97630596]
Sequence Outputs Shape:(1, 128, 512)
Sequence Outputs Values:[[-0.28946295  0.34321263  0.33231527 ...  0.2130087   0.71020836
  -0.05771071]
 [-0.2874208   0.31981027 -0.23018518 ...  0.5845508  -0.21329744
   0.7269212 ]
 [-0.66157013  0.6887687  -0.8743292  ...  0.10877226 -0.26173285
   0.47855547]
 ...
 [-0.22561097 -0.2892568  -0.07064426 ...  0.47566074  0.83277184
   0.40025318]
 [-0.29824227 -0.27473107 -0.05450526 ...  0.488498    1.0955358
   0.18163362]
 [-0.4437815   0.00930744  0.07223801 ...  0.17290124  1.1833242
   0.07898009]]


In [6]:
import pandas as pd
# Load the MTSamples medical-transcription table.
# The location defaults to the original cluster path but can be overridden with the
# MTSAMPLES_CSV environment variable, so the notebook runs on other machines without edits.
MTSAMPLES_CSV = os.environ.get("MTSAMPLES_CSV", "/ssd003/projects/pets/datasets/mtsamples.csv")
df_mt = pd.read_csv(MTSAMPLES_CSV)

In [17]:
# Keep only rows that have a transcription. Inputs and labels are selected with one shared
# mask so they stay row-aligned; df_mt's original index labels are preserved (not reset).
has_transcription = df_mt["transcription"].notna()
X_all = df_mt.loc[has_transcription, "transcription"]
Y_all = df_mt.loc[has_transcription, "medical_specialty"]

In [13]:
# Tokenize every non-null transcription in a single call to the preprocessing layer.
# For the one-sentence test above the preprocessor padded to 128 positions, so long
# transcriptions are presumably truncated to that length — NOTE(review): confirm how much
# text is lost for this corpus.
# NOTE(review): X_all is a pandas Series passed directly to the hub layer; this relies on its
# implicit conversion to a string tensor.
text_preprocessed_mt = bert_preprocess_model(X_all)

In [14]:
import pickle
# Cache the tokenized corpus on disk so it can be reloaded without re-running the preprocessor.
with open("text_preprocessed_mt.pkl", mode="wb") as pkl_file:
    pickle.dump(text_preprocessed_mt, pkl_file)

In [15]:
# Read the cached tokenization back from disk.
# NOTE(review): pickle.load can execute arbitrary code — only load a file this notebook wrote.
# NOTE(review): text_preprocessed_mt2 is never compared with text_preprocessed_mt nor used in the
# cells shown here; confirm whether a round-trip check was intended.
with open("text_preprocessed_mt.pkl", mode="rb") as pkl_file:
    text_preprocessed_mt2 = pickle.load(pkl_file)

In [28]:
import numpy as np

# Encode the tokenized corpus in fixed-size batches and collect BERT's pooled output per row.

# Transcriptions per encoder call. This is unrelated to the 128-token sequence length
# produced by the preprocessor; the two only happen to share a value.
BATCH_SIZE = 128
# Width of the encoder's pooled_output (the single-sentence check above printed (1, 512)).
# NOTE(review): update this if tfhub_handle_encoder is switched to a different-sized model.
POOLED_DIM = 512

num_examples = len(X_all)
# One row per non-null transcription, in X_all order (np.zeros default dtype, float64).
pooled_output_all = np.zeros([num_examples, POOLED_DIM])

# Step through the corpus by batch start offset. Counting `num_examples // BATCH_SIZE + 1`
# batches instead would send an empty final batch to the encoder whenever the corpus size is
# an exact multiple of BATCH_SIZE (and a spurious call even for an empty corpus).
for start in range(0, num_examples, BATCH_SIZE):
    end = start + BATCH_SIZE  # slicing clamps, so the last batch may be shorter
    text_preprocessed_batch = {
        key: text_preprocessed_mt[key][start:end]
        for key in ("input_mask", "input_word_ids", "input_type_ids")
    }
    bert_mt_results = bert_model(text_preprocessed_batch)
    pooled_output_all[start:end] = bert_mt_results["pooled_output"]

In [36]:
# Persist the pooled embeddings. The explicit ".npy" suffix is the same one np.save would append,
# so the file written is unchanged.
np.save(file="bert_mt_pooled_output.npy", arr=pooled_output_all, allow_pickle=True)

In [35]:
# Display the embedding matrix (one row per non-null transcription, in X_all order);
# numpy's repr elides the middle rows and columns of large arrays.
pooled_output_all

array([[ 0.90054685,  0.9077723 , -0.30310431, ...,  0.12300892,
        -0.5520494 , -0.67643225],
       [ 0.11917152,  0.96000177, -0.09910253, ..., -0.07028765,
        -0.26213399,  0.77441353],
       [-0.31382662,  0.993572  , -0.22049351, ...,  0.22879362,
        -0.7720679 ,  0.64762485],
       ...,
       [-0.60929734,  0.78391075, -0.27688962, ..., -0.05187986,
        -0.65109682,  0.90731198],
       [-0.17592755,  0.86092067, -0.10273649, ...,  0.33312511,
        -0.29383278, -0.80268782],
       [-0.97767252,  0.95949841,  0.08247349, ..., -0.05513402,
        -0.6456095 , -0.89641678]])