In [1]:
import os
import logging
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from math import ceil
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.linear_model import LinearRegression

from bert_serving.client import BertClient
from bert_serving.server.graph import optimize_graph
from bert_serving.server.helper import get_args_parser
from bert_serving.server.bert.tokenization import FullTokenizer
from bert_serving.server.bert.extract_features import convert_lst_to_features

import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.keras.utils import Progbar

# Reuse the logger named 'tensorflow', set it to INFO, and drop any handlers
# already attached to it. The same logger object is later passed to
# convert_lst_to_features.
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)
log.handlers = []

# BERT Embeddings on GPU

This notebook uses `Bert-as-a-Service` and `TensorFlow` to extract embeddings for individual tokens and for whole documents. It is part of a larger project that uses these embeddings to compute semantic similarity in order to clean search results over user-generated data. Please see `Embedding Similarity Across Models` for that similarity analysis (alongside other models).

In [2]:
def predict_inference_speed(seq_len):
    """Extrapolate an inference time for ``seq_len`` from recorded timings.

    An ordinary least-squares line is fitted to hard-coded measurements taken
    at sequence lengths 4 through 14 and then evaluated at ``seq_len``. The
    model is refitted on every call.

    Args:
        seq_len: Sequence length (in tokens) to predict for.

    Returns:
        The predicted value as a scalar, in the same unit as the recorded
        timings (presumably seconds per input -- TODO confirm).
    """
    measured_lengths = np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
    measured_times = np.array([0.0214, 0.0236, 0.0257, 0.0277, 0.0289, 0.0316,
                               0.0344, 0.0355, 0.0422, 0.0429, 0.0454])

    # Both arrays are fed as column vectors; a 2-D target makes predict()
    # return a 2-D array, hence the double index below.
    model = LinearRegression()
    model.fit(X = measured_lengths.reshape(-1, 1),
              y = measured_times.reshape(-1, 1))

    query = np.array([[seq_len]])
    return model.predict(query)[0][0]

In [3]:
# Projected total hours for 550000 inputs at a sequence length of 22
# (predicted per-input time * 550000, converted with / 60 / 60).
predicted_time = predict_inference_speed(seq_len = 22)
f'Hours: {round((predicted_time * 550000) / 60 / 60, 3)}'

'Hours: 9.826'

## Build / Export Graph

 - **Models**:
     - Token Model Seq Len: 4
     - Doc Model Seq Len: 22 (99.9% coverage)

In [None]:
# Configuration for building the feature-extraction graph.
# NUM_WORKER and POOL_LAYER are kept as strings; every value is passed through
# str() before being handed to the bert-as-service argument parser.
MODEL_DIR = 'models/wwm_uncased_L-24_H-1024_A-16'
GRAPH_DIR = 'models/graph'
NUM_WORKER = "8"                                # concurrency
POOL_LAYER = "-2"                               # second-to-last layer (suggested because the last layer is too biased toward the training target; see the Bert-as-a-Service docs for more info)
POOL_STRAT = 'REDUCE_MEAN'                      # averages the vectors of all tokens in a sequence
SEQ_LEN = 22                                    # in tokens; changing this value linearly affects inference speed
# File name of the exported graph; embeds SEQ_LEN so graphs for different lengths do not collide.
GRAPH_OUT = f'extractor_seq-len-{SEQ_LEN}.pbtxt'

In [4]:
# Create the directory that will hold the exported graph.
tf.gfile.MkDir(GRAPH_DIR)

# Command-line style options for bert-as-service, built from the config cell.
cli_args = [
    '-model_dir', MODEL_DIR,
    "-graph_tmp_dir", GRAPH_DIR,
    '-max_seq_len', str(SEQ_LEN),
    '-pooling_layer', str(POOL_LAYER),
    '-pooling_strategy', str(POOL_STRAT),
    '-num_worker', str(NUM_WORKER),
]
parser = get_args_parser()
carg = parser.parse_args(args=cli_args)

# optimize_graph hands back the path of a temporary graph file plus a config
# object; the temporary file is moved to its final, seq-len-specific name.
tmpfi_name, config = optimize_graph(carg)
graph_fout = os.path.join(GRAPH_DIR, GRAPH_OUT)
tf.gfile.Rename(tmpfi_name, graph_fout, overwrite=True)

print("Serialized graph to {}".format(graph_fout))

## Create Feature Extractor

In [5]:
# SEQ_LEN comes from the configuration cell; coerce it to int so it can be
# used as a numeric length by the feature builder.
SEQ_LEN = int(SEQ_LEN)
# Derive both paths from the configuration constants instead of repeating the
# directory names, so editing MODEL_DIR / GRAPH_DIR / GRAPH_OUT cannot leave
# the extractor pointing at a stale graph or vocabulary. GRAPH_PATH is built
# exactly like the path the graph was exported to.
GRAPH_PATH = os.path.join(GRAPH_DIR, GRAPH_OUT)
VOCAB_PATH = os.path.join(MODEL_DIR, 'vocab.txt')

In [9]:
# Names of the three input tensors of the exported graph. They are used as
# feature attribute names, as feed-dict keys, and (with a ':0' suffix) as the
# keys of the graph's input map.
INPUT_NAMES = ['input_ids', 'input_mask', 'input_type_ids']
# Tokenizer built from the vocabulary file at VOCAB_PATH.
bert_tokenizer = FullTokenizer(VOCAB_PATH)

def build_feed_dict(texts):
    """Convert raw texts into the int32 input arrays expected by the graph.

    The texts are tokenized and converted to features with
    ``convert_lst_to_features`` (using ``SEQ_LEN`` for both the maximum
    sequence length and the maximum position embeddings).

    Args:
        texts: Sized collection of untokenized strings.

    Returns:
        dict mapping each name in ``INPUT_NAMES`` to an ``int32`` array with
        ``len(texts)`` rows.
    """
    features = list(convert_lst_to_features(
        lst_str = texts,
        max_seq_length = SEQ_LEN,
        max_position_embeddings = SEQ_LEN,
        tokenizer = bert_tokenizer,
        logger = log,
        is_tokenized = False,
        mask_cls_sep = False,
    ))

    n_rows = len(texts)

    # One 2-D array per input tensor: a row per text, columns inferred.
    return {
        name: np.array([getattr(feature, name) for feature in features])
                .reshape((n_rows, -1))
                .astype("int32")
        for name in INPUT_NAMES
    }

def build_input_fn(container):
    """Create an Estimator ``input_fn`` that streams feeds from ``container``.

    Args:
        container: Object exposing ``get()`` that returns the texts to encode
            next (see ``DataContainer``).

    Returns:
        A zero-argument ``input_fn`` that builds a ``tf.data.Dataset`` from an
        endless generator. Every element is the feed dict for whatever texts
        the container holds at the moment the element is requested.
    """

    def gen():
        while True:
            try:
                feed = build_feed_dict(container.get())
            except Exception:
                # Best-effort single retry kept from the original code; a
                # second failure propagates to the caller.
                feed = build_feed_dict(container.get())
            # Yield outside the try block: a bare ``except`` around the yield
            # also caught GeneratorExit / KeyboardInterrupt, so closing the
            # generator made it yield again and raise
            # "RuntimeError: generator ignored GeneratorExit".
            yield feed

    def input_fn():
        return tf.data.Dataset.from_generator(
            gen,
            output_types = {iname: tf.int32 for iname in INPUT_NAMES},
            output_shapes = {iname: (None, None) for iname in INPUT_NAMES})

    return input_fn

class DataContainer:
    """Single-slot holder for the batch of texts to be encoded next.

    The input generator reads the slot with ``get()`` while the vectorizer
    fills it with ``set()`` before requesting each prediction.
    """

    def __init__(self):
        # Nothing is queued until set() is called.
        self._texts = None

    def set(self, texts):
        """Store ``texts``; a single string is wrapped into a one-item list.

        ``isinstance`` is used rather than an exact type check so that ``str``
        subclasses (e.g. ``numpy.str_``) are wrapped as well instead of being
        stored bare and later iterated character by character.
        """
        if isinstance(texts, str):
            texts = [texts]
        self._texts = texts

    def get(self):
        """Return the most recently stored texts (``None`` if never set)."""
        return self._texts
    
def model_fn(features, mode):
    """Estimator ``model_fn`` that wraps the exported graph.

    Parses the serialized ``GraphDef`` stored at ``GRAPH_PATH``, imports it
    with the incoming ``features`` wired to the graph's input tensors, and
    exposes the ``final_encodes:0`` tensor as the ``'output'`` prediction.

    Args:
        features: Mapping keyed by the names in ``INPUT_NAMES``.
        mode: Estimator mode, forwarded unchanged to the ``EstimatorSpec``.

    Returns:
        ``EstimatorSpec`` whose predictions dict has the single key ``'output'``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(GRAPH_PATH, 'rb') as graph_file:
        graph_def.ParseFromString(graph_file.read())

    # Graph tensor names carry a ':0' suffix.
    tensor_inputs = {}
    for name in INPUT_NAMES:
        tensor_inputs[name + ':0'] = features[name]

    imported = tf.import_graph_def(
        graph_def,
        input_map=tensor_inputs,
        return_elements=['final_encodes:0']
    )
    encodings = imported[0]

    return EstimatorSpec(mode=mode, predictions={'output': encodings})

def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items.

    ``iterable`` must support ``len()`` and slicing; the final slice may be
    shorter than ``n``.
    """
    total = len(iterable)
    start = 0
    for start in range(0, total, n):
        # Slicing clamps at the end of the sequence, so no explicit bound is needed.
        yield iterable[start:start + n]

def build_vectorizer(_estimator, _input_fn_builder, batch_size=128):
    """Build a callable that turns texts into a stacked array of embeddings.

    A single prediction generator is created here and shared by every call of
    the returned function; the texts to encode are passed to it through a
    ``DataContainer`` that the input function reads from.

    Args:
        _estimator: Estimator whose ``predict`` yields dicts with an
            ``'output'`` entry.
        _input_fn_builder: Callable taking the container and returning the
            estimator's ``input_fn``.
        batch_size: Maximum number of texts sent per prediction step.

    Returns:
        ``vectorize(text, verbose=False)``, which returns the per-batch
        outputs stacked vertically with ``np.vstack``.
    """
    container = DataContainer()
    predict_fn = _estimator.predict(_input_fn_builder(container), yield_single_examples=False)

    def vectorize(text, verbose=False):
        """Encode ``text`` batch by batch; show a progress bar if ``verbose``."""
        outputs = []
        progress = Progbar(len(text))

        for chunk in batch(text, batch_size):
            # Publish the chunk, then advance the shared generator by one step.
            container.set(chunk)
            prediction = next(predict_fn)
            outputs.append(prediction['output'])
            if verbose:
                progress.add(len(chunk))

        return np.vstack(outputs)

    return vectorize

In [10]:
# Init and build the embedding generator: the Estimator wraps model_fn, and
# build_vectorizer returns the callable used to embed lists of texts.
estimator = Estimator(model_fn = model_fn)
bert_vectorizer = build_vectorizer(estimator, build_input_fn)
# Run one throwaway call and discard the result -- presumably to pay the
# one-time graph-loading / session start-up cost up front (TODO confirm).
_ = bert_vectorizer(['music']); del _

# Get Embeddings

In [11]:
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('all_queries_and_plns.pkl', 'rb') as f:
    all_queries_and_plns = pickle.load(f)

# Unique whitespace-separated tokens across all loaded strings.
all_tokens = list({tok for text in all_queries_and_plns for tok in text.split()})

## Tokens

In [12]:
# Embed every unique token, then key each embedding (as a single-row 2-D
# array) by its token.
token_embeds = bert_vectorizer(all_tokens)

token_embed_dict = dict(
    zip(all_tokens, (vec.reshape(1, -1) for vec in token_embeds))
)

# Persist the token -> embedding mapping.
with open('bert_token_embeds.pkl', 'wb') as f:
    pickle.dump(token_embed_dict, f, protocol = pickle.HIGHEST_PROTOCOL)

## Docs

In [14]:
def get_emb_dict(word_list):
    """Embed ``word_list`` and return ``{text: embedding}``.

    Each embedding is reshaped to a single-row 2-D array. If a text occurs
    more than once, the embedding of its last occurrence is kept.
    """
    vectors = bert_vectorizer(word_list)
    return dict(zip(word_list, (vec.reshape(1, -1) for vec in vectors)))

def get_doc_embeds(batches_ = 10):
    """Embed every entry of ``all_queries_and_plns`` in ``batches_`` chunks.

    After each chunk the embeddings accumulated so far are re-pickled to
    ``bert_doc_embeds.pkl``, so an interrupted run still leaves the completed
    chunks on disk.

    Args:
        batches_: Number of chunks to split the data into.

    Returns:
        dict mapping each text to its embedding (a single-row 2-D array).
    """
    write_fname = 'bert_doc_embeds.pkl'
    batch_size_ = ceil(len(all_queries_and_plns) / batches_)

    saved_emb = {}

    for i in range(batches_):

        batch_ = list(all_queries_and_plns[batch_size_*i:batch_size_*(i+1)])
        if not batch_:
            # Because the chunk size is rounded up, the data can run out
            # before all chunks are used (e.g. 11 items in 10 chunks). An
            # empty chunk would reach np.vstack([]) in the vectorizer and
            # raise, so stop here instead.
            break

        print(f'Batch: {i+1}')
        print(f'Saved Emb Size: {len(saved_emb)}')
        print()

        # Merge in place rather than rebuilding the whole dict every chunk.
        saved_emb.update(get_emb_dict(batch_))

        with open(write_fname, 'wb') as f:
            pickle.dump(saved_emb, f, protocol = pickle.HIGHEST_PROTOCOL)

    return saved_emb

In [None]:
# Embed all documents with the default number of batches; get_doc_embeds also
# writes the accumulated embeddings to 'bert_doc_embeds.pkl' after each batch.
doc_embeds = get_doc_embeds()