This project is no longer actively maintained. While existing releases remain available, there are no planned updates, bug fixes, new features, or security patches. Users should be aware that vulnerabilities may not be addressed.
A stateful model exploits dependencies between successive inference requests: it keeps persistent state across requests, so the outcome of one request feeds into the next. Online speech recognition with recurrent models such as the Long Short-Term Memory (LSTM) network is a typical example. Serving a stateful model therefore requires the model server to preserve the order of inference requests within a sequence, so that each prediction can build on the previous results.
For this purpose, TorchServe provides sequence continuous batching: it takes at most one inference request from each sequence and combines requests from different sequences into a single batch. Each request carries a unique sequence ID, which a custom handler can read via the `get_sequence_id` function of `context.py`. The sequence ID serves as the key that custom handlers use to store and fetch values in the backend cache store, enabling efficient management of stateful inference. A client can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side. In addition, continuous batching allows a new inference request of a sequence to be served while the previous one is still in response streaming mode.
The following picture shows the workflow of stateful inference. Each job group has a job queue that stores incoming inference requests from a stream; the maximum capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls one inference request from each job group, and the resulting batch of requests is sent to the backend.
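The batching logic itself lives in TorchServe's frontend. As a rough, hypothetical illustration of the polling pattern described above (the `JobGroup` class and `poll_batch` function below are made-up names for this sketch, not TorchServe APIs):

```python
# Illustration only: a simplified sketch of "one request per sequence per batch".
# JobGroup and poll_batch are hypothetical names, not part of TorchServe's API.
from collections import deque


class JobGroup:
    """Pending requests of one sequence, bounded like maxSequenceJobQueueSize."""

    def __init__(self, sequence_id, max_queue_size=10):
        self.sequence_id = sequence_id
        self.max_queue_size = max_queue_size
        self.queue = deque()

    def add(self, request):
        if len(self.queue) >= self.max_queue_size:
            return False  # queue is full; the request cannot be enqueued
        self.queue.append(request)
        return True

    def poll(self):
        return self.queue.popleft() if self.queue else None


def poll_batch(job_groups, batch_size):
    """Take at most one request from each sequence, up to batch_size entries."""
    batch = []
    for group in job_groups:
        if len(batch) == batch_size:
            break
        request = group.poll()
        if request is not None:
            batch.append((group.sequence_id, request))
    return batch
```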
This example is a practical demonstration of stateful inference via sequence batching and continuous batching. Under the hood, the backend uses an LRU dictionary as a caching layer; users can swap in a different caching library in their handler implementation to match their own use case.
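For reference, a minimal sketch of how such an LRU cache behaves (a dictionary bounded by a fixed capacity, with least-recently-used eviction), assuming the `lru-dict` package is installed:

```python
from lru import LRU

cache = LRU(2)      # keep at most 2 entries
cache["seq_0"] = 5
cache["seq_1"] = 7
cache["seq_2"] = 9  # evicts the least recently used entry ("seq_0")

print(cache.has_key("seq_0"))  # False: evicted
print(cache.has_key("seq_2"))  # True
print(len(cache))              # 2
```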
`stateful_handler.py` is an example of a stateful handler. It creates a cache `self.cache` by calling [LRU](https://github.com/amitdev/lru-dict).
```python
def initialize(self, ctx: Context):
    """
    Loads the model and Initializes the necessary artifacts
    """
    ctx.cache = {}
    if ctx.model_yaml_config["handler"] is not None:
        self.cache = LRU(
            int(
                ctx.model_yaml_config["handler"]
                .get("cache", {})
                .get("capacity", StatefulHandler.DEFAULT_CAPACITY)
            )
        )

    self.initialized = True
```
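The cache capacity is read from the `handler.cache.capacity` entry of model-config.yaml (shown further below); if that entry is missing, the handler falls back to `StatefulHandler.DEFAULT_CAPACITY`.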
The handler uses the sequence ID (i.e., `sequence_id = self.context.get_sequence_id(idx)`) as the key to store and fetch values from `self.cache`.
```python
def preprocess(self, data):
    """
    Preprocess function to convert the request input to a tensor(Torchserve supported format).
    The user needs to override to customize the pre-processing

    Args :
        data (list): List of the data from the request input.

    Returns:
        tensor: Returns the tensor data of the input
    """
    results = []

    for idx, row in enumerate(data):
        sequence_id = self.context.get_sequence_id(idx)
        # SageMaker sticky router relies on response header to identify the sessions
        # The sequence_id from request headers must be set in response headers
        self.context.set_response_header(
            idx, self.context.header_key_sequence_id, sequence_id
        )

        # check if sequence_id exists
        if self.context.get_request_header(
            idx, self.context.header_key_sequence_start
        ):
            prev = int(0)
            self.context.cache[sequence_id] = {
                "start": True,
                "cancel": False,
                "end": False,
                "num_requests": 0,
            }
        elif self.cache.has_key(sequence_id):
            prev = int(self.cache[sequence_id])
        else:
            prev = None
            logger.error(
                f"Not received sequence_start request for sequence_id:{sequence_id} before"
            )

        req_id = self.context.get_request_id(idx)
        # process a new request
        if req_id not in self.context.cache:
            logger.info(
                f"received a new request sequence_id={sequence_id}, request_id={req_id}"
            )
            request = row.get("data") or row.get("body")
            if isinstance(request, (bytes, bytearray)):
                request = request.decode("utf-8")

            self.context.cache[req_id] = {
                "stopping_criteria": self._create_stopping_criteria(
                    req_id=req_id, seq_id=sequence_id
                ),
                "stream": True,
            }
            self.context.cache[sequence_id]["num_requests"] += 1

            if type(request) is dict and "input" in request:
                request = request.get("input")

            # -1: cancel
            if int(request) == -1:
                self.context.cache[sequence_id]["cancel"] = True
                self.context.cache[req_id]["stream"] = False
                results.append(int(request))
            elif prev is None:
                logger.info(
                    f"Close the sequence:{sequence_id} without open session request"
                )
                self.context.cache[sequence_id]["end"] = True
                self.context.cache[req_id]["stream"] = False
                self.context.set_response_header(
                    idx, self.context.header_key_sequence_end, sequence_id
                )
                results.append(int(request))
            else:
                val = prev + int(request)
                self.cache[sequence_id] = val
                # 0: end
                if int(request) == 0:
                    self.context.cache[sequence_id]["end"] = True
                    self.context.cache[req_id]["stream"] = False
                    self.context.set_response_header(
                        idx, self.context.header_key_sequence_end, sequence_id
                    )
                # non stream input:
                elif int(request) % 2 == 0:
                    self.context.cache[req_id]["stream"] = False

                results.append(val)
        else:
            # continue processing stream
            logger.info(
                f"received continuous request sequence_id={sequence_id}, request_id={req_id}"
            )
            time.sleep(1)
            results.append(prev)

    return results
```
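`preprocess` stores a `stopping_criteria` callable per request via `self._create_stopping_criteria`, which is also defined in `stateful_handler.py`. As a hypothetical sketch only (not the exact helper shipped with the example), such a helper could consult the cancel/end flags kept in `self.context.cache`:

```python
# Hypothetical sketch: a per-request callable that tells the streaming loop to
# stop once its sequence was cancelled or ended. Not the example's actual code.
def _create_stopping_criteria(self, req_id, seq_id):
    cache = self.context.cache

    class StoppingCriteria:
        def __call__(self, *args, **kwargs):
            seq_state = cache.get(seq_id, {})
            req_state = cache.get(req_id, {})
            return (
                not req_state.get("stream", False)
                or seq_state.get("cancel", False)
                or seq_state.get("end", False)
            )

    return StoppingCriteria()
```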
Stateful inference introduces three model configuration parameters. TorchServe can process up to maxWorkers * batchSize sequences of inference requests of a model in parallel.
- sequenceMaxIdleMSec: the maximum idle time in milliseconds that a sequence of this stateful model may go without a new inference request. The default value is 0 (i.e., the model is not treated as stateful). TorchServe does not process a new inference request for a sequence once its max idle time is exceeded.
- sequenceTimeoutMSec: the maximum duration in milliseconds of a sequence of inference requests for this stateful model. The default value is 0 (i.e., there is effectively no sequence timeout and the sequence does not expire). TorchServe does not process a new inference request for a sequence once its timeout is exceeded.
- maxSequenceJobQueueSize: the job queue size of an inference sequence of this stateful model. The default value is 1.

These settings are combined in model-config.yaml, for example:
```yaml
#cat model-config.yaml

minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 60000
maxSequenceJobQueueSize: 10
sequenceBatching: true
continuousBatching: true

handler:
  cache:
    capacity: 4
```
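With this configuration, TorchServe can serve up to maxWorkers * batchSize = 2 * 4 = 8 sequences of this model in parallel; each open sequence can hold at most 10 pending requests in its job queue, and a sequence that stays idle for more than 60000 ms (one minute) is no longer served new inference requests.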
```bash
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
```
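Assuming the default export path, place the generated `stateful.mar` into the `models` directory that is passed to `--model-store` in the torchserve command below.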
The details can be found here.
- Install gRPC python dependencies
```bash
git submodule init
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
```
- Generate python gRPC client stub using the proto files
```bash
cd ../../..
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```
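This generates `inference_pb2.py`, `inference_pb2_grpc.py`, `management_pb2.py`, and `management_pb2_grpc.py` under `ts_scripts`, which the gRPC client script used below relies on.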
- Start TorchServe
```bash
torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```
- Run sequence inference via the gRPC client
```bash
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
```
- Run sequence inference via HTTP
```bash
curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
```
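Follow-up requests that belong to the same session can reuse the `ts_request_sequence_id: seq_0` header (for example with `examples/stateful/sample/sample2.txt`) as long as the sequence has not been closed or timed out on the TorchServe side.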