Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/#486: provide redis_connection for creating all Object Models #487

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions gptcache/manager/scalar_data/redis_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,32 @@
from redis_om import JsonModel, EmbeddedJsonModel, NotFoundError, Field, Migrator


def get_models(global_key):
def get_models(global_key: str, redis_connection: Redis):
"""
Get all the models for the given global key and redis connection.
:param global_key: Global key will be used as a prefix for all the keys
:type global_key: str

:param redis_connection: Redis connection to use for all the models.
Note: This needs to be explicitly mentioned in `Meta` class for each Object Model,
otherwise it will use the default connection from the pool.
:type redis_connection: Redis
"""

class Counter:
"""
counter collection
"""
key_name = global_key + ":counter"
database = redis_connection

@classmethod
def incr(cls, con: Redis):
con.incr(cls.key_name)
def incr(cls):
cls.database.incr(cls.key_name)

@classmethod
def get(cls, con: Redis):
return con.get(cls.key_name)
def get(cls):
return cls.database.get(cls.key_name)

class Embedding:
"""
Expand Down Expand Up @@ -75,6 +90,9 @@ class Answers(EmbeddedJsonModel):
answer: str
answer_type: int

class Meta:
database = redis_connection

class Questions(JsonModel):
"""
questions collection
Expand All @@ -89,6 +107,7 @@ class Questions(JsonModel):
class Meta:
global_key_prefix = global_key
model_key_prefix = "questions"
database = redis_connection

class Sessions(JsonModel):
"""
Expand All @@ -98,6 +117,7 @@ class Sessions(JsonModel):
class Meta:
global_key_prefix = global_key
model_key_prefix = "sessions"
database = redis_connection

session_id: str = Field(index=True)
session_question: str
Expand All @@ -111,6 +131,7 @@ class QuestionDeps(JsonModel):
class Meta:
global_key_prefix = global_key
model_key_prefix = "ques_deps"
database = redis_connection

question_id: str = Field(index=True)
dep_name: str
Expand All @@ -125,6 +146,7 @@ class Report(JsonModel):
class Meta:
global_key_prefix = global_key
model_key_prefix = "report"
database = redis_connection

user_question: str
cache_question_id: int = Field(index=True)
Expand Down Expand Up @@ -194,16 +216,16 @@ def __init__(
self._session,
self._counter,
self._report,
) = get_models(global_key_prefix)
) = get_models(global_key_prefix, redis_connection=self.con)

Migrator().run()

def create(self):
pass

def _insert(self, data: CacheData, pipeline: Pipeline = None):
self._counter.incr(self.con)
pk = str(self._counter.get(self.con))
self._counter.incr()
pk = str(self._counter.get())
answers = data.answers if isinstance(data.answers, list) else [data.answers]
all_data = []
for answer in answers:
Expand Down Expand Up @@ -360,7 +382,8 @@ def delete_session(self, keys: List[str]):
self._session.delete_many(sessions_to_delete, pipeline)
pipeline.execute()

def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value, cache_delta_time):
def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value,
cache_delta_time):
self._report(
user_question=user_question,
cache_question=cache_question,
Expand Down
19 changes: 12 additions & 7 deletions tests/unit_tests/manager/test_redis_cache_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,30 @@
import numpy as np

from gptcache.manager.scalar_data.base import CacheData, Question
from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage
from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage, get_models
from gptcache.utils import import_redis

import_redis()
from redis_om import get_redis_connection
from redis_om import get_redis_connection, RedisModel


class TestRedisStorage(unittest.TestCase):
test_dbname = "gptcache_test"
url = "redis://default:default@localhost:6379"

def setUp(cls) -> None:
cls._clear_test_db()

@staticmethod
def _clear_test_db():
r = get_redis_connection()
r = get_redis_connection(url=TestRedisStorage.url)
r.flushall()
r.flushdb()
time.sleep(1)

def test_normal(self):
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
url=self.url)
data = []
for i in range(1, 10):
data.append(
Expand Down Expand Up @@ -61,7 +63,8 @@ def test_normal(self):
assert redis_storage.count(is_all=True) == 7

def test_with_deps(self):
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
url=self.url)
data_id = redis_storage.batch_insert(
[
CacheData(
Expand Down Expand Up @@ -98,7 +101,8 @@ def test_with_deps(self):
assert ret.question.deps[1].dep_type == 1

def test_create_on(self):
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
url=self.url)
redis_storage.create()
data = []
for i in range(1, 10):
Expand All @@ -124,7 +128,8 @@ def test_create_on(self):
assert last_access1 < last_access2

def test_session(self):
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
url=self.url)
data = []
for i in range(1, 11):
data.append(
Expand Down
Loading