⚠️ Notice: Limited Maintenance

This project is no longer actively maintained. While existing releases remain available, there are no planned updates, bug fixes, new features, or security patches. Users should be aware that vulnerabilities may not be addressed.

Stateful Inference

A stateful model is able to leverage interdependencies between successive inference requests: it maintains persistent state across requests, so that each outcome is linked to the outcomes of the requests that preceded it. Typical examples of stateful models are online speech recognition systems, such as those built on Long Short-Term Memory (LSTM) networks. Stateful inference requires the model server to preserve the sequential order of inference requests, ensuring that predictions build upon previous outcomes.

Within this context, TorchServe offers a mechanism known as sequence continuous batching. This approach retrieves an individual inference request from a particular sequence and combines multiple requests originating from different sequences into a unified batch. Each request is associated with a unique sequence ID, which can be extracted using the "get_sequence_id" function of context.py. This sequence ID serves as the key that custom handlers use to store and retrieve values in the backend cache store, enabling efficient management of stateful inference. A client can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side. Additionally, continuous batching enables a new inference request of a sequence to be served while the previous one is still in response streaming mode.
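
To make the idea concrete before walking through the example, here is a toy, self-contained sketch (plain Python, not TorchServe code) of state keyed by a sequence ID, where each request builds on the previous result of its own sequence:

    # Toy illustration only: per-sequence state keyed by a sequence ID.
    # Requests carrying the same sequence ID share state, so each result
    # builds on the previous result of that sequence.
    cache = {}

    def handle(sequence_id, value):
        prev = cache.get(sequence_id, 0)   # state left by earlier requests of this sequence
        cache[sequence_id] = prev + value  # persist updated state for the next request
        return cache[sequence_id]

    print(handle("seq_0", 1))  # 1
    print(handle("seq_0", 2))  # 3, builds on seq_0's previous result
    print(handle("seq_1", 5))  # 5, an independent sequence has independent state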

The following picture shows the workflow of stateful inference. Each job group has a job queue that stores incoming inference requests from a stream. The maximum capacity of a job queue is defined by maxSequenceJobQueueSize. A sequence batch aggregator polls one inference request from each job group, and the resulting batch of requests is sent to the backend.

[Figure: sequence batch workflow]

This example is a practical showcase of stateful inference via sequence batching and continuous batching. Under the hood, the backend uses an LRU dictionary as a caching layer. Users can choose a different caching library in the handler implementation based on their own use cases.
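
As a side note on that caching layer: the example relies on the lru-dict package, which provides a fixed-capacity dictionary that evicts the least recently used entry, bounding how much per-sequence state the handler keeps around. A small standalone sketch of the behavior the handler depends on:

    # Standalone sketch of the lru-dict behavior used by the handler
    # (pip install lru-dict). The capacity bound keeps per-sequence state
    # from growing without limit.
    from lru import LRU

    cache = LRU(2)                 # keep state for at most 2 sequences
    cache["seq_0"] = 1
    cache["seq_1"] = 2
    cache["seq_2"] = 3             # over capacity: least recently used entry is evicted
    print(cache.has_key("seq_2"))  # True
    print(len(cache))              # 2, only the two most recently used sequences remain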

Step 1: Implement handler

stateful_handler.py is an example of a stateful handler. It creates a cache, self.cache, by calling [LRU](https://github.com/amitdev/lru-dict).

    def initialize(self, ctx: Context):
        """
        Loads the model and Initializes the necessary artifacts
        """

        ctx.cache = {}
        if ctx.model_yaml_config["handler"] is not None:
            self.cache = LRU(
                int(
                    ctx.model_yaml_config["handler"]
                    .get("cache", {})
                    .get("capacity", StatefulHandler.DEFAULT_CAPACITY)
                )
            )
        self.initialized = True

The handler uses the sequence ID (i.e., sequence_id = self.context.get_sequence_id(idx)) as the key to store and fetch values in self.cache.

    def preprocess(self, data):
        """
        Preprocess function to convert the request input to a tensor(Torchserve supported format).
        The user needs to override to customize the pre-processing

        Args :
            data (list): List of the data from the request input.

        Returns:
            tensor: Returns the tensor data of the input
        """

        results = []
        for idx, row in enumerate(data):
            sequence_id = self.context.get_sequence_id(idx)
            # SageMaker sticky router relies on response header to identify the sessions
            # The sequence_id from request headers must be set in response headers
            self.context.set_response_header(
                idx, self.context.header_key_sequence_id, sequence_id
            )

            # check if sequence_id exists
            if self.context.get_request_header(
                idx, self.context.header_key_sequence_start
            ):
                prev = int(0)
                self.context.cache[sequence_id] = {
                    "start": True,
                    "cancel": False,
                    "end": False,
                    "num_requests": 0,
                }
            elif self.cache.has_key(sequence_id):
                prev = int(self.cache[sequence_id])
            else:
                prev = None
                logger.error(
                    f"Not received sequence_start request for sequence_id:{sequence_id} before"
                )

            req_id = self.context.get_request_id(idx)
            # process a new request
            if req_id not in self.context.cache:
                logger.info(
                    f"received a new request sequence_id={sequence_id}, request_id={req_id}"
                )
                request = row.get("data") or row.get("body")
                if isinstance(request, (bytes, bytearray)):
                    request = request.decode("utf-8")

                self.context.cache[req_id] = {
                    "stopping_criteria": self._create_stopping_criteria(
                        req_id=req_id, seq_id=sequence_id
                    ),
                    "stream": True,
                }
                self.context.cache[sequence_id]["num_requests"] += 1

                if type(request) is dict and "input" in request:
                    request = request.get("input")

                # -1: cancel
                if int(request) == -1:
                    self.context.cache[sequence_id]["cancel"] = True
                    self.context.cache[req_id]["stream"] = False
                    results.append(int(request))
                elif prev is None:
                    logger.info(
                        f"Close the sequence:{sequence_id} without open session request"
                    )
                    self.context.cache[sequence_id]["end"] = True
                    self.context.cache[req_id]["stream"] = False
                    self.context.set_response_header(
                        idx, self.context.header_key_sequence_end, sequence_id
                    )
                    results.append(int(request))
                else:
                    val = prev + int(request)
                    self.cache[sequence_id] = val
                    # 0: end
                    if int(request) == 0:
                        self.context.cache[sequence_id]["end"] = True
                        self.context.cache[req_id]["stream"] = False
                        self.context.set_response_header(
                            idx, self.context.header_key_sequence_end, sequence_id
                        )
                    # non stream input:
                    elif int(request) % 2 == 0:
                        self.context.cache[req_id]["stream"] = False

                    results.append(val)
            else:
                # continue processing stream
                logger.info(
                    f"received continuous request sequence_id={sequence_id}, request_id={req_id}"
                )
                time.sleep(1)
                results.append(prev)

        return results
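
The preprocess code above calls a helper, _create_stopping_criteria, whose definition is not shown here. As a rough, hypothetical sketch of what such a factory could look like, assuming it only needs to consult the flags stored in the request cache (the actual helper in stateful_handler.py may differ):

    def _create_stopping_criteria(self, req_id, seq_id):
        """
        Hypothetical sketch, not the actual example code: returns a callable
        the streaming loop can poll to decide whether to keep emitting responses.
        """
        ctx = self.context

        def should_stop():
            seq_state = ctx.cache.get(seq_id, {})
            # stop once the sequence was cancelled or ended, or once this
            # particular request dropped out of streaming mode
            return (
                seq_state.get("cancel", False)
                or seq_state.get("end", False)
                or not ctx.cache.get(req_id, {}).get("stream", True)
            )

        return should_stop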

Step 2: Model configuration

Stateful inference introduces three model configuration parameters. TorchServe can process up to (maxWorkers * batchSize) sequences of inference requests of a model in parallel. For example, with maxWorkers: 2 and batchSize: 4 as in the configuration below, up to 8 sequences can be served concurrently.

  • sequenceMaxIdleMSec: the maximum idle time in milliseconds of a sequence of inference requests for this stateful model. The default value is 0 (i.e., this is not a stateful model). TorchServe does not process a new inference request of a sequence once its max idle timeout has been exceeded.
  • sequenceTimeoutMSec: the maximum duration in milliseconds of a sequence of inference requests for this stateful model. The default value is 0 (i.e., there is effectively no sequence timeout and the sequence does not expire). TorchServe does not process a new inference request if the sequence timeout is exceeded.
  • maxSequenceJobQueueSize: the job queue size of an inference sequence for this stateful model. The default value is 1.
#cat model-config.yaml

minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 60000
maxSequenceJobQueueSize: 10
sequenceBatching: true
continuousBatching: true

handler:
  cache:
    capacity: 4

Step 3: Generate mar or tgz file

torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml

Step 4: Build GRPC Client

The details can be found here.

  • Install gRPC python dependencies
git submodule init
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
  • Generate python gRPC client stub using the proto files
cd ../../..
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto

Step 5: Run inference

  • Start TorchServe
torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
  • Run sequence inference via GRPC client
python ts_scripts/torchserve_grpc_client.py  infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
  • Run sequence inference via HTTP
curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt