Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ scratch
.DS_Store
*.csv
wiki_schema.yaml
docs/_build/
docs/_build/
.venv
34 changes: 33 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,38 @@ Here's how to get started with your code contribution:
### Dev Environment
The repository provides `requirements.txt` and `requirements-dev.txt` files that you can use to install the required libraries with `pip` into your virtual environment.

Or use the local package editable install method:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e ".[all,dev]"
```

Then to deactivate the env:
```
deactivate
```

### Linting and Tests

Check formatting, linting, and typing:
```bash
make check
```

Tests (with vectorizers):
```bash
make test-cov
```

Tests without vectorizers:
```bash
SKIP_VECTORIZERS=true make test-cov
```

> The dev requirements must be installed to run the tests and linting.
> See the other available commands in the [Makefile](Makefile).

### Docker Tips

Make sure to have [Redis](https://redis.io) accessible with Search & Query features enabled on [Redis Cloud](https://redis.com/try-free) or locally in docker with [Redis Stack](https://redis.io/docs/getting-started/install-stack/docker/):
Expand All @@ -38,7 +70,7 @@ Make sure to have [Redis](https://redis.io) accessible with Search & Query featu
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
```

This will also spin up the [Redis Insight GUI](https://redis.com/redis-enterprise/redis-insight/) at `http://localhost:8001`.
This will also spin up the [FREE RedisInsight GUI](https://redis.com/redis-enterprise/redis-insight/) at `http://localhost:8001`.

## How to Report a Bug

Expand Down
2 changes: 1 addition & 1 deletion redisvl/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def __init__(
self.schema = schema

self._storage = self._STORAGE_MAP[self.schema.storage_type](
self.schema.prefix, self.schema.key_separator
prefix=self.schema.prefix, key_separator=self.schema.key_separator
)

@property
Expand Down
2 changes: 1 addition & 1 deletion redisvl/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
distance_threshold: float = 0.1,
ttl: Optional[int] = None,
vectorizer: BaseVectorizer = HFTextVectorizer(
"sentence-transformers/all-mpnet-base-v2"
model="sentence-transformers/all-mpnet-base-v2"
),
redis_url: str = "redis://localhost:6379",
connection_args: Dict[str, Any] = {},
Expand Down
3 changes: 2 additions & 1 deletion redisvl/schema/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic.v1 import BaseModel, Field, validator
from redis.commands.search.field import Field as RedisField
from redis.commands.search.field import GeoField as RedisGeoField
from redis.commands.search.field import NumericField as RedisNumericField
Expand Down Expand Up @@ -69,6 +69,7 @@ class BaseVectorField(BaseModel):
as_name: Optional[str] = None

@validator("algorithm", "datatype", "distance_metric", pre=True)
@classmethod
def uppercase_strings(cls, v):
return v.upper()

Expand Down
3 changes: 2 additions & 1 deletion redisvl/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Union

import yaml
from pydantic import BaseModel, validator
from pydantic.v1 import BaseModel, validator
from redis.commands.search.field import Field as RedisField

from redisvl.schema.fields import BaseField, BaseVectorField, FieldFactory
Expand Down Expand Up @@ -66,6 +66,7 @@ class IndexSchema(BaseModel):
fields: Dict[str, List[Union[BaseField, BaseVectorField]]] = {}

@validator("fields", pre=True)
@classmethod
def check_unique_field_names(cls, fields):
"""Validate that field names are all unique."""
all_names = cls._get_field_names(fields)
Expand Down
50 changes: 31 additions & 19 deletions redisvl/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,27 @@
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional

from pydantic.v1 import BaseModel
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from redis.commands.search.indexDefinition import IndexType

from redisvl.utils.utils import convert_bytes


class BaseStorage:
type: IndexType
DEFAULT_BATCH_SIZE: int = 200
DEFAULT_WRITE_CONCURRENCY: int = 20
class BaseStorage(BaseModel):
"""
Base class for internal storage handling in Redis.

def __init__(self, prefix: str, key_separator: str):
"""Initialize the BaseStorage with a specific prefix and key separator
for Redis keys.
Provides foundational methods for key management, data preprocessing,
validation, and basic read/write operations (both sync and async).
"""

Args:
prefix (str): The prefix to prepend to each Redis key.
key_separator (str): The separator to use between the prefix and
the key value.
"""
self._prefix = prefix
self._key_separator = key_separator
type: IndexType # Type of index used in storage
prefix: str # Prefix for Redis keys
key_separator: str # Separator between prefix and key value
default_batch_size: int = 200 # Default size for batch operations
default_write_concurrency: int = 20 # Default concurrency for async ops

@staticmethod
def _key(key_value: str, prefix: str, key_separator: str) -> str:
Expand Down Expand Up @@ -69,7 +67,7 @@ def _create_key(self, obj: Dict[str, Any], key_field: Optional[str] = None) -> s
raise ValueError(f"Key field {key_field} not found in record {obj}")

return self._key(
key_value, prefix=self._prefix, key_separator=self._key_separator
key_value, prefix=self.prefix, key_separator=self.key_separator
)

@staticmethod
Expand Down Expand Up @@ -202,7 +200,7 @@ def write(

if batch_size is None:
# Use default or calculate based on the input data
batch_size = self.DEFAULT_BATCH_SIZE
batch_size = self.default_batch_size

keys_iterator = iter(keys) if keys else None
added_keys: List[str] = []
Expand Down Expand Up @@ -272,7 +270,7 @@ async def awrite(
raise ValueError("Length of keys does not match the length of objects")

if not concurrency:
concurrency = self.DEFAULT_WRITE_CONCURRENCY
concurrency = self.default_write_concurrency

semaphore = asyncio.Semaphore(concurrency)
keys_iterator = iter(keys) if keys else None
Expand Down Expand Up @@ -322,7 +320,7 @@ def get(

if batch_size is None:
batch_size = (
self.DEFAULT_BATCH_SIZE
self.default_batch_size
) # Use default or calculate based on the input data

# Use a pipeline to batch the retrieval
Expand Down Expand Up @@ -363,7 +361,7 @@ async def aget(
return []

if not concurrency:
concurrency = self.DEFAULT_WRITE_CONCURRENCY
concurrency = self.default_write_concurrency

semaphore = asyncio.Semaphore(concurrency)

Expand All @@ -378,6 +376,13 @@ async def _get(key: str) -> Dict[str, Any]:


class HashStorage(BaseStorage):
"""
Internal subclass of BaseStorage for the Redis hash data type.

Implements hash-specific logic for validation and read/write operations
(both sync and async) in Redis.
"""

type: IndexType = IndexType.HASH

def _validate(self, obj: Dict[str, Any]):
Expand Down Expand Up @@ -443,6 +448,13 @@ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:


class JsonStorage(BaseStorage):
"""
Internal subclass of BaseStorage for the Redis JSON data type.

Implements json-specific logic for validation and read/write operations
(both sync and async) in Redis.
"""

type: IndexType = IndexType.JSON

def _validate(self, obj: Dict[str, Any]):
Expand Down
31 changes: 13 additions & 18 deletions redisvl/vectorize/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
from typing import Callable, List, Optional

from redisvl.utils.utils import array_to_buffer
from typing import Any, Callable, List, Optional

from pydantic.v1 import BaseModel, validator

class BaseVectorizer:
_dims = None

def __init__(self, model: str):
self._model = model
from redisvl.utils.utils import array_to_buffer

@property
def model(self) -> str:
return self._model

@property
def dims(self) -> Optional[int]:
return self._dims
class BaseVectorizer(BaseModel):
model: str
dims: int
client: Any

def set_model(self, model: str, dims: Optional[int] = None) -> None:
self._model = model
if dims is not None:
self._dims = dims
@validator("dims", pre=True)
@classmethod
def check_dims(cls, v):
if v <= 0:
raise ValueError("Dimension must be a positive integer")
return v

def embed_many(
self,
Expand Down
22 changes: 11 additions & 11 deletions redisvl/vectorize/text/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
ValueError: If the API key is not provided.

"""
super().__init__(model)
# Dynamic import of the cohere module
try:
import cohere
Expand All @@ -80,15 +79,16 @@ def __init__(
"Provide it in api_config or set the COHERE_API_KEY environment variable."
)

self._model = model
self._model_client = cohere.Client(api_key)
self._dims = self._set_model_dims()
client = cohere.Client(api_key)
dims = self._set_model_dims(client, model)
super().__init__(model=model, dims=dims, client=client)

def _set_model_dims(self) -> int:
@staticmethod
def _set_model_dims(client, model) -> int:
try:
embedding = self._model_client.embed(
embedding = client.embed(
texts=["dimension test"],
model=self._model,
model=model,
input_type="search_document",
).embeddings[0]
except (KeyError, IndexError) as ke:
Expand Down Expand Up @@ -150,8 +150,8 @@ def embed(
)
if preprocess:
text = preprocess(text)
embedding = self._model_client.embed(
texts=[text], model=self._model, input_type=input_type
embedding = self.client.embed(
texts=[text], model=self.model, input_type=input_type
).embeddings[0]
return self._process_embedding(embedding, as_buffer)

Expand Down Expand Up @@ -219,8 +219,8 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self._model_client.embed(
texts=batch, model=self._model, input_type=input_type
response = self.client.embed(
texts=batch, model=self.model, input_type=input_type
)
embeddings += [
self._process_embedding(embedding, as_buffer)
Expand Down
25 changes: 13 additions & 12 deletions redisvl/vectorize/text/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def __init__(
ImportError: If the sentence-transformers library is not installed.
ValueError: If there is an error setting the embedding model dimensions.
"""
super().__init__(model)

# Load the SentenceTransformer model
try:
from sentence_transformers import SentenceTransformer
Expand All @@ -53,16 +51,19 @@ def __init__(
"Please install with `pip install sentence-transformers`"
)

self._model_client = SentenceTransformer(model)
client = SentenceTransformer(model)
dims = self._set_model_dims(client)
super().__init__(model=model, dims=dims, client=client)

# Initialize model dimensions
@staticmethod
def _set_model_dims(client):
try:
self._dims = self._set_model_dims()
except Exception as e:
raise ValueError(f"Error setting embedding model dimensions: {e}")

def _set_model_dims(self):
embedding = self._model_client.encode(["dimension check"])[0]
embedding = client.encode(["dimension check"])[0]
except (KeyError, IndexError) as ke:
raise ValueError(f"Empty response from the embedding model: {str(ke)}")
except Exception as e: # pylint: disable=broad-except
# fall back (TODO get more specific)
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
return len(embedding)

def embed(
Expand Down Expand Up @@ -92,7 +93,7 @@ def embed(

if preprocess:
text = preprocess(text)
embedding = self._model_client.encode([text])[0]
embedding = self.client.encode([text])[0]
return self._process_embedding(embedding.tolist(), as_buffer)

def embed_many(
Expand Down Expand Up @@ -128,7 +129,7 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
batch_embeddings = self._model_client.encode(batch)
batch_embeddings = self.client.encode(batch)
embeddings.extend(
[
self._process_embedding(embedding.tolist(), as_buffer)
Expand Down
Loading