This repository was archived by the owner on Aug 7, 2025. It is now read-only.
60 changes: 60 additions & 0 deletions examples/redis_cache/README.md
@@ -0,0 +1,60 @@
# Caching with a Redis database

We will build a minimal working example that uses a Redis server to cache the input/output of a custom handler.

The example will be based on the [MNIST classifier example](https://github.com/pytorch/serve/tree/master/examples/image_classifier/mnist).

### Pre-requisites

- Redis is installed on your system. Follow the [Redis getting started guide](https://redis.io/docs/getting-started/) to install Redis.

Start a Redis server (it listens on `localhost:6379` by default):
```bash
redis-server
# optionally specify the port:
# redis-server --port 6379
```
- The [Python Redis interface](https://github.com/redis/redis-py) is installed:
```bash
pip install redis
```

Note that if the pre-requisites are not met, a no-op decorator is used instead and no exceptions are raised.
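The fallback behavior can be sketched as follows. This is a self-contained illustration, not the example's actual code: `make_cache_decorator` and its boolean flag are hypothetical stand-ins for the `import redis` try/except and server ping that the real decorator performs.

```python
from functools import wraps


def no_op_decorator(func):
    """Pass-through wrapper used when Redis is unavailable."""

    @wraps(func)
    def wrapper(*args, **kwds):
        return func(*args, **kwds)

    return wrapper


def make_cache_decorator(redis_available: bool):
    # In the real example this flag comes from a try/except around
    # `import redis` plus a ping to the server; here it is a plain
    # parameter for illustration.
    if not redis_available:
        return no_op_decorator
    raise NotImplementedError("caching path omitted in this sketch")


@make_cache_decorator(redis_available=False)
def handle(data):
    return [d * 2 for d in data]


print(handle([1, 2, 3]))  # → [2, 4, 6], computed directly with no caching
```

Because the wrapper is a transparent pass-through, handlers behave identically whether or not caching is active.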

### Using the `ts.utils.redis_cache.handler_cache` decorator

The decorator's usage is similar to that of the built-in `functools.lru_cache`.

A typical usage would be:
```python
from ts.utils.redis_cache import handler_cache

class SomeHandler(BaseHandler):
    def __init__(self):
        ...
        self.handle = handler_cache(host='localhost', port=6379, db=0, maxsize=128)(self.handle)
```
See [mnist_handler_cached.py](https://github.com/pytorch/serve/tree/master/examples/redis_cache/mnist_handler_cached.py) for a minimal concrete example.
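For comparison, the built-in `functools.lru_cache` is applied with the same wrap-and-reassign pattern. This self-contained sketch (an in-process cache, no Redis involved) uses a hypothetical `SomeHandler` class to show the analogy:

```python
from functools import lru_cache


class SomeHandler:
    def __init__(self):
        self.calls = 0
        # Wrap the bound method, just as handler_cache wraps self.handle above.
        self.handle = lru_cache(maxsize=128)(self.handle)

    def handle(self, x):
        self.calls += 1
        return x * 2


h = SomeHandler()
h.handle(21)
h.handle(21)    # second call with the same argument is served from the cache
print(h.calls)  # → 1
```

`handler_cache` swaps the in-process dictionary for a Redis database, so cached results survive worker restarts and can be shared across workers.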

### Package and serve the model as usual

Execute commands from the project root:
```bash
torch-model-archiver --model-name mnist --version 1.0 --model-file examples/image_classifier/mnist/mnist.py --serialized-file examples/image_classifier/mnist/mnist_cnn.pt --handler examples/redis_cache/mnist_handler_cached.py
mkdir -p model_store
mv mnist.mar model_store/
torchserve --start --model-store model_store --models mnist=mnist.mar --ts-config examples/image_classifier/mnist/config.properties
```

Run inference using:
```bash
curl http://127.0.0.1:8080/predictions/mnist -T examples/image_classifier/mnist/test_data/0.png
# The second call will return the cached result
curl http://127.0.0.1:8080/predictions/mnist -T examples/image_classifier/mnist/test_data/0.png
```

### Brief note on performance
Both the input and the output are serialized with pickle before being stored in the cache, and on a cache hit the output must be fetched and deserialized.

If the input and/or output are very large objects, this serialization can take noticeable time, and longer keys take longer to compare.
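A rough way to gauge this cost is to time the pickling of a representative payload. The payload below is an illustrative stand-in, not data from the MNIST example:

```python
import pickle
import time

# Stand-in for a large handler input/output
payload = list(range(1_000_000))

start = time.perf_counter()
blob = pickle.dumps(payload)
elapsed = time.perf_counter() - start

print(f"serialized size: {len(blob)} bytes, time: {elapsed:.4f}s")
# Larger payloads mean larger keys/values, slower round-trips to Redis,
# and slower key comparisons on lookup.
```

If serialization dominates the handler's own compute time, caching may not pay off for that model.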
10 changes: 10 additions & 0 deletions examples/redis_cache/mnist_handler_cached.py
@@ -0,0 +1,10 @@
from examples.image_classifier.mnist.mnist_handler import MNISTDigitClassifier
from examples.redis_cache.redis_cache import handler_cache


class MNISTDigitClassifierCached(MNISTDigitClassifier):
    def __init__(self):
        super(MNISTDigitClassifierCached, self).__init__()
        self.handle = handler_cache(host="localhost", port=6379, db=0, maxsize=2)(
            self.handle
        )
76 changes: 76 additions & 0 deletions examples/redis_cache/redis_cache.py
@@ -0,0 +1,76 @@
import logging
import pickle
from functools import wraps

try:
    import redis

    _has_redis = True
except ImportError:
    _has_redis = False

from ts.context import Context


def _make_key(args, kwds):
    # Build a flat, deterministic key; the object() sentinel separates
    # positional arguments from keyword items, mirroring functools._make_key.
    key = args
    if kwds:
        key += (object(),)
        for item in kwds.items():
            key += item
    return pickle.dumps(key)


def _no_op_decorator(func):
    @wraps(func)
    def wrapper(*args, **kwds):
        return func(*args, **kwds)

    return wrapper


def handler_cache(host, port, db, maxsize=128):
    """Decorator for a handler's handle() method that caches input/output in a Redis database.

    A typical usage would be:

        class SomeHandler(BaseHandler):
            def __init__(self):
                ...
                self.handle = handler_cache(host='localhost', port=6379, db=0, maxsize=128)(self.handle)

    The user should ensure that both the input and the output can be pickled.
    """
    if not _has_redis:
        logging.error("Cannot import redis; try `pip install redis`.")
        return _no_op_decorator
    r = redis.Redis(host=host, port=port, db=db)
    try:
        r.ping()
    except redis.exceptions.ConnectionError:
        logging.error(
            f"Cannot connect to a Redis server, ensure a server is running on {host}:{port}."
        )
        return _no_op_decorator

    def decorating_function(func):
        @wraps(func)
        def wrapper(*args, **kwds):
            # Exclude Context objects from key hashing so worker-specific
            # state does not defeat caching
            key = _make_key(
                args=[arg for arg in args if not isinstance(arg, Context)],
                kwds={k: v for (k, v) in kwds.items() if not isinstance(v, Context)},
            )
            value_str = r.get(key)
            if value_str is not None:
                return pickle.loads(value_str)
            value = func(*args, **kwds)
            # Randomly evict one entry if maxsize is reached
            if r.dbsize() >= maxsize:
                r.delete(r.randomkey())
            r.set(key, pickle.dumps(value))
            return value

        return wrapper

    return decorating_function
3 changes: 2 additions & 1 deletion ts_scripts/spellcheck_conf/wordlist.txt
@@ -989,6 +989,7 @@ deepspeed
mii
Diffusers
diffusers
Redis
AzureML
Largemodels
bigscience
@@ -997,4 +998,4 @@ sharded
NVfuser
fuser
ort
sess
sess