Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] fix shadow variable lookup race condition #280

Merged

Conversation

alionkun
Copy link
Contributor

@alionkun alionkun commented Sep 28, 2022

Description

Brief Description of the PR:

Fixes #278

Type of change

  • Bug fix
  • New Tutorial
  • Updated or additional documentation
  • Additional Testing
  • New Feature

Testing with TFServing

Codes

  • model.py
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de
from absl import app, flags, logging


def build_model():
    embedding_size = 2
    user_id = tf.keras.Input(shape=(1,), name='user_id', dtype=tf.int64)
    user_emb = de.keras.layers.SquashedEmbedding(embedding_size, name='user_de')(user_id)
    user_emb = tf.keras.layers.Lambda(lambda x: x, name='user_emb')(user_emb)
    # Outputing both user_id and user_emb for the convenience of checking race condition.
    # We are expecting user_id.shape.num_elements() * embedding_size == user_emb.shape.num_elements(),
    # or there is a race condition
    model = tf.keras.Model({
                                'user_id': user_id
                           },
                           {
                                'user_id': user_id,
                                'user_emb': user_emb
                           })

    return model


def main(_):
    model = build_model()
    save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    savedmodel_dir = '/tmp/savedmodel/tfra/train/1'
    model.save(savedmodel_dir, options=save_options)

    del model
    de.enable_inference_mode()
    model = build_model()
    save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    savedmodel_dir = '/tmp/savedmodel/tfra/inference/1'
    model.save(savedmodel_dir, options=save_options)


if __name__ == '__main__':
    app.run(main)
  • client.py
import time
import grpc
import requests
import tensorflow as tf
import numpy as np
from absl import app, flags, logging
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


total_count = 0
bad_count = 0
good_count = 0


def make_tfs_request(batch_size):
    user_id = np.random.randint(100, size=(batch_size, 1))
    user_id = tf.make_tensor_proto(user_id, dtype=np.int64)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'default'
    request.model_spec.signature_name = 'serving_default'
    request.inputs['user_id'].CopyFrom(user_id)
    return request


def check_emb(result):
    def element_count_from_shape(shape):
        ret = None
        for dim in shape.dim:
            size = dim.size
            if ret == None:
                ret = size
            else:
                ret = ret * size
        return ret

    user_id = result.outputs['user_id']
    user_id_count = element_count_from_shape(user_id.tensor_shape)
    user_emb = result.outputs['user_emb']
    user_emb_count = element_count_from_shape(user_emb.tensor_shape)
    if user_id_count * 2 != user_emb_count:
        global bad_count
        bad_count += 1
    else:
        global good_count
        good_count += 1


def process_resp(call_future):
    global total_count
    total_count += 1
    try:
        result = call_future.result()
    except :
        global bad_count
        bad_count += 1
    else:
        check_emb(result)


def main(_):
    channel = grpc.insecure_channel("0.0.0.0:8500")
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    tfs_requests = [ make_tfs_request(batch_size) for batch_size in range(1, 50) ]
    # warm up tfs
    stub.Predict(tfs_requests[0])
    for i in range(100):
        # prevent tfs from overloading
        time.sleep(0.1)
        # try simulating inference requests with different batch size coming in simultaneously to trigger a potential race condition
        futures = [ stub.Predict.future(request).add_done_callback(process_resp) for request in tfs_requests ]
        futures = [ stub.Predict.future(request).add_done_callback(process_resp) for request in reversed(tfs_requests) ]

    # wait for pending requests
    time.sleep(5)
    global total_count, bad_count, good_count
    assert total_count == (bad_count + good_count)
    # after fixing, bad_count should be 0
    logging.info(f'total_count={total_count}, bad_count={bad_count}, good_count={good_count}')


if __name__ == '__main__':
    app.run(main)

Testing steps

  1. generate SavedModels in both train model and inference mode
python model.py
# a train mode model is saved in /tmp/savedmodel/tfra/train/1
# a inference mode model is saved in /tmp/savedmodel/tfra/inference/1
  1. serve the train mode model with tfs in a separate terminal
tensorflow_model_server --model_base_path=/tmp/savedmodel/tfra/train
  1. request tfs and check
python client.py
# bad_count is likely not 0
  1. serve the inference mode model with tfs in a separate terminal (after killing the former tfs process)
tensorflow_model_server --model_base_path=/tmp/savedmodel/tfra/inference
  1. request tfs and check
python client.py
# bad_count should be 0

@google-cla
Copy link

google-cla bot commented Sep 28, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@rhdong rhdong requested a review from Lifann September 29, 2022 04:51
@rhdong
Copy link
Member

rhdong commented Sep 29, 2022

Hi @Lifann, would you please help review this PR? Thank you!

Copy link
Member

@rhdong rhdong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@rhdong rhdong merged commit 61eb5e4 into tensorflow:master Sep 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

de.shadow_ops.embedding_lookup() is non thread-safe
3 participants